In [1]:
import random
from pyexpat import features

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Dataset, Subset, DataLoader
from torchvision import datasets, transforms, models
from sklearn.metrics import accuracy_score
from torchvision.models import ResNet18_Weights

# Constants
RANDOM_SEED = 42
TRAIN_VAL_SPLIT_RATIO = 0.7
SUBSET_RATIO = 0.1
BAG_SIZE = 8
BATCH_SIZE = 32
NUM_WORKERS = 4
PIN_MEMORY = True
INPUT_DIM = 512
OUTPUT_DIM = 1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Set random seed for reproducibility
random.seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Ensure deterministic behavior for CUDA operations
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# Data transformation
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Dataset preparation
mnist_train = datasets.MNIST(root='data', train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root='data', train=False, download=True, transform=transform)

In [4]:
# Subset and splitting data
subset_size = int(SUBSET_RATIO * len(mnist_train))
subset_indices = torch.randperm(len(mnist_train))[:subset_size]
mnist_train_subset = Subset(mnist_train, subset_indices)

train_size = int(TRAIN_VAL_SPLIT_RATIO * len(mnist_train_subset))
val_size = len(mnist_train_subset) - train_size
train_dataset, val_dataset = random_split(mnist_train_subset, [train_size, val_size])

In [5]:
class BagDataset(Dataset):
    def __init__(self, dataset, bag_size=BAG_SIZE):
        self.dataset = dataset
        self.bag_size = bag_size

    def __len__(self):
        return len(self.dataset) // self.bag_size

    def __getitem__(self, idx):
        bag_images = []
        bag_labels = []

        for i in range(self.bag_size):
            img, label = self.dataset[idx * self.bag_size + i]
            bag_images.append(img)
            bag_labels.append(label)
        
        bag_images = torch.stack(bag_images)
        bag_label = 1 if 9 in bag_labels else 0

        return bag_images, bag_label

In [6]:
class BagDataLoader:
    def __init__(self, dataset, bag_size, batch_size, shuffle=True, num_workers=0, pin_memory=False, drop_last=False):
        self.bag_dataset = BagDataset(dataset, bag_size)
        self.loader = DataLoader(
            self.bag_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=drop_last
        )

    def get_loader(self):
        return self.loader

In [7]:
def create_bag_loader(dataset, bag_size, batch_size, shuffle, num_workers, pin_memory, drop_last):
    return BagDataLoader(
        dataset=dataset,
        bag_size=bag_size,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last
    ).get_loader()

train_loader = create_bag_loader(
    train_dataset, BAG_SIZE, BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=True
)

val_loader = create_bag_loader(
    val_dataset, BAG_SIZE, BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=True
)

test_loader = create_bag_loader(
    mnist_test, BAG_SIZE, BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=True
)

# Model configurations

## Dual-Stream Multiple Instance Learning (DSMIL)

In [8]:
class FCLayer(nn.Module):
    def __init__(self, input_dim, output_dim=1):
        super(FCLayer, self).__init__()
        self.fc = nn.Sequential(nn.Linear(input_dim, output_dim))
        
    def forward(self, x):
        return self.fc(x)
    
class InstanceClassifier(nn.Module):
    def __init__(self, input_dim, output_dim=1):
        super(InstanceClassifier, self).__init__()
        self.features_extractor = models.resnet18(weights=ResNet18_Weights.DEFAULT)
        self.features_extractor.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.features_extractor.fc = nn.Identity()
        
        self.fc = nn.Linear(input_dim, output_dim)
        
    def forward(self, x):
        batch_size, num_instances, C, H, W = x.shape
        x = x.view(batch_size * num_instances, C, H, W)
        
        instance_features = self.features_extractor(x).view(batch_size, num_instances, -1)
        classes = self.fc(instance_features)
        
        return instance_features, classes
    
class BagClassifier(nn.Module):
    def __init__(self, input_dim, output_dim=1, hidden_dim=128, dropout_v=0.2, non_linear=True, passing_v=False):
        super(BagClassifier, self).__init__()
        self.hidden_dim = hidden_dim
        
        if non_linear:
            self.q = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.Tanh()
            )
        else:
            self.q = nn.Linear(input_dim, hidden_dim)
        
        if passing_v:
            self.v = nn.Sequential(
                nn.Dropout(dropout_v),
                nn.Linear(input_dim, input_dim),
                nn.ReLU()
            )
        else:
            self.v = nn.Identity()
            
        self.fc = FCLayer(input_dim, output_dim)
        
    def forward(self, features, classes):
        batch_size = features.size(0)
        num_instances = features.size(1)
        features_dim = features.size(2)
        
        combine_features = features.view(features.shape[0] * features.shape[1], -1)
        V = self.v(combine_features)
        Q = self.q(combine_features)
        assert V.shape[0] == Q.shape[0] == batch_size * num_instances, f'V: {V.shape}, Q: {Q.shape}'
        assert V.shape[1] == features_dim, f'V: {V.shape} should be [{batch_size * num_instances}, {features_dim}]'
        assert Q.shape[1] == self.hidden_dim, f'Q: {Q.shape} should be [{batch_size * num_instances}, {self.hidden_dim}]'
        
        # Get critical instance indices by squeezing classes
        critical_indices = torch.squeeze(classes).argmax(dim=1)  # Shape [32]
        assert critical_indices.shape[0] == batch_size, f'Critical indices: {critical_indices.shape}'

        # Gather features for each batch using critical instance indices
        m_features = features[torch.arange(batch_size).unsqueeze(1), critical_indices.unsqueeze(1)].squeeze()
        assert m_features.shape[0] == batch_size, f'M features: {m_features.shape} should be [{batch_size}, {features_dim}]'
        q_max = self.q(m_features)
        assert q_max.shape[0] == batch_size and q_max.shape[1] == self.hidden_dim, f'Q max: {q_max.shape} should be [{batch_size}, {self.hidden_dim}]'
        
        A = torch.mm(Q, q_max.mT)
        A = F.softmax(A / torch.sqrt(torch.tensor(Q.shape[-1]).float()), dim=0)
        assert A.shape[0] == batch_size * num_instances and A.shape[1] == batch_size, f'A: {A.shape} should be [{batch_size * num_instances}, {batch_size}]'
        
        B = torch.mm(A.T, V)
        assert B.shape[0] == batch_size and B.shape[1] == features_dim, f'B: {B.shape} should be [{batch_size}, {features_dim}]'
        
        B = B.view(1, B.shape[0], B.shape[1])
        C = self.fc(B)
        C = C.view(1, -1)
        
        return C, A, B
    
class MILNetwork(nn.Module):
    def __init__(self, input_dim, output_dim=1):
        super(MILNetwork, self).__init__()
        self.instance_classifier = InstanceClassifier(input_dim)
        self.bag_classifier = BagClassifier(input_dim)
        
    def forward(self, x):
        instance_features, classes = self.instance_classifier(x)
        predicted_bags, A, B = self.bag_classifier(instance_features, classes)
        
        return classes, predicted_bags, A, B

## Attention Layer

In [9]:
class SecondStreamDSMIL(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(SecondStreamDSMIL, self).__init__()

        self.W_q = nn.Linear(input_dim, input_dim)
        self.W_v = nn.Linear(input_dim, input_dim)
        self.W_b = nn.Linear(input_dim, output_dim)

        # You can choose any activation function here
        self.activation = nn.ReLU()  # Example: using ReLU

    def forward(self, B):
        # Linear transformations followed by non-linear activation
        q = self.activation(self.W_q(B))
        v = self.activation(self.W_v(B))

        # Max pooling over the value vectors
        h_m = torch.max(v, dim=0)[0]

        # Compute cosine similarity and attention weights
        U = F.softmax(F.cosine_similarity(q.unsqueeze(1), h_m.unsqueeze(0), dim=-1), dim=1).unsqueeze(-1)

        # Weighted sum of values
        return self.W_b(torch.sum(U * v, dim=1))


# class SecondStreamDSMIL(nn.Module):
#     def __init__(self, input_dim, output_dim):
#         super(SecondStreamDSMIL, self).__init__()
# 
#         self.W_q = nn.Linear(input_dim, input_dim)
#         self.W_v = nn.Linear(input_dim, input_dim)
#         self.W_b = nn.Linear(input_dim, output_dim)
# 
#     def forward(self, B):
#         h_m = torch.max(B @ self.W_v.weight.T, dim=0)[0]
#         q, v = self.W_q(B), self.W_v(B)
#         U = F.softmax(F.cosine_similarity(q.unsqueeze(1), h_m.unsqueeze(0), dim=-1), dim=1).unsqueeze(-1)
#         return self.W_b(torch.sum(U * v, dim=1))

# class SecondStreamDSMIL(nn.Module):
#     def __init__(self, input_dim, output_class, non_linear=True, passing_v=False, dropout_v=0.2):
#         super(SecondStreamDSMIL, self).__init__()
#         if non_linear:
#             self.q = nn.Sequential(
#                 nn.Linear(input_dim, 128),
#                 nn.ReLU(),
#                 nn.Linear(128, 128),
#                 nn.Tanh())
#         else:
#             self.q = nn.Linear(input_dim, 128)
# 
#         if passing_v:
#             self.v = nn.Sequential(
#                 nn.Dropout(dropout_v),
#                 nn.Linear(input_dim, input_dim),
#                 nn.ReLU())
#         else:
#             self.v = nn.Identity()
# 
#         # 1D Convolutional layer that can handle multiple classes (including binary)
#         self.fcc = nn.Conv1d(output_class, output_class, kernel_size=input_dim)
# 
#     def forward(self, features, classes): # N * K, N * C
#         V = self.v(features) # N * V, unsorted
#         Q = self.q(features).view(features.shape[0], -1) # N * Q, unsorted 
#         print(f'Shape of V: {V.shape}, Q: {Q.shape}')
# 
#         # handle multiple classes
#         _, m_indices = torch.sort(classes, 0,descending=True) # sort class scores along the instance dimension, m_indices in shape N x C
#         print(f'Shape of m_indices: {m_indices.shape}')
#         m_features = torch.index_select(features, dim=0, index=m_indices[0, :]) # Select critical instances based on class scores, m_features in shape C x K
#         print(f'Shape of m_features: {m_features.shape}')
#         q_max = self.q(m_features) # Extract features from critical instances, q_max in shape C x Q
#         print(f'Shape of q_max: {q_max.shape}')
#         A = torch.mm(Q, q_max.T) # Compute attention weights, A in shape N x C
#         print(f'Shape of A after mm: {A.shape}') 
#         A = F.softmax(A / torch.sqrt(torch.tensor(Q.shape[-1]).float()), dim=0) # normalize attention scores, A in shape N x C,
#         print(f'Shape of A after softmax: {A.shape}')
#         B = torch.mm(A.T, V) # compute bag representation, B in shape C x V
#         print(f'Shape of B after mm: {B.shape}')
# 
#         B = B.view(1, B.shape[0], B.shape[1]) # 1 x C x V
#         print(f'Shape of B after view: {B.shape}')
#         C = self.fcc(B) # 1 x C x 1
#         print(f'Shape of C after fcc: {C.shape}')
#         print(f'C after fcc: {C}')
#         C = C.view(1, -1)
#         print(f'Shape of C after view: {C.shape}')
#         return C, A, B 

class AttentionLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(AttentionLayer, self).__init__()
        self.attention = nn.Sequential(
          nn.Linear(input_dim, hidden_dim), # in = out
          nn.PReLU(),
          nn.Linear(hidden_dim, 1) # 1 = num of classes
        )

    def forward(self, x):
        # x shape: (batch_size, num_instances, feature_dim)
        attention_weights = self.attention(x)
        weights = F.softmax(attention_weights, dim=1)
        return (x * weights).sum(dim=1), weights.squeeze(-1)

class DualStreamMIL(nn.Module):
    def __init__(self, input_dim, output_dim, mode='all'):
        super(DualStreamMIL, self).__init__()
        self.input_dim = input_dim
        self.mode = mode  # New parameter to handle different modes

        # Initialize ResNet18 with custom input channels
        self.resnet18 = models.resnet18(weights=ResNet18_Weights.DEFAULT)
        self.resnet18.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.resnet18.fc = nn.Identity()

        self.attention_layer = AttentionLayer(input_dim, input_dim)
        # self.second_stream = SecondStreamDSMIL(input_dim, output_class=1)
        self.second_stream = SecondStreamDSMIL(input_dim, input_dim)

        self.classifier = nn.Sequential(
            nn.Dropout(p=0.25),
            nn.Linear(in_features=512, out_features=128),
            nn.ReLU(),          
            nn.Dropout(p=0.25),
            nn.Linear(in_features=128, out_features=128),
            nn.Tanh(),
            nn.Dropout(p=0.25),
            nn.Linear(in_features=128, out_features=1)  # 1 = num of classes
        )

        self.fc = nn.Linear(input_dim, 1)

    def forward(self, x):
        N, bag_size, C, H, W = x.shape
        x = x.view(N * bag_size, C, H, W)

        features = nn.Dropout(0.2)(self.resnet18(x).view(N, bag_size, -1))

        # Feature extraction
        max_aggregation = torch.max(features, dim=1)[0]
        classes = self.classifier(max_aggregation)

        if self.mode == 'max-pooling':
            return classes

        elif self.mode == 'dual-stream':
            # Use max aggregation and second stream features
            second_stream_features = torch.max(self.second_stream(features), dim=1)[0]
            # predicted_bags, A, B = self.second_stream(features, classes)
            return classes, self.classifier(second_stream_features)

        elif self.mode == 'all':
            # Use a combination of max, second stream and attention features
            attention_features, _ = self.attention_layer(features)
            second_stream_features = torch.max(self.second_stream(features), dim=1)[0]

            return classes, self.classifier(attention_features), self.classifier(second_stream_features)

        else:
            raise ValueError("Invalid mode selected. Choose from 'max-pooling', 'dual-stream', or 'all'.")

        return None

## Model, criterion, and optimizer setup

In [10]:
# Model, criterion, and optimizer setup
# model = DualStreamMIL(INPUT_DIM, OUTPUT_DIM, 'dual-stream').to(DEVICE)
model = MILNetwork(INPUT_DIM, OUTPUT_DIM).to(DEVICE)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

## Training loop

In [11]:
# Training loop
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    train_labels = []
    train_preds = []
    for images, labels in train_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        # max_agg, second_stream = model(images)
        classes, predicted_bags, A, B = model(images)
        max_agg = torch.max(classes, dim=1)[0]
        loss_max = criterion(max_agg.squeeze(), labels.float())
        loss_bag = criterion(predicted_bags.squeeze(), labels.float())
        
        loss = 0.5 * loss_max + 0.5 * loss_bag
        loss = loss.mean()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        preds = torch.sigmoid(predicted_bags).squeeze().round()
        
        train_labels.extend(labels.cpu().numpy())
        train_preds.extend(preds.detach().cpu().numpy())

    train_accuracy = accuracy_score(train_labels, train_preds)
    
    # Validation loop
    model.eval()
    val_loss = 0
    val_labels = []
    val_preds = []
    with torch.inference_mode():
        for images, labels in val_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            # max_agg, second_stream = model(images)
            classes, predicted_bags, A, B = model(images)
            max_agg = torch.max(classes, dim=1)[0]
            loss_max = criterion(max_agg.squeeze(), labels.float())
            loss_bag = criterion(predicted_bags.squeeze(), labels.float())
            
            loss = 0.5 * loss_max + 0.5 * loss_bag
            loss = loss.mean()
            val_loss += loss.item()
            preds = torch.sigmoid(predicted_bags).squeeze().round()
            
            val_labels.extend(labels.cpu().numpy())
            val_preds.extend(preds.detach().cpu().numpy())

    val_accuracy = accuracy_score(val_labels, val_preds)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss/len(train_loader):.4f}, Train Accuracy: {train_accuracy:.4f}, Validation Loss: {val_loss/len(val_loader):.4f}, Validation Accuracy: {val_accuracy:.4f}")

Epoch 1/5, Train Loss: 0.5912, Train Accuracy: 0.5391, Validation Loss: 0.6654, Validation Accuracy: 0.5714
Epoch 2/5, Train Loss: 0.3144, Train Accuracy: 0.6738, Validation Loss: 0.4164, Validation Accuracy: 0.8125
Epoch 3/5, Train Loss: 0.0883, Train Accuracy: 0.9746, Validation Loss: 0.1929, Validation Accuracy: 0.9375
Epoch 4/5, Train Loss: 0.0295, Train Accuracy: 0.9961, Validation Loss: 0.0700, Validation Accuracy: 0.9821
Epoch 5/5, Train Loss: 0.0317, Train Accuracy: 0.9961, Validation Loss: 0.0672, Validation Accuracy: 0.9821


In [12]:
# Test loop
model.eval()
test_loss = 0
test_labels = []
test_preds = []
with torch.inference_mode():
    for images, labels in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        # max_agg, second_stream = model(images)
        classes, predicted_bags, A, B = model(images)
        max_agg = torch.max(classes, dim=1)[0]
        loss_max = criterion(max_agg.squeeze(), labels.float())
        loss_bag = criterion(predicted_bags.squeeze(), labels.float())
        loss = 0.5 * loss_max + 0.5 * loss_bag
        loss = loss.mean()
        test_loss += loss.item()
        preds = torch.sigmoid(predicted_bags).squeeze().round()
        test_labels.extend(labels.cpu().numpy())
        test_preds.extend(preds.detach().cpu().numpy())
        
test_accuracy = accuracy_score(test_labels, test_preds)
print(f"Test Loss: {test_loss/len(test_loader):.4f}, Test Accuracy: {test_accuracy:.4f}")

Test Loss: 0.1158, Test Accuracy: 0.9655
