In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix
)

import pandas as pd
import numpy as np
from pathlib import Path
import os
from PIL import Image
from tqdm import tqdm
import json
from collections import defaultdict

In [None]:
class ConfidenceWeightedVoting(nn.Module):
    """Aggregate per-instance logits into bag scores by confidence-weighted voting.

    Every instance votes for its arg-max class; the vote's weight is the softmax
    probability of that class. The per-class sums of those weights are returned
    as the bag-level scores.

    This is an ``nn.Module`` so that instances are callable (``agg(logits)``),
    which is how the classifier in this file invokes it. It holds no parameters.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes

    def forward(self, instance_logits):
        """Aggregate one bag.

        Args:
            instance_logits: tensor whose last dimension indexes classes. The
                per-instance predictions must form a 1-D tensor, i.e. the input
                is expected to be 2-D (instances x classes).

        Returns:
            (bag_logits, instance_info): ``bag_logits`` is a 1-D tensor of length
            ``n_classes`` with the summed confidences per class;
            ``instance_info`` is a dict with ``"predictions"``, ``"confidences"``
            and ``"probabilities"`` for the individual instances.
        """
        # get probabilities
        instance_probs = torch.softmax(instance_logits, dim=-1)

        # get predictions and confidences in one pass so they are always consistent
        instance_confidences, instance_predictions = torch.max(instance_probs, dim=-1)

        # confidence-weighted majority vote: scatter-add each confidence into the
        # slot of its predicted class (differentiable w.r.t. the confidences)
        bag_logits = instance_probs.new_zeros(self.n_classes).index_add(
            0, instance_predictions, instance_confidences
        )

        instance_info = {
            "predictions": instance_predictions,
            "confidences": instance_confidences,
            "probabilities": instance_probs,
        }

        return bag_logits, instance_info

In [None]:
class MIL_FabricClassifier(nn.Module):
    """Multiple-instance classifier: a ResNet-18 scores each instance, an aggregator scores the bag.

    Args:
        n_classes: number of output classes of the instance classifier.
        pretrained_path: optional path to a checkpoint for the ResNet-18 backbone.
            Either a raw state dict or a dict with a ``'model_state_dict'`` entry.
            Keys containing ``'fc'`` are dropped so the final layer is always
            freshly initialised. When ``None`` the backbone is randomly initialised.
        agg_type: bag aggregation strategy. Only ``'confidence_voting'`` is supported.

    Raises:
        ValueError: if ``agg_type`` is not a supported strategy.
    """

    def __init__(self, n_classes, pretrained_path=None, agg_type='confidence_voting'):
        super().__init__()

        # build a ResNet-18 without downloading any weights
        self.instance_model = models.resnet18()

        # optionally load backbone weights from a local checkpoint
        if pretrained_path is not None:
            checkpoint = torch.load(pretrained_path, map_location='cpu')
            state_dict = checkpoint.get('model_state_dict', checkpoint)
            state_dict = {k: v for k, v in state_dict.items() if 'fc' not in k}
            self.instance_model.load_state_dict(state_dict, strict=False)

        # replace final classifier to output n_classes scores
        in_features = self.instance_model.fc.in_features
        self.instance_model.fc = nn.Linear(in_features, n_classes)

        self.n_classes = n_classes

        self.agg_type = agg_type
        if agg_type == 'confidence_voting':
            self.agg = ConfidenceWeightedVoting(n_classes)
        else:
            # fail here rather than with an AttributeError on the first forward pass
            raise ValueError(
                f"Unsupported agg_type {agg_type!r}; expected 'confidence_voting'"
            )

    def forward(self, bag_dict):
        """Score every bag in ``bag_dict``.

        Args:
            bag_dict: mapping ``bag_id -> instances`` tensor. Each tensor is cast
                to float and fed to the ResNet as one batch
                (presumably shaped (n_instances, 3, H, W) -- confirm with caller).

        Returns:
            (bag_logits_batch, bag_pred_batch, bag_info_dict, bag_ids) where the
            first two are stacked over bags in iteration order, ``bag_info_dict``
            maps each bag id to CPU copies of the per-instance predictions,
            confidences and probabilities, and ``bag_ids`` lists each bag id once,
            in the same order as the stacked tensors.
        """
        bag_logits_list = []
        bag_pred_list = []
        bag_info_dict = {}
        bag_ids = []

        for bag_id, instances in bag_dict.items():
            instance_logits = self.instance_model(instances.float())

            # MIL aggregation. ``forward`` is invoked explicitly so this works
            # whether or not the aggregator object itself is callable.
            bag_logits, instance_info = self.agg.forward(instance_logits)

            bag_prediction = torch.argmax(bag_logits)

            bag_logits_list.append(bag_logits)
            bag_pred_list.append(bag_prediction)
            bag_ids.append(bag_id)

            bag_info_dict[bag_id] = {
                'instance_predictions': instance_info['predictions'].cpu(),
                'instance_confidences': instance_info['confidences'].cpu(),
                'instance_probabilities': instance_info['probabilities'].cpu()
            }

        bag_logits_batch = torch.stack(bag_logits_list)
        bag_pred_batch = torch.stack(bag_pred_list)

        return bag_logits_batch, bag_pred_batch, bag_info_dict, bag_ids



In [None]:
class FabricMILDataset(Dataset):
    """Dataset of bags: each item is all images belonging to one id plus one label.

    ``data_dict`` maps an item id to a record with an ``'images'`` list of paths
    and a ``'label'``. ``transform`` (optional) is applied to every loaded image.
    """

    def __init__(self, data_dict, transform=None):
        self.data_dict = data_dict
        self.item_ids = list(data_dict.keys())
        self.transform = transform

    def __len__(self):
        return len(self.item_ids)

    def _load_instance(self, img_path):
        """Open one image as RGB and run it through the transform, if any."""
        picture = Image.open(img_path).convert('RGB')
        return self.transform(picture) if self.transform else picture

    def __getitem__(self, idx):
        """Return ``(item_id, stacked_instances, label_tensor)`` for bag ``idx``."""
        item_id = self.item_ids[idx]
        record = self.data_dict[item_id]
        paths = record['images']
        label = record['label']

        # load every image of the bag; unreadable files are reported and skipped
        loaded = []
        for img_path in paths:
            try:
                loaded.append(self._load_instance(img_path))
            except Exception as e:
                print(f"Error loading {img_path}: {e}")

        # a bag with nothing loadable is replaced by a single all-zero instance
        if not loaded:
            loaded = [torch.zeros(3, 224, 224)]

        return item_id, torch.stack(loaded), torch.tensor(label, dtype=torch.long)

# collate function for FabricMILDataset batches
def collate_mil_bags(batch):
    """Turn a list of ``(bag_id, instances, label)`` samples into ``(bag_dict, labels)``.

    Bags are kept in a dict keyed by bag id (they are not stacked together);
    the labels are stacked into a single tensor.
    """
    ids, bags, targets = zip(*batch)
    return dict(zip(ids, bags)), torch.stack(targets)

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device, reg_factor=0.0):
    """Run one optimisation pass over ``dataloader``.

    For each batch the bags and labels are moved to ``device``, the model is
    called, ``criterion`` is applied to the returned logits and, when
    ``reg_factor > 0``, ``reg_factor`` times the sum of the parameters' 2-norms
    is added to the loss before back-propagation.

    Returns:
        (epoch_loss, epoch_acc, all_preds, all_labels, all_probs): mean batch
        loss, fraction of correct predictions, and the per-bag predictions,
        labels and softmax probabilities (the latter as a numpy array).
    """
    model.train()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    epoch_preds, epoch_labels, epoch_probs = [], [], []
    class_predictions = {}

    progress = tqdm(dataloader, desc="training")

    for step, (bag_dict, labels) in enumerate(progress, start=1):
        bag_dict = {bag_id: bag.to(device) for bag_id, bag in bag_dict.items()}
        labels = labels.to(device)

        optimizer.zero_grad()
        logits, _, _, _ = model(bag_dict)
        loss = criterion(logits, labels)
        if reg_factor > 0:
            # explicit penalty: sum of the 2-norms of all parameters
            loss += reg_factor * sum(torch.norm(param, 2) for param in model.parameters())

        loss.backward()
        optimizer.step()

        loss_sum += loss.item()

        probs = torch.softmax(logits, dim=-1)
        preds = probs.argmax(dim=1)

        n_seen += labels.size(0)
        n_correct += (preds == labels).sum().item()

        # running histogram of predicted classes, printed at the end of the epoch
        for cls in preds.cpu().numpy():
            class_predictions[cls] = class_predictions.get(cls, 0) + 1

        epoch_preds.extend(preds.cpu().tolist())
        epoch_labels.extend(labels.cpu().tolist())
        epoch_probs.extend(probs.cpu().tolist())

        progress.set_postfix({
            "loss": f"{loss_sum/step:.4f}",
            "acc": f"{100 * n_correct / n_seen:.2f}%"
        })

    epoch_loss = loss_sum / len(dataloader)
    epoch_acc = n_correct / n_seen

    print(f"  class predictions: {dict(sorted(class_predictions.items()))}")

    return epoch_loss, epoch_acc, epoch_preds, epoch_labels, np.array(epoch_probs)

def test(model, dataloader, criterion, device, return_details=False):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    all_preds = []
    all_labels = []
    all_probs = []
    all_bag_ids = []

    with torch.no_grad():
        for bag_dict, labels in tqdm(dataloader, desc="Validation"):
            bag_dict = {k: v.to(device) for k, v in bag_dict.items()}
            labels = labels.to(device)

            logits, _, _, bag_ids = model(bag_dict)

            loss = criterion(logits, labels)

            probs = torch.softmax(logits, dim=-1)
            preds = probs.argmax(dim=1)

            running_loss += loss.item()
            total += labels.size(0)
            correct += (preds == labels).sum().item()

            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())
            all_probs.extend(probs.cpu().tolist())
            all_bag_ids.extend(bag_ids)

    epoch_loss = running_loss / len(dataloader)
    epoch_acc = correct / total

    if return_details:
        return (
            epoch_loss,
            epoch_acc,
            all_preds,
            all_labels,
            np.array(all_probs),
            all_bag_ids
        )
    else:
        return epoch_loss, epoch_acc, all_preds, all_labels

In [None]:
def calculate_metrics(true_labels, pred_labels, probas, num_classes):
    """Compute aggregate and per-class classification metrics.

    Args:
        true_labels: sequence of integer ground-truth labels.
        pred_labels: sequence of integer predicted labels.
        probas: per-sample class scores passed to ``roc_auc_score``; column ``i``
            is taken to be the score of class ``i``.
        num_classes: total number of classes; classes are ``0 .. num_classes-1``.

    Returns:
        dict with macro accuracy/precision/recall/F1/AUC, per-class and
        macro-averaged FPR/FNR/TPR/TNR, and the confusion matrix (nested lists).
        ``auc_macro`` is NaN when the AUC cannot be computed.
    """
    true_labels = np.array(true_labels)
    pred_labels = np.array(pred_labels)
    class_ids = list(range(num_classes))

    accuracy = accuracy_score(true_labels, pred_labels)
    precision = precision_score(true_labels, pred_labels, average='macro', zero_division=0)
    recall = recall_score(true_labels, pred_labels, average='macro', zero_division=0)
    f1 = f1_score(true_labels, pred_labels, average='macro', zero_division=0)

    try:
        # ``labels`` tells sklearn which class each score column belongs to, so
        # the AUC is still defined when a split happens to lack some class.
        auc_macro = roc_auc_score(
            true_labels, probas, multi_class='ovo', average='macro', labels=class_ids
        )
    except Exception:
        # best effort: e.g. only one class present in true_labels
        auc_macro = np.nan

    # confusion matrix (rows: true class, columns: predicted class)
    cm = confusion_matrix(true_labels, pred_labels, labels=class_ids)

    # per-class counts
    TP = np.diag(cm).astype(float)
    FN = cm.sum(axis=1) - TP
    FP = cm.sum(axis=0) - TP
    TN = cm.sum() - (TP + FP + FN)

    def _rate(numerator, denominator):
        # Divide only where the denominator is positive; other entries stay 0.0.
        # (np.where would evaluate the division everywhere and emit warnings.)
        return np.divide(
            numerator,
            denominator,
            out=np.zeros_like(denominator, dtype=float),
            where=denominator > 0,
        )

    FPR = _rate(FP, FP + TN)
    FNR = _rate(FN, FN + TP)
    TNR = _rate(TN, TN + FP)
    TPR = _rate(TP, TP + FN)

    return {
        'accuracy': float(accuracy),
        'precision_macro': float(precision),
        'recall_macro': float(recall),
        'f1_macro': float(f1),
        'auc_macro': float(auc_macro),

        'FPR_per_class': FPR.tolist(),
        'FNR_per_class': FNR.tolist(),
        'TPR_per_class': TPR.tolist(),
        'TNR_per_class': TNR.tolist(),

        'FPR_macro': float(np.mean(FPR)),
        'FNR_macro': float(np.mean(FNR)),
        'TPR_macro': float(np.mean(TPR)),
        'TNR_macro': float(np.mean(TNR)),

        'confusion_matrix': cm.tolist()
    }

def print_metrics(metrics, phase="Metrics"):
    """Print the headline macro metrics of a metrics dict under a ``phase`` banner."""
    # (label, key) pairs in display order; labels are left-padded to 16 columns
    rows = (
        ("Accuracy:", 'accuracy'),
        ("Precision (mac):", 'precision_macro'),
        ("Recall (mac):", 'recall_macro'),
        ("F1 (mac):", 'f1_macro'),
        ("AUC (mac):", 'auc_macro'),
        ("FPR (mac):", 'FPR_macro'),
        ("FNR (mac):", 'FNR_macro'),
    )
    print(f"\n===== {phase} =====")
    for caption, key in rows:
        print(f"{caption:<16}{metrics[key]:.4f}")

In [None]:
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split

def pad_to_square(img, fill=(255, 255, 255)):
    """Pad the shorter side of ``img`` with ``fill`` so width equals height.

    The padding is split between the two sides of the short axis; when the
    difference is odd the extra pixel goes to the right/bottom. Square images
    are returned unchanged.
    """
    width, height = img.size
    if width == height:
        return img

    gap = abs(height - width)
    lead = gap // 2
    trail = gap - lead

    # padding tuple order: (left, top, right, bottom)
    padding = (lead, 0, trail, 0) if width < height else (0, lead, 0, trail)
    return transforms.functional.pad(img, padding, fill=fill)

def _build_mil_loader(items, transform, batch_size, shuffle, pin_memory):
    """Wrap a ``{bag_id: record}`` mapping in a FabricMILDataset and a DataLoader."""
    return DataLoader(
        FabricMILDataset(items, transform),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        collate_fn=collate_mil_bags,
        pin_memory=pin_memory,
    )


def _json_safe(value):
    """Recursively convert numpy scalars/arrays and tuples into JSON-serialisable builtins."""
    if isinstance(value, dict):
        return {str(k): _json_safe(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(v) for v in value]
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    return value


def train_mil_fabric_kfold(
    data_dict,
    pretrained_checkpoint_path,
    num_classes,
    save_dir='mil_fabric_experiments',
    n_folds=5,
    n_epochs=15,
    batch_size=4,
    lr=0.0001,
    reg_factor=0.0,
    agg_type='confidence_voting'
):
    """Train and evaluate the MIL classifier with stratified k-fold cross-validation.

    For every fold the held-out part is the test set; 10% of the remaining bags
    (stratified) form the validation set used for LR scheduling and for picking
    the best checkpoint. The best checkpoint is then evaluated on the test set.

    Args:
        data_dict: mapping ``bag_id -> {'images': [paths], 'label': int}``.
        pretrained_checkpoint_path: backbone checkpoint handed to MIL_FabricClassifier.
        num_classes: number of classes.
        save_dir: directory for checkpoints, per-fold prediction CSVs and JSON summaries.
        n_folds: number of cross-validation folds.
        n_epochs: training epochs per fold.
        batch_size: bags per batch.
        lr: AdamW learning rate.
        reg_factor: extra parameter-norm penalty passed to ``train_epoch``.
        agg_type: bag aggregation strategy passed to MIL_FabricClassifier.

    Returns:
        list with one dict per fold: ``fold``, ``best_val_acc``, ``test_acc``,
        ``test_metrics``, ``history`` (per-epoch losses, accuracies and metric
        dicts) and ``predictions`` (test-set bag ids, labels, predictions,
        probabilities).

    Side effects:
        writes ``best_model_fold{k}.pth``, ``predictions_fold{k}.csv``,
        ``cv_results.json`` and ``all_predictions.json`` into ``save_dir``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    save_path = Path(save_dir)
    save_path.mkdir(exist_ok=True, parents=True)

    item_ids = list(data_dict.keys())
    stratify_labels = [data_dict[item_id]['label'] for item_id in item_ids]

    # class distribution
    unique_labels, label_counts = np.unique(stratify_labels, return_counts=True)

    print(f"\n Class Distribution:")
    for label, count in zip(unique_labels, label_counts):
        print(f"  Class {label}: {count} items ({count/len(stratify_labels)*100:.1f}%)")

    # NOTE(review): these squared inverse-frequency weights are computed but are
    # NOT passed to the loss below (that matches the previous behaviour). If
    # class weighting is wanted, pass ``weight=class_weights`` to CrossEntropyLoss
    # and make sure every class 0..num_classes-1 occurs in data_dict.
    class_weights = torch.FloatTensor(
        [(len(stratify_labels) / (num_classes * count)) ** 2 for count in label_counts]
    ).to(device)

    train_transform = transforms.Compose([
        pad_to_square,
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        pad_to_square,
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # honour the requested number of folds
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)

    # pinned memory only helps (and only avoids a warning) when a GPU is used
    pin_memory = device.type == 'cuda'

    fold_results = []
    all_predictions = {}

    for fold, (trainval_idx, test_idx) in enumerate(skf.split(item_ids, stratify_labels)):
        print(f"\n{'='*70}")
        print(f"Fold {fold+1}/{n_folds}")
        print(f"{'='*70}")

        trainval_labels = [stratify_labels[i] for i in trainval_idx]
        train_idx, val_idx = train_test_split(
            trainval_idx, test_size=0.10, stratify=trainval_labels, random_state=42
        )

        train_items = {item_ids[i]: data_dict[item_ids[i]] for i in train_idx}
        val_items = {item_ids[i]: data_dict[item_ids[i]] for i in val_idx}
        test_items = {item_ids[i]: data_dict[item_ids[i]] for i in test_idx}

        print(f"Train bags: {len(train_items)}, Val bags: {len(val_items)}, Test bags: {len(test_items)}")

        train_loader = _build_mil_loader(train_items, train_transform, batch_size, True, pin_memory)
        val_loader = _build_mil_loader(val_items, val_transform, batch_size, False, pin_memory)
        test_loader = _build_mil_loader(test_items, val_transform, batch_size, False, pin_memory)

        # initialize model
        model = MIL_FabricClassifier(
            n_classes=num_classes,
            pretrained_path=pretrained_checkpoint_path,
            agg_type=agg_type
        ).to(device)

        # optimizer and loss
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=5e-4)
        # ``verbose`` is deprecated on LR schedulers; LR changes are printed below instead
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.5, patience=3
        )

        # training loop
        checkpoint_path = save_path / f'best_model_fold{fold}.pth'
        best_val_acc = 0.0
        best_epoch = None  # None until a checkpoint has been written in this run
        fold_history = {
            'train_loss': [], 'train_acc': [],
            'val_loss': [], 'val_acc': [],
            'train_metrics': [], 'val_metrics': []
        }

        for epoch in range(n_epochs):
            print(f"\nEpoch {epoch+1}/{n_epochs}")

            # TRAIN
            train_loss, train_acc, train_preds, train_labels, train_probs = train_epoch(
                model, train_loader, criterion, optimizer, device, reg_factor
            )

            # VAL
            val_loss, val_acc, val_preds, val_labels, val_probs, val_bag_ids = test(
                model, val_loader, criterion, device, return_details=True
            )

            # train + val metrics, recorded for every epoch
            train_metrics = calculate_metrics(
                true_labels=train_labels,
                pred_labels=train_preds,
                probas=train_probs,
                num_classes=num_classes
            )
            val_metrics = calculate_metrics(
                true_labels=val_labels,
                pred_labels=val_preds,
                probas=val_probs,
                num_classes=num_classes
            )

            fold_history['train_loss'].append(train_loss)
            fold_history['train_acc'].append(train_acc)
            fold_history['val_loss'].append(val_loss)
            fold_history['val_acc'].append(val_acc)
            fold_history['train_metrics'].append(train_metrics)
            fold_history['val_metrics'].append(val_metrics)

            print(f"  Train - Loss: {train_loss:.4f}, Acc: {train_acc*100:.2f}%")
            print(f"  Val   - Loss: {val_loss:.4f}, Acc: {val_acc*100:.2f}%")

            lr_before = optimizer.param_groups[0]['lr']
            scheduler.step(val_acc)
            lr_after = optimizer.param_groups[0]['lr']
            if lr_after != lr_before:
                print(f"  Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

            # save best model; the first epoch always saves so that a checkpoint
            # from this run exists even if validation accuracy never exceeds 0
            if best_epoch is None or val_acc > best_val_acc:
                best_val_acc = val_acc
                best_epoch = epoch
                torch.save({
                    'fold': fold,
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_acc': val_acc,
                    'val_loss': val_loss,
                }, checkpoint_path)
                print(f"  ✓ Saved best model (acc: {val_acc*100:.2f}%)")

        print(f"\nFold {fold+1} Best Val Acc: {best_val_acc*100:.2f}%")

        print(f"\n{'='*70}")
        print(f"Testing best model from fold {fold+1}")
        print(f"{'='*70}")

        # load best model of this run (nothing to load when no epoch was trained)
        if best_epoch is not None:
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            print("  No checkpoint was saved in this run; testing the current weights")

        # test
        test_loss, test_acc, test_preds, test_labels, test_probs, test_bag_ids = test(
            model, test_loader, criterion, device, return_details=True
        )

        test_metrics = calculate_metrics(test_labels, test_preds, test_probs, num_classes)
        print_metrics(test_metrics, phase="Test")

        fold_predictions = {
            'bag_id': test_bag_ids,
            'true_label': test_labels,
            'predicted_label': test_preds,
            'probabilities': test_probs.tolist()
        }

        fold_pred_df = pd.DataFrame({
            'bag_id': test_bag_ids,
            'true_label': test_labels,
            'predicted_label': test_preds,
            **{f'prob_class_{i}': test_probs[:, i] for i in range(num_classes)}
        })
        fold_pred_df.to_csv(save_path / f'predictions_fold{fold}.csv', index=False)

        for bag_id, true_label, pred_label, prob in zip(test_bag_ids, test_labels, test_preds, test_probs):
            all_predictions[bag_id] = {
                'fold': int(fold),
                'true_label': int(true_label),
                'predicted_label': int(pred_label),
                'probabilities': [float(p) for p in prob]
            }

        # exactly one record per fold, so the summary below can rely on every key
        fold_results.append({
            'fold': fold,
            'best_val_acc': best_val_acc,
            'test_acc': test_acc,
            'test_metrics': test_metrics,
            'history': fold_history,
            'predictions': fold_predictions
        })

        print(f"\nFold {fold+1} - Val Acc: {best_val_acc*100:.2f}%, Test Acc: {test_acc*100:.2f}%")

    print(f"\n{'='*70}")
    print("CROSS-VALIDATION SUMMARY")
    print(f"{'='*70}")

    val_accs = [r['best_val_acc'] for r in fold_results]
    test_accs = [r['test_acc'] for r in fold_results]
    avg_acc = np.mean(val_accs)
    std_acc = np.std(val_accs)
    avg_test_acc = np.mean(test_accs)
    std_test_acc = np.std(test_accs)

    print(f"\nAverage Validation Accuracy: {avg_acc*100:.2f}% ± {std_acc*100:.2f}%")
    print(f"Average Test Accuracy: {avg_test_acc*100:.2f}% ± {std_test_acc*100:.2f}%")
    print(f"\nPer-fold results:")
    for result in fold_results:
        print(
            f"  Fold {result['fold']+1}: "
            f"val {result['best_val_acc']*100:.2f}%, test {result['test_acc']*100:.2f}%"
        )

    results_summary = {
        'fold_results': [{
            'fold': r['fold'],
            'best_val_acc': float(r['best_val_acc']),
            'test_acc': float(r['test_acc']),
            'test_metrics': r['test_metrics'],
            'history': r['history']
        } for r in fold_results],
        'avg_acc': float(avg_acc),
        'std_acc': float(std_acc),
        'avg_test_acc': float(avg_test_acc),
        'std_test_acc': float(std_test_acc),
        'config': {
            'n_folds': n_folds,
            'n_epochs': n_epochs,
            'batch_size': batch_size,
            'lr': lr,
            'reg_factor': reg_factor,
            'agg_type': agg_type,
            'num_classes': num_classes
        }
    }

    with open(save_path / 'cv_results.json', 'w') as f:
        json.dump(_json_safe(results_summary), f, indent=2)

    # out-of-fold test predictions for every bag, keyed by bag id
    with open(save_path / 'all_predictions.json', 'w') as f:
        json.dump(_json_safe(all_predictions), f, indent=2)

    print(f"\n✓ Results saved to {save_path}")

    return fold_results

In [None]:
def create_data_dict_from_csv(csv_path, images_folder):
    """Build the bag dictionary used for MIL training from a CSV index.

    The CSV must contain the columns ``item_id``, ``image_id`` and ``label``.
    Rows are grouped by ``item_id``; every ``image_id`` is resolved to
    ``<images_folder>/<image_id>.jpg`` and kept only if that file exists.
    Items with no existing image are skipped. The label of an item is the label
    of its first row.

    Args:
        csv_path: path of the CSV file.
        images_folder: directory containing the ``.jpg`` files.

    Returns:
        (data_dict, label_mapping, reverse_mapping):
        ``data_dict`` maps ``'item_<item_id>'`` to ``{'images': [paths], 'label': int}``,
        ``label_mapping`` maps label name -> index (names sorted),
        ``reverse_mapping`` maps index -> label name.
    """
    df = pd.read_csv(csv_path)

    # create label mapping
    unique_labels = sorted(df['label'].unique())
    label_mapping = {label: idx for idx, label in enumerate(unique_labels)}
    reverse_mapping = {idx: label for label, idx in label_mapping.items()}

    print(f"Found {len(unique_labels)} fabric classes:")
    for label, idx in label_mapping.items():
        print(f"  {idx}: {label}")

    images_path = Path(images_folder)

    data_dict = {}
    missing_images = []

    for item_id, group in df.groupby('item_id'):
        image_paths = []
        for image_id in group['image_id']:
            img_path = images_path / f"{image_id}.jpg"
            if img_path.exists():
                image_paths.append(str(img_path))
            else:
                missing_images.append(f"{image_id}.jpg")

        # skip items with no valid images
        if len(image_paths) == 0:
            print(f"Warning: Item {item_id} has no valid images, skipping...")
            continue

        # get label (taken from the item's first row)
        label_name = group['label'].iloc[0]
        label_idx = label_mapping[label_name]

        data_dict[f'item_{item_id}'] = {
            'images': image_paths,
            'label': label_idx
        }

    if missing_images:
        # leading newline separates the report from the class listing above
        print(f"\n{len(missing_images)} images not found in {images_folder}")
        print(f"Missing: {missing_images[:5]}")

    print(f"\nCreated data dictionary with {len(data_dict)} items")
    if len(data_dict) > 0:
        example_key = list(data_dict.keys())[0]
        print(f"Example item: {example_key}")
        print(f"  - Images: {len(data_dict[example_key]['images'])}")
        print(f"  - Label: {data_dict[example_key]['label']} ({reverse_mapping[data_dict[example_key]['label']]})")

    return data_dict, label_mapping, reverse_mapping

In [55]:
csv_path = "D:\\csci_461_textiles_project\\data\\fiber\\fiber_data.csv"
images_folder = "D:\\csci_461_textiles_project\\data\\fiber\\fiber_images"
data_dict, label_mapping, reverse_mapping = create_data_dict_from_csv(csv_path, images_folder)

print(data_dict.get('item_1'))
print(label_mapping)

Found 7 fabric classes:
  0: Acrylic
  1: Cotton
  2: Linen
  3: Nylon
  4: Polyester
  5: Suede
  6: Viscose

Created data dictionary with 2145 items
Example item: item_1
  - Images: 6
  - Label: 0 (Acrylic)
{'images': ['D:\\csci_461_textiles_project\\data\\fiber\\fiber_images\\1.jpg', 'D:\\csci_461_textiles_project\\data\\fiber\\fiber_images\\2.jpg', 'D:\\csci_461_textiles_project\\data\\fiber\\fiber_images\\3.jpg', 'D:\\csci_461_textiles_project\\data\\fiber\\fiber_images\\4.jpg', 'D:\\csci_461_textiles_project\\data\\fiber\\fiber_images\\5.jpg', 'D:\\csci_461_textiles_project\\data\\fiber\\fiber_images\\6.jpg'], 'label': 0}
{'Acrylic': 0, 'Cotton': 1, 'Linen': 2, 'Nylon': 3, 'Polyester': 4, 'Suede': 5, 'Viscose': 6}


In [None]:
import random
from collections import Counter

# fixed seed so the same bags are drawn on every run
random.seed(42)

# upper bound on the number of bags kept for a label index;
# labels not listed here are left untouched
max_per_class = {
    1: 200,
    2: 200,
}

# bucket the (bag_id, record) pairs by label, preserving data_dict order
by_class = {}
for bag_key, bag_record in data_dict.items():
    by_class.setdefault(bag_record["label"], []).append((bag_key, bag_record))

# draw a random subset from every bucket that exceeds its cap
balanced_items = []
for lbl, items in by_class.items():
    cap = max_per_class.get(lbl)
    if cap is not None and len(items) > cap:
        items = random.sample(items, cap)
    balanced_items.extend(items)

# rebuild the mapping from the kept pairs
balanced_data_dict = dict(balanced_items)

# report the resulting label distribution
print(Counter(record["label"] for record in balanced_data_dict.values()))
# from here on data_dict refers to the downsampled set
data_dict = balanced_data_dict
len(data_dict)


Counter({1: 200, 2: 200, 5: 154, 6: 101, 3: 88, 4: 61, 0: 22})


826

In [None]:
if __name__ == '__main__':    
    # Run the cross-validated training on the (downsampled) data_dict.
    # NOTE(review): checkpoint and output locations are absolute, machine-specific paths.
    results = train_mil_fabric_kfold(
        data_dict=data_dict,
        pretrained_checkpoint_path='D:\\csci_461_textiles_project\\res18_ckpt.pth',
        num_classes=7,  # must match the number of entries in label_mapping
        save_dir='D:\\csci_461_textiles_project\\fiber_resnet_test_three',
        n_folds=5,
        n_epochs=30,
        batch_size=8,
        lr=0.00001,
        agg_type='confidence_voting'  # the only agg_type MIL_FabricClassifier.__init__ handles in this notebook
    )