In [1]:
import transformers
import torch
import torchvision
import cv2
from PIL import Image
import json
import numpy as np
from matplotlib import pyplot as plt
from glob import glob
import os

device = "cuda"
model_name = "SenseTime/deformable-detr"
coco_folder='./partnetsim-1024-fixed-viewpoints/coco'

In [2]:
# Deformable-DETR assumes the N class labels are 0..N-1, but the category ids in our COCO files start at 1.
# Write a `modified_` copy of each annotation file with every category id shifted down by one.
# Only the `categories` entries are rewritten on disk; the per-annotation `category_id` values keep their
# original ids here and are shifted by the same amount when samples are loaded in `CocoDetection.__getitem__`.
annotation_dir = os.path.join(coco_folder, "coco_annotation")
for source_name in ("MotionNet_train.json", "MotionNet_valid.json"):
    source_path = os.path.join(annotation_dir, source_name)
    with open(source_path, 'r') as src:
        coco_data = json.load(src)

    for category in coco_data['categories']:
        category['id'] -= 1

    # Always derived from the untouched source file, so re-running this cell is idempotent.
    with open(os.path.join(annotation_dir, 'modified_' + source_name), 'w') as dst:
        json.dump(coco_data, dst)

In [3]:
class CocoDetection(torchvision.datasets.CocoDetection):
    """COCO-format dataset yielding ``(pixel_values, target)`` pairs for Deformable-DETR.

    Reads the ``modified_*`` annotation files produced by the preprocessing cell (category ids
    already shifted to 0..N-1). At load time it shifts each annotation's ``category_id`` by the
    same amount. It then runs the Hugging Face image processor over the image and its annotations.

    Args:
        coco_folder: dataset root containing ``coco_annotation/``, ``train/origin`` and ``valid/origin``.
        processor: image processor that prepares the image and converts the annotations into labels.
        train: if True, load the training split; otherwise load the validation split.
    """

    def __init__(self, coco_folder, processor, train=True):
        ann_name = "modified_MotionNet_train.json" if train else "modified_MotionNet_valid.json"
        ann_file = os.path.join(coco_folder, "coco_annotation", ann_name)
        img_folder = os.path.join(coco_folder, "train/origin" if train else "valid/origin")
        super().__init__(img_folder, ann_file)
        self.processor = processor

    def __getitem__(self, idx):
        """Return ``(pixel_values, target)`` for the idx-th image.

        ``target`` is the processor's label dict for that image.
        """
        img, target = super().__getitem__(idx)
        image_id = self.ids[idx]

        # The base class hands back the annotation dicts cached inside ``self.coco`` (via
        # ``COCO.loadAnns``). Decrementing them in place would shift the labels again every time
        # the same image is loaded: in-process loading, persistent workers, or manual
        # ``dataset[i]`` calls. Build shifted copies so the cached annotations stay untouched.
        annotations = [
            {**annotation, 'category_id': annotation['category_id'] - 1}
            for annotation in target
        ]

        target = {'image_id': image_id, 'annotations': annotations}
        encoding = self.processor(images=img, annotations=target, return_tensors="pt")
        # The processor returns a batch of one image; drop only that leading batch dimension.
        pixel_values = encoding["pixel_values"].squeeze(0)
        target = encoding["labels"][0]
        return pixel_values, target


In [4]:
from transformers import DeformableDetrImageProcessor

processor = DeformableDetrImageProcessor.from_pretrained(model_name)

# Reuse the `coco_folder` constant from the config cell instead of repeating the path literal,
# so changing the dataset location in one place updates every cell.
train_dataset = CocoDetection(coco_folder=coco_folder, processor=processor)
val_dataset = CocoDetection(coco_folder=coco_folder, processor=processor, train=False)

The `max_size` parameter is deprecated and will be removed in v4.26. Please specify in `size['longest_edge'] instead`.


loading annotations into memory...
Done (t=0.05s)
creating index...
index created!
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!


In [5]:
# Category ids in the modified annotation files are already 0..N-1, so they are used directly
# as the model's class indices.
cats = train_dataset.coco.cats
id2label = {cat_id: cat_info['name'] for cat_id, cat_info in cats.items()}
label2id = dict(zip(id2label.values(), id2label.keys()))

In [7]:
from torch.utils.data import DataLoader

def collate_fn(batch):
    """Merge ``(pixel_values, target)`` samples into one padded batch dict.

    The images are padded to a common size with ``processor.pad``, which also returns the
    ``pixel_mask`` for the padded batch. Targets stay a plain Python list because each image
    can carry a different number of annotations.
    """
    padded = processor.pad([sample[0] for sample in batch], return_tensors="pt")
    return {
        'pixel_values': padded['pixel_values'],
        'pixel_mask': padded['pixel_mask'],
        'labels': [sample[1] for sample in batch],
    }

# Settings shared by both loaders: pinned host memory, 4 worker processes, 4 batches prefetched per worker.
_loader_kwargs = dict(collate_fn=collate_fn, pin_memory=True, prefetch_factor=4, num_workers=4)
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, **_loader_kwargs)
val_dataloader = DataLoader(val_dataset, batch_size=1, **_loader_kwargs)

In [8]:
import pytorch_lightning as pl
from transformers import DeformableDetrForObjectDetection
import torch

class DeformableDetr(pl.LightningModule):
    """Lightning module that fine-tunes a pretrained Hugging Face Deformable-DETR detector.

    Args:
        lr: learning rate for every trainable parameter whose name does not contain "backbone".
        lr_backbone: learning rate for trainable parameters whose name contains "backbone".
        weight_decay: AdamW weight decay applied to both parameter groups.
        id2label: mapping from class index to class name.
        label2id: reverse mapping from class name to class index.
    """

    def __init__(self, lr, lr_backbone, weight_decay, id2label, label2id):
        super().__init__()
        # Store the constructor arguments in checkpoints so the module can be rebuilt with
        # `DeformableDetr.load_from_checkpoint(path)` without re-supplying them.
        self.save_hyperparameters()
        self.model = DeformableDetrForObjectDetection.from_pretrained(
            model_name,
            num_labels=len(id2label),
            id2label=id2label,
            # Pass the reverse mapping as well. Previously this argument was accepted but never
            # forwarded, so the model config's label2id did not match the custom id2label, and
            # that inconsistent config is what gets saved / pushed to the Hub.
            label2id=label2id,
            # The pretrained classification head's shape does not match len(id2label) classes;
            # mismatched weights are re-initialized instead of raising.
            ignore_mismatched_sizes=True,
        )
        self.model = self.model.to(device)
        self.lr = lr
        self.lr_backbone = lr_backbone
        self.weight_decay = weight_decay

    def forward(self, pixel_values, pixel_mask):
        """Run inference and return the raw Hugging Face model outputs."""
        outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask)
        return outputs

    def common_step(self, batch, batch_idx):
        """Forward a collated batch with its labels; return ``(loss, loss_dict)``."""
        pixel_values = batch["pixel_values"]
        pixel_mask = batch["pixel_mask"]
        # Labels stay a list of per-image dicts (see collate_fn), so Lightning does not move
        # them to the device automatically; do it here.
        labels = [{k: v.to(self.device) for k, v in t.items()} for t in batch["labels"]]

        outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)

        loss = outputs.loss
        loss_dict = outputs.loss_dict

        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        """Log the total training loss and each loss component; return the loss to optimize."""
        loss, loss_dict = self.common_step(batch, batch_idx)
        self.log("training_loss", loss)
        for k, v in loss_dict.items():
            self.log("train_" + k, v.item())

        return loss

    def validation_step(self, batch, batch_idx):
        """Log the total validation loss and each loss component."""
        loss, loss_dict = self.common_step(batch, batch_idx)
        self.log("validation_loss", loss)
        for k, v in loss_dict.items():
            self.log("validation_" + k, v.item())
        return loss

    def configure_optimizers(self):
        """AdamW with a separate learning rate for the backbone parameters."""
        param_dicts = [
            {"params": [p for n, p in self.named_parameters() if "backbone" not in n and p.requires_grad]},
            {
                "params": [p for n, p in self.named_parameters() if "backbone" in n and p.requires_grad],
                "lr": self.lr_backbone,
            },
        ]
        optimizer = torch.optim.AdamW(param_dicts, lr=self.lr,
                                      weight_decay=self.weight_decay)

        return optimizer

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return val_dataloader

In [None]:
# Let float32 matrix multiplications take faster reduced-precision paths on tensor-core GPUs:
# "medium" permits bfloat16-based internal computation when such a fast algorithm is available.
torch.set_float32_matmul_precision(precision="medium")

In [None]:
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from datetime import datetime

wandb_logger = WandbLogger(project="DeformableDETR-pl-finetune")

# One checkpoint directory per run, named by launch time. The same timestamp is reused
# later as the Hub repo suffix.
dirpath = f"checkpoints/DeformableDETR/{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
os.makedirs(dirpath, exist_ok=True)
# Save a checkpoint every 5 epochs (save_top_k=-1 keeps all of them) plus a "last" checkpoint.
checkpoint_callback = ModelCheckpoint(
    dirpath=dirpath,
    filename="deformable-detr-{epoch:02d}",
    save_top_k=-1,
    every_n_epochs=5,
    save_last=True,
)

# Single source of truth for this run: the model and Trainer below read their settings from
# this dict, so the values logged to W&B are exactly the values used for training.
# NOTE: "batch_size" is only recorded here; the train DataLoader hard-codes batch_size=4,
# so keep the two in sync.
hparams = {"checkpoints_path": dirpath,
           "lr": 1e-5,
           "lr_backbone": 1e-5,
           "weight_decay": 1e-3,
           "id2label": id2label,
           "label2id": label2id,
           "max_steps": 3000,
           "gradient_clip_val": 0.2,
           "accelerator": device,
           "devices": 1,
           "batch_size": 4,
           "model_name": model_name}

model = DeformableDetr(
    lr=hparams["lr"],
    lr_backbone=hparams["lr_backbone"],
    weight_decay=hparams["weight_decay"],
    id2label=hparams["id2label"],
    label2id=hparams["label2id"],
).to(hparams["accelerator"])

trainer = pl.Trainer(
    max_steps=hparams["max_steps"],
    gradient_clip_val=hparams["gradient_clip_val"],
    logger=wandb_logger,
    accelerator=hparams["accelerator"],
    devices=hparams["devices"],
    callbacks=[checkpoint_callback],
)

wandb_logger.log_hyperparams(hparams)

trainer.fit(model)

In [None]:
# Publish the fine-tuned detector and its matching image processor to the same Hub repo,
# named after this run's checkpoint directory (its timestamp).
hub_repo_id = f"diliash/deformable-detr-{dirpath.rsplit('/', 1)[-1]}"
for artifact in (model.model, processor):
    artifact.push_to_hub(hub_repo_id)