# Import Libraries

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import v2 as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split

# Data Preparation

In [9]:
# Load and preprocess the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_data, eval_data = train_test_split(mnist_data, train_size=0.1, random_state=42, stratify=mnist_data.targets)

# Bag-level dataset
class BagDataset(Dataset):
    def __init__(self, data, bag_size=8):
        self.data = data
        self.bag_size = bag_size

    def __len__(self):
        return len(self.data) // self.bag_size

    def __getitem__(self, idx):
        images, labels = [], []
        for i in range(self.bag_size):
            img, label = self.data[idx * self.bag_size + i]
            images.append(img)
            labels.append(label)
        bag_label = 1 if 9 in labels else 0
        return torch.stack(images), torch.tensor(bag_label)

train_dataset = BagDataset(train_data)
eval_dataset = BagDataset(eval_data)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, pin_memory=True, num_workers=4, drop_last=True)
eval_loader = DataLoader(eval_dataset, batch_size=32, shuffle=False, pin_memory=True, num_workers=4, drop_last=True)

# Model Definition

## NTXentLoss

In [10]:
# Define NTXentLoss (provided by you)
class NTXentLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        batch_size = z_i.size(0)
        z = torch.cat([z_i, z_j], dim=0)
        z = F.normalize(z, dim=1)
        similarity_matrix = torch.mm(z, z.T) / self.temperature
        mask = torch.eye(2 * batch_size, device=z.device).bool()
        similarity_matrix.masked_fill_(mask, -float('inf'))
        exp_sim = torch.exp(similarity_matrix)
        denominator = exp_sim.sum(dim=1)
        positive_samples = torch.cat(
            [torch.arange(batch_size, 2 * batch_size), torch.arange(batch_size)], dim=0
        ).to(z.device)
        positives = similarity_matrix[torch.arange(2 * batch_size), positive_samples]
        loss = -torch.log(torch.exp(positives) / denominator)
        return loss.mean()

## Dual-Stream MIL Model (DSMIL)

In [11]:
class FCLayer(nn.Module):
    def __init__(self, input_dim, output_dim=1):
        super(FCLayer, self).__init__()
        self.fc = nn.Sequential(nn.Linear(input_dim, output_dim))
        
    def forward(self, x):
        return self.fc(x)
    
class InstanceClassifier(nn.Module):
    def __init__(self, input_dim, output_dim=1):
        super(InstanceClassifier, self).__init__()
        self.features_extractor = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.features_extractor.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.features_extractor.fc = nn.Identity()
        
        self.fc = nn.Linear(input_dim, output_dim)
        
    def forward(self, x):
        batch_size, num_instances, C, H, W = x.shape
        x = x.view(batch_size * num_instances, C, H, W)
        
        instance_features = nn.Dropout(0.25)(self.features_extractor(x)).view(batch_size, num_instances, -1)
        classes = self.fc(instance_features)
        
        return instance_features, classes
    
class BagClassifier(nn.Module):
    def __init__(self, input_dim, output_dim=1, hidden_dim=128, dropout_v=0.2, non_linear=True, passing_v=False):
        super(BagClassifier, self).__init__()
        self.hidden_dim = hidden_dim
        
        if non_linear:
            self.q = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.Tanh()
            )
        else:
            self.q = nn.Linear(input_dim, hidden_dim)
        
        if passing_v:
            self.v = nn.Sequential(
                nn.Dropout(dropout_v),
                nn.Linear(input_dim, input_dim),
                nn.ReLU()
            )
        else:
            self.v = nn.Identity()
            
        self.fc = FCLayer(input_dim, output_dim)
        
    def forward(self, features, classes):
        batch_size = features.size(0)
        num_instances = features.size(1)
        features_dim = features.size(2)
        
        combine_features = features.view(features.shape[0] * features.shape[1], -1)
        V = self.v(combine_features)
        Q = self.q(combine_features)
        assert V.shape[0] == Q.shape[0] == batch_size * num_instances, f'V: {V.shape}, Q: {Q.shape}'
        assert V.shape[1] == features_dim, f'V: {V.shape} should be [{batch_size * num_instances}, {features_dim}]'
        assert Q.shape[1] == self.hidden_dim, f'Q: {Q.shape} should be [{batch_size * num_instances}, {self.hidden_dim}]'
        
        # Get critical instance indices by squeezing classes
        critical_indices = torch.squeeze(classes).argmax(dim=1)  # Shape [32]
        assert critical_indices.shape[0] == batch_size, f'Critical indices: {critical_indices.shape}'

        # Gather features for each batch using critical instance indices
        m_features = features[torch.arange(batch_size).unsqueeze(1), critical_indices.unsqueeze(1)].squeeze()
        assert m_features.shape[0] == batch_size, f'M features: {m_features.shape} should be [{batch_size}, {features_dim}]'
        q_max = self.q(m_features)
        assert q_max.shape[0] == batch_size and q_max.shape[1] == self.hidden_dim, f'Q max: {q_max.shape} should be [{batch_size}, {self.hidden_dim}]'
        
        A = torch.mm(Q, q_max.mT)
        A = F.softmax(A / torch.sqrt(torch.tensor(Q.shape[-1]).float()), dim=0)
        assert A.shape[0] == batch_size * num_instances and A.shape[1] == batch_size, f'A: {A.shape} should be [{batch_size * num_instances}, {batch_size}]'
        
        B = torch.mm(A.T, V)
        assert B.shape[0] == batch_size and B.shape[1] == features_dim, f'B: {B.shape} should be [{batch_size}, {features_dim}]'
        
        B = B.view(1, B.shape[0], B.shape[1])
        C = self.fc(B)
        C = C.view(1, -1)
        
        return C, A, B

## Encoder Model

In [12]:
class Encoder(nn.Module):
    def __init__(self, base_model, projection_dim=128):
        super(Encoder, self).__init__()
        self.encoder = base_model
        self.projection = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, projection_dim)
        )
        
        self.instance_classifier = InstanceClassifier(512)
        self.bag_classifier = BagClassifier(512)
        
    def forward(self, x):
        batch_size, num_instances, channels, height, width = x.size()
        
        instances_features, classes = self.instance_classifier(x)
        
        features = instances_features.view(batch_size * num_instances, -1)  # Flatten to (batch_size * num_instances, feature_dim)
        
        projection_features = self.projection(features)
        
        predicted_bags, A, B = self.bag_classifier(instances_features, classes)
        
        return projection_features, classes, predicted_bags, A, B

# Augmentation function
def augment_batch(batch_images):
    batch_size, num_instances, channels, height, width = batch_images.shape
    aug_transform = transforms.Compose([
        # transforms.RandomResizedCrop(36, scale=(0.8, 1)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3)], p=0.6),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        # transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    # Apply transformation to each image instance in the batch
    augmented_batch = []
    for i in range(batch_size):
        augmented_instances = [aug_transform(transforms.ToPILImage()(img.cpu())) for img in batch_images[i]]
        augmented_batch.append(torch.stack(augmented_instances))
    
    return torch.stack(augmented_batch).cuda()  # Move the augmented batch to GPU


# Model Initialization

In [13]:
# Initialize models, loss function, and optimizer
base_model = models.resnet18(weights=None)
base_model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
base_model.fc = nn.Identity()
encoder = Encoder(base_model).cuda()
projection_dim = 256
ntxent_loss = NTXentLoss().cuda()

# Model Training & Evaluation

In [14]:
# Training parameters
epochs = 150
learning_rate = 1e-3
optimizer = torch.optim.Adam(encoder.parameters(), lr=learning_rate)
bceLoss = nn.BCEWithLogitsLoss()


correct_predictions = 0
total_samples = 0
# Training loop for contrastive learning
for epoch in range(epochs):
    encoder.train()
    total_loss = 0
    for images, labels in train_loader:
        images = images.cuda()  # Move the batch to GPU
        labels = labels.cuda()
        # aug1 = augment_batch(images).cuda()
        aug1 = images 
        aug2 = augment_batch(images).cuda()
        
        z_i, outputs_1, predicted_bags_1, _, _ = encoder(aug1)
        z_j, outputs_2, predicted_bags_2, _, _ = encoder(aug2)
        
        NTXLoss = ntxent_loss(z_i, z_j)
        max_agg_1 = torch.max(outputs_1, dim=1).values.squeeze()
        max_agg_2 = torch.max(outputs_2, dim=1).values.squeeze()
        
        loss_max_1 = bceLoss(max_agg_1, labels.float())
        loss_max_2 = bceLoss(max_agg_2, labels.float())
        loss_bag_1 = bceLoss(predicted_bags_1.squeeze(), labels.float())
        loss_bag_2 = bceLoss(predicted_bags_2.squeeze(), labels.float())
        
        loss = 0.2 * NTXLoss + 0.2 * loss_max_1 + 0.2 * loss_max_2 + 0.2 * loss_bag_1 + 0.2 * loss_bag_2
        loss = loss.mean()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        
        # Calculate predictions and update correct predictions count
        predicted = (torch.sigmoid(predicted_bags_1.squeeze()) > 0.5).float()  # Binary classification threshold
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)

    # Calculate average loss and accuracy for this epoch
    avg_loss = total_loss / len(train_loader)
    accuracy = (correct_predictions / total_samples) * 100  # Convert to percentage
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_loader):.5f}, Accuracy: {accuracy:.2f}%")
    
# Evaluation loop
with torch.no_grad():
    encoder.eval()
    correct, total = 0, 0
    for images, labels in eval_loader:
        images = images.cuda()
        labels = labels.cuda()
        z_i, output, predicted_bags, _, _ = encoder(images)
        predicted = (torch.sigmoid(predicted_bags.squeeze()) > 0.5).float()
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print(f"Test Accuracy: {correct/total}")



Epoch [1/150], Loss: 1.78608, Accuracy: 55.71%
Epoch [2/150], Loss: 1.64287, Accuracy: 60.39%
Epoch [3/150], Loss: 1.42234, Accuracy: 69.20%
Epoch [4/150], Loss: 1.30298, Accuracy: 75.17%
Epoch [5/150], Loss: 1.26689, Accuracy: 78.97%
Epoch [6/150], Loss: 1.25020, Accuracy: 81.25%
Epoch [7/150], Loss: 1.19878, Accuracy: 83.39%
Epoch [8/150], Loss: 1.13303, Accuracy: 85.26%
Epoch [9/150], Loss: 1.11858, Accuracy: 86.68%
Epoch [10/150], Loss: 1.11586, Accuracy: 87.84%
Epoch [11/150], Loss: 1.09246, Accuracy: 88.76%
Epoch [12/150], Loss: 1.09648, Accuracy: 89.47%
Epoch [13/150], Loss: 1.09496, Accuracy: 90.09%
Epoch [14/150], Loss: 1.07725, Accuracy: 90.69%
Epoch [15/150], Loss: 1.06048, Accuracy: 91.23%
Epoch [16/150], Loss: 1.07648, Accuracy: 91.63%
Epoch [17/150], Loss: 1.05111, Accuracy: 92.06%
Epoch [18/150], Loss: 1.03062, Accuracy: 92.47%
Epoch [19/150], Loss: 0.99908, Accuracy: 92.84%
Epoch [20/150], Loss: 1.01378, Accuracy: 93.18%
Epoch [21/150], Loss: 0.99616, Accuracy: 93.48%
E