<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/Longitudinal_Image_MHAResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

MHAResNet

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class TemporalResNet(nn.Module):
    """Classify a pair of images with a shared ResNet encoder and self-attention.

    Both frames go through the same ResNet backbone, are globally
    average-pooled, projected to ``embed_dim`` and treated as a two-token
    sequence for multi-head self-attention. The attended tokens are averaged
    and mapped to class logits by a linear head.

    Args:
        resnet_type: Name of a constructor in ``torchvision.models``. The
            resulting model must expose an ``fc`` layer, whose ``in_features``
            gives the backbone feature width.
        num_classes: Number of output logits.
        embed_dim: Embedding width used by the attention layer.
            ``nn.MultiheadAttention`` requires it to be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        pretrained: When true (the default) the ``IMAGENET1K_V1`` checkpoint
            is loaded into the backbone. Pass ``False`` to build a randomly
            initialised backbone without loading a checkpoint.
    """

    def __init__(
        self,
        resnet_type: str = "resnet18",
        num_classes: int = 1000,
        embed_dim: int = 512,
        num_heads: int = 8,
        *,
        pretrained: bool = True,
    ) -> None:
        super().__init__()

        resnet = self._build_backbone(resnet_type, pretrained)

        # Keep only the convolutional trunk: the last two children (avgpool and
        # fc) are dropped because pooling is done explicitly in forward().
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-2])

        # Channel count of the backbone output, read from the discarded fc layer.
        self.feature_dim = resnet.fc.in_features

        # Projection layer to match the MultiheadAttention embedding size.
        self.projection = nn.Linear(self.feature_dim, embed_dim)

        # Self-attention over the frame axis; batch_first=True means inputs are
        # laid out as (B, T, embed_dim).
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=num_heads, batch_first=True
        )

        # Fully connected classification head.
        self.fc = nn.Linear(embed_dim, num_classes)

    @staticmethod
    def _build_backbone(resnet_type: str, pretrained: bool) -> nn.Module:
        """Instantiate ``torchvision.models.<resnet_type>``, optionally pretrained."""
        constructor = getattr(models, resnet_type)
        try:
            # `weights=` replaces the `pretrained=` flag deprecated in
            # torchvision 0.13. IMAGENET1K_V1 is the checkpoint that
            # `pretrained=True` used to select.
            return constructor(weights="IMAGENET1K_V1" if pretrained else None)
        except TypeError:
            # Older torchvision releases only understand the legacy keyword.
            return constructor(pretrained=pretrained)

    def forward(self, frame1: torch.Tensor, frame2: torch.Tensor) -> torch.Tensor:
        """Return class logits for a pair of image batches.

        Args:
            frame1: First image batch in the layout the backbone accepts,
                e.g. ``(B, 3, H, W)``.
            frame2: Second image batch with the same batch size as ``frame1``.

        Returns:
            Tensor of shape ``(B, num_classes)`` holding unnormalised logits.
        """
        tokens = []
        for frame in (frame1, frame2):
            feature_map = self.feature_extractor(frame)  # (B, C, H', W')
            # Global average pooling collapses the spatial grid to one vector.
            pooled = F.adaptive_avg_pool2d(feature_map, (1, 1)).flatten(1)  # (B, C)
            tokens.append(self.projection(pooled))  # (B, embed_dim)

        # Two-step sequence with the batch axis first: (B, 2, embed_dim).
        features = torch.stack(tokens, dim=1)

        # Self-attention across the two frames; attention weights are unused.
        attended_features, _ = self.attention(features, features, features)

        # Fuse the frames by averaging over the sequence axis: (B, embed_dim).
        fused_features = attended_features.mean(dim=1)

        # Classification: (B, num_classes).
        return self.fc(fused_features)

# Example usage
if __name__ == "__main__":
    # Smoke test: build a 3-class model and run one random frame pair through it.
    model = TemporalResNet(resnet_type="resnet18", num_classes=3, embed_dim=512, num_heads=8)

    # Each input is a batch of 4 RGB images at 224x224.
    frame_shape = (4, 3, 224, 224)
    frame1, frame2 = torch.randn(*frame_shape), torch.randn(*frame_shape)

    output = model(frame1, frame2)
    # One row of num_classes=3 logits per sample, i.e. torch.Size([4, 3]).
    print(output.shape)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 129MB/s]


torch.Size([4, 3])
