# Import Dependencies

In [None]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as T
import torch.nn as nn
import torch.optim as optim

# Read CSV and Prepare a Custom Dataset

In [None]:
class CrashDataset(Dataset):
    """Frame-level dataset built from a per-video crash annotation table.

    Each CSV row describes one video: an integer ``vidname`` column plus one
    label column per frame, ``frame_1`` .. ``frame_<num_frames>``. Every frame
    becomes its own sample, whose image is expected at
    ``<img_dir>/C_<vidname zero-padded to 6 digits>_<frame zero-padded to 2 digits>.jpg``.

    Samples are stored in CSV row order, so each video's frames are contiguous.

    Args:
        csv_path: Path to the annotation CSV.
        img_dir: Directory containing the extracted frame images.
        transform: Optional callable applied to each RGB PIL image.
        num_frames: Number of ``frame_<k>`` label columns per row (default 50).
    """

    def __init__(self, csv_path, img_dir, transform=None, num_frames=50):
        self.df = pd.read_csv(csv_path)
        self.img_dir = img_dir
        self.transform = transform

        # Flatten the table into (image_path, label) pairs, one per frame.
        # to_dict("records") avoids building a pandas Series per row (iterrows).
        self.samples = []
        for row in self.df.to_dict("records"):
            vidname = f"{int(row['vidname']):06d}"  # zero-pad to 6 digits
            for fnum in range(1, num_frames + 1):
                img_name = f"C_{vidname}_{fnum:02d}.jpg"
                # Label is 0 or 1; int() stores a plain Python int rather than
                # a numpy/pandas scalar (and fails loudly on missing values).
                label = int(row[f"frame_{fnum}"])
                self.samples.append((os.path.join(self.img_dir, img_name), label))

    def __len__(self):
        """Return the total number of frames (videos x frames per video)."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load frame ``idx`` as RGB, apply the transform, and return (image, label).

        The label is returned as a ``torch.long`` scalar tensor.
        """
        img_path, label = self.samples[idx]
        # Context manager guarantees the file handle is closed; convert()
        # returns an independent in-memory image, so it outlives the block.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

# Per-frame preprocessing: resize every image to 224x224 (the spatial size the
# classifier's flatten layer is sized for), then convert the PIL image to a
# float tensor.
transform = T.Compose(
    [
        T.Resize(size=(224, 224)),
        T.ToTensor(),
    ]
)

In [None]:
# Placeholders: point these at your local copy of the data.
csv_path = "path to your Crash_Table.csv"
img_dir = "path to your CrashBest"

dataset = CrashDataset(csv_path, img_dir, transform=transform)

# 70/30 train/val split. A fixed-seed generator makes the split identical
# across runs and kernel restarts, so validation metrics stay comparable
# between sessions (any fixed seed works; 42 is arbitrary).
# NOTE(review): the split is per frame, so frames of the same video can land in
# both subsets and inflate validation accuracy; consider splitting by video.
# Also confirm this split matches the one used to train the checkpoint loaded
# for fine-tuning, otherwise val_ds may contain frames that model trained on.
train_size = int(0.7 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=8, shuffle=False)

# Redefine the Model Class (its layers must match the saved checkpoint's architecture)

In [None]:
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier for single RGB frames.

    Each stage is Conv2d(3x3, padding 1) -> ReLU -> MaxPool2d(2), so two stages
    shrink a 224x224 input to 56x56 with 32 channels. That is why the first
    linear layer takes ``32 * 56 * 56`` features: other input sizes will fail
    at that layer.

    All layers live in a single ``nn.Sequential`` attribute named ``net``, which
    fixes the ``state_dict`` keys to ``net.<index>.*``. Keep this layout so
    existing checkpoints continue to load.

    Args:
        num_classes: Number of output logits (default 2).
    """

    def __init__(self, num_classes=2):
        super().__init__()
        layers = []
        channels_in = 3
        for channels_out in (16, 32):
            layers.extend(
                [
                    nn.Conv2d(channels_in, channels_out, kernel_size=3, stride=1, padding=1),
                    nn.ReLU(),
                    nn.MaxPool2d(2),
                ]
            )
            channels_in = channels_out

        # Classifier head on the flattened 32x56x56 feature map.
        layers.extend(
            [
                nn.Flatten(),
                nn.Linear(32 * 56 * 56, 128),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(128, num_classes),
            ]
        )
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        return self.net(x)


# Module-level instance with freshly initialized weights.
model = SimpleCNN(num_classes=2)

# Fine-Tune the Model

In [None]:
# Instantiate and load your pre-trained model
model = SimpleCNN(num_classes=2).cuda()
model.load_state_dict(torch.load("model/model_epoch_5.pth"))

# Define the loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)  # using a smaller LR for fine-tuning

# Fine-tune the model
finetune_epochs = 3  # number of epochs can be adjusted

for epoch in range(finetune_epochs):
    model.train()
    epoch_train_loss, correct_train, total_train = 0.0, 0, 0
    
    for images, labels in train_loader:
        images, labels = images.cuda(), labels.cuda()
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        epoch_train_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs, 1)
        correct_train += (predicted == labels).sum().item()
        total_train += labels.size(0)

    avg_train_loss = epoch_train_loss / total_train
    train_acc = correct_train / total_train

    # Evaluate on validation set
    model.eval()
    val_loss, correct_val, total_val = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.cuda(), labels.cuda()
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            correct_val += (predicted == labels).sum().item()
            total_val += labels.size(0)
    
    avg_val_loss = val_loss / total_val
    val_acc = correct_val / total_val

    print(f"Fine-Tune Epoch {epoch+1}: "
          f"Train Loss {avg_train_loss:.4f}, Train Acc {train_acc:.4f}, "
          f"Val Loss {avg_val_loss:.4f}, Val Acc {val_acc:.4f}")

torch.save(model.state_dict(), f"fine_tuned_model/model_fine_tuned_{epoch+1}.pth")