# Notebook 3: Inference and Fine-tuning

In [ ]:
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import DataLoader
from accident_ai_uganda.models.cnn_gru import CNN_GRU_Classifier
from accident_ai_uganda.data.dataset import RoadsideClips
from sklearn.metrics import classification_report

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the trained checkpoint; map_location lets a state dict saved on GPU load on CPU.
model = CNN_GRU_Classifier().to(device)
model.load_state_dict(torch.load('./artifacts/cnn_gru.pt', map_location=device))
model.eval()  # inference mode: disables dropout / uses running norm stats, if the model has them

val_ds = RoadsideClips('./data/roadside', split='val')
val_loader = DataLoader(val_ds, batch_size=2, shuffle=False)

# Collect ground-truth labels and 0.5-thresholded sigmoid predictions over the validation split.
ytrue, ypred = [], []
with torch.no_grad():
    for x, y in val_loader:
        # Labels are only consumed on the CPU by sklearn, so only the inputs go to the device.
        probs = torch.sigmoid(model(x.to(device)))
        # Reshape predictions to the label tensor's shape so a (B, 1) model output does not
        # produce nested per-sample lists alongside flat labels (sklearn rejects that mix).
        preds = (probs > 0.5).int().cpu().reshape_as(y)
        ytrue += y.int().tolist()
        ypred += preds.tolist()
print(classification_report(ytrue, ypred))

# Fine-tuning (unfreeze, smaller LR)
# Fine-tune on a training split rather than on val_loader: training on the validation
# clips would leave no held-out data to judge the fine-tuned model.
# NOTE(review): assumes RoadsideClips accepts split='train' -- confirm against the dataset.
train_ds = RoadsideClips('./data/roadside', split='train')
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)  # shuffle batches across epochs

# Make model.spatial trainable again (presumably frozen during the original training).
for p in model.spatial.parameters():
    p.requires_grad = True
opt = optim.Adam(model.parameters(), lr=1e-4)  # small LR to adjust, not overwrite, the checkpoint
crit = nn.BCEWithLogitsLoss()

for epoch in range(2):
    model.train()
    running_loss, n_batches = 0.0, 0
    for x, y in train_loader:
        x = x.to(device)
        opt.zero_grad()
        logits = model(x)
        # BCEWithLogitsLoss requires float targets with exactly the logits' shape;
        # reshape_as also covers a (B, 1) output paired with (B,) labels.
        target = y.to(device).float().reshape_as(logits)
        loss = crit(logits, target)
        loss.backward()
        opt.step()
        running_loss += loss.item()
        n_batches += 1
    mean_loss = running_loss / max(n_batches, 1)  # guard against an empty loader
    print(f'Finetune epoch {epoch} done, mean loss {mean_loss:.4f}')

# Leave the model in inference mode for any cells that follow.
model.eval()
