In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
from typing import Tuple, List
import cv2
from collections import deque

In [2]:


class TemporalAnomalyDetector(nn.Module):
    """
    GRU-based temporal anomaly detection module for identifying
    sudden motion pattern changes in video sequences.

    Args:
        input_dim: Size of each per-frame feature vector fed to the GRU.
        hidden_dim: Size of the GRU hidden state (and of the vector fed to both heads).
        num_layers: Number of stacked GRU layers.
    """
    def __init__(self, input_dim: int = 512, hidden_dim: int = 256, num_layers: int = 2):
        super(TemporalAnomalyDetector, self).__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # nn.GRU only applies dropout *between* stacked layers. With a single
        # layer a non-zero value has no effect and triggers a UserWarning, so
        # it is disabled in that case.
        gru_dropout = 0.3 if num_layers > 1 else 0.0

        # GRU for temporal sequence modeling
        self.gru = nn.GRU(input_dim, hidden_dim, num_layers,
                          batch_first=True, dropout=gru_dropout)

        # Anomaly score prediction head (sigmoid output in [0, 1])
        self.anomaly_head = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

        # Motion intensity head for detecting abrupt changes (sigmoid output in [0, 1])
        self.motion_head = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        Score a sequence of per-frame feature vectors.

        Args:
            x: Tensor of shape (batch_size, sequence_length, input_dim).

        Returns:
            Tuple ``(anomaly_score, motion_intensity, hidden)`` where the two
            scores have shape (batch_size, 1) and ``hidden`` is the final GRU
            hidden state as returned by ``nn.GRU``.
        """
        gru_out, hidden = self.gru(x)

        # Output of the top GRU layer at the final time step summarises the sequence
        last_output = gru_out[:, -1, :]

        anomaly_score = self.anomaly_head(last_output)
        motion_intensity = self.motion_head(last_output)

        return anomaly_score, motion_intensity, hidden

In [3]:
class AccidentDetectionCNN(nn.Module):
    """
    YOLOv5-inspired CNN backbone with temporal anomaly detection
    for motorcycle accident detection in roadside camera footage.

    A ResNet-50 backbone produces one feature map per frame. The pooled map is
    projected to a 512-d embedding that feeds the temporal detector, and the
    un-pooled map of the most recent frame feeds a YOLO-style detection head.

    Args:
        num_classes: Number of object classes predicted by the detection head.
        sequence_length: Nominal number of frames per clip (stored for
            configuration purposes; ``forward`` accepts any sequence length).
        pretrained: Forwarded to ``torchvision.models.resnet50``. Pass
            ``False`` to build the network without loading pretrained weights.
    """

    # Thresholds used by verify_confidence()
    CONFIRM_WINDOW = 3        # number of most recent predictions that must agree
    CONFIRM_MIN_MEAN = 0.7    # mean score above which the window counts as an accident
    CONFIRM_MAX_STD = 0.1     # spread below which the window counts as consistent

    def __init__(self, num_classes: int = 1, sequence_length: int = 16, pretrained: bool = True):
        super(AccidentDetectionCNN, self).__init__()

        # Kept as attributes so the configuration can be read back from the
        # model instance (e.g. when writing a checkpoint).
        self.num_classes = num_classes
        self.sequence_length = sequence_length

        # CNN Backbone (Modified ResNet-50 for efficiency)
        self.backbone = models.resnet50(pretrained=pretrained)

        # Replace final layers to extract features instead of classification
        self.backbone.fc = nn.Identity()

        # Feature extraction layer: 2048-d pooled backbone output -> 512-d embedding
        self.feature_extractor = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.BatchNorm1d(512)
        )

        # Temporal anomaly detection module
        self.temporal_detector = TemporalAnomalyDetector(
            input_dim=512,
            hidden_dim=256,
            num_layers=2
        )

        # Object detection head (simplified YOLO-style)
        self.detection_head = nn.Sequential(
            nn.Conv2d(2048, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 5 + num_classes, 1)  # 5 = x, y, w, h, confidence
        )

        # Multi-frame confidence verification
        self.confidence_buffer = deque(maxlen=5)  # Store last 5 predictions

    def _backbone_feature_map(self, x):
        """Run the backbone stem and its four residual stages, stopping before global pooling."""
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        return self.backbone.layer4(x)  # Shape: (batch, 2048, H/32, W/32)

    def _embed_feature_map(self, feature_map):
        """Pool a backbone feature map and project it to the per-frame embedding."""
        # Same tail as the backbone's own forward pass: avgpool -> flatten -> fc (Identity here)
        pooled = torch.flatten(self.backbone.avgpool(feature_map), 1)
        return self.feature_extractor(self.backbone.fc(pooled))

    def extract_features(self, x):
        """Extract the per-frame embedding from a single frame batch of shape (batch_size, 3, H, W)."""
        features = self.backbone(x)
        return self.feature_extractor(features)

    def detect_objects(self, x):
        """Run object detection on a single frame batch of shape (batch_size, 3, H, W)."""
        return self.detection_head(self._backbone_feature_map(x))

    def forward(self, frame_sequence):
        """
        Process sequence of frames for accident detection.

        Args:
            frame_sequence: Tensor of shape (batch_size, sequence_length, 3, H, W).

        Returns:
            dict with keys ``'anomaly_score'`` and ``'motion_intensity'``
            (outputs of the temporal detector), ``'object_detection'``
            (detection-head output for the last frame) and ``'features'``
            (stacked per-frame embeddings, (batch, seq_len, 512)).
        """
        batch_size, seq_len, C, H, W = frame_sequence.shape

        features_sequence = []
        latest_detection = None

        for i in range(seq_len):
            frame = frame_sequence[:, i]  # (batch_size, 3, H, W)

            # One backbone pass per frame serves both the embedding and the detector
            feature_map = self._backbone_feature_map(frame)
            features_sequence.append(self._embed_feature_map(feature_map))

            # Only the most recent frame's detection is returned, so the
            # detection head is not evaluated on earlier frames.
            if i == seq_len - 1:
                latest_detection = self.detection_head(feature_map)

        # Stack features for temporal analysis
        features_tensor = torch.stack(features_sequence, dim=1)  # (batch, seq_len, 512)

        # Temporal anomaly detection
        anomaly_score, motion_intensity, _ = self.temporal_detector(features_tensor)

        return {
            'anomaly_score': anomaly_score,
            'motion_intensity': motion_intensity,
            'object_detection': latest_detection,
            'features': features_tensor
        }

    def verify_confidence(self, prediction_scores):
        """
        Multi-frame confidence verification to reduce false positives.

        Appends ``prediction_scores`` to an internal rolling buffer and reports
        whether the three most recent entries have a mean above 0.7 and a
        standard deviation below 0.1.

        Args:
            prediction_scores: Tensor of scores for the current step.

        Returns:
            bool: ``True`` when the recent predictions consistently indicate an
            accident; ``False`` otherwise or while fewer than three
            predictions have been recorded.
        """
        # detach() so scores that are still attached to the autograd graph can be converted
        self.confidence_buffer.append(prediction_scores.detach().cpu().numpy())

        if len(self.confidence_buffer) < self.CONFIRM_WINDOW:
            return False

        # Check if last 3 predictions consistently indicate accident
        recent_scores = list(self.confidence_buffer)[-self.CONFIRM_WINDOW:]
        avg_score = np.mean(recent_scores)
        consistency = np.std(recent_scores) < self.CONFIRM_MAX_STD  # Low variation indicates consistency

        return bool(avg_score > self.CONFIRM_MIN_MEAN and consistency)



In [4]:
class AccidentDetectionLoss(nn.Module):
    """
    Combined loss function for accident detection system.

    The total is a weighted sum of up to three terms; a term is included only
    when its label key is present in ``targets``.

    Args:
        anomaly_weight: Weight of the binary cross-entropy term on the anomaly score.
        motion_weight: Weight of the MSE term on the motion intensity.
        detection_weight: Weight of the (simplified) MSE term on the detection map.
    """
    def __init__(self, anomaly_weight=1.0, motion_weight=0.5, detection_weight=1.0):
        super(AccidentDetectionLoss, self).__init__()
        self.anomaly_weight = anomaly_weight
        self.motion_weight = motion_weight
        self.detection_weight = detection_weight

        self.bce_loss = nn.BCELoss()
        self.mse_loss = nn.MSELoss()

    def forward(self, predictions, targets):
        """
        Calculate combined loss.

        Args:
            predictions: dict with 'anomaly_score', 'motion_intensity', 'object_detection'.
            targets: dict with any of 'anomaly_labels', 'motion_labels',
                'detection_labels'; missing keys simply skip that term.

        Returns:
            Scalar tensor. When ``targets`` contains none of the known keys the
            result is a zero tensor that is not connected to the autograd
            graph (so ``.item()`` works but there is nothing to back-propagate).
        """
        # (target key, prediction key, criterion, weight) for every supported term
        loss_terms = (
            ('anomaly_labels', 'anomaly_score', self.bce_loss, self.anomaly_weight),
            ('motion_labels', 'motion_intensity', self.mse_loss, self.motion_weight),
            # Object detection loss (simplified)
            ('detection_labels', 'object_detection', self.mse_loss, self.detection_weight),
        )

        total_loss = 0
        for target_key, prediction_key, criterion, weight in loss_terms:
            if target_key in targets:
                term = criterion(predictions[prediction_key], targets[target_key])
                total_loss = total_loss + weight * term

        # Always hand back a tensor so callers can rely on .item() and friends
        if not torch.is_tensor(total_loss):
            total_loss = torch.zeros(())

        return total_loss



In [5]:
# Utility functions for preprocessing
class VideoPreprocessor:
    """
    Preprocessor for video frames from roadside cameras.

    Args:
        target_size: Size passed to ``cv2.resize`` for every frame.
        sequence_length: Number of frames in every sequence returned by
            ``create_sequence``.
    """
    def __init__(self, target_size=(416, 416), sequence_length=16):
        self.target_size = target_size
        self.sequence_length = sequence_length

    def preprocess_frame(self, frame):
        """
        Preprocess single frame.

        Resizes to ``target_size``, scales pixel values by 1/255, converts
        BGR to RGB and moves the channel axis first.

        Args:
            frame: Image array in BGR channel order
                (assumed uint8, H x W x 3 -- TODO confirm with callers).

        Returns:
            float32 array in channel-first (C, H, W) layout.
        """
        # Resize frame
        frame = cv2.resize(frame, self.target_size)

        # Normalize to [0, 1]
        frame = frame.astype(np.float32) / 255.0

        # Convert BGR to RGB
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Transpose to CHW format
        frame = np.transpose(frame, (2, 0, 1))

        return frame

    def create_sequence(self, frames):
        """
        Create sequence of frames for temporal analysis.

        Keeps the last ``sequence_length`` frames; shorter inputs are padded
        by repeating the last frame.

        Args:
            frames: Iterable of raw frames accepted by ``preprocess_frame``.

        Returns:
            Array stacking the preprocessed frames along a new first axis.

        Raises:
            ValueError: If ``frames`` is empty.
        """
        # Trim before preprocessing so frames that would be discarded anyway
        # are never resized/converted. list() also lets generators be passed.
        recent_frames = list(frames)[-self.sequence_length:]
        if not recent_frames:
            raise ValueError("create_sequence() requires at least one frame")

        processed_frames = [self.preprocess_frame(frame) for frame in recent_frames]

        # Pad with copies of the last frame until the sequence is full
        missing = self.sequence_length - len(processed_frames)
        if missing > 0:
            last_frame = processed_frames[-1]
            processed_frames.extend(last_frame.copy() for _ in range(missing))

        return np.stack(processed_frames)



In [6]:
# Example usage and testing
if __name__ == "__main__":
    # Smoke test: build the model, run one forward pass, evaluate the loss and
    # exercise the preprocessor on random data.

    # Initialize model
    model = AccidentDetectionCNN(num_classes=1, sequence_length=16)

    # Inference only: disable dropout and stop BatchNorm layers from updating
    # their running statistics with random data.
    model.eval()

    # Create dummy input (batch_size=2, sequence_length=16, channels=3, height=416, width=416)
    dummy_input = torch.randn(2, 16, 3, 416, 416)

    # Forward pass
    with torch.no_grad():
        outputs = model(dummy_input)

        print("Model outputs:")
        print(f"Anomaly scores shape: {outputs['anomaly_score'].shape}")
        print(f"Motion intensity shape: {outputs['motion_intensity'].shape}")
        print(f"Object detection shape: {outputs['object_detection'].shape}")
        print(f"Features shape: {outputs['features'].shape}")

    # Test loss function
    loss_fn = AccidentDetectionLoss()

    # Create dummy targets
    targets = {
        'anomaly_labels': torch.randint(0, 2, (2, 1)).float(),
        'motion_labels': torch.rand(2, 1),
        'detection_labels': torch.randn_like(outputs['object_detection'])
    }

    loss = loss_fn(outputs, targets)
    print(f"\nTotal loss: {loss.item():.4f}")

    # Test preprocessing
    preprocessor = VideoPreprocessor()

    # Create dummy video frames; randint's upper bound is exclusive, so 256
    # is needed to cover the full uint8 range 0..255.
    dummy_frames = [np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8) for _ in range(20)]

    # Process sequence
    processed_sequence = preprocessor.create_sequence(dummy_frames)
    print(f"\nProcessed sequence shape: {processed_sequence.shape}")

    print("\nModel initialization complete!")
    print("Ready for training on accident detection dataset.")



Model outputs:
Anomaly scores shape: torch.Size([2, 1])
Motion intensity shape: torch.Size([2, 1])
Object detection shape: torch.Size([2, 6, 13, 13])
Features shape: torch.Size([2, 16, 512])

Total loss: 1.9551

Processed sequence shape: (16, 3, 416, 416)

Model initialization complete!
Ready for training on accident detection dataset.


In [7]:
# Save model classes for easy import in other notebooks
def save_model_classes():
    """
    Pickle a name -> class mapping of the model classes to
    'accident_model_classes.pkl' and return that mapping.

    NOTE(review): pickle serialises classes by reference (module + qualified
    name), not by value, so loading this file elsewhere still requires the
    class definitions to be importable there.
    """
    import pickle

    exported_classes = dict(
        AccidentDetectionCNN=AccidentDetectionCNN,
        AccidentDetectionLoss=AccidentDetectionLoss,
        VideoPreprocessor=VideoPreprocessor,
        TemporalAnomalyDetector=TemporalAnomalyDetector,
    )

    with open('accident_model_classes.pkl', 'wb') as pickle_file:
        pickle.dump(exported_classes, pickle_file)

    return exported_classes



In [8]:
def save_trained_model(model, optimizer, epoch, accuracy, save_path='best_accident_model.pth'):
    """
    Write a training checkpoint with ``torch.save`` and return it.

    The checkpoint dict holds the epoch, the model's state dict, the
    optimizer's state dict (``None`` when no optimizer is given), the accuracy
    and a small ``model_config`` read from the model's ``num_classes`` and
    ``sequence_length`` attributes (defaulting to 1 and 16 when absent).
    """
    model_state = model.state_dict()
    optimizer_state = optimizer.state_dict() if optimizer else None

    # Fall back to the defaults when the model does not expose these attributes
    model_config = {
        'num_classes': getattr(model, 'num_classes', 1),
        'sequence_length': getattr(model, 'sequence_length', 16),
    }

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model_state,
        'optimizer_state_dict': optimizer_state,
        'accuracy': accuracy,
        'model_config': model_config,
    }

    torch.save(checkpoint, save_path)
    return checkpoint



In [9]:
# Write the class registry pickle only when this code runs as the main module
# (e.g. executed as a notebook cell), not when it is imported.
if __name__ == "__main__":
    save_model_classes()