# Imports

In [None]:
# Install the third-party packages this notebook relies on.
# `grad-cam` provides the `pytorch_grad_cam` module imported in the next cell;
# `torchcam` is not imported anywhere in the visible notebook.
# numpy and pandas are pinned to exact versions.
!pip install grad-cam torchcam numpy==1.26.4 pandas==2.2.2 seaborn scikit-learn

In [None]:
import copy
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import Subset, WeightedRandomSampler
from torchvision import transforms, models
from torchvision import datasets
import torchvision.models as models
from torchvision.models import (
    resnet18, resnet50, resnet152, ResNet18_Weights,
    efficientnet_b0, EfficientNet_B0_Weights,
    vgg16, VGG16_Weights,
    inception_v3, Inception_V3_Weights
)
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
from sklearn.model_selection import train_test_split
from pytorch_grad_cam import GradCAM, GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

In [None]:
# Seed every RNG the notebook draws from, so that runs are reproducible:
#   * torch -> weight initialisation, random augmentations, WeightedRandomSampler
#   * numpy -> the bag shuffling done with np.random.shuffle in MILImageFolder
#     (previously unseeded, so bag membership changed on every run)
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Expose Google Drive at /content/drive; the dataset directories used later live under it.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

# Dataset

In [None]:
class MILImageFolder(Dataset):
    """Multiple-instance-learning view over a ``torchvision`` ``ImageFolder``.

    Images under ``root`` are grouped into *bags* of exactly ``bag_size`` images
    that all come from the same class; each bag is labelled with that class.
    Bags are built once, at construction time, so bag membership is fixed for
    the lifetime of the dataset (only ``transform`` is re-applied on access).

    Args:
        root: Directory in the layout ``ImageFolder`` expects (one sub-folder
            per class).
        transform: Per-image transform handed to ``ImageFolder``; it must
            return tensors so the images of a bag can be stacked.
        bag_size: Number of images per bag. Within each class, the
            ``len(class) % bag_size`` images that do not fill a complete bag
            are dropped.
        seed: Optional seed for bag shuffling. ``None`` (default) uses NumPy's
            global RNG exactly as before; an int makes bag construction
            reproducible independently of the global RNG state.

    Attributes:
        dataset: The underlying ``ImageFolder``.
        label_to_indices: ``{class_index: [sample indices]}``; each list has
            been shuffled in place.
        bags: List of ``(sample_indices, class_index)`` pairs.
        targets: ``np.int64`` array of bag labels aligned with ``bags``, used
            for stratified splitting and class-balanced sampling.

    Raises:
        ValueError: If ``bag_size`` is smaller than 1.
    """

    def __init__(self, root, transform, bag_size=10, seed=None):
        # Validate first: a non-positive bag_size would otherwise surface as an
        # opaque `range() arg 3 must not be zero` or silently produce no bags.
        if bag_size < 1:
            raise ValueError(f"bag_size must be a positive integer, got {bag_size!r}")
        self.dataset = datasets.ImageFolder(root=root, transform=transform)
        self.bag_size = bag_size

        rng = np.random if seed is None else np.random.RandomState(seed)
        labels = [label for _, label in self.dataset.samples]
        self.label_to_indices, self.bags = self._build_bags(labels, bag_size, rng)

        # Explicit integer dtype so that an empty dataset still yields an array
        # np.bincount / stratified splitting accept.
        self.targets = np.array([label for _, label in self.bags], dtype=np.int64)

    @staticmethod
    def _build_bags(labels, bag_size, rng):
        """Group sample indices by class and cut each class into full bags.

        ``labels[i]`` is the class of sample ``i``. ``rng`` must provide
        ``shuffle`` (``np.random`` or a ``RandomState``); it shuffles each
        class's index list and then the final bag list, in that order.

        Returns:
            ``(label_to_indices, bags)`` as described on the class.
        """
        label_to_indices = {}
        for idx, label in enumerate(labels):
            label_to_indices.setdefault(label, []).append(idx)

        bags = []
        for label, indices in label_to_indices.items():
            rng.shuffle(indices)
            # Stop at the last *complete* bag; a trailing partial bag is discarded.
            for start in range(0, len(indices) - bag_size + 1, bag_size):
                bags.append((indices[start:start + bag_size], label))
        rng.shuffle(bags)
        return label_to_indices, bags

    def __len__(self):
        return len(self.bags)

    def __getitem__(self, index):
        """Return ``(bag, label)``, where ``bag`` stacks the bag's images along a new dim 0."""
        bag_indices, bag_label = self.bags[index]
        images = [self.dataset[i][0] for i in bag_indices]
        return torch.stack(images), bag_label

In [None]:
# Training-time preprocessing: resize, small random rotation, then a 224x224 centre crop.
# NOTE(review): no ImageNet mean/std normalisation is applied even though the
# backbones can load ImageNet-pretrained weights — confirm this is intended.
train_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomRotation(degrees=15),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])

# Deterministic preprocessing for evaluation (validation and test bags).
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Dataset directories
train_augmented_dir = '/content/drive/MyDrive/GBCUD/dataset/train'
output_test_folder = '/content/drive/MyDrive/GBCUD/dataset/test'

# MIL datasets: every bag holds 10 images from the same class.
train_dataset = MILImageFolder(root=train_augmented_dir, transform=train_transforms, bag_size=10)
test_dataset = MILImageFolder(root=output_test_folder, transform=test_transforms, bag_size=10)

# Stratified train/validation split on bag labels.
targets = train_dataset.targets
train_idx, val_idx = train_test_split(
    np.arange(len(targets)),
    test_size=0.2,
    stratify=targets,
    random_state=42
)

# Validation bags must not be augmented. Building a second MILImageFolder would
# draw a *different* random partition into bags (and leak training images into
# validation), so reuse the exact same bags and only swap the per-image
# transform on shallow copies.
val_source = copy.copy(train_dataset)
val_source.dataset = copy.copy(train_dataset.dataset)
val_source.dataset.transform = test_transforms

train_subset = Subset(train_dataset, train_idx)
val_subset = Subset(val_source, val_idx)

# Class-balanced sampling for training: each bag is weighted by
# 1 / (number of training bags of its class).
train_targets = targets[train_idx]
class_counts = np.bincount(train_targets)
sample_weights = 1.0 / class_counts[train_targets]
sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

# DataLoaders consumed by run_experiment (they were referenced but never created).
# TODO(review): the batch size is not specified anywhere in the notebook; tune it to GPU memory.
BATCH_SIZE = 4
train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, sampler=sampler)
val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Model

In [None]:
class AttentionMIL(nn.Module):
    """Attention-based multiple-instance-learning (MIL) image classifier.

    A bag of images is classified as a whole. Each image (instance) is encoded
    into an ``L``-dimensional embedding, and an attention MLP scores each
    instance. A temperature-scaled softmax over the bag turns the scores into
    weights, and the attention-weighted sum of embeddings is fed to a linear
    classifier.

    Args:
        backbone_name: One of ``"resnet18"``, ``"resnet50"``, ``"resnet152"``,
            ``"efficientnet_b0"``, ``"vgg16"``, ``"inception_v3"``.
        num_classes: Number of bag-level classes.
        pretrained: Load ImageNet weights into the backbone.
        temperature: Softmax temperature for the attention weights; values
            below 1 sharpen the distribution over instances. ``forward`` and
            ``get_attention_map`` both use it, so they report the same weights.

    Raises:
        ValueError: For an unknown ``backbone_name`` or a non-positive ``temperature``.
    """

    def __init__(self, backbone_name="resnet18", num_classes=3, pretrained=True, temperature=0.5):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature!r}")

        self.L = 512  # instance-embedding size
        self.D = 256  # hidden size of the attention MLP
        self.K = 1    # number of attention weights per instance
        self.temperature = temperature

        # Backbone with its classification head removed; emits a spatial feature map.
        self.feature_extractor, backbone_output_dim = self.get_backbone(backbone_name, pretrained)

        # Resample any backbone's feature map to a fixed 30x30 grid so the flattened
        # size below does not depend on the backbone (e.g. a ResNet map for a
        # 224x224 input is 7x7, so this upsamples).
        self.pool = nn.AdaptiveAvgPool2d((30, 30))

        self.extra_conv_layers = nn.Sequential(
            nn.Conv2d(backbone_output_dim, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        # Flattened 64x30x30 map -> L-dimensional instance embedding.
        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(64 * 30 * 30, self.L),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(self.L, self.L),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

        # Per-instance attention score(s): Linear -> Tanh -> Linear.
        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )

        self.classifier = nn.Linear(self.L * self.K, num_classes)

    def get_backbone(self, name, pretrained):
        """Return ``(feature_extractor, out_channels)`` for backbone ``name``.

        The extractor is the torchvision model with its classifier (and global
        pooling, where present) stripped, so it outputs a spatial feature map.
        """
        if name == "resnet18":
            model = resnet18(weights=ResNet18_Weights.DEFAULT if pretrained else None)
            return nn.Sequential(*list(model.children())[:-2]), 512
        elif name == "resnet50":
            # IMAGENET1K_V1 is exactly what the deprecated `pretrained=True` loaded.
            model = resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None)
            return nn.Sequential(*list(model.children())[:-2]), 2048
        elif name == "resnet152":
            model = resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1 if pretrained else None)
            return nn.Sequential(*list(model.children())[:-2]), 2048
        elif name == "efficientnet_b0":
            model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT if pretrained else None)
            return model.features, 1280
        elif name == "vgg16":
            model = vgg16(weights=VGG16_Weights.DEFAULT if pretrained else None)
            # Drop the final max-pool to keep a larger spatial map.
            return nn.Sequential(*list(model.features.children())[:-1]), 512
        elif name == "inception_v3":
            model = inception_v3(weights=Inception_V3_Weights.DEFAULT if pretrained else None, aux_logits=False)
            # Drop avgpool, dropout and fc; keeping avgpool (as [:-2] did) collapsed
            # the map to 1x1 before self.pool upsampled it.
            # NOTE(review): running the children as a Sequential bypasses Inception3's
            # own input transform — confirm that is acceptable.
            return nn.Sequential(*list(model.children())[:-3]), 2048
        else:
            raise ValueError(f"Unsupported backbone: {name}")

    def _embed_instances(self, x):
        """Map a flat stack of images ``(N, C, H, W)`` to instance embeddings ``(N, L)``."""
        features = self.feature_extractor(x)
        features = self.pool(features)
        features = self.extra_conv_layers(features)
        features = features.reshape(features.size(0), -1)
        return self.feature_extractor_part2(features)

    def forward(self, x):
        """Classify a batch of bags.

        Args:
            x: Tensor of shape ``(B, bag_size, C, H, W)``.

        Returns:
            ``(logits, probs, A)``: ``logits`` and ``probs`` (softmax of the logits)
            have shape ``(B, num_classes)``; ``A`` holds the attention weights,
            shape ``(B, K, bag_size)``, summing to 1 over the bag.
        """
        B, bag_size, C, H, W = x.shape
        instances = self._embed_instances(x.reshape(B * bag_size, C, H, W))  # (B*bag, L)

        scores = self.attention(instances)                             # (B*bag, K)
        scores = scores.view(B, bag_size, self.K).transpose(1, 2)      # (B, K, bag)
        A = F.softmax(scores / self.temperature, dim=2)

        H_features = instances.view(B, bag_size, -1)                   # (B, bag, L)
        M = torch.bmm(A, H_features).view(B, -1)                       # (B, K*L)
        logits = self.classifier(M)
        probs = F.softmax(logits, dim=1)
        return logits, probs, A

    def calculate_classification_error(self, X, Y):
        """Return ``(error_rate, predictions)`` for bags ``X`` with labels ``Y``."""
        logits, _, _ = self.forward(X)
        preds = torch.argmax(logits, dim=1)
        error = 1.0 - preds.eq(Y).cpu().float().mean().item()
        return error, preds

    def calculate_objective(self, X, Y):
        """Return ``(cross_entropy_loss, attention_weights)`` for bags ``X`` with labels ``Y``."""
        logits, _, A = self.forward(X)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits, Y)
        return loss, A

    def get_attention_map(self, x):
        """Return the first attention head's weights for one bag as a NumPy array.

        Args:
            x: A single bag, either ``(bag_size, C, H, W)`` or ``(1, bag_size, C, H, W)``.

        Returns:
            Array of shape ``(bag_size,)`` that sums to 1. It uses the same
            temperature as ``forward``.

        Raises:
            ValueError: If ``x`` is not a single bag of 4-D images.

        The model is run in eval mode without gradients; its previous
        train/eval mode is restored afterwards.
        """
        if x.ndimension() == 5:
            if x.size(0) != 1:
                raise ValueError(f"get_attention_map expects a single bag, got batch shape {tuple(x.shape)}")
            x = x.squeeze(0)
        elif x.ndimension() != 4:
            raise ValueError(f"wrong input shape in get_attention_map: {x.shape}")

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                scores = self.attention(self._embed_instances(x))     # (bag, K)
                A = F.softmax(scores / self.temperature, dim=0)
        finally:
            self.train(was_training)
        return A[:, 0].cpu().numpy()

# Train and Evaluate Model

In [None]:
def train(model, dataloader, optimizer, criterion):
    """Run one training epoch.

    Args:
        model: Module whose output is either a logits tensor or a tuple/list
            whose first element is the logits (``AttentionMIL.forward`` returns
            ``(logits, probs, A)``).
        dataloader: Any iterable of ``(inputs, labels)`` batches; each batch is
            moved to the module-level ``device``.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss taking ``(logits, labels)``.

    Returns:
        ``(mean_batch_loss, accuracy)`` over the epoch.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    correct = 0
    seen = 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        # The loss and the predictions need the raw logits, not the whole output tuple.
        logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        correct += (torch.argmax(logits, dim=1) == labels).sum().item()
        seen += labels.numel()

    if num_batches == 0:
        raise ValueError("train() received an empty dataloader")
    return total_loss / num_batches, correct / seen

In [None]:
def evaluate(model, dataloader, criterion):
    """Evaluate ``model`` without gradient tracking.

    Args:
        model: Module whose output is either a logits tensor or a tuple/list
            whose first element is the logits (``AttentionMIL.forward`` returns
            ``(logits, probs, A)``).
        dataloader: Any iterable of ``(inputs, labels)`` batches; each batch is
            moved to the module-level ``device``.
        criterion: Loss taking ``(logits, labels)``.

    Returns:
        ``(mean_batch_loss, accuracy, macro_f1)``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # The loss and the predictions need the raw logits, not the whole output tuple.
            logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
            loss = criterion(logits, labels)
            total_loss += loss.item()
            num_batches += 1
            all_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if num_batches == 0:
        raise ValueError("evaluate() received an empty dataloader")
    acc = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average="macro")
    return total_loss / num_batches, acc, f1

In [None]:
def run_experiment(model_name, train_loader, val_loader, lr=1e-4, epochs=5):
    """Train an ``AttentionMIL`` model with backbone ``model_name`` and log per-epoch metrics.

    Args:
        model_name: Backbone name accepted by ``AttentionMIL`` (e.g. ``"resnet50"``).
        train_loader: Batches of ``(bags, labels)`` used for optimisation.
        val_loader: Batches of ``(bags, labels)`` evaluated after every epoch.
        lr: Adam learning rate.
        epochs: Number of training epochs.

    Returns:
        The trained model (on the module-level ``device``).
    """
    # `build_model` was never defined in the notebook; AttentionMIL is the model it builds on.
    model = AttentionMIL(backbone_name=model_name).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    print(f"Training {model_name} for {epochs} epochs...")
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train(model, train_loader, optimizer, criterion)
        val_loss, val_acc, val_f1 = evaluate(model, val_loader, criterion)
        print(
            f"Epoch {epoch}/{epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.3f} "
            f"| Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.3f} | Val F1: {val_f1:.3f}"
        )
    return model

# Experiment

In [None]:
# Train the resnet50-backbone experiment for 10 epochs at a 1e-4 learning rate.
model = run_experiment(
    "resnet50",
    train_loader=train_loader,
    val_loader=val_loader,
    lr=1e-4,
    epochs=10,
)