In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import timm
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class VideoRecognitionModel(nn.Module):
    def __init__(self, num_classes, num_frames, embed_dim, num_heads, num_layers, hidden_dim):
        super(VideoRecognitionModel, self).__init__()
        
        # Load a pre-trained MobileNet model from timm
        self.mobilenet = timm.create_model('mobilenetv3_large_100', pretrained=True, features_only=True)
        
        # Remove the last layer to get feature maps
        self.mobilenet.global_pool = nn.Identity()
        self.mobilenet.classifier = nn.Identity()
        
        # Feature dimension from MobileNet
        self.feature_dim = 960  # For mobilenetv3_large_100, adjust if using a different model
        
        # Transformer Encoder
        self.embed_dim = embed_dim
        self.positional_encoding = nn.Parameter(torch.zeros(1, num_frames, embed_dim))
        encoder_layers = TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=hidden_dim)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers=num_layers)
        
        # Fully connected layer for classification
        self.fc = nn.Linear(embed_dim, num_classes)
        
    def forward(self, x):
        # x shape: (batch_size, num_frames, channels, height, width)
        batch_size, num_frames, channels, height, width = x.shape
        
        # Process each frame through MobileNet
        features = []
        for t in range(num_frames):
            frame = x[:, t, :, :, :]  # Extract frame at time t
            frame_features = self.mobilenet(frame)  # Extract features using MobileNet
            frame_features = frame_features[-1]  # Use the last feature map
            frame_features = frame_features.mean([2, 3])  # Global average pooling
            features.append(frame_features)
        
        # Stack features along the time dimension
        features = torch.stack(features, dim=1)  # Shape: (batch_size, num_frames, feature_dim)
        
        # Project features to the embedding dimension
        features = nn.Linear(self.feature_dim, self.embed_dim)(features)
        
        # Add positional encoding
        features = features + self.positional_encoding
        
        # Pass through Transformer Encoder
        transformer_output = self.transformer_encoder(features)  # Shape: (batch_size, num_frames, embed_dim)
        
        # Aggregate over time (e.g., mean pooling)
        aggregated_output = transformer_output.mean(dim=1)
        
        # Final classification
        output = self.fc(aggregated_output)
        return output



In [2]:
# Example usage

    # Hyperparameters
num_classes = 10  # Number of classes for classification
num_frames = 16   # Number of frames in the video
embed_dim = 512   # Embedding dimension for Transformer
num_heads = 8     # Number of attention heads
num_layers = 2    # Number of Transformer layers
hidden_dim = 1024 # Hidden dimension in Transformer feed-forward network
    
    # Initialize model
model = VideoRecognitionModel(num_classes, num_frames, embed_dim, num_heads, num_layers, hidden_dim)
    
    # Dummy input (batch_size, num_frames, channels, height, width)
dummy_input = torch.randn(2, num_frames, 3, 224, 224)
    
    # Forward pass
output = model(dummy_input)
print("Output shape:", output.shape)  # Should be (batch_size, num_classes)

model.safetensors:   0%|          | 0.00/22.1M [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Unexpected keys (classifier.bias, classifier.weight, conv_head.bias, conv_head.weight) found while loading pretrained weights. This may be expected if model is being adapted.


Output shape: torch.Size([2, 10])


In [3]:
print(output)

tensor([[ 0.0126, -0.0574, -0.0610, -0.1870,  0.2715, -0.2460, -0.0294, -0.2435,
          0.0934, -0.2143],
        [ 0.2715, -0.0578, -0.1424, -0.3074,  0.2688, -0.2412,  0.1157, -0.2850,
          0.1176, -0.1314]], grad_fn=<AddmmBackward0>)
