In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from tracknet.model import GridTrackNetModel
from tracknet.dataset import TrackNet
from torchinfo import summary
from torch.optim import Adadelta
import numpy as np
from torch import cuda
from tqdm import tqdm
import torch
import wandb

In [None]:
# Instantiate the network and display a torchinfo summary for a dummy input of
# shape (10, 15, 432, 768) -- presumably (batch, channels, height, width);
# TODO confirm against GridTrackNetModel.
model = GridTrackNetModel()
summary(model, input_size=(10, 15, 432, 768))

In [None]:
# Log in to Weights & Biases, start a run in the 'GridTrackNet' project and
# attach the model to it via wandb.watch (log_freq=500).
# NOTE(review): the run name mentions "l1_adam", but train() in this notebook
# optimises custom_loss with Adadelta -- confirm the name is still intended.
wandb.login()
wandb.init(project='GridTrackNet', name='wandb_train_l1_adam')
wandb.watch(model, log_freq=500)

In [None]:
# Define a Custom Loss function

def custom_loss(y_pred, y_true, confWeight=1, offsetWeight=0.001):
    """Focal confidence loss plus a confidence-masked L1 offset loss.

    Both inputs are ``[batch, frames * 3, H, W]`` tensors whose channels are
    grouped per frame as (confidence, x-offset, y-offset). This is the layout
    produced by ``label.permute(0, 1, 4, 2, 3).reshape(batch, 3 * 5, 27, 48)``
    in the training loop.

    Args:
        y_pred: Model output. The focal term takes the log of the confidence
            channel, so it is assumed to lie in [0, 1] (e.g. after a sigmoid)
            -- TODO confirm against the model head.
        y_true: Target grids in the same layout as ``y_pred``.
        confWeight: Weight of the confidence (focal) term.
        offsetWeight: Weight of the offset (L1) term.

    Returns:
        Scalar tensor: the per-sample losses summed over the batch.
    """

    def _split_components(grid):
        # [batch, frames * 3, H, W] -> three [batch, frames, H, W, 1] tensors.
        # Channels have to be regrouped in threes (one group per frame);
        # splitting into blocks of five channels mixes the confidence and
        # offset grids of different frames.
        batch, _, height, width = grid.shape
        per_frame = grid.reshape(batch, -1, 3, height, width).permute(0, 1, 3, 4, 2)
        return per_frame.chunk(3, dim=-1)

    # Separate confidence grids from offset grids
    confGridTrue, xOffsetGridTrue, yOffsetGridTrue = _split_components(y_true)
    confGridPred, xOffsetGridPred, yOffsetGridPred = _split_components(y_pred)

    # Combine x-offset and y-offset into a single tensor: [batch, frames, H, W, 2]
    yTrueOffset = torch.cat([xOffsetGridTrue, yOffsetGridTrue], dim=-1)
    yPredOffset = torch.cat([xOffsetGridPred, yOffsetGridPred], dim=-1)

    # Offset loss (L1), counted only in cells where confGridTrue is non-zero
    diff = torch.abs(yTrueOffset - yPredOffset)
    sum_diff = diff.sum(dim=-1, keepdim=True)
    masked_sum_diff = confGridTrue * sum_diff
    sum_offset = masked_sum_diff.sum(dim=[2, 3, 4])  # sum over H, W -> [batch, frames]
    offset = sum_offset.mean(dim=1)  # mean over frames -> [batch]

    # Confidence loss (focal); the clamp keeps log() finite at 0
    alpha = 0.75
    gamma = 2
    positiveConfLoss = alpha * confGridTrue * (1 - confGridPred).pow(gamma) * torch.log(torch.clamp(confGridPred, min=1e-7, max=1))
    negativeConfLoss = (1 - alpha) * (1 - confGridTrue) * confGridPred.pow(gamma) * torch.log(torch.clamp(1 - confGridPred, min=1e-7, max=1))
    confidence = -(positiveConfLoss + negativeConfLoss).mean(dim=[1, 2, 3, 4])  # mean over all dimensions except batch

    # Total loss per sample
    loss = offsetWeight * offset + confWeight * confidence

    return loss.sum()  # Return the sum of the loss over the batch


In [None]:
# def validate(model, val_loader):
#     device = "cuda" if cuda.is_available() else "cpu"
#     model.eval()

#     corrects = []
#     losses = []

#     for _, (instances, label) in enumerate(val_loader):
#         with torch.no_grad():
#             label = label.permute(0, 1, 4, 2, 3).reshape(val_loader.batch_size, 3 * 5, 27, 48).to(device, dtype=torch.float32)

#             instances = instances.to(device, dtype=torch.float32)
#             outputs = model(instances)

#             loss = custom_loss(outputs, label)
#             losses.append(loss.item() * instances.size(0))

#             for i in range(val_loader.batch_size):
#                 # Each 3 items is a grid confidence for 1 of 5 frames
#                 for j in range(0, 15, 3):
#                     gt = np.argmax(label[i][j].flatten().cpu())
#                     gt_x, gt_y = np.unravel_index(gt, (27, 48))

#                     out = np.argmax(outputs[i][j].flatten().cpu())
#                     out_x, out_y = np.unravel_index(out, (27, 48))
#                     print(gt_x, out_x, gt_y, out_y)

#                     corrects.append(gt_x == out_x and gt_y == out_y)

#     acc = sum(corrects) / len(corrects)
#     avg_loss = np.average(losses)

#     return acc, avg_loss

In [None]:
import torch

def validate(model, val_loader):
    """Evaluate ``model`` on ``val_loader``.

    Labels are expected as ``[batch, frames, H, W, 3]`` with the last axis
    ordered (confidence, x-offset, y-offset); model outputs are expected as
    ``[batch, frames * 3, H, W]`` in the same per-frame channel grouping that
    the training targets use.

    Args:
        model: Network to evaluate; it is switched to eval mode. It is assumed
            to already live on the device selected here.
        val_loader: Iterable yielding ``(instances, label)`` batches.

    Returns:
        ``(acc, avg_loss)`` where ``acc`` is the fraction of frames whose
        highest-confidence predicted cell equals the highest-confidence label
        cell, and ``avg_loss`` is the mean ``custom_loss`` per sample.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()

    correct_frames = 0
    total_frames = 0
    total_samples = 0
    loss_sum = 0.0

    with torch.no_grad():
        for instances, label in val_loader:
            instances = instances.to(device, dtype=torch.float32)
            label = label.to(device, dtype=torch.float32)

            # Use the real batch size: the last batch may be smaller than
            # val_loader.batch_size.
            batch, frames, height, width, components = label.shape

            # Forward pass
            outputs = model(instances)

            # Same channel layout as the training targets:
            # channel index = frame * components + component.
            target = label.permute(0, 1, 4, 2, 3).reshape(batch, frames * components, height, width)

            # custom_loss already sums over the batch, so accumulate it as-is.
            loss_sum += custom_loss(outputs, target).item()
            total_samples += batch

            # Confidence grids, both [batch, frames, H, W]. The prediction is
            # decoded with the same per-frame grouping as the target.
            conf_true = label[..., 0]
            conf_pred = outputs.reshape(batch, frames, components, height, width)[:, :, 0]

            # A frame is correct when the arg-max cell of both grids matches.
            true_cell = conf_true.flatten(start_dim=2).argmax(dim=2)
            pred_cell = conf_pred.flatten(start_dim=2).argmax(dim=2)
            correct_frames += (true_cell == pred_cell).sum().item()
            total_frames += true_cell.numel()

    # Calculate accuracy (per frame) and average loss (per sample)
    acc = correct_frames / total_frames
    avg_loss = loss_sum / total_samples

    return acc, avg_loss

In [None]:

def train(model, train_loader, val_loader, epochs = 50):
    """Train ``model`` with Adadelta on ``custom_loss`` and log metrics to wandb.

    After every epoch the model is evaluated with ``validate`` and the epoch
    index, mean training loss, validation accuracy and validation loss are
    printed and sent to ``wandb.log``.

    Args:
        model: Network to optimise; moved to CUDA when available, else CPU.
        train_loader: Iterable yielding ``(instances, label)`` batches, with
            labels reshaped below from ``[batch, 5, 27, 48, 3]``.
        val_loader: Loader passed to ``validate`` after each epoch.
        epochs: Number of passes over ``train_loader``.
    """
    device = "cuda" if cuda.is_available() else "cpu"
    model.to(device)

    optimizer = Adadelta(model.parameters(), lr=1.0)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        total_samples = 0

        for instances, label in tqdm(train_loader):
            optimizer.zero_grad()

            # [batch, 5, 27, 48, 3] -> [batch, 15, 27, 48]. Use the actual
            # batch size: the last batch may be smaller than
            # train_loader.batch_size.
            label = label.permute(0, 1, 4, 2, 3).reshape(label.size(0), 3 * 5, 27, 48).to(device, dtype=torch.float32)

            instances = instances.to(device, dtype=torch.float32)
            outputs = model(instances)

            loss = custom_loss(outputs, label)
            loss.backward()
            optimizer.step()

            # custom_loss is already summed over the batch, so it must not be
            # multiplied by the batch size again.
            running_loss += loss.item()
            total_samples += instances.size(0)

        # Mean loss per training sample for this epoch
        avg_loss = running_loss / total_samples

        val_acc, val_loss = validate(model, val_loader)

        print(f"Epoch {epoch+1}/{epochs}: Train Loss {avg_loss:.4f}, Val Acc={val_acc:.4f}, Val Loss={val_loss:.4f}")

        wandb.log({
            'epoch': epoch,
            'train_loss': avg_loss,
            'val_acc': val_acc,
            'val_loss': val_loss,
        })


In [None]:
import os
import random
from torch.utils.data import DataLoader

# Dataset location and split configuration
DATA_DIR = 'compiled_dataset'
NUM_VAL_FILES = 6
BATCH_SIZE = 10
SPLIT_SEED = 0  # fixed so the train/validation split is the same on every run

# os.listdir returns entries in arbitrary order, so sort before the seeded
# shuffle; otherwise the split changes between kernel restarts and validation
# files can end up in a later run's training set.
files = sorted(name for name in os.listdir(DATA_DIR) if name.endswith('.hdf5'))
random.Random(SPLIT_SEED).shuffle(files)

val_files = files[:NUM_VAL_FILES]
train_files = files[NUM_VAL_FILES:]

if not train_files:
    raise ValueError(
        f"Found {len(files)} .hdf5 file(s) in {DATA_DIR!r}; need more than "
        f"{NUM_VAL_FILES} to leave any for training."
    )

# Load the dataset

val_dataset = TrackNet(DATA_DIR, val_files, debug=False)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, pin_memory=True, num_workers=0)

train_dataset = TrackNet(DATA_DIR, train_files, debug=False)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, pin_memory=True, num_workers=0)

In [None]:
# Start training with train()'s default number of epochs.
train(model, train_loader, val_loader)

# NOTE(review): wandb.finish() is disabled, so the run is not closed explicitly
# in this cell -- confirm that is intentional.
# wandb.finish()