# Attention Layer

## Import Libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.xpu import device
from torchvision import models
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score

import gpytorch

import matplotlib.pyplot as plt
import numpy as np
import random

In [None]:
# Fix the seed of every random number generator used in this notebook
# (Python's `random`, NumPy and PyTorch) so runs are reproducible.
seed_value = 40

for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed_value)

# Also seed every visible GPU when CUDA is present.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed_value)

In [None]:
import warnings
from sklearn.exceptions import UndefinedMetricWarning

# Silence RuntimeWarnings and sklearn's warnings about ill-defined metrics
# (e.g. precision/recall when a class never appears in the predictions).
for _category in (RuntimeWarning, UndefinedMetricWarning):
    warnings.filterwarnings("ignore", category=_category)

## Configurations

In [None]:
# Hyperparameters
batch_size = 32
learning_rate = 3e-4
num_epochs = 10
img_size = 28  # MNIST images are 28x28 pixels
num_classes = 1  # Single sigmoid output: probability that a bag contains a '9'
patch_size = 7  # Size of each patch (7x7) -> (28 // 7) ** 2 = 16 patches per image
embedding_dim = 7 * 7  # Dimensionality of the embeddings; must be divisible by num_heads
num_heads = 7  # Number of attention heads (49 / 7 = 7 dims per head)
num_layers = 6  # Number of transformer layers
dropout_rate = 0.1  # Dropout rate for regularization

# Fail fast on inconsistent settings instead of deep inside model construction.
if img_size % patch_size != 0:
    raise ValueError(f"img_size ({img_size}) must be divisible by patch_size ({patch_size}).")
if embedding_dim % num_heads != 0:
    raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads}).")

## Data Preparation

In [None]:
class DatasetGenerator:
    """Build multiple-instance-learning "bags" of MNIST digits.

    Each bag holds between ``min_instances`` and ``max_instances`` images drawn
    without replacement (within the bag) from ``mnist_data``. It is then padded
    with all-zero images up to ``pad_to`` instances, so every bag stacks to the
    same shape. A bag is labelled 1 if any of its real instances has label
    ``target_digit``, otherwise 0.

    Args:
        mnist_data: Indexable dataset whose items are ``(image, label)`` pairs.
            Images may be PIL images / ndarrays (converted with ``ToTensor``) or
            tensors already shaped ``(1, 28, 28)``.
        n_bags: Number of bags to generate.
        min_instances: Minimum number of real images per bag (inclusive).
        max_instances: Maximum number of real images per bag (inclusive).
        pad_to: Number of instances every bag is padded to.
        target_digit: Label whose presence makes a bag positive.

    Raises:
        ValueError: If ``0 <= min_instances <= max_instances <= pad_to`` does not hold.
    """

    def __init__(self, mnist_data, n_bags=1000, min_instances=3, max_instances=5,
                 pad_to=7, target_digit=9):
        if not 0 <= min_instances <= max_instances <= pad_to:
            raise ValueError(
                "Expected 0 <= min_instances <= max_instances <= pad_to, got "
                f"min_instances={min_instances}, max_instances={max_instances}, pad_to={pad_to}."
            )
        self.mnist_data = mnist_data
        self.n_bags = n_bags
        self.min_instances = min_instances
        self.max_instances = max_instances
        self.pad_to = pad_to
        self.target_digit = target_digit
        # All-zero (1x28x28) image used to pad bags shorter than ``pad_to``.
        self.empty_image = torch.zeros(1, 28, 28)

    def create_bags(self):
        """Generate ``n_bags`` padded bags and their binary labels.

        Sampling uses NumPy's global RNG, so seed ``np.random`` for reproducibility.

        Returns:
            tuple[list[torch.Tensor], list[int]]: bags of shape
            ``(pad_to, 1, 28, 28)`` and the matching 0/1 labels.
        """
        bags = []
        labels = []

        for _ in range(self.n_bags):
            # Randomly choose how many real instances this bag gets.
            n_instances = np.random.randint(self.min_instances, self.max_instances + 1)

            # Randomly select distinct instances from the dataset. Each sample is
            # fetched once, instead of once for the image and again for the label.
            bag_indices = np.random.choice(len(self.mnist_data), n_instances, replace=False)
            samples = [self.mnist_data[i] for i in bag_indices]

            # Positive bag iff any real instance carries the target digit.
            label = int(any(digit == self.target_digit for _, digit in samples))

            # Convert to tensors and pad with empty images up to ``pad_to``.
            instances = [self._to_tensor(img) for img, _ in samples]
            instances.extend([self.empty_image] * (self.pad_to - len(instances)))

            bags.append(torch.stack(instances))
            labels.append(label)

        return bags, labels

    @staticmethod
    def _to_tensor(img):
        """Return ``img`` as a tensor, converting PIL images / ndarrays with ``ToTensor``."""
        if isinstance(img, torch.Tensor):
            return img
        return ToTensor()(img)

class TrainDatasetGenerator(DatasetGenerator):
    """Bag generator preset for training (1000 bags by default)."""

    def __init__(self, mnist_data, n_bags=1000):
        super().__init__(mnist_data, n_bags)

class TestDatasetGenerator(DatasetGenerator):
    """Bag generator preset for testing (500 bags by default)."""

    def __init__(self, mnist_data, n_bags=500):  # Example: fewer bags for testing
        super().__init__(mnist_data, n_bags)

In [None]:
# Seed the RNGs that drive bag sampling so the generated bags are reproducible.
np.random.seed(42)
torch.manual_seed(42)

# Load MNIST. Training and test bags come from disjoint splits, so test
# metrics are not computed on images the model was trained on.
mnist_train = datasets.MNIST(root='./data', train=True, download=True)
mnist_test = datasets.MNIST(root='./data', train=False, download=True)
mnist_dataset = mnist_train  # alias kept in case other cells refer to the original name

# Create training bags; drop_last keeps every training batch the same size.
train_generator = TrainDatasetGenerator(mnist_train)
train_bags, train_labels = train_generator.create_bags()
train_loader = DataLoader(list(zip(train_bags, train_labels)), batch_size=batch_size, shuffle=True, drop_last=True)

# Create test bags; keep the final partial batch so every test bag is evaluated.
test_generator = TestDatasetGenerator(mnist_test)
test_bags, test_labels = test_generator.create_bags()
test_loader = DataLoader(list(zip(test_bags, test_labels)), batch_size=batch_size, shuffle=False, drop_last=False)

## Patch Embeddings

In [None]:
class PatchEmbedding(nn.Module):
    """Split images into non-overlapping square patches and linearly embed each one.

    Patching and projection happen in a single ``Conv2d`` whose kernel size and
    stride both equal ``patch_size``.

    Args:
        img_size: Input image side length (accepted for API symmetry; unused here).
        in_channels: Number of input channels.
        patch_size: Side length of each square patch.
        embed_dim: Size of each patch embedding.
    """

    def __init__(self, img_size, in_channels, patch_size, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.conv = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            padding=0,
        )

    def forward(self, x):
        """Return patch embeddings of shape ``(N, num_patches, embed_dim)``.

        ``x`` is expected as ``(N, C, H, W)``. A 3-D ``(N, H, W)`` tensor is
        treated as a batch of single-channel images.
        """
        images = x[:, None] if x.dim() == 3 else x
        feature_map = self.conv(images)            # (N, embed_dim, H // p, W // p)
        tokens = feature_map.flatten(start_dim=2)  # (N, embed_dim, num_patches)
        return tokens.permute(0, 2, 1)             # (N, num_patches, embed_dim)

## Multi-Head Self-Attention Layer

In [None]:
class MultiheadSelfAttentionBlock(nn.Module):
    """Pre-norm multi-head self-attention block ("MSA block" for short).

    Applies ``LayerNorm`` and then ``nn.MultiheadAttention``, using the
    normalized input as query, key and value. The residual connection is
    added by the caller.

    Args:
        embedding_dim: Hidden size D (768 for ViT-Base); must be divisible by ``num_heads``.
        num_heads: Number of attention heads (12 for ViT-Base).
        attn_dropout: Dropout probability on the attention weights
            (the ViT paper does not appear to use dropout here).
    """

    def __init__(self,
                 embedding_dim: int = 768,
                 num_heads: int = 12,
                 attn_dropout: float = 0):
        super().__init__()

        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

        # batch_first=True: inputs and outputs are (batch, seq, embedding_dim).
        self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                    num_heads=num_heads,
                                                    dropout=attn_dropout,
                                                    batch_first=True)

    def forward(self, x):
        """Return the self-attention output, with the same shape as ``x``."""
        x = self.layer_norm(x)
        # The attention map is never used, so don't ask for it; need_weights=False
        # also lets PyTorch take its optimized attention path.
        attn_output, _ = self.multihead_attn(query=x,
                                             key=x,
                                             value=x,
                                             need_weights=False)
        return attn_output

## MLP Block Layer

In [None]:
class MLPBlock(nn.Module):
    """Pre-norm feed-forward block ("MLP block").

    Layout: LayerNorm -> Linear(D, mlp_size) -> GELU -> Dropout
    -> Linear(mlp_size, D) -> Dropout. Dropout follows each dense layer.

    Args:
        embedding_dim: Input/output feature size D (768 for ViT-Base).
        mlp_size: Hidden layer width (3072 for ViT-Base).
        dropout: Dropout probability after each dense layer (0.1 for ViT-Base).
    """

    def __init__(self,
                 embedding_dim:int=768,
                 mlp_size:int=3072,
                 dropout:float=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

        # Widen to mlp_size, apply the non-linearity, then project back to embedding_dim.
        expand = nn.Linear(in_features=embedding_dim, out_features=mlp_size)
        contract = nn.Linear(in_features=mlp_size, out_features=embedding_dim)
        self.mlp = nn.Sequential(
            expand,
            nn.GELU(),
            nn.Dropout(p=dropout),
            contract,
            nn.Dropout(p=dropout),
        )

    def forward(self, x):
        """Normalize ``x`` and pass it through the two-layer MLP."""
        return self.mlp(self.layer_norm(x))

## Transformer Encoder Layer

In [None]:
class TransformerEncoderBlock(nn.Module):
    """One pre-norm Transformer encoder layer.

    Computes ``h = x + MSA(x)`` followed by ``h + MLP(h)`` (equations 2 and 3
    of the ViT paper).

    Args:
        embedding_dim: Hidden size D (768 for ViT-Base).
        num_heads: Number of attention heads (12 for ViT-Base).
        mlp_size: Hidden width of the MLP block (3072 for ViT-Base).
        mlp_dropout: Dropout after the MLP's dense layers.
        attn_dropout: Dropout on the attention weights.
    """

    def __init__(self,
                 embedding_dim:int=768,
                 num_heads:int=12,
                 mlp_size:int=3072,
                 mlp_dropout:float=0.1,
                 attn_dropout:float=0):
        super().__init__()
        self.msa_block = MultiheadSelfAttentionBlock(
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            attn_dropout=attn_dropout,
        )
        self.mlp_block = MLPBlock(
            embedding_dim=embedding_dim,
            mlp_size=mlp_size,
            dropout=mlp_dropout,
        )

    def forward(self, x):
        """Apply both residual sub-blocks; the output has the same shape as ``x``."""
        msa_out = self.msa_block(x)
        assert type(msa_out) is torch.Tensor, "The MSA block output should be a PyTorch tensor."
        hidden = x + msa_out
        return hidden + self.mlp_block(hidden)

## Attention Layer

In [None]:
class AttentionLayer(nn.Module):
    """Pool a set of instance features into one vector using learned attention weights.

    Each instance gets a scalar score from ``Linear -> tanh -> Linear``. Scores
    are softmax-normalized across instances and used to take a weighted sum of
    the instance features.

    Args:
        input_dim: Feature size of each instance.
        hidden_dim: Width of the scoring network's hidden layer.
    """

    def __init__(self, input_dim, hidden_dim=512):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        """Pool ``x`` over its instance axis.

        Args:
            x: Tensor expected as ``(batch_size, num_instances, input_dim)``.

        Returns:
            tuple: pooled features ``(batch_size, input_dim)`` and attention weights
            ``(batch_size, num_instances)`` that sum to 1 over instances.
        """
        scores = self.attention(x)            # (batch, instances, 1)
        alpha = torch.softmax(scores, dim=1)  # normalize across instances
        pooled = torch.sum(alpha * x, dim=1)  # (batch, input_dim)
        return pooled, alpha.squeeze(dim=-1)

## Vision Transformer Model

In [None]:
class ViT(nn.Module):
    """Vision Transformer used as a bag (multiple-instance) classifier.

    Each image in a bag is encoded independently by a ViT: patch embedding,
    prepended class token, learnable position embedding, then a stack of
    Transformer encoder blocks. The encoder outputs are max-pooled across the
    bag's instances. The pooled class-token features go through a
    LayerNorm + Linear head followed by a sigmoid.

    Defaults are the ViT-Base hyperparameters (Tables 1 and 3 of the ViT paper).

    Args:
        img_size: Side length of the square input images.
        in_channels: Number of channels per image.
        patch_size: Side length of each square patch; must divide ``img_size``.
        num_transformer_layers: Number of stacked encoder blocks.
        embedding_dim: Hidden size D.
        mlp_size: Hidden width of each encoder MLP.
        num_heads: Attention heads per encoder block.
        attn_dropout: Dropout on attention weights inside each encoder block.
        mlp_dropout: Dropout after dense layers in each encoder MLP.
        embedding_dropout: Dropout on the patch + position embeddings.
        num_classes: Number of sigmoid outputs.
    """

    def __init__(self,
                 img_size:int=224,
                 in_channels:int=3,
                 patch_size:int=16,
                 num_transformer_layers:int=12,
                 embedding_dim:int=768,
                 mlp_size:int=3072,
                 num_heads:int=12,
                 attn_dropout:float=0,
                 mlp_dropout:float=0.1,
                 embedding_dropout:float=0.1,
                 num_classes:int=1000):
        super().__init__()

        # The image must tile exactly into patches.
        assert img_size % patch_size == 0, f"Image size must be divisible by patch size, image size: {img_size}, patch size: {patch_size}."

        # Stored so forward() uses this model's width rather than any global of the same name.
        self.embedding_dim = embedding_dim

        # Number of patches = (img_size / patch_size) ** 2.
        self.num_patches = (img_size * img_size) // patch_size**2

        # Learnable class token, prepended to each instance's patch sequence.
        self.class_embedding = nn.Parameter(data=torch.randn(1, 1, embedding_dim),
                                            requires_grad=True)

        # Learnable position embedding for the class token + all patches.
        self.position_embedding = nn.Parameter(data=torch.randn(1, self.num_patches + 1, embedding_dim),
                                               requires_grad=True)

        self.embedding_dropout = nn.Dropout(p=embedding_dropout)

        self.patch_embedding = PatchEmbedding(img_size=img_size,
                                              in_channels=in_channels,
                                              patch_size=patch_size,
                                              embed_dim=embedding_dim)

        # Stack of encoder blocks; the list is unpacked into nn.Sequential.
        # attn_dropout is forwarded so the constructor argument takes effect.
        self.transformer_encoder = nn.Sequential(*[
            TransformerEncoderBlock(embedding_dim=embedding_dim,
                                    num_heads=num_heads,
                                    mlp_size=mlp_size,
                                    mlp_dropout=mlp_dropout,
                                    attn_dropout=attn_dropout)
            for _ in range(num_transformer_layers)
        ])

        # Classifier head applied to the pooled class-token features.
        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim,
                      out_features=num_classes)
        )

        # NOTE(review): not used in forward(). Its input size 833 == (16 + 1) * 49
        # matches the 28x28 / patch 7 / dim 49 configuration only. Kept so the
        # parameter set (and saved state_dicts) stay unchanged.
        self.attention = AttentionLayer(input_dim=833)

    def forward(self, x):
        """Classify a batch of bags.

        Args:
            x: Tensor of shape ``(batch, num_instances, in_channels, img_size, img_size)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding sigmoid probabilities.
        """
        batch_size, num_instance = x.shape[0], x.shape[1]

        # Fold the bag axis into the batch so every instance is encoded independently.
        x = x.reshape(batch_size * num_instance, *x.shape[2:])

        # One class token per instance (equation 1).
        class_token = self.class_embedding.expand(x.shape[0], -1, -1)

        x = self.patch_embedding(x)             # (B*K, num_patches, D)
        x = torch.cat((class_token, x), dim=1)  # (B*K, num_patches + 1, D)
        x = self.position_embedding + x
        x = self.embedding_dropout(x)
        x = self.transformer_encoder(x)         # (B*K, num_patches + 1, D)

        # Unfold the bag axis and max-pool each feature over the bag's instances.
        x = x.reshape(batch_size, num_instance, -1, self.embedding_dim)
        x = x.max(dim=1).values                 # (B, num_patches + 1, D)

        # Classify from the pooled class-token position (index 0), equation 4.
        x = self.classifier(x[:, 0])
        return torch.sigmoid(x)

## Model & Optimizer Definition

In [None]:
import torch.optim as optim

model = ViT(
    img_size=img_size,
    in_channels=1,
    patch_size=patch_size,
    embedding_dim=embedding_dim,
    num_heads=num_heads,
    num_transformer_layers=num_layers,
    num_classes=num_classes
)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# The model ends in a sigmoid over a single output, so binary cross-entropy on
# probabilities is the matching loss. CrossEntropyLoss over one class is
# degenerate: the log-softmax of a single logit is always 0.
criterion = nn.BCELoss()

## Training Step

In [None]:
def train(model, dataloader, epochs=5, lr=0.0005):
    """Train ``model`` on bags from ``dataloader`` with Adam and binary cross-entropy.

    A fresh Adam optimizer and ``BCELoss`` are created on every call. After each
    epoch, prints the mean batch loss plus accuracy, precision, recall and F1
    (model outputs thresholded at 0.5).

    Args:
        model: Module mapping a batch of bags to probabilities of shape ``(batch, 1)``.
        dataloader: Iterable of ``(bags, labels)`` batches.
        epochs: Number of passes over ``dataloader``.
        lr: Adam learning rate.
    """
    # Put batches wherever the model lives, instead of relying on a global `device`.
    device = next(model.parameters()).device
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCELoss()

    model.train()

    for epoch in range(epochs):
        all_labels = []
        all_outputs = []
        total_loss = 0.0
        n_batches = 0

        for batch_images, batch_labels in dataloader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            optimizer.zero_grad()

            # Forward pass. Flatten (batch, 1) -> (batch,) explicitly: squeeze()
            # would also drop the batch axis for a batch of one.
            outputs = model(batch_images.float()).reshape(-1)
            loss = criterion(outputs, batch_labels.float())

            total_loss += loss.item()
            n_batches += 1

            # Backward pass
            loss.backward()
            optimizer.step()

            # Collect labels and binarized predictions for the epoch metrics.
            all_labels.extend(batch_labels.cpu().tolist())
            all_outputs.extend((outputs.detach() > 0.5).long().cpu().tolist())

        if n_batches == 0:
            print(f'Epoch [{epoch+1}/{epochs}]: dataloader yielded no batches')
            continue

        # Calculate metrics (zero_division=0 matches sklearn's default value, without the warning).
        avg_loss = total_loss / n_batches
        accuracy = accuracy_score(all_labels, all_outputs)
        recall = recall_score(all_labels, all_outputs, zero_division=0)
        precision = precision_score(all_labels, all_outputs, zero_division=0)
        f1 = f1_score(all_labels, all_outputs, zero_division=0)

        print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, '
              f'Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f},F1 Score: {f1:.4f}')


# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ViT(
    img_size=img_size,
    in_channels=1,
    patch_size=patch_size,
    embedding_dim=embedding_dim,
    num_heads=num_heads,
    num_transformer_layers=num_layers,
    num_classes=num_classes
).to(device)
# Use the epoch count from the hyperparameter cell rather than train()'s default of 5.
train(model, train_loader, epochs=num_epochs)

In [None]:
def test(model, dataloader):
    """Evaluate ``model`` on ``dataloader`` and print accuracy, precision, recall and F1.

    The model is switched to eval mode, gradients are disabled, and its sigmoid
    outputs are thresholded at 0.5.

    Args:
        model: Module mapping a batch of bags to probabilities of shape ``(batch, 1)``.
        dataloader: Iterable of ``(bags, labels)`` batches.
    """
    # Put batches wherever the model lives, instead of relying on a global `device`.
    device = next(model.parameters()).device
    model.eval()
    all_labels = []
    all_outputs = []

    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            batch_images = batch_images.to(device)

            # Forward pass. Flatten (batch, 1) -> (batch,) explicitly: squeeze()
            # would also drop the batch axis for a batch of one.
            outputs = model(batch_images.float()).reshape(-1)

            # Collect labels and binarized predictions for the metrics.
            all_labels.extend(batch_labels.cpu().tolist())
            all_outputs.extend((outputs > 0.5).long().cpu().tolist())

    if not all_labels:
        print('No test batches to evaluate.')
        return

    # Calculate metrics (zero_division=0 matches sklearn's default value, without the warning).
    accuracy = accuracy_score(all_labels, all_outputs)
    recall = recall_score(all_labels, all_outputs, zero_division=0)
    precision = precision_score(all_labels, all_outputs, zero_division=0)
    f1 = f1_score(all_labels, all_outputs, zero_division=0)

    print(f'Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f},F1 Score: {f1:.4f}')

# Evaluate the trained model on the test bags.
test(model=model, dataloader=test_loader)

## References:
[1] https://medium.com/@wangdk93/implement-self-attention-and-cross-attention-in-pytorch-1f1a366c9d4b