In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCELoss
# import torchvision.transforms as transforms
from torchvision.transforms import v2 as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split

In [2]:
# Define NTXentLoss (provided by you)
class NTXentLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        batch_size = z_i.size(0)
        z = torch.cat([z_i, z_j], dim=0)
        z = F.normalize(z, dim=1)
        similarity_matrix = torch.mm(z, z.T) / self.temperature
        mask = torch.eye(2 * batch_size, device=z.device).bool()
        similarity_matrix.masked_fill_(mask, -float('inf'))
        exp_sim = torch.exp(similarity_matrix)
        denominator = exp_sim.sum(dim=1)
        positive_samples = torch.cat(
            [torch.arange(batch_size, 2 * batch_size), torch.arange(batch_size)], dim=0
        ).to(z.device)
        positives = similarity_matrix[torch.arange(2 * batch_size), positive_samples]
        loss = -torch.log(torch.exp(positives) / denominator)
        return loss.mean()

In [3]:
# Load and preprocess the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_data, eval_data = train_test_split(mnist_data, train_size=0.3, random_state=42, stratify=mnist_data.targets)

# Bag-level dataset
class BagDataset(Dataset):
    def __init__(self, data, bag_size=8):
        self.data = data
        self.bag_size = bag_size

    def __len__(self):
        return len(self.data) // self.bag_size

    def __getitem__(self, idx):
        images, labels = [], []
        for i in range(self.bag_size):
            img, label = self.data[idx * self.bag_size + i]
            images.append(img)
            labels.append(label)
        bag_label = 1 if 9 in labels else 0
        return torch.stack(images), torch.tensor(bag_label)

train_dataset = BagDataset(train_data)
eval_dataset = BagDataset(eval_data)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, pin_memory=True, num_workers=4, drop_last=True)
eval_loader = DataLoader(eval_dataset, batch_size=32, shuffle=False, pin_memory=True, num_workers=4, drop_last=True)



In [4]:
class Attention(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(Attention, self).__init__()
        self.attention = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.ReLU(),
            nn.Linear(256, out_dim)
        )

    def forward(self, x):
        attention_weights = self.attention(x)
        weights = F.softmax(attention_weights, dim=1)
        return (x * weights).sum(dim=1), weights.squeeze(-1)

In [5]:
class Encoder(nn.Module):
    def __init__(self, base_model, projection_dim=128):
        super(Encoder, self).__init__()
        self.encoder = base_model
        self.projection = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, projection_dim)
        )
        
        self.attention = Attention(512, 1)
        # self.fc = nn.Linear(512, 1)
        self.fc = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )
    def forward(self, x):
        batch_size, num_instances, channels, height, width = x.size()
        x = x.view(-1, channels, height, width)  # Reshape to (batch_size * num_instances, channels, height, width)
        features = self.encoder(x)
        features = nn.Dropout(0.25)(features)
        features = features.view(batch_size * num_instances, -1)  # Flatten to (batch_size * num_instances, feature_dim)
        
        projection_features = self.projection(features)
        attention_features, _ = self.attention(features.view(batch_size, num_instances, -1))
        output = self.fc(attention_features)
        
        return projection_features, output

# Augmentation function
def augment_batch(batch_images):
    batch_size, num_instances, channels, height, width = batch_images.shape
    aug_transform = transforms.Compose([
        # transforms.RandomResizedCrop(224, scale=(0.8, 1.1)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3)], p=0.6),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        # transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    # Apply transformation to each image instance in the batch
    augmented_batch = []
    for i in range(batch_size):
        augmented_instances = [aug_transform(transforms.ToPILImage()(img.cpu())) for img in batch_images[i]]
        augmented_batch.append(torch.stack(augmented_instances))
    
    return torch.stack(augmented_batch).cuda()  # Move the augmented batch to GPU

# Initialize models, loss function, and optimizer
base_model = models.resnet18(weights=None)
base_model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
base_model.fc = nn.Identity()
encoder = Encoder(base_model).cuda()
projection_dim = 256
ntxent_loss = NTXentLoss().cuda()

In [6]:
# Training parameters
epochs = 50
learning_rate = 1e-3
optimizer = torch.optim.Adam(encoder.parameters(), lr=learning_rate)
bceLoss = nn.BCEWithLogitsLoss()


correct_predictions = 0
total_samples = 0
# Training loop for contrastive learning
for epoch in range(epochs):
    encoder.train()
    total_loss = 0
    for images, labels in train_loader:
        images = images.cuda()  # Move the batch to GPU
        labels = labels.cuda()
        # aug1 = augment_batch(images).cuda()
        aug1 = images 
        aug2 = augment_batch(images).cuda()
        # aug2 = images
        
        z_i, outputs_1 = encoder(aug1)
        z_j, outputs_2 = encoder(aug2)
        
        NTXLoss = ntxent_loss(z_i, z_j)
        BCELoss_1 = bceLoss(outputs_1.squeeze(), labels.float())
        BCELoss_2 = bceLoss(outputs_2.squeeze(), labels.float())
        loss = 0.4 * NTXLoss + 0.3 * BCELoss_1 + 0.3 * BCELoss_2
        loss = loss.mean()
        # loss = ntxent_loss(z_i, z_j)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        
        # Calculate predictions and update correct predictions count
        predicted = (torch.sigmoid(outputs_1.squeeze()) > 0.5).float()  # Binary classification threshold
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)

    # Calculate average loss and accuracy for this epoch
    avg_loss = total_loss / len(train_loader)
    accuracy = (correct_predictions / total_samples) * 100  # Convert to percentage
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_loader):.5f}, Accuracy: {accuracy:.2f}%")
    
# Evaluation loop
with torch.no_grad():
    encoder.eval()
    correct, total = 0, 0
    for images, labels in eval_loader:
        images = images.cuda()
        labels = labels.cuda()
        z_i, output = encoder(images)
        predicted = (torch.sigmoid(output.squeeze()) > 0.5).float()
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print(f"Test Accuracy: {correct/total}")



Epoch [1/50], Loss: 3.03937, Accuracy: 74.31%
Epoch [2/50], Loss: 2.70825, Accuracy: 83.55%
Epoch [3/50], Loss: 2.62387, Accuracy: 87.65%
Epoch [4/50], Loss: 2.57440, Accuracy: 89.99%
Epoch [5/50], Loss: 2.53745, Accuracy: 91.61%
Epoch [6/50], Loss: 2.54068, Accuracy: 92.52%
Epoch [7/50], Loss: 2.50965, Accuracy: 93.38%
Epoch [8/50], Loss: 2.47856, Accuracy: 94.07%
Epoch [9/50], Loss: 2.46844, Accuracy: 94.65%
Epoch [10/50], Loss: 2.45623, Accuracy: 95.10%
Epoch [11/50], Loss: 2.43674, Accuracy: 95.49%
Epoch [12/50], Loss: 2.42786, Accuracy: 95.83%
Epoch [13/50], Loss: 2.42450, Accuracy: 96.08%
Epoch [14/50], Loss: 2.42110, Accuracy: 96.32%
Epoch [15/50], Loss: 2.41485, Accuracy: 96.52%
Epoch [16/50], Loss: 2.41130, Accuracy: 96.71%
Epoch [17/50], Loss: 2.40327, Accuracy: 96.88%
Epoch [18/50], Loss: 2.39383, Accuracy: 97.03%
Epoch [19/50], Loss: 2.40393, Accuracy: 97.16%
Epoch [20/50], Loss: 2.39562, Accuracy: 97.27%
Epoch [21/50], Loss: 2.39074, Accuracy: 97.38%
Epoch [22/50], Loss: 2

## Freeze the encoder and train a linear classifier

In [7]:
# # Freeze the encoder
# # Function to freeze all layers of a model
# def freeze_model(model):
#     for param in model.parameters():
#         param.requires_grad = False
# 
# # Freeze the contrastive learning model
# freeze_model(encoder)

## Define the linear classifier

In [8]:
# class LinearClassifier(nn.Module):
#     def __init__(self, in_dim, out_dim):
#         super(LinearClassifier, self).__init__()
#         self.fc = nn.Sequential(
#             nn.Linear(in_dim, 128),
#             nn.ReLU(),
#             nn.Linear(128, out_dim)
#         )
#         self.dropout = nn.Dropout(0.25)
#         self.attention = Attention(in_dim, 1)
# 
#     def forward(self, x):
#         x = x.view(32, 8, 128)
#         # # Max pooling over the bag instances
#         # x, _ = torch.max(x, dim=1)
#         x = self.dropout(x)
#         x, _ = self.attention(x)
#         return self.fc(x)

In [9]:
# # Initialize the linear classifier
# classifier = LinearClassifier(in_dim=projection_dim, out_dim=1).cuda()
# 
# # Training parameters
# epochs = 50
# learning_rate = 5e-4
# optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)
# criterion = nn.BCEWithLogitsLoss()

In [10]:
# # Training loop for linear classification with accuracy calculation
# for epoch in range(epochs):
#     classifier.train()
#     total_loss = 0
#     correct_predictions = 0
#     total_samples = 0
#     
#     for images, labels in train_loader:
#         images, labels = images.cuda(), labels.float().cuda()
#         features = encoder(images)
#         outputs = classifier(features).squeeze()
#         
#         # Calculate the loss
#         loss = criterion(outputs, labels)
#         
#         # Backpropagation and optimization
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         
#         total_loss += loss.item()
# 
#         # Calculate predictions and update correct predictions count
#         predicted = (torch.sigmoid(outputs) > 0.5).float()  # Binary classification threshold
#         correct_predictions += (predicted == labels).sum().item()
#         total_samples += labels.size(0)
# 
#     # Calculate average loss and accuracy for this epoch
#     avg_loss = total_loss / len(train_loader)
#     accuracy = (correct_predictions / total_samples) * 100  # Convert to percentage
# 
#     print(f"Epoch [{epoch + 1}/{epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")

In [11]:
# # Evaluation loop
# classifier.eval()
# encoder.eval()
# correct, total = 0, 0
# with torch.no_grad():
#     for images, labels in eval_loader:
#         images, labels = images.cuda(), labels.float().cuda()
#         # images = augment_batch(images)
#         features = encoder(images)
#         outputs = classifier(features).squeeze()
#         predicted = (torch.sigmoid(outputs) > 0.5).float()
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()
# print(f"Accuracy: {correct/total}")
# 
# # Save the model
# torch.save(encoder.state_dict(), "encoder.pth")
# torch.save(classifier.state_dict(), "classifier.pth")