### Contributor - Jainil

In [None]:
import os
import torch
import json
import cv2
import numpy as np
from torchvision import transforms
from PIL import Image

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of frames every loaded clip is resampled/padded to.
DESIRED_FRAME_COUNT = 16

# Lookup tables turning annotation strings into integer class indices.
# Each table is keyed by task; "No Offence" / "No offence" are two spellings
# of the same class and share index 2.
EVENT_DICTIONARY = {
    'action_class': {
        name: index for index, name in enumerate((
            "Tackling", "Standing tackling", "High leg", "Holding", "Pushing",
            "Elbowing", "Challenge", "Dive", "Dont know",
        ))
    },
    'offence_class': {
        "Offence": 0,
        "Between": 1,
        "No Offence": 2,
        "No offence": 2,
    },
    # "1.0" -> 0, "2.0" -> 1, ..., "5.0" -> 4
    'severity_class': {f"{level}.0": level - 1 for level in range(1, 6)},
    'bodypart_class': {
        name: index for index, name in enumerate(("Upper body", "Under body"))
    },
    'offence_severity_class': {
        name: index for index, name in enumerate((
            "No offence",
            "Offence + No card",
            "Offence + Yellow card",
            "Offence + Red card",
        ))
    },
}

# RGB frames: resize to 112x112, convert to a tensor, then normalise with the
# per-channel mean/std given below.
rgb_transform = transforms.Compose([
    transforms.Resize(size=(112, 112)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Optical-flow frames: resize to 112x112 and convert to a tensor (no normalisation).
flow_transform = transforms.Compose([
    transforms.Resize(size=(112, 112)),
    transforms.ToTensor(),
])

# Card associated with each annotated severity string.
# NOTE(review): "1.0"/"3.0"/"5.0" -> no card / yellow / red follows the
# SoccerNet-MVFoul label convention; the borderline values "2.0" and "4.0" are
# rounded down to the lower card here -- confirm this is the desired policy.
_SEVERITY_TO_CARD = {
    "1.0": "No card",
    "2.0": "No card",
    "3.0": "Yellow card",
    "4.0": "Yellow card",
    "5.0": "Red card",
}


def _offence_severity_label(offence_class, severity_key):
    """Return the combined offence/severity class index for one action."""
    if offence_class != "Offence":
        # "No offence", "No Offence" and "Between" all map to class 0.
        return 0
    card = _SEVERITY_TO_CARD.get(severity_key, "No card")
    return EVENT_DICTIONARY['offence_severity_class'][f"Offence + {card}"]


def _fit_to_frame_count(frames, target):
    """Uniformly subsample, or pad with the last frame, to exactly ``target`` items."""
    if len(frames) > target:
        indices = np.linspace(0, len(frames) - 1, target).astype(int)
        return [frames[i] for i in indices]
    return frames + [frames[-1]] * (target - len(frames))


def _load_clip_tensors(clip_path):
    """Decode one video into stacked RGB and optical-flow tensors.

    Returns ``(rgb, flow)`` where both tensors have ``DESIRED_FRAME_COUNT``
    entries along dim 0, or ``None`` when the video yields fewer than two
    frames (a flow field needs a frame pair).
    """
    rgb_frames, flow_frames = [], []
    cap = cv2.VideoCapture(clip_path)
    try:
        ret, prev_frame = cap.read()
        if not ret:
            return None
        prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # RGB stream: OpenCV decodes BGR, the transform expects an RGB PIL image.
            rgb_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            rgb_frames.append(rgb_transform(rgb_image))

            # Flow stream: dense Farneback flow between the previous and current frame.
            curr_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            flow = cv2.calcOpticalFlowFarneback(prev_gray, curr_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)
            flow = np.clip(flow, -20, 20)  # limit extreme displacements
            flow = ((flow + 20) * (255.0 / 40)).astype(np.uint8)  # rescale [-20, 20] to [0, 255]
            flow_image = Image.fromarray(flow[..., 0])  # horizontal component only
            flow_frames.append(flow_transform(flow_image))
            prev_gray = curr_gray
    finally:
        # Always release the decoder, including on the early return above.
        cap.release()

    if not rgb_frames:
        return None

    rgb_frames = _fit_to_frame_count(rgb_frames, DESIRED_FRAME_COUNT)
    flow_frames = _fit_to_frame_count(flow_frames, DESIRED_FRAME_COUNT)
    return torch.stack(rgb_frames, dim=0), torch.stack(flow_frames, dim=0)


def load_filtered_clips_and_labels(DATA_PATH, split, max_samples_o, max_samples_no):
    """Load up to a quota of annotated actions together with their clips.

    Reads ``<DATA_PATH>/<split>/annotations.json`` and, for every action, the
    videos ``clip_0.mp4`` and ``clip_1.mp4`` inside
    ``<DATA_PATH>/<split>/action_<key>/``.

    Args:
        DATA_PATH: Root directory of the dataset.
        split: Sub-directory name of the split (e.g. ``"train"``).
        max_samples_o: Maximum number of actions annotated ``"Offence"``.
        max_samples_no: Maximum number of actions with any other offence
            annotation (no offence and "Between" share this quota).

    Returns:
        A 7-tuple ``(rgb_clips, flow_clips, labels_action, labels_offence,
        labels_severity, labels_bodypart, labels_offence_severity)``. The two
        clip lists hold, per action, a list with one stacked tensor per decoded
        video; the label lists hold integer class indices.

    Tensors are left on the CPU: moving every decoded frame to the GPU wastes
    device memory, and CUDA tensors cannot be pinned or handed to DataLoader
    worker processes. The training loop moves each batch to the device itself.
    """
    rgb_clips, flow_clips = [], []
    labels_action, labels_offence, labels_severity, labels_bodypart, labels_offence_severity = [], [], [], [], []

    annotations_path = os.path.join(DATA_PATH, split, "annotations.json")
    print(f"Loading annotations from: {annotations_path}")

    with open(annotations_path, 'r') as f:
        annotations = json.load(f)
    print(f"Total actions found in annotations: {len(annotations['Actions'])}")

    offence_count, no_offence_count, skipped_actions = 0, 0, 0

    for action_key, action_data in annotations['Actions'].items():
        offence_class = action_data['Offence']
        is_offence = offence_class == "Offence"

        # Everything that is not an "Offence" is counted against the no-offence
        # quota, matching how the counters are incremented further down.
        if is_offence:
            quota_reached = offence_count >= max_samples_o
        else:
            quota_reached = no_offence_count >= max_samples_no
        if quota_reached:
            continue

        # Missing or empty annotation values fall back to the defaults.
        severity_key = action_data.get('Severity') or '1.0'
        bodypart_key = action_data.get('Bodypart') or 'Upper body'

        # Map labels to indices using the dictionary
        action_label = EVENT_DICTIONARY['action_class'].get(action_data['Action class'])
        offence_label = EVENT_DICTIONARY['offence_class'].get(offence_class)
        severity_label = EVENT_DICTIONARY['severity_class'].get(severity_key)
        bodypart_label = EVENT_DICTIONARY['bodypart_class'].get(bodypart_key)
        # Derived from the severity *string*; looking up the integer index in the
        # string-keyed table never matches and always yields "No card".
        offence_severity_label = _offence_severity_label(offence_class, severity_key)

        # Skip if any label is missing
        if None in [action_label, offence_label, severity_label, bodypart_label, offence_severity_label]:
            skipped_actions += 1
            continue

        action_folder = os.path.join(DATA_PATH, split, f"action_{action_key}")
        if not os.path.exists(action_folder):
            skipped_actions += 1
            continue

        rgb_action_clips, flow_action_clips = [], []
        for clip_idx in range(2):
            clip_path = os.path.join(action_folder, f"clip_{clip_idx}.mp4")
            if not os.path.exists(clip_path):
                continue

            clip_tensors = _load_clip_tensors(clip_path)
            if clip_tensors is None:
                # Unreadable video or too short to form a single frame pair.
                continue

            rgb_clip, flow_clip = clip_tensors
            rgb_action_clips.append(rgb_clip)
            flow_action_clips.append(flow_clip)

        if rgb_action_clips and flow_action_clips:
            rgb_clips.append(rgb_action_clips)
            flow_clips.append(flow_action_clips)
            labels_action.append(action_label)
            labels_offence.append(offence_label)
            labels_severity.append(severity_label)
            labels_bodypart.append(bodypart_label)
            labels_offence_severity.append(offence_severity_label)

            if is_offence:
                offence_count += 1
            else:
                no_offence_count += 1

        if offence_count >= max_samples_o and no_offence_count >= max_samples_no:
            break

    print("\nSummary:")
    print(f"Total actions loaded: {len(rgb_clips)}")
    print(f"Total actions skipped: {skipped_actions}")
    return rgb_clips, flow_clips, labels_action, labels_offence, labels_severity, labels_bodypart, labels_offence_severity


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import math

class MultiScaleAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Note: the implementation contains no pooling or resolution change; it is
    plain multi-head self-attention with a fused QKV projection.

    Args:
        dim: Channel size of the input and output tokens.
        num_heads: Number of attention heads; each head gets ``dim // num_heads`` channels.
        qkv_bias: Whether the fused QKV projection has a bias term.
        dropout: Dropout probability applied to the attention weights and to the output.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, dropout=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # 1 / sqrt(head_dim), the usual scaling of the attention logits.
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape ``(batch, tokens, dim)``; output has the same shape."""
        batch, tokens, channels = x.shape

        # (batch, tokens, 3, heads, head_dim) -> three (batch, heads, tokens, head_dim) tensors
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        query, key, value = (part.transpose(1, 2) for part in packed.unbind(dim=2))

        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        weights = self.dropout(scores.softmax(dim=-1))

        # Merge the heads back into the channel dimension.
        context = torch.matmul(weights, value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.dropout(self.proj(context))

class MViTBlock(nn.Module):
    """Pre-norm transformer block: self-attention followed by an MLP.

    Both sub-layers are wrapped in a residual connection, with optional
    stochastic depth applied to the residual branch.

    Args:
        dim: Token channel size.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden size of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the attention QKV projection has a bias.
        dropout: Dropout probability inside the MLP.
        attention_dropout: Dropout probability passed to the attention layer.
        drop_path: Stochastic-depth probability; ``0`` disables it.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
                 dropout=0., attention_dropout=0., drop_path=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiScaleAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            dropout=attention_dropout,
        )

        self.norm2 = nn.LayerNorm(dim)
        hidden_features = int(dim * mlp_ratio)
        mlp_layers = [
            nn.Linear(dim, hidden_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_features, dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

        # Only build a DropPath module when stochastic depth is actually enabled.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Apply the attention and MLP residual branches to ``x``."""
        attended = self.attn(self.norm1(x))
        x = x + self.drop_path(attended)
        transformed = self.mlp(self.norm2(x))
        return x + self.drop_path(transformed)

class DropPath(nn.Module):
    """Stochastic depth: randomly zero whole samples of a residual branch.

    During training each sample (index along dim 0) is dropped with
    probability ``drop_prob``; surviving samples are scaled by
    ``1 / (1 - drop_prob)`` so the expected value is unchanged. In eval mode,
    or when ``drop_prob`` is 0, the input is returned untouched.

    Args:
        drop_prob: Probability of dropping a sample, in ``[0, 1]``.
            ``None`` is treated as ``0`` (disabled).

    Raises:
        ValueError: If ``drop_prob`` lies outside ``[0, 1]``.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        # The default used to be stored as None, which crashed on `1 - None`
        # the first time the module ran in training mode.
        drop_prob = 0.0 if drop_prob is None else float(drop_prob)
        if not 0.0 <= drop_prob <= 1.0:
            raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.:
            # Every sample is dropped; dividing by keep_prob would give NaN.
            return torch.zeros_like(x)
        # One Bernoulli(keep_prob) draw per sample, broadcast over the other dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
        return x.div(keep_prob) * random_tensor

class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    A strided convolution (kernel size == stride == ``patch_size``) projects
    every patch to ``embed_dim`` channels; the result is layer-normalised.

    Args:
        img_size: Side length used to compute ``num_patches``.
        patch_size: Side length of one square patch.
        in_chans: Number of input channels.
        embed_dim: Channel size of each patch token.
    """

    def __init__(self, img_size=112, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Map ``(batch, channels, height, width)`` to ``(batch, patches, embed_dim)``."""
        # Unpacking four values rejects inputs that are not 4-D.
        _batch, _channels, _height, _width = x.shape
        feature_map = self.proj(x)
        # Flatten the spatial grid and move channels last: (B, D, h, w) -> (B, h*w, D)
        tokens = feature_map.flatten(2).transpose(1, 2)
        return self.norm(tokens)

class MViTForFoulDetection(nn.Module):
    """Frame-wise vision transformer with five classification heads.

    Every frame is patch-embedded and run through the transformer blocks on
    its own (attention never mixes tokens from different frames); the
    per-frame class tokens are then averaged over time and fed to one head
    per task.

    Args:
        img_size: Side length of the (square) input frames.
        patch_size: Side length of one patch.
        in_chans: Number of input channels.
        embed_dim: Token channel size.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden size as a multiple of ``embed_dim``.
        qkv_bias: Whether the attention QKV projections have a bias.
        dropout: Dropout probability for embeddings and MLPs.
        attention_dropout: Dropout probability inside attention.
        drop_path: Maximum stochastic-depth rate (linearly ramped over depth).
        num_frames: Number of frames per clip the model is built for.
        num_classes_action / num_classes_offence / num_classes_severity /
        num_classes_bodypart / num_classes_offence_severity: Output sizes of
            the respective heads.
    """

    def __init__(
        self,
        img_size=112,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.,
        qkv_bias=True,
        dropout=0.1,
        attention_dropout=0.1,
        drop_path=0.1,
        num_frames=8,
        num_classes_action=9,
        num_classes_offence=3,
        num_classes_severity=5,
        num_classes_bodypart=2,
        num_classes_offence_severity=4
    ):
        super().__init__()

        self.num_frames = num_frames
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        num_patches = self.patch_embed.num_patches
        # Slot 0 of pos_embed is reserved for the class token; forward_features
        # only adds slots 1..N to the patch tokens.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.temporal_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim))

        self.pos_drop = nn.Dropout(p=dropout)

        # Stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]

        self.blocks = nn.ModuleList([
            MViTBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                dropout=dropout,
                attention_dropout=attention_dropout,
                drop_path=dpr[i]
            )
            for i in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)

        # Task-specific heads (identical architecture, different output size).
        # Attribute names and layer order are kept so state-dict keys do not change.
        self.fc_action = self._make_head(embed_dim, num_classes_action)
        self.fc_offence = self._make_head(embed_dim, num_classes_offence)
        self.fc_severity = self._make_head(embed_dim, num_classes_severity)
        self.fc_bodypart = self._make_head(embed_dim, num_classes_bodypart)
        self.fc_offence_severity = self._make_head(embed_dim, num_classes_offence_severity)

        # Initialize weights
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.temporal_embed, std=0.02)
        self.apply(self._init_weights)

    @staticmethod
    def _make_head(embed_dim, num_classes):
        """Build one classification head: Linear -> LayerNorm -> GELU -> Dropout -> Linear."""
        return nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.LayerNorm(embed_dim // 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(embed_dim // 2, num_classes)
        )

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights, zeros/ones for biases and LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        """Encode a clip batch into one feature vector per clip.

        Args:
            x: Tensor of shape ``(batch, num_frames, channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, embed_dim)``.

        Raises:
            ValueError: If ``x`` is not 5-D or its frame count differs from
                the ``num_frames`` the temporal embedding was built for.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected input of shape (B, F, C, H, W), got {tuple(x.shape)}"
            )
        if x.shape[1] != self.num_frames:
            raise ValueError(
                f"expected {self.num_frames} frames per clip, got {x.shape[1]}"
            )

        # Reshape input: B, F, C, H, W -> (B * F), C, H, W
        B = x.shape[0]
        x = rearrange(x, 'b f c h w -> (b f) c h w')

        # Patch embedding
        x = self.patch_embed(x)

        # Reshape back: (B * F), N, D -> B, F, N, D
        x = rearrange(x, '(b f) n d -> b f n d', b=B)

        # Add temporal embeddings (same vector for every patch of a frame)
        temporal_embed = repeat(self.temporal_embed, '1 f d -> b f d', b=B)
        x = x + temporal_embed.unsqueeze(2)

        # Reshape for transformer blocks: B, F, N, D -> (B * F), N, D
        x = rearrange(x, 'b f n d -> (b f) n d')

        # Add position embeddings (slot 0 belongs to the class token and is skipped)
        x = x + self.pos_embed[:, 1:]

        # Add classification token
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=x.shape[0])
        x = torch.cat((cls_tokens, x), dim=1)

        x = self.pos_drop(x)

        # Apply transformer blocks
        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)

        # Take only the classification token output
        x = x[:, 0]

        # Reshape back to batch dimension: (B * F), D -> B, F, D
        x = rearrange(x, '(b f) d -> b f d', b=B)

        # Global temporal pooling
        x = x.mean(dim=1)

        return x

    def forward(self, x):
        """Return a dict of logits, one entry per task, for a clip batch ``x``."""
        x = self.forward_features(x)

        return {
            'action': self.fc_action(x),
            'offence': self.fc_offence(x),
            'severity': self.fc_severity(x),
            'bodypart': self.fc_bodypart(x),
            'offence_severity': self.fc_offence_severity(x)
        }

# Training utilities
class FoulDetectionLoss(nn.Module):
    """Sum of per-task cross-entropy losses for the multi-head model.

    Args:
        class_weights: Optional mapping from task name to a 1-D tensor of
            per-class weights; tasks without an entry are unweighted.
    """

    def __init__(self, class_weights=None):
        super().__init__()
        self.class_weights = class_weights

    def forward(self, predictions, targets):
        """Compute the loss for every task present in ``predictions``.

        Args:
            predictions: Mapping task name -> logits.
            targets: Mapping task name -> class-index targets.

        Returns:
            ``(total_loss, losses)`` where ``losses`` maps each task name to
            its cross-entropy and ``total_loss`` is their sum.
        """
        weights = self.class_weights
        losses = {}

        for task, logits in predictions.items():
            extra = {}
            if weights is not None and task in weights:
                # Class weights must live on the same device as the logits.
                extra["weight"] = weights[task].to(logits.device)
            losses[task] = F.cross_entropy(logits, targets[task], **extra)

        total_loss = sum(losses.values())
        return total_loss, losses

def create_mvit_model(num_frames=16):
    """Build an ``MViTForFoulDetection`` with the fixed configuration below.

    Args:
        num_frames: Number of frames per clip the model is built for.

    Returns:
        The freshly constructed (untrained) model.
    """
    config = {
        "img_size": 112,
        "patch_size": 16,
        "in_chans": 3,
        "embed_dim": 768,
        "depth": 12,
        "num_heads": 12,
        "mlp_ratio": 4.,
        "num_frames": num_frames,
        "dropout": 0.1,
        "attention_dropout": 0.1,
        "drop_path": 0.1,
    }
    return MViTForFoulDetection(**config)

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score


#from mvit_model import MViTForFoulDetection, FoulDetectionLoss  # Import from our previous implementation

class ActionDataset(Dataset):
    """Dataset of (RGB clip, optical-flow clip, labels) triples.

    Args:
        rgb_clips: One entry per action. An entry is either a sequence of
            frames, a clip tensor whose dim 0 indexes frames, or a list of
            4-D per-view clip tensors (the format produced by
            ``load_filtered_clips_and_labels``).
        flow_clips: Same structure as ``rgb_clips`` for the flow stream.
        labels: Mapping task name -> list of integer labels, indexed like the clips.
        transform: Optional callable applied to frames that are not already tensors.
        num_frames: Number of frames every returned clip has.
    """

    def __init__(self, rgb_clips, flow_clips, labels, transform=None, num_frames=8):
        self.rgb_clips = rgb_clips
        self.flow_clips = flow_clips
        self.labels = labels
        self.transform = transform
        self.num_frames = num_frames

    def __len__(self):
        return len(self.rgb_clips)

    @staticmethod
    def _primary_view(clip):
        """Reduce a multi-view entry to a single clip.

        A list of 4-D tensors is a list of camera views, not a list of frames
        (frames are 3-D). Treating views as frames would yield a 6-D batch the
        model cannot consume.
        NOTE(review): the first view is used; confirm whether another view
        selection or a fusion of views is wanted.
        """
        is_multi_view = (
            isinstance(clip, (list, tuple))
            and len(clip) > 0
            and all(isinstance(view, torch.Tensor) and view.dim() == 4 for view in clip)
        )
        return clip[0] if is_multi_view else clip

    def sample_frames(self, clip):
        """Return exactly ``num_frames`` frames from ``clip`` as a new list.

        Longer clips are sampled uniformly; shorter clips are padded by
        repeating the last frame. ``clip`` itself is never modified, so
        repeated access to the same item stays consistent.

        Raises:
            ValueError: If ``clip`` contains no frames.
        """
        total = len(clip)
        if total == 0:
            raise ValueError("cannot sample frames from an empty clip")

        if total > self.num_frames:
            indices = np.linspace(0, total - 1, self.num_frames, dtype=int)
            return [clip[int(i)] for i in indices]

        # list() copies a list and splits a tensor along dim 0, so padding
        # below never touches the stored data.
        frames = list(clip)
        frames.extend([frames[-1]] * (self.num_frames - total))
        return frames

    def __getitem__(self, idx):
        # Sample frames
        rgb_frames = self.sample_frames(self._primary_view(self.rgb_clips[idx]))
        flow_frames = self.sample_frames(self._primary_view(self.flow_clips[idx]))

        # Apply transformations (frames that are already tensors are left as-is)
        if self.transform:
            rgb_frames = [self.transform(frame) if not isinstance(frame, torch.Tensor) else frame
                          for frame in rgb_frames]
            flow_frames = [self.transform(frame) if not isinstance(frame, torch.Tensor) else frame
                           for frame in flow_frames]

        # Stack frames
        rgb_frames = torch.stack(rgb_frames, dim=0)  # [num_frames, C, H, W]
        flow_frames = torch.stack(flow_frames, dim=0)  # [num_frames, C, H, W]

        # Create label dictionary
        label_dict = {key: torch.tensor(self.labels[key][idx], dtype=torch.long)
                      for key in self.labels.keys()}

        return rgb_frames, flow_frames, label_dict

class MViTTrainer:
    """Training/validation loop for the multi-task foul-detection model.

    Args:
        model: Module returning a dict of logits keyed by task name.
        train_loader: Iterable of ``(rgb, flow, labels)`` training batches.
        valid_loader: Iterable of ``(rgb, flow, labels)`` validation batches.
        criterion: Callable ``(outputs, labels) -> (total_loss, per_task_losses)``.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        device: Device the batches are moved to.
        num_epochs: Number of epochs run by :meth:`train`.
        grad_clip_val: Maximum gradient norm.
        mixed_precision: Request CUDA automatic mixed precision. It is only
            honoured when ``device`` is a CUDA device; gradient scaling and
            CUDA autocast have nothing to act on otherwise.
        checkpoint_path: File the best checkpoint (lowest validation loss) is
            written to.
    """

    # Tasks for which predictions are collected and accuracy is reported.
    TASKS = ('action', 'offence', 'severity', 'bodypart', 'offence_severity')

    def __init__(self, model, train_loader, valid_loader, criterion, optimizer, scheduler, device, num_epochs, grad_clip_val=1.0, mixed_precision=True, checkpoint_path="best_mvit_model.pth"):
        self.model = model
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.num_epochs = num_epochs
        self.grad_clip_val = grad_clip_val
        self.mixed_precision = bool(mixed_precision) and torch.device(device).type == "cuda"
        self.scaler = GradScaler() if self.mixed_precision else None
        self.checkpoint_path = checkpoint_path

    def _optimization_step(self, loss):
        """Backpropagate ``loss``, clip gradients and update the parameters."""
        if self.mixed_precision:
            self.scaler.scale(loss).backward()
            # Gradients must be unscaled before their norm is clipped.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_val)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_val)
            self.optimizer.step()

    def _run_epoch(self, loader, training):
        """Run one pass over ``loader``; update the model when ``training`` is true.

        Returns:
            ``(avg_loss, accuracies)`` with the mean batch loss and a dict of
            per-task accuracies.

        Raises:
            ValueError: If ``loader`` yields no batches.
        """
        num_batches = len(loader)
        if num_batches == 0:
            raise ValueError("data loader produced no batches")

        running_loss = 0.0
        all_preds = {task: [] for task in self.TASKS}
        all_labels = {task: [] for task in self.TASKS}

        # Only the RGB stream is fed to the model, so the flow batch is not
        # copied to the device.
        for batch_idx, (rgb_input, _flow_input, labels) in enumerate(tqdm(loader)):
            rgb_input = rgb_input.to(self.device)
            labels = {k: v.to(self.device) for k, v in labels.items()}

            if training:
                self.optimizer.zero_grad()
                with autocast(enabled=self.mixed_precision):
                    outputs = self.model(rgb_input)
                    loss, task_losses = self.criterion(outputs, labels)
                self._optimization_step(loss)
            else:
                outputs = self.model(rgb_input)
                loss, task_losses = self.criterion(outputs, labels)

            running_loss += loss.item()

            for task in self.TASKS:
                all_preds[task].extend(outputs[task].argmax(dim=1).cpu().numpy())
                all_labels[task].extend(labels[task].cpu().numpy())

            # Print batch metrics every 10 batches
            if training and batch_idx % 10 == 0:
                print(f"Batch {batch_idx}: Loss = {loss.item():.4f}")
                for task, l in task_losses.items():
                    print(f"{task}_loss = {l.item():.4f}")

        avg_loss = running_loss / num_batches
        accuracies = {task: accuracy_score(all_labels[task], all_preds[task]) for task in self.TASKS}
        return avg_loss, accuracies

    def train_one_epoch(self):
        """Train for one epoch and return ``(avg_loss, accuracies)``."""
        self.model.train()
        return self._run_epoch(self.train_loader, training=True)

    @torch.no_grad()
    def validate(self):
        """Evaluate on the validation loader and return ``(avg_loss, accuracies)``."""
        self.model.eval()
        return self._run_epoch(self.valid_loader, training=False)

    def train(self):
        """Run ``num_epochs`` epochs, checkpointing whenever validation loss improves."""
        best_val_loss = float('inf')
        for epoch in range(self.num_epochs):
            print(f"\nEpoch {epoch + 1}/{self.num_epochs}")

            # Training phase
            train_loss, train_accuracies = self.train_one_epoch()

            # Validation phase
            val_loss, val_accuracies = self.validate()

            # Update learning rate
            self.scheduler.step()

            # Print metrics
            print(f"Train Loss: {train_loss:.4f} | Train Accuracies: {train_accuracies}")
            print(f"Val Loss: {val_loss:.4f} | Val Accuracies: {val_accuracies}")
            print(f"Learning Rate: {self.optimizer.param_groups[0]['lr']:.6f}")

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'best_val_loss': best_val_loss,
                }, self.checkpoint_path)
                print("Saved best model.")

def main(data_path, num_epochs=5, batch_size=2, learning_rate=5e-5, num_frames=8, max_samples_o=2, max_samples_no=2):
    """Load the train/valid splits, build the model and run training.

    Args:
        data_path: Root directory of the dataset.
        num_epochs: Number of training epochs.
        batch_size: Batch size of both data loaders.
        learning_rate: Initial AdamW learning rate.
        num_frames: Frames per clip fed to the model.
        max_samples_o: Per-split cap on actions annotated as offences.
        max_samples_no: Per-split cap on the remaining actions.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    label_keys = ("action", "offence", "severity", "bodypart", "offence_severity")

    def load_split(split):
        """Load one split once and group its label lists by task name."""
        rgb_clips, flow_clips, *label_lists = load_filtered_clips_and_labels(
            data_path, split, max_samples_o, max_samples_no)
        return rgb_clips, flow_clips, dict(zip(label_keys, label_lists))

    # Each split is decoded exactly once; video decoding and optical flow are
    # by far the most expensive part of start-up.
    train_rgb_clips, train_flow_clips, train_labels = load_split("train")
    valid_rgb_clips, valid_flow_clips, valid_labels = load_split("valid")

    # Define transforms with additional augmentation for training.
    # NOTE: ActionDataset only applies a transform to frames that are not
    # already tensors, so these have no effect on pre-tensorised clips.
    train_transform = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    valid_transform = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Create datasets and dataloaders
    train_dataset = ActionDataset(train_rgb_clips, train_flow_clips, train_labels,
                                  transform=train_transform, num_frames=num_frames)
    valid_dataset = ActionDataset(valid_rgb_clips, valid_flow_clips, valid_labels,
                                  transform=valid_transform, num_frames=num_frames)

    # The clips are fully decoded tensors held in memory (and may already live
    # on the GPU), so worker processes bring no benefit, and CUDA tensors can
    # neither be pinned nor shared safely with forked workers.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=0, pin_memory=False)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False,
                              num_workers=0, pin_memory=False)

    # Initialize model and move to device
    model = MViTForFoulDetection(num_frames=num_frames).to(device)

    # Unweighted multi-task cross-entropy (no class weights are supplied)
    criterion = FoulDetectionLoss()

    # Initialize optimizer with weight decay
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate,
                            weight_decay=0.05, betas=(0.9, 0.999))

    # Initialize learning rate scheduler
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

    # Create trainer instance; CUDA mixed precision only makes sense on a GPU
    trainer = MViTTrainer(
        model=model,
        train_loader=train_loader,
        valid_loader=valid_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        num_epochs=num_epochs,
        grad_clip_val=1.0,
        mixed_precision=(device.type == "cuda")
    )

    # Start training
    trainer.train()


# Script entry point: train with the default hyper-parameters on the dataset
# rooted at the relative directory below.
if __name__ == "__main__":
    DATA_PATH = 'mvfouls'
    main(data_path=DATA_PATH)