### Single batch overfit test

In [None]:
import torch
import torch.nn as nn

from torch.utils.data import DataLoader
from torchvision.models import resnet34, ResNet34_Weights
from src.lib.models import MultiObjectNet
from src.lib.dataset import FashionDataset
from src.lib.loss import Loss
from src.lib.utils import create_transforms, collate

# --- Single-batch overfit test: data and loss setup -------------------------
# Pull one fixed batch out of the training set so a model can be trained on it
# repeatedly; if the pipeline is wired correctly the loss should collapse.

MAX_OBJS = 5
max_classes = 14

# Choose the best available accelerator instead of hard-coding one backend, so
# the same cell runs on CUDA machines, Apple-silicon machines and CPU-only hosts.
_mps_backend = getattr(torch.backends, "mps", None)  # absent on older torch builds
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

train_transforms = create_transforms("train", False)

train_dataset = FashionDataset(
    "train",
    "data/train/annotations",
    "data/train/images",
    train_transforms,
)
batch_size = 8
# shuffle=False + num_workers=0 keeps the fetched batch identical between runs.
train_dataloader = DataLoader(
    train_dataset,
    batch_size,
    shuffle=False,
    num_workers=0,
    collate_fn=collate,
)

# Only the first batch is ever used.
batch = next(iter(train_dataloader))
x, y, y_box = batch

x = x.to(device)
# NOTE(review): `collate` presumably returns one label sequence and one box
# sequence per image -- confirm against src.lib.utils.collate.
# torch.as_tensor accepts lists, arrays and existing tensors alike, and avoids
# the copy + UserWarning that torch.tensor() produces for tensor inputs.
y = [torch.as_tensor(labels, device=device) for labels in y]
y_box = [torch.as_tensor(boxes, device=device) for boxes in y_box]

loss_fn = Loss(max_classes, 0.5, 0.15, 0.5, warmup_epochs=0, learn_balance_weights=False)

# --- Single-batch overfit test: model and optimizer -------------------------
# ImageNet-pretrained ResNet-34. `pretrained=True` is deprecated in favour of
# the `weights=` enum; IMAGENET1K_V1 is the checkpoint `pretrained=True` loaded.
resnet = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)

# Drop the last two children of the ResNet (avgpool and fc) and append a 1x1
# convolution that reduces the channel count from 512 to 256.
backbone_layers = list(resnet.children())[:-2]
net_resnet = MultiObjectNet(
    in_size=224,
    max_objects=MAX_OBJS,
    num_classes=max_classes,
    backbone=nn.Sequential(
        *backbone_layers,
        nn.Conv2d(512, 256, 1),  # Dimensionality reduction
    ),
)

# Reuse the `device` the batch was already moved to. Rebinding `device` to a
# different backend here would leave the model and its inputs on different
# devices and make the first forward pass fail.
# The move happens before the optimizer is built, as the torch.optim docs advise.
net_resnet.to(device)

# The backbone gets a smaller learning rate than the remaining components.
optimizer = torch.optim.Adam([
    {"params": net_resnet.backbone.parameters(), "lr": 1e-5},
    {"params": net_resnet.decoder.parameters(), "lr": 2e-4},
    {"params": net_resnet.queries, "lr": 2e-4},
    {"params": net_resnet.category.parameters(), "lr": 2e-4},
    {"params": net_resnet.boxes.parameters(), "lr": 2e-4},
])

net_resnet.train()
# --- Single-batch overfit test: optimisation loop ---------------------------
# Run 500 optimizer steps on the same batch and print loss / matched accuracy
# every 50 steps.
for i in range(500):
    optimizer.zero_grad()
    out, boxes_out = net_resnet(x)
    loss, _ = loss_fn(out, y, boxes_out, y_box)
    loss.backward()
    optimizer.step()

    if i % 50 != 0:
        continue

    # Reporting only: nothing below needs gradients, so keep autograd from
    # recording these operations on the graph-attached model outputs.
    with torch.no_grad():
        probs = out.softmax(dim=-1)
        scores, labels = probs.max(dim=-1)

        # returns list of (src_idx, tgt_idx), one pair per batch element
        matched_indices = loss_fn.hungarian_matcher(out, y, boxes_out, y_box)

        correct = 0
        total = 0
        for b, (src_idx, tgt_idx) in enumerate(matched_indices):
            matched_labels = labels[b][src_idx]
            matched_scores = scores[b][src_idx]

            # Keep only predictions with score > threshold and not class 0
            # (class 0 is presumably the "no object" class -- TODO confirm).
            valid = (matched_labels != 0) & (matched_scores > 0.5)
            filtered_pred = matched_labels[valid]
            filtered_gt = y[b][tgt_idx][valid]

            correct += (filtered_pred == filtered_gt).sum().item()
            total += len(filtered_gt)

    # Accuracy is measured over the retained matched predictions only; report
    # 0.0 when none passed the filter to avoid dividing by zero.
    acc = correct / total if total > 0 else 0.0
    print(f"Step {i}: Loss={loss.item():.4f}, Acc={acc:.2f}")