# Notebook 2: Train CNN+GRU on CADP

# In[ ]:
from pathlib import Path

import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np, matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_fscore_support

from accident_ai_uganda.models.cnn_gru import CNN_GRU_Classifier
from accident_ai_uganda.data.dataset import RoadsideClips

# Dataset location and loader settings, kept as constants so they are easy to tweak per run.
DATA_ROOT = './data/roadside'
BATCH_SIZE = 2

train_ds = RoadsideClips(DATA_ROOT, split='train')
val_ds = RoadsideClips(DATA_ROOT, split='val')
# Fail fast on an empty split. An empty validation set would otherwise only surface
# after a full training epoch, as a ZeroDivisionError when the validation loss is averaged.
if len(train_ds) == 0 or len(val_ds) == 0:
    raise ValueError(
        f"Empty dataset split under {DATA_ROOT!r}: "
        f"train={len(train_ds)} val={len(val_ds)}"
    )
# Shuffle only the training data; keep validation order fixed between epochs.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

# Use the GPU when PyTorch can see CUDA; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Binary classifier trained on raw logits: BCEWithLogitsLoss applies the sigmoid itself.
model = CNN_GRU_Classifier()
model.to(device)  # nn.Module.to moves parameters in place
crit = nn.BCEWithLogitsLoss()
opt = optim.Adam(params=model.parameters(), lr=1e-3)

# Train for NUM_EPOCHS epochs. After each epoch, record three values and print them:
#   - the sample-weighted mean training loss (train_losses),
#   - the sample-weighted mean validation loss (val_losses),
#   - the binary validation F1 score (f1s).
NUM_EPOCHS = 3
# Probability cut-off that turns sigmoid outputs into hard 0/1 predictions.
DECISION_THRESHOLD = 0.5

train_losses, val_losses, f1s = [], [], []
for epoch in range(NUM_EPOCHS):
    # ---- training pass ----
    model.train()
    tl, n_train = 0.0, 0
    for x, y in train_loader:
        x = x.to(device)
        opt.zero_grad()
        logits = model(x)
        # BCEWithLogitsLoss requires float targets shaped exactly like the logits.
        # Cast and reshape so integer labels, or (B,) targets vs (B, 1) logits, also work.
        # NOTE(review): assumes one logit per clip; confirm against CNN_GRU_Classifier.
        y = y.to(device).float().reshape_as(logits)
        loss = crit(logits, y)
        loss.backward()
        opt.step()
        # Weight by batch size so a short final batch doesn't skew the epoch mean.
        tl += loss.item() * x.size(0)
        n_train += x.size(0)
    train_losses.append(tl / n_train)

    # ---- validation pass ----
    model.eval()
    vl, n_val = 0.0, 0
    ytrue, ypred = [], []
    with torch.no_grad():
        for x, y in val_loader:
            x = x.to(device)
            logits = model(x)
            y = y.to(device).float().reshape_as(logits)
            loss = crit(logits, y)
            vl += loss.item() * x.size(0)
            n_val += x.size(0)
            probs = torch.sigmoid(logits)
            # Flatten to 1-D so sklearn always receives plain label vectors, even for
            # (B, 1) logits. Assumes hard 0/1 labels, since F1 is computed as binary.
            ytrue += y.reshape(-1).long().cpu().tolist()
            ypred += (probs > DECISION_THRESHOLD).reshape(-1).long().cpu().tolist()
    val_losses.append(vl / n_val)
    p, r, f1, _ = precision_recall_fscore_support(
        ytrue, ypred, average='binary', zero_division=0
    )
    f1s.append(f1)
    print(f"Epoch {epoch}: TrainLoss {train_losses[-1]:.3f} ValLoss {val_losses[-1]:.3f} F1 {f1:.3f}")

# Persist the final-epoch weights (not the best-F1 epoch). Create the output directory
# first so a fresh checkout does not fail with FileNotFoundError.
CHECKPOINT_PATH = Path('./artifacts/cnn_gru.pt')
CHECKPOINT_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), CHECKPOINT_PATH)
