In [49]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
import cv2
import pandas as pd

class VideoDataset(Dataset):
    def __init__(self, df, num_frames=16, transform=None):
        self.df = df
        self.num_frames = num_frames
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        video_path = self.df.iloc[idx]['clip_path']
        label = self.df.iloc[idx]['encoded_label']

        # Simple frame sampling logic
        cap = cv2.VideoCapture(video_path)
        frames = []
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Sample equidistant frames
        indices = torch.linspace(0, total_frames - 1, self.num_frames).long()

        for i in range(total_frames):
            ret, frame = cap.read()
            if not ret: break
            if i in indices:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                if self.transform:
                    frame = self.transform(frame)
                frame_t = torch.from_numpy(frame).permute(2, 0, 1).contiguous().float() / 255.0
                frames.append(frame_t)
                # frames.append(frame)
        cap.release()

        # Stack frames into shape (T, C, H, W)
        return torch.stack(frames), torch.tensor(label, dtype=torch.long)

In [50]:
class VideoClassifier(nn.Module):
    def __init__(self, num_classes=3, pooling='avg'):
        super(VideoClassifier, self).__init__()
        # Load pre-trained ResNet
        resnet = models.resnet18(weights='IMAGENET1K_V1')

        # Feature extractor: all layers except the last FC
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        # Freeze early layers, fine-tune only the last block and FC
        for param in self.backbone[:-2].parameters():
            param.requires_grad = False

        self.pooling = pooling
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        # x shape: (Batch, Frames, C, H, W)
        batch_size, T, C, H, W = x.shape

        # Reshape for 2D CNN: (Batch * Frames, C, H, W)
        x = x.view(batch_size * T, C, H, W)

        # Extract features: (Batch * T, 512, 1, 1)
        features = self.backbone(x)
        features = features.view(batch_size, T, 512)

        # Temporal Pooling
        if self.pooling == 'avg':
            combined = torch.mean(features, dim=1)
        else: # Max pooling
            combined, _ = torch.max(features, dim=1)

        return self.fc(combined)

In [51]:
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [52]:
def train_model(model, train_loader, val_loader, epochs=10):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
    best_acc = 0.0

    for epoch in range(epochs):
        model.train()
        for videos, labels in train_loader:
            videos, labels = videos.to(DEVICE), labels.to(DEVICE)

            outputs = model(videos)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validation logic...
        # If val_acc > best_acc: torch.save(model.state_dict(), 'best_model.pth')

In [53]:
train_df = pd.read_csv("./dataset/splits/train.csv", index_col='index')
test_df = pd.read_csv("./dataset/splits/test.csv", index_col='index')
val_df = pd.read_csv("./dataset/splits/validation.csv", index_col='index')
train_df.head()

Unnamed: 0_level_0,clip_name,clip_path,label,encoded_label
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,v_Diving_g03_c01.avi,./dataset/Diving/v_Diving_g03_c01.avi,Diving,0
1,v_Diving_g19_c03.avi,./dataset/Diving/v_Diving_g19_c03.avi,Diving,0
2,v_Diving_g03_c04.avi,./dataset/Diving/v_Diving_g03_c04.avi,Diving,0
3,v_Diving_g05_c04.avi,./dataset/Diving/v_Diving_g05_c04.avi,Diving,0
4,v_Diving_g15_c03.avi,./dataset/Diving/v_Diving_g15_c03.avi,Diving,0


In [54]:
dataset = VideoDataset(train_df, 10)
val_dataset = VideoDataset(val_df, 10)
model = VideoClassifier().to(DEVICE)
sample_video, sample_label = dataset[0]
print(sample_video.shape, sample_video.dtype, sample_video.min().item(), sample_video.max().item())
print(sample_label, sample_label.dtype)
train_model(model, DataLoader(dataset, batch_size=1), DataLoader(val_dataset, batch_size=1))

torch.Size([9, 3, 240, 320]) torch.float32 0.0 1.0
tensor(0) torch.int64
