In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix
)

import pandas as pd
import numpy as np
from pathlib import Path
import os
from PIL import Image
from tqdm import tqdm
import json
from collections import defaultdict

In [None]:
class ConfidenceWeightedVoting(nn.Module):
    """Confidence-weighted majority vote over the instances of one bag.

    Every instance votes for its argmax class, and its vote is weighted by its
    confidence (its largest softmax probability). The bag score for class ``c``
    is the sum of the confidences of the instances that predicted ``c``.

    This is an ``nn.Module`` so instances are callable (``agg(logits)``) and can
    be registered as a sub-module; it has no parameters or buffers.

    Note: the returned bag scores are non-negative vote totals (sums of
    probabilities), not unnormalised log-odds; keep that in mind before feeding
    them to a loss such as ``nn.CrossEntropyLoss``.

    Args:
        n_classes: number of classes, i.e. the length of the bag score vector.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes

    def forward(self, instance_logits):
        """Aggregate per-instance logits into bag-level vote totals.

        Args:
            instance_logits: tensor of per-instance class logits with classes on
                the last dimension; expected shape (n_instances, n_classes).
                A single 1-D logit vector is treated as a one-instance bag.

        Returns:
            bag_logits: tensor of shape (n_classes,) on the same device as
                ``instance_logits`` (default float dtype) holding, per class,
                the summed confidences of the instances that predicted it.
            instance_info: dict with per-instance 'predictions' (argmax class),
                'confidences' (max probability) and 'probabilities' (softmax).
        """
        instance_probs = torch.softmax(instance_logits, dim=-1)

        # torch.max over a dim returns (values, indices); the indices are the argmax
        instance_confidences, instance_predictions = torch.max(instance_probs, dim=-1)

        # vectorised scatter-add: bag_logits[c] = sum of confidences of instances
        # whose prediction is c (autograd flows through the confidences)
        bag_logits = torch.zeros(self.n_classes, device=instance_logits.device)
        bag_logits = bag_logits.index_add(
            0,
            instance_predictions.reshape(-1),
            instance_confidences.reshape(-1).to(bag_logits.dtype),
        )

        instance_info = {
            "predictions": instance_predictions,
            "confidences": instance_confidences,
            "probabilities": instance_probs,
        }

        return bag_logits, instance_info

In [None]:
class MIL_FabricClassifier(nn.Module):
    """Multiple-instance-learning (MIL) fabric classifier.

    A ResNet-18 produces class logits for every image (instance) in a bag, and
    an aggregator combines those instance logits into one bag-level score
    vector and prediction.

    Args:
        n_classes: number of output classes of the replaced ``fc`` head.
        pretrained_path: optional path to a checkpoint for the ResNet-18
            backbone. The checkpoint may be a bare state_dict or a dict holding
            it under ``'model_state_dict'`` or ``'model'``; a leading
            ``'module.'`` key prefix (DataParallel) is stripped and ``fc.*``
            weights are dropped because the head is replaced. When ``None`` the
            backbone keeps its random initialisation.
        agg_type: aggregation strategy; only ``'confidence_voting'`` is supported.

    Raises:
        ValueError: if ``agg_type`` is not supported.
    """

    SUPPORTED_AGG_TYPES = ('confidence_voting',)

    def __init__(self, n_classes, pretrained_path=None, agg_type='confidence_voting'):
        super().__init__()

        # fail fast instead of leaving self.agg undefined until forward()
        if agg_type not in self.SUPPORTED_AGG_TYPES:
            raise ValueError(
                f"Unsupported agg_type {agg_type!r}; expected one of {self.SUPPORTED_AGG_TYPES}"
            )

        # architecture only: weights come from pretrained_path, never downloaded
        self.instance_model = models.resnet18()
        if pretrained_path is not None:
            self._load_backbone_weights(pretrained_path)

        # replace the final classifier so the backbone outputs n_classes logits
        self.instance_model.fc = nn.Linear(self.instance_model.fc.in_features, n_classes)

        self.n_classes = n_classes
        self.agg_type = agg_type
        self.agg = ConfidenceWeightedVoting(n_classes)

    def _load_backbone_weights(self, pretrained_path):
        """Load every non-``fc`` ResNet-18 weight found in the checkpoint file."""
        checkpoint = torch.load(pretrained_path, map_location='cpu')

        # unwrap common checkpoint layouts; otherwise treat it as a bare state_dict
        state_dict = checkpoint
        for wrapper_key in ('model_state_dict', 'model'):
            if isinstance(checkpoint, dict) and wrapper_key in checkpoint:
                state_dict = checkpoint[wrapper_key]
                break

        prefix = 'module.'
        backbone_state = {}
        for key, value in state_dict.items():
            if key.startswith(prefix):
                key = key[len(prefix):]
            if key.startswith('fc.'):
                continue  # the head is re-created for the new task
            backbone_state[key] = value

        result = self.instance_model.load_state_dict(backbone_state, strict=False)
        # strict=False is required because fc.* is intentionally absent; any
        # other mismatch means part of the backbone was not loaded
        missing = [k for k in result.missing_keys if not k.startswith('fc.')]
        if missing or result.unexpected_keys:
            print(
                f"Warning: backbone checkpoint mismatch ({len(missing)} missing, "
                f"{len(result.unexpected_keys)} unexpected keys)"
            )

    def forward(self, bag_dict):
        """Classify a batch of bags.

        Args:
            bag_dict: mapping bag_id -> instance tensor for that bag (as built
                by ``collate_mil_bags``); presumably shaped
                (n_instances, 3, H, W) — TODO confirm against the transform.

        Returns:
            bag_logits_batch: stacked aggregator outputs, one row per bag.
            bag_pred_batch: argmax class index per bag.
            bag_info_dict: bag_id -> detached CPU copies of the per-instance
                predictions, confidences and probabilities.
            bag_ids: bag ids in the same order as the rows of the batch tensors.
        """
        # collate_mil_bags does not move tensors, so send each bag to the backbone's device
        device = next(self.instance_model.parameters()).device

        bag_logits_list = []
        bag_pred_list = []
        bag_info_dict = {}
        bag_ids = []

        for bag_id, instances in bag_dict.items():
            instance_logits = self.instance_model(
                instances.to(device=device, dtype=torch.float32)
            )

            # MIL aggregation; .forward is called explicitly so this also works
            # when the aggregator is a plain (non-callable) object
            bag_logits, instance_info = self.agg.forward(instance_logits)

            bag_logits_list.append(bag_logits)
            bag_pred_list.append(torch.argmax(bag_logits))
            bag_ids.append(bag_id)  # exactly once per bag, aligned with the batch rows

            # detach so these diagnostics do not keep the autograd graph alive
            bag_info_dict[bag_id] = {
                'instance_predictions': instance_info['predictions'].detach().cpu(),
                'instance_confidences': instance_info['confidences'].detach().cpu(),
                'instance_probabilities': instance_info['probabilities'].detach().cpu()
            }

        bag_logits_batch = torch.stack(bag_logits_list)
        bag_pred_batch = torch.stack(bag_pred_list)

        return bag_logits_batch, bag_pred_batch, bag_info_dict, bag_ids



In [None]:
class FabricMILDataset(Dataset):
    """Bag-level dataset: one item = all images of one fabric item plus its label.

    Args:
        data_dict: mapping item_id -> {'images': [image paths], 'label': int}
            (the structure returned by ``create_data_dict_from_csv``).
        transform: callable applied to each RGB PIL image. It must return a
            tensor, because a bag's instances are combined with ``torch.stack``;
            with ``transform=None`` bags with loaded images cannot be stacked.
        empty_instance_shape: shape of the single all-zero placeholder instance
            returned when none of a bag's images could be loaded. Defaults to
            (3, 224, 224); set it to match the transform's output size.
    """

    def __init__(self, data_dict, transform=None, empty_instance_shape=(3, 224, 224)):
        self.data_dict = data_dict
        self.item_ids = list(data_dict.keys())
        self.transform = transform
        self.empty_instance_shape = tuple(empty_instance_shape)

    def __len__(self):
        return len(self.item_ids)

    def __getitem__(self, idx):
        """Return ``(item_id, instances, label)`` for the idx-th bag.

        ``instances`` stacks the transformed images along a new first dimension;
        ``label`` is a 0-d ``torch.long`` tensor. Images that fail to load or
        transform are reported and skipped (best effort).
        """
        item_id = self.item_ids[idx]
        item_data = self.data_dict[item_id]

        image_paths = item_data['images']
        label = item_data['label']

        # load all images for the bag
        instances = []
        for img_path in image_paths:
            try:
                # the context manager closes the file handle deterministically
                with Image.open(img_path) as raw_img:
                    img = raw_img.convert('RGB')
                if self.transform:
                    img = self.transform(img)
                instances.append(img)
            except Exception as e:
                print(f"Error loading {img_path}: {e}")

        # keep the bag non-empty so stacking and the model still work
        if len(instances) == 0:
            instances = [torch.zeros(*self.empty_instance_shape)]

        instances = torch.stack(instances)
        return item_id, instances, torch.tensor(label, dtype=torch.long)

# custom collate function for dataset
def collate_mil_bags(batch):
    bag_ids, instances, labels = zip(*batch)
    bag_dict = {bid: inst for bid, inst in zip(bag_ids, instances)}
    labels = torch.stack(labels)
    return bag_dict, labels


In [None]:
def pad_to_square(img, fill=(255, 255, 255)):
    """Pad an image to a square canvas without resizing it.

    The shorter side is padded on both sides; when the total padding is odd,
    the extra pixel goes to the right (or bottom). Square images are returned
    unchanged (the same object).

    Args:
        img: image exposing ``.size`` as (width, height), e.g. a PIL image.
        fill: padding colour, default white.

    Returns:
        The square (padded) image.
    """
    width, height = img.size
    gap = abs(height - width)
    if gap == 0:
        return img

    before = gap // 2
    after = gap - before
    # torchvision padding order is (left, top, right, bottom)
    if width < height:
        padding = (before, 0, after, 0)
    else:
        padding = (0, before, 0, after)
    return transforms.functional.pad(img, padding, fill=fill)


In [None]:
def create_data_dict_from_csv(csv_path, images_folder, image_ext=".jpg"):
    """Build the bag dictionary used by the MIL pipeline from a CSV index.

    The CSV must have the columns 'item_id', 'image_id' and 'label'. Rows that
    share an 'item_id' form one bag; each image is expected at
    ``images_folder / f"{image_id}{image_ext}"``.

    Args:
        csv_path: path to the CSV file.
        images_folder: directory containing the image files.
        image_ext: file extension appended to each image_id (default ".jpg").

    Returns:
        data_dict: 'item_<item_id>' -> {'images': [existing image paths as str],
            'label': int index}. Items with no existing image are skipped.
        label_mapping: label name -> int index (sorted label names).
        reverse_mapping: int index -> label name.

    If an item's rows disagree on the label, the first row's label is used and
    a warning is printed.
    """
    df = pd.read_csv(csv_path)

    # create label mapping
    unique_labels = sorted(df['label'].unique())
    label_mapping = {label: idx for idx, label in enumerate(unique_labels)}
    reverse_mapping = {idx: label for label, idx in label_mapping.items()}

    print(f"Found {len(unique_labels)} fabric classes:")
    for label, idx in label_mapping.items():
        print(f"  {idx}: {label}")

    images_path = Path(images_folder)

    data_dict = {}
    missing_images = []

    for item_id, group in df.groupby('item_id'):
        image_paths = []
        for image_id in group['image_id']:
            file_name = f"{image_id}{image_ext}"
            img_path = images_path / file_name
            if img_path.exists():
                image_paths.append(str(img_path))
            else:
                missing_images.append(file_name)

        # skip items with no valid images
        if len(image_paths) == 0:
            print(f"Warning: Item {item_id} has no valid images, skipping...")
            continue

        # get label (first row wins; flag inconsistent annotations)
        label_name = group['label'].iloc[0]
        item_labels = group['label'].unique()
        if len(item_labels) > 1:
            print(f"Warning: Item {item_id} has conflicting labels {list(item_labels)}; using {label_name!r}")
        label_idx = label_mapping[label_name]

        data_dict[f'item_{item_id}'] = {
            'images': image_paths,
            'label': label_idx
        }

    if missing_images:
        print(f"\nWarning: {len(missing_images)} images not found in {images_folder}")
        print(f"Missing: {missing_images[:5]}")

    print(f"\nCreated data dictionary with {len(data_dict)} items")
    if len(data_dict) > 0:
        example_key = list(data_dict.keys())[0]
        print(f"Example item: {example_key}")
        print(f"  - Images: {len(data_dict[example_key]['images'])}")
        print(f"  - Label: {data_dict[example_key]['label']} ({reverse_mapping[data_dict[example_key]['label']]})")

    return data_dict, label_mapping, reverse_mapping

In [None]:
# Project root: set the TEXTILES_PROJECT_ROOT environment variable instead of
# editing this machine-specific absolute path.
PROJECT_ROOT = Path(os.environ.get("TEXTILES_PROJECT_ROOT", "D:/csci_461_textiles_project"))
FIBER_DATA_DIR = PROJECT_ROOT / "data" / "fiber"

csv_path = FIBER_DATA_DIR / "fiber_data.csv"
images_folder = FIBER_DATA_DIR / "fiber_images"
data_dict, label_mapping, reverse_mapping = create_data_dict_from_csv(csv_path, images_folder)

print(label_mapping)

Found 7 fabric classes:
  0: Acrylic
  1: Cotton
  2: Linen
  3: Nylon
  4: Polyester
  5: Suede
  6: Viscose

Created data dictionary with 2145 items
Example item: item_1
  - Images: 6
  - Label: 0 (Acrylic)
{'images': ['D:\\csci_461_textiles_project\\data\\fiber\\fiber_images\\1.jpg', 'D:\\csci_461_textiles_project\\data\\fiber\\fiber_images\\2.jpg', 'D:\\csci_461_textiles_project\\data\\fiber\\fiber_images\\3.jpg', 'D:\\csci_461_textiles_project\\data\\fiber\\fiber_images\\4.jpg', 'D:\\csci_461_textiles_project\\data\\fiber\\fiber_images\\5.jpg', 'D:\\csci_461_textiles_project\\data\\fiber\\fiber_images\\6.jpg'], 'label': 0}
{'Acrylic': 0, 'Cotton': 1, 'Linen': 2, 'Nylon': 3, 'Polyester': 4, 'Suede': 5, 'Viscose': 6}


In [None]:
# Inspect one bag entry (its image paths and integer label) before balancing.
first_bag = data_dict.get("item_1")
print(first_bag)


{'images': ['D:\\csci_461_textiles_project\\data\\fiber\\fiber_images\\1.jpg', 'D:\\csci_461_textiles_project\\data\\fiber\\fiber_images\\2.jpg', 'D:\\csci_461_textiles_project\\data\\fiber\\fiber_images\\3.jpg', 'D:\\csci_461_textiles_project\\data\\fiber\\fiber_images\\4.jpg', 'D:\\csci_461_textiles_project\\data\\fiber\\fiber_images\\5.jpg', 'D:\\csci_461_textiles_project\\data\\fiber\\fiber_images\\6.jpg'], 'label': 0}


In [None]:
import random
from collections import Counter


def downsample_bags(bags, max_per_class, seed=42):
    """Cap the number of bags per class by random sampling without replacement.

    Args:
        bags: mapping bag_id -> {'label': int, ...}.
        max_per_class: mapping label -> maximum number of bags to keep; labels
            that are not listed are kept in full.
        seed: seed for a private ``random.Random``, so the selection neither
            depends on nor disturbs the global ``random`` state.

    Returns:
        A new dict with the kept bags (the input dict is not modified), grouped
        by class in first-seen class order.
    """
    rng = random.Random(seed)

    # group items by label, preserving first-seen class order
    by_class = {}
    for bag_id, bag in bags.items():
        by_class.setdefault(bag["label"], []).append((bag_id, bag))

    kept = {}
    for label, items in by_class.items():
        cap = max_per_class.get(label)
        if cap is not None and len(items) > cap:
            items = rng.sample(items, cap)
        kept.update(items)
    return kept


# target maximum per class
MAX_PER_CLASS = {
    1: 200,  # downsample class 1 (Cotton)
    2: 200,  # downsample class 2 (Linen)
}
data_dict = downsample_bags(data_dict, MAX_PER_CLASS, seed=42)

# check new distribution; note class 0 stays far smaller than the capped classes
print(Counter(v["label"] for v in data_dict.values()))
len(data_dict)


Counter({1: 200, 2: 200, 5: 154, 6: 101, 3: 88, 4: 61, 0: 22})


826

In [None]:
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.models import resnet18
from PIL import Image
from pathlib import Path
import pandas as pd
from tqdm import tqdm


def load_image(path, transform):
    """Load one image from disk as RGB and apply ``transform`` to it.

    Args:
        path: image file path.
        transform: callable applied to the RGB PIL image (in
            ``run_resnet_baseline_bagging`` it ends with ToTensor/Normalize,
            so a tensor is returned there).

    Returns:
        Whatever ``transform`` returns.
    """
    # the context manager closes the underlying file as soon as it is decoded
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    return transform(rgb)


def _load_textilenet_resnet18(checkpoint_path, num_classes, device):
    """Build a ResNet-18 with ``num_classes`` outputs, load the checkpoint, and return it in eval mode on ``device``."""
    model = resnet18(num_classes=num_classes)

    ckpt = torch.load(checkpoint_path, map_location=device)
    if "model" in ckpt:
        print("✓ Found 'model' inside checkpoint — using ckpt['model']")
        state_dict = ckpt["model"]
    else:
        state_dict = ckpt

    # DataParallel checkpoints prefix every key with "module."; strip only that
    # leading prefix (str.replace would also rewrite the substring mid-key)
    prefix = "module."
    clean_state = {}
    stripped_any = False
    for k, v in state_dict.items():
        if k.startswith(prefix):
            k = k[len(prefix):]
            stripped_any = True
        clean_state[k] = v
    if stripped_any:
        print("✓ Stripped 'module.' prefix from state_dict keys")

    model.load_state_dict(clean_state)
    model.to(device)
    model.eval()

    print("✓ Loaded TextileNet ResNet18 checkpoint")
    return model


def run_resnet_baseline_bagging(
    data_dict,
    pretrained_checkpoint_path,
    textilenet_num_classes=33,
    save_dir="textilenet_image_baseline",
):
    """Image-level TextileNet baseline with mean-probability bagging.

    Every image of a bag is classified by a pretrained ResNet-18; the bag's
    class probabilities are the mean of its images' softmax outputs and the
    bag prediction is their argmax. One row per bag is written to
    ``<save_dir>/textilenet_bag_predictions.csv``.

    Args:
        data_dict: mapping bag_id -> {'images': [paths], 'label': int}.
        pretrained_checkpoint_path: checkpoint for a ResNet-18 with
            ``textilenet_num_classes`` outputs (optionally wrapped under
            'model' and/or saved with a 'module.' key prefix).
        textilenet_num_classes: number of outputs of the checkpoint's head.
        save_dir: output directory (created if needed).

    Returns:
        DataFrame with columns bag_id, true_label, predicted_label and
        prob_class_0 .. prob_class_{textilenet_num_classes - 1}.

    NOTE(review): predicted_label indexes the checkpoint's own
    ``textilenet_num_classes`` classes, while true_label comes from data_dict
    (in this notebook the 7-class fiber mapping). The two columns live in
    different label spaces unless a class mapping is applied; do not compare
    them directly.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Running TextileNet image-level baseline on:", device)

    save_path = Path(save_dir)
    save_path.mkdir(exist_ok=True, parents=True)

    model = _load_textilenet_resnet18(pretrained_checkpoint_path, textilenet_num_classes, device)

    transform = transforms.Compose([
        pad_to_square,
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    rows = []

    for item_id, item in tqdm(data_dict.items(), desc="Processing bags"):
        image_paths = item["images"]
        if not image_paths:
            # the mean over zero images would be NaN; nothing to predict
            print(f"Warning: bag {item_id} has no images, skipping")
            continue

        # one forward pass per bag instead of one per image
        batch = torch.stack([load_image(p, transform) for p in image_paths]).to(device)
        with torch.no_grad():
            image_probs = F.softmax(model(batch), dim=1)

        # mean-probability pooling over the bag's images
        bag_probs = image_probs.mean(dim=0).cpu().numpy()

        row = {
            "bag_id": item_id,
            "true_label": item["label"],
            "predicted_label": int(bag_probs.argmax()),
        }
        row.update({f"prob_class_{c}": float(bag_probs[c]) for c in range(textilenet_num_classes)})
        rows.append(row)

    df = pd.DataFrame(rows)
    out_file = save_path / "textilenet_bag_predictions.csv"
    df.to_csv(out_file, index=False)

    print("✓ Saved:", out_file)
    return df


In [None]:
# Project root: set the TEXTILES_PROJECT_ROOT environment variable instead of
# editing this machine-specific absolute path.
PROJECT_ROOT = Path(os.environ.get("TEXTILES_PROJECT_ROOT", "D:/csci_461_textiles_project"))

results = run_resnet_baseline_bagging(
    data_dict=data_dict,
    pretrained_checkpoint_path=PROJECT_ROOT / "res18_ckpt.pth",
    textilenet_num_classes=33,
    save_dir=PROJECT_ROOT / "fiber_resnet_test_three" / "image_baseline",
)


Running TextileNet image-level baseline on: cpu
✓ Found 'model' inside checkpoint — using ckpt['model']
✓ Stripped 'module.' prefix from state_dict keys
✓ Loaded TextileNet ResNet18 checkpoint


Processing bags: 100%|██████████| 826/826 [06:12<00:00,  2.22it/s]

✓ Saved: D:\csci_461_textiles_project\fiber_resnet_test_three\image_baseline\textilenet_bag_predictions.csv



