# Train DETR

In [None]:
# !pip install pytorch-lightning

# !wget -q https://github.com/direito-a-sombra/bus-view/releases/latest/download/imgs.tar.gz
# !wget -q -P ./data/training https://raw.githubusercontent.com/direito-a-sombra/bus-view/refs/heads/main/data/training/train_boxes.json
# !wget -q -P ./data/training https://raw.githubusercontent.com/direito-a-sombra/bus-view/refs/heads/main/data/training/train_files.json
# !tar -xzf imgs.tar.gz

In [None]:
import json

from PIL import Image as PImage
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from sklearn.model_selection import train_test_split
from transformers import AutoImageProcessor, DetrForObjectDetection

import torch
from torch.utils.data import DataLoader, Dataset
from torch import stack as t_stack
from torch.optim import AdamW

## Create COCO-ish Dataset

In [None]:
# Index every training image file by its "<dir>/<name>" id

IMG_DIR = "./imgs"
TRAINING_FILES_JSON = "./data/training/train_files.json"

with open(TRAINING_FILES_JSON, "r") as ifp:
  training_files = json.load(ifp)

# Flatten { label: { dir: [name, ...] } } into { "dir/name": { "dir", "name" } }.
# The top-level label key is not needed here, only the nested dir -> names maps.
id2path = {
  f"{img_dir}/{img_name}": { "dir": img_dir, "name": img_name }
  for dir2names in training_files.values()
  for img_dir, img_names in dir2names.items()
  for img_name in img_names
}

# Integer index of each image id = its position in the sorted list of ids
id2idx = dict(zip(sorted(id2path), range(len(id2path))))

In [None]:
# Prep COCO-ish dataset

LABEL2ID = {
  "bus_stop": 0,
  "bus_sign": 1,
}
ID2LABEL = { id:label for label,id in LABEL2ID.items() }

TRAINING_BOXES_JSON = "./data/training/train_boxes.json"
VAL_FRACTION = 0.2
SPLIT_SEED = 1010

with open(TRAINING_BOXES_JSON, "r") as ifp:
  box_info = json.load(ifp)

# Fail early, with the offending ids, if a boxed image has no entry in the file index
missing_ids = sorted(set(box_info) - set(id2path))
if missing_ids:
  raise KeyError(f"{len(missing_ids)} image ids in {TRAINING_BOXES_JSON} are missing from {TRAINING_FILES_JSON}: {missing_ids[:5]}")

# Fixed seed so the train/val membership is the same on every run
tids, vids = train_test_split(list(box_info.keys()), test_size=VAL_FRACTION, random_state=SPLIT_SEED)

tids = set(tids)
vids = set(vids)

# One record per image: file path, integer image id and its list of boxes.
# Both train and val images live in the same dict; tids / vids select between them.
annotated = {}

for img_id, objs in box_info.items():
  img_info = id2path[img_id]
  img_idx = id2idx[img_id]

  img_annotations = []
  # objs maps a label to a single (x0, y0, x1, y1) corner box
  for label, (x0, y0, x1, y1) in objs.items():
    # Corner box -> [x, y, width, height], with integer width/height
    bw, bh = int(x1 - x0), int(y1 - y0)
    img_annotations.append({
      "image_id": img_idx,
      "category_id": LABEL2ID[label],
      "area": bw * bh,
      "bbox": [x0, y0, bw, bh]
    })

  annotated[img_id] = {
    "filepath": f"{IMG_DIR}/{img_info['dir']}/{img_info['name']}",
    "image_id": img_idx,
    "annotations": img_annotations
  }

## PyTorch Dataset

In [None]:
class DetrDataset(Dataset):
  """Dataset that yields processor-encoded images and their detection targets.

  Args:
    ids: sequence of image ids; item `i` of the dataset is image `ids[i]`.
    annotations: dict mapping an image id to a record with the keys
      "filepath", "image_id" and "annotations".
    processor_name: name/path handed to `AutoImageProcessor.from_pretrained`.
  """

  def __init__(self, ids, annotations, processor_name):
    SIZE = { "shortest_edge": 640, "longest_edge": 1280 }
    self.ids = ids
    self.processor = AutoImageProcessor.from_pretrained(processor_name, size=SIZE)
    self.annotations = annotations

  def __len__(self):
    """Number of images in this dataset."""
    return len(self.ids)

  def __getitem__(self, idx):
    """Load and encode the image at position `idx`.

    Returns:
      dict with "pixel_values" (the encoded image, batch dimension removed)
      and "labels" (the processor's target for this image).
    """
    img_id = self.ids[idx]
    annotations = self.annotations[img_id]

    # Close the file handle deterministically and force 3 channels, so that
    # grayscale / palette / RGBA files are encoded the same way as RGB ones.
    with PImage.open(annotations["filepath"]) as img_file:
      img = img_file.convert("RGB")

    processed = self.processor(images=img, annotations=annotations, return_tensors="pt")

    return {
      # Index the batch dimension explicitly: a bare squeeze() would also
      # drop any other dimension that happens to have size 1.
      "pixel_values": processed["pixel_values"][0],
      "labels": processed["labels"][0],
    }

  def collate_fn(self, batch):
    """Pad a list of `__getitem__` results into one batch.

    Images are handed to `processor.pad` as a list rather than stacked first,
    so images of different sizes can be batched together.

    Returns:
      dict with "pixel_values" and "pixel_mask" as returned by the processor's
      `pad`, and "labels" as a list with one target per image.
    """
    pixel_values = [x["pixel_values"] for x in batch]
    padded = self.processor.pad(pixel_values, return_tensors="pt")

    return {
      "pixel_values": padded["pixel_values"],
      "pixel_mask": padded["pixel_mask"],
      "labels": [x["labels"] for x in batch],
    }

## PyTorchLightning Class

In [None]:
class DetrTrain(LightningModule):
  """LightningModule wrapping `DetrForObjectDetection` for fine-tuning.

  Args:
    model_name: checkpoint name/path passed to `from_pretrained`.
    label2id: dict mapping class label to integer id; the inverse mapping
      is derived from it.
    lr: learning rate for all parameters outside the backbone.
    lr_backbone: learning rate for parameters whose name contains "backbone".
    weight_decay: weight decay passed to AdamW.
  """

  def __init__(self, model_name, label2id, lr, lr_backbone, weight_decay):
    super().__init__()
    self.model = DetrForObjectDetection.from_pretrained(
      model_name,
      revision="no_timm",
      label2id=label2id,
      id2label={ id:label for label,id in label2id.items() },
      ignore_mismatched_sizes=True
    )

    self.lr = lr
    self.lr_backbone = lr_backbone
    self.weight_decay = weight_decay

    self.save_hyperparameters()

  def forward(self, pixel_values, pixel_mask):
    """Run the wrapped model without labels (inference-style call)."""
    return self.model(pixel_values=pixel_values, pixel_mask=pixel_mask)

  def common_step(self, batch, split):
    """Compute the loss for one batch and log it under a `{split}_` prefix.

    Logs the total loss as `{split}_loss` and every entry of the model's
    `loss_dict` as `{split}_{name}`.

    Returns:
      The total loss tensor from the model output.
    """
    outputs = self.model(pixel_values=batch["pixel_values"],
                         pixel_mask=batch["pixel_mask"],
                         labels=batch["labels"])

    # The batch also holds per-image label tensors with varying leading sizes,
    # so state the batch size explicitly instead of letting it be inferred.
    batch_size = batch["pixel_values"].shape[0]

    self.log(f"{split}_loss", outputs.loss, batch_size=batch_size)
    for k,v in outputs.loss_dict.items():
      # Log the detached tensor; calling .item() here would force a
      # device-to-host sync for every loss component on every step.
      self.log(f"{split}_{k}", v.detach(), batch_size=batch_size)

    return outputs.loss

  def training_step(self, batch, batch_idx):
    """Training hook: returns the loss logged with the "train" prefix."""
    return self.common_step(batch, "train")

  def validation_step(self, batch, batch_idx):
    """Validation hook: returns the loss logged with the "val" prefix."""
    return self.common_step(batch, "val")

  def configure_optimizers(self):
    """Build AdamW with a separate learning rate for backbone parameters.

    Only parameters with `requires_grad` set are handed to the optimizer.
    """
    trainable = [(n, p) for n,p in self.named_parameters() if p.requires_grad]
    param_dicts = [
      # Everything outside the backbone uses the optimizer-level default lr
      { "params": [p for n,p in trainable if "backbone" not in n] },
      {
        "params": [p for n,p in trainable if "backbone" in n],
        "lr": self.lr_backbone,
      },
    ]
    return AdamW(param_dicts, lr=self.lr, weight_decay=self.weight_decay)

## Instantiate

In [None]:
DETR_MODEL_NAME = "facebook/detr-resnet-50"

# tids / vids are sets of string ids, whose iteration order can change between
# interpreter runs. Sorting pins the dataset index -> image mapping so that
# validation batches (e.g. the first one) are the same after a kernel restart.
train_ds = DetrDataset(sorted(tids), annotations=annotated, processor_name=DETR_MODEL_NAME)
val_ds = DetrDataset(sorted(vids), annotations=annotated, processor_name=DETR_MODEL_NAME)

# Each dataset's own collate_fn pads the images of a batch to a common size
train_dl = DataLoader(train_ds, collate_fn=train_ds.collate_fn, batch_size=8, shuffle=True)
val_dl = DataLoader(val_ds, collate_fn=val_ds.collate_fn, batch_size=16, shuffle=False)

detr_train = DetrTrain(DETR_MODEL_NAME, LABEL2ID, lr=5e-5, lr_backbone=1e-5, weight_decay=1e-4)

## Test

In [None]:
# Smoke test: push one validation batch through the model and look at the
# shape of the predicted logits. Falls back to the CPU when no GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"

batch0 = next(iter(val_dl))

# No gradients are needed for a shape check; skip building the autograd graph
with torch.no_grad():
  out0 = detr_train.to(device)(pixel_values=batch0["pixel_values"].to(device), pixel_mask=batch0["pixel_mask"].to(device))

out0.logits.shape

## Train

In [None]:
# Keep the 4 best checkpoints by validation loss (checked every 2 epochs), plus the last one
checkpoint_callback = ModelCheckpoint(every_n_epochs=2, save_top_k=4, monitor="val_loss", save_last=True)

# The plotting cells below read metrics.csv, so request the CSV logger
# explicitly instead of relying on whichever logger the Trainer picks by default.
# Writes to ./lightning_logs/version_<N>/metrics.csv
csv_logger = CSVLogger(save_dir=".", name="lightning_logs")

trainer = Trainer(max_epochs=128,
                  gradient_clip_val=0.1,
                  fast_dev_run=False,
                  log_every_n_steps=16,
                  logger=csv_logger,
                  callbacks=[checkpoint_callback])

detr_train.train()
trainer.fit(detr_train, train_dl, val_dl)

## Plot Training Metrics

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd

LOGS_DIR = Path("./lightning_logs")

# Find the metrics file of the highest-numbered run instead of assuming a
# fixed version directory, which only exists after a specific number of runs.
metrics_by_version = {
  int(p.parent.name.split("_", 1)[1]): p
  for p in LOGS_DIR.glob("version_*/metrics.csv")
  if p.parent.name.split("_", 1)[1].isdigit()
}
if not metrics_by_version:
  raise FileNotFoundError(f"No metrics.csv found under {LOGS_DIR}/version_*")

METRICS_CSV = metrics_by_version[max(metrics_by_version)]
df = pd.read_csv(METRICS_CSV)

# Rows without a val_loss are training rows and vice versa. Only drop columns
# that are completely empty for that split, so a single missing value does
# not remove a whole metric from the plots.
train_df = df[df["val_loss"].isna()].dropna(axis=1, how="all")
val_df = df[df["train_loss"].isna()].dropna(axis=1, how="all")

In [None]:
# One figure per logged training metric, plotted against the epoch column
t_col_plot = list(filter(lambda name: name.startswith("train_"), train_df.columns))

for metric in t_col_plot:
  fig, ax = plt.subplots()
  epochs, values = train_df["epoch"], train_df[metric]
  # Markers for the individual log points, line to show the trend
  ax.scatter(epochs, values, s=4)
  ax.plot(epochs, values)
  ax.set_ylabel(metric)
  plt.show()

In [None]:
# One figure per logged validation metric, plotted against the epoch column
v_col_plot = list(filter(lambda name: name.startswith("val_"), val_df.columns))

for metric in v_col_plot:
  fig, ax = plt.subplots()
  epochs, values = val_df["epoch"], val_df[metric]
  # Markers for the individual log points, line to show the trend
  ax.scatter(epochs, values, s=4)
  ax.plot(epochs, values)
  ax.set_ylabel(metric)
  plt.show()