# Attention Layer

## Import Lib

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.xpu import device
from torchvision import models
from torchvision import datasets
from torch.utils.data import DataLoader
import numpy as np

## Data Preparation

In [2]:
import numpy as np
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor

class DatasetGenerator:
    def __init__(self, mnist_data, n_bags=1000, min_instances=3, max_instances=5):
        self.mnist_data = mnist_data
        self.n_bags = n_bags
        self.min_instances = min_instances
        self.max_instances = max_instances
        self.empty_image = torch.empty(1, 28, 28)  # Create an empty image tensor (1x28x28)

    def create_bags(self):
        bags = []
        labels = []
        
        for _ in range(self.n_bags):
            # Randomly choose a number of instances for the bag
            n_instances = np.random.randint(self.min_instances, self.max_instances + 1)
            
            # Randomly select instances from the dataset
            bag_indices = np.random.choice(len(self.mnist_data), n_instances, replace=False)
            bag_images = [self.mnist_data[i][0] for i in bag_indices]
            
            # Determine the label: 1 if any instance is '9', else 0
            label = 1 if any(self.mnist_data[i][1] == 9 for i in bag_indices) else 0
            
            # Convert images to tensors and pad to ensure exactly 7 instances
            bag_images_tensors = [ToTensor()(img) for img in bag_images]
            while len(bag_images_tensors) < 7:
                bag_images_tensors.append(self.empty_image)  # Pad with empty image
            
            bags.append(torch.stack(bag_images_tensors))
            labels.append(label)

        return bags, labels

class TrainDatasetGenerator(DatasetGenerator):
    def __init__(self, mnist_data, n_bags=1000):
        super().__init__(mnist_data, n_bags)

class TestDatasetGenerator(DatasetGenerator):
    def __init__(self, mnist_data, n_bags=500):  # Example: fewer bags for testing
        super().__init__(mnist_data, n_bags)

In [3]:
# Load MNIST dataset
mnist_dataset = datasets.MNIST(root='./data', train=True, download=True)

# Create training dataset generator and generate bags
train_generator = TrainDatasetGenerator(mnist_dataset)
train_bags, train_labels = train_generator.create_bags()

# Create DataLoader for training
from torch.utils.data import DataLoader

train_loader = DataLoader(list(zip(train_bags, train_labels)), batch_size=32, shuffle=True)

# Create test dataset generator and generate bags
test_generator = TestDatasetGenerator(mnist_dataset)
test_bags, test_labels = test_generator.create_bags()

# Create DataLoader for testing
test_loader = DataLoader(list(zip(test_bags, test_labels)), batch_size=32, shuffle=False)

## Attention Layer

In [4]:
class SelfAttention(nn.Module):
    def __init__(self, input_dim):
        super(SelfAttention, self).__init__()
        self.input_dim = input_dim
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):  # x.shape (batch_size, seq_length, input_dim)
        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)

        score = torch.bmm(queries, keys.transpose(1, 2)) / (self.input_dim ** 0.5)
        attention = self.softmax(score)
        weighted = torch.bmm(attention, values)
        return weighted


## MIL-CNN Model

In [5]:
class MILResNet18(nn.Module):
    def __init__(self):
        super(MILResNet18, self).__init__()
        self.resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Modify the first convolutional layer to accept grayscale images
        # Change in_channels from 3 to 1
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.resnet.fc = nn.Identity()  # Remove the final classification layer
        self.attention = SelfAttention(input_dim=512)  # Assuming output dim from ResNet is 512
        self.classifier = nn.Linear(512, 1)  # Binary classification

    def forward(self, bags):
        # bags.shape: (batch_size, num_instances, channels, height, width)
        batch_size, num_instances = bags.size(0), bags.size(1)
        
        # Flatten to (batch_size * num_instances, channels, height, width)
        bags_flattened = bags.view(-1, *bags.shape[2:])
        
        # Get features from ResNet
        features = self.resnet(bags_flattened)  # Shape: (batch_size * num_instances, 512)

        # Reshape back to (batch_size, num_instances, 512)
        features = features.view(batch_size, num_instances, -1)

        # Apply attention mechanism
        attended_features = self.attention(features)

        # Aggregate features (e.g., mean pooling)
        aggregated_features = attended_features.mean(dim=1)  # Shape: (batch_size, 512)
        # Classify bag
        outputs = torch.sigmoid(self.classifier(aggregated_features))
        
        return outputs


## Training Process

In [6]:
import torch
from sklearn.metrics import accuracy_score, precision_score, f1_score

def train(model, dataloader, epochs=10):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.BCELoss()
    
    model.train()
    
    for epoch in range(epochs):
        all_labels = []
        all_outputs = []
        total_loss = 0
        
        for batch_images, batch_labels in dataloader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(batch_images.float())
            loss = criterion(outputs.squeeze(), batch_labels.float())
            total_loss += loss.item()
            
            # Backward pass
            loss.backward()
            optimizer.step()

            # Collect outputs and labels for metrics calculation
            all_labels.extend(batch_labels.cpu().numpy())
            all_outputs.extend((outputs.squeeze().cpu().detach().numpy() > 0.5).astype(int))  # Binarize outputs
            
        # Calculate metrics
        avg_loss = total_loss / len(dataloader)
        accuracy = accuracy_score(all_labels, all_outputs)
        precision = precision_score(all_labels, all_outputs)
        f1 = f1_score(all_labels, all_outputs)

        print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, '
              f'Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, F1 Score: {f1:.4f}')

### Testing Function
def test(model, dataloader):
    model.eval()
    all_labels = []
    all_outputs = []
    
    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            # Forward pass
            outputs = model(batch_images.float())
            
            # Collect outputs and labels for metrics calculation
            all_labels.extend(batch_labels.cpu().numpy())
            all_outputs.extend((outputs.squeeze().cpu().detach().numpy() > 0.5).astype(int))  # Binarize outputs

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_outputs)
    precision = precision_score(all_labels, all_outputs)
    f1 = f1_score(all_labels, all_outputs)

    print(f'Test Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, F1 Score: {f1:.4f}')


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MILResNet18()
model.to(device)
# Train the model
train(model, train_loader)
# Test the model
test(model, test_loader)

Epoch [1/10], Loss: 0.7951, Accuracy: 0.6580, Precision: 0.5652, F1 Score: 0.4062
Epoch [2/10], Loss: 0.3888, Accuracy: 0.8520, Precision: 0.8486, F1 Score: 0.7843
Epoch [3/10], Loss: 0.2651, Accuracy: 0.9080, Precision: 0.9342, F1 Score: 0.8663
Epoch [4/10], Loss: 0.1544, Accuracy: 0.9500, Precision: 0.9443, F1 Score: 0.9313
Epoch [5/10], Loss: 0.1215, Accuracy: 0.9580, Precision: 0.9632, F1 Score: 0.9418
Epoch [6/10], Loss: 0.1082, Accuracy: 0.9770, Precision: 0.9860, F1 Score: 0.9683
Epoch [7/10], Loss: 0.1572, Accuracy: 0.9510, Precision: 0.9444, F1 Score: 0.9328
Epoch [8/10], Loss: 0.0893, Accuracy: 0.9730, Precision: 0.9698, F1 Score: 0.9632
Epoch [9/10], Loss: 0.0614, Accuracy: 0.9810, Precision: 0.9808, F1 Score: 0.9741
Epoch [10/10], Loss: 0.0654, Accuracy: 0.9790, Precision: 0.9677, F1 Score: 0.9717
Test Accuracy: 0.9360, Precision: 0.9854, F1 Score: 0.8940


## References:
[1] https://medium.com/@wangdk93/implement-self-attention-and-cross-attention-in-pytorch-1f1a366c9d4b