In [1]:
import torch
import torch.nn as nn

In [2]:
class SpatioTemporalBlock(nn.Module):
    """One Conv3d -> BatchNorm3d -> ReLU -> MaxPool3d stage.

    The 3x3x3 convolution is padded by 1, so it keeps the temporal and
    spatial extents; the 2x2x2 max-pool with stride 2 then halves each of
    them (rounding down).

    Args:
        in_channels: number of channels in the incoming tensor.
        out_channels: number of channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Layers are created in the order they are applied and stored under
        # ``conv`` so parameter names remain ``conv.<index>.*``.
        stages = [
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=2, stride=2),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stage and return the pooled feature map."""
        pooled = self.conv(x)
        return pooled

In [5]:
class AnomalySTCN(nn.Module):
    """Multi-view spatio-temporal CNN producing one anomaly score per sample.

    Every view is passed through the same ``feature_extraction`` stack
    (weights are shared across views), globally average-pooled to a vector,
    and the per-view vectors are concatenated, fused by a linear layer and
    mapped to a single sigmoid-activated score.

    Args:
        num_views: number of views expected along dimension 1 of the input.
        input_channels: channel count of each view.
        base_channels: width of the first block; the following blocks use
            2x and 4x this value.
    """

    def __init__(self, num_views=3, input_channels=3, base_channels=32):
        super().__init__()
        self.num_views = num_views

        # Feature extraction: three stages, each halving time/H/W, so those
        # extents should each be at least 8 for the pooling to succeed.
        self.feature_extraction = nn.Sequential(
            SpatioTemporalBlock(input_channels, base_channels),
            SpatioTemporalBlock(base_channels, base_channels * 2),
            SpatioTemporalBlock(base_channels * 2, base_channels * 4)
        )

        # Global average pooling collapses (time, H, W) to a single value
        # per channel.
        self.gap = nn.AdaptiveAvgPool3d(1)

        # Multi-view fusion: concatenated per-view vectors -> fusion_dim.
        fusion_dim = base_channels * 4
        self.fusion = nn.Sequential(
            nn.Linear(fusion_dim * num_views, fusion_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )

        # Anomaly detection head: fusion_dim -> 1, squashed by a sigmoid.
        self.anomaly_head = nn.Sequential(
            nn.Linear(fusion_dim, fusion_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(fusion_dim // 2, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Score a batch of multi-view clips.

        Args:
            x: tensor of shape (batch_size, num_views, channels, time, H, W).

        Returns:
            Tensor of shape (batch_size, 1) holding the sigmoid output of the
            anomaly head.

        Raises:
            ValueError: if ``x`` is not 6-dimensional or its view dimension
                does not equal ``self.num_views``.
        """
        # Fail early with a clear message: without these checks extra views
        # would be silently ignored and a 5-D input would be misread as an
        # unbatched volume by the pooling layer.
        if x.dim() != 6:
            raise ValueError(
                "expected input of shape (batch_size, num_views, channels, "
                f"time, H, W), got {tuple(x.shape)}"
            )
        if x.size(1) != self.num_views:
            raise ValueError(
                f"expected {self.num_views} views along dim 1, "
                f"got {x.size(1)} (input shape {tuple(x.shape)})"
            )

        batch_size = x.size(0)

        # Process each view independently with the shared extractor, then
        # pool and flatten to (batch_size, features).
        view_features = [
            self.gap(self.feature_extraction(x[:, view_idx])).view(batch_size, -1)
            for view_idx in range(self.num_views)
        ]

        # Concatenate features from all views along the feature axis.
        combined_features = torch.cat(view_features, dim=1)

        # Multi-view fusion followed by the anomaly head -> (batch_size, 1).
        fused_features = self.fusion(combined_features)
        return self.anomaly_head(fused_features)

In [6]:
def get_model(num_views=3, input_channels=3, base_channels=32):
    """Build an ``AnomalySTCN`` instance.

    Args:
        num_views: number of views the model expects per sample.
        input_channels: channel count of each view.
        base_channels: width of the first feature-extraction block; the
            default matches the ``AnomalySTCN`` constructor default.

    Returns:
        A freshly initialised ``AnomalySTCN``.
    """
    return AnomalySTCN(
        num_views=num_views,
        input_channels=input_channels,
        base_channels=base_channels,
    )