# Training

## Installations

In [2]:
# pip install transformers torchvision datasets evaluate matplotlib
! pip install -q evaluate

## Id <-> label look-up maps

In [None]:
# Vehicle class names listed in class-index order: the position of a name in
# this sequence is its integer class id.
id2label = dict(enumerate((
    "Auto",
    "2-Wheeler",
    "Bicycle",
    "Bus",
    "Hatchback",
    "LCV",
    "Mini-bus",
    "MUV",
    "Sedan",
    "SUV",
    "Tempo-traveller",
    "Truck",
    "Van",
    "Vehicle (others)",
)))

# Reverse lookup: class name -> integer class id.
label2id = {name: idx for idx, name in id2label.items()}

## Import the dataset

In [None]:
from datasets import load_dataset

# Load every split found under dataset/. No `split=` argument is passed on
# purpose: the later cells index the result by split name (dataset["train"],
# dataset["val"]), which requires the mapping of splits rather than a single
# split.
# NOTE(review): confirm the validation folder is exposed under the key "val";
# the image-folder loader may name that split "validation" instead.
dataset = load_dataset('imagefolder', data_dir='dataset/')

## Load the base DETR model

In [None]:
from transformers import DetrImageProcessor, DetrForObjectDetection

# Pretrained checkpoint shared by the image processor and the model.
BASE_CHECKPOINT = "facebook/detr-resnet-50"

processor = DetrImageProcessor.from_pretrained(BASE_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(
    BASE_CHECKPOINT,
    # Derived from the label map so the number of classes cannot drift out of
    # sync with id2label / label2id.
    num_labels=len(id2label),
    # Important to allow changing the number of classes relative to the
    # checkpoint being loaded.
    ignore_mismatched_sizes=True,
    id2label=id2label,
    label2id=label2id,
)

## Encode dataset images

In [None]:
def transform(example):
    """Encode one dataset example with the DETR image processor.

    Args:
        example: A single dataset row providing ``'image'``, ``'image_id'``
            and ``'objects'``. ``'objects'`` is assumed to already be in the
            COCO annotation layout the processor expects (TODO confirm
            against the dataset's metadata).

    Returns:
        dict: The processor outputs for this one example with the batch
        level removed from every entry.
    """
    image = example['image']
    annotations = {
        "image_id": example["image_id"],
        "annotations": example["objects"],  # Assumes 'objects' field is COCO-style
    }
    encoding = processor(images=image, annotations=annotations, return_tensors="pt")

    # The processor returns batched outputs (batch size 1 here). Take element 0
    # of each entry instead of calling .squeeze():
    #   * "labels" is a list with one target per image, and lists have no
    #     .squeeze();
    #   * indexing removes only the batch axis, whereas a bare .squeeze()
    #     would also drop any other axis that happens to have size 1.
    unbatched = {}
    for key, value in encoding.items():
        first = value[0]
        # Per-image targets are mapping-like; store them as a plain dict.
        unbatched[key] = dict(first) if hasattr(first, "keys") else first
    return unbatched

In [None]:
# Encode both splits with `transform`, dropping every original column so that
# only the values returned by `transform` remain.
train_dataset, val_dataset = (
    dataset[split_name].map(
        transform,
        remove_columns=dataset[split_name].column_names,
    )
    for split_name in ("train", "val")
)

## Training parameters

In [None]:
from transformers import Trainer, TrainingArguments

# Hyperparameters and bookkeeping options for the fine-tuning run.
args = TrainingArguments(
    # Output locations; at most two checkpoints are kept on disk.
    output_dir="./detr-resnet-50-vehicle-finetuned",
    logging_dir="./logs",
    save_total_limit=2,

    # Evaluate and checkpoint once per epoch; log every 10 steps.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_steps=10,

    # Optimisation schedule.
    num_train_epochs=25,
    learning_rate=1e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,

    # Mixed-precision training.
    fp16=True,
)

## Training loop

In [None]:
# Wire the model, the training arguments and the encoded splits together.
# NOTE(review): no `data_collator` is passed, so the Trainer's default
# collation is used; confirm it can batch the nested `labels` entries that
# `transform` produces.
trainer = Trainer(
    args=args,
    model=model,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
    # The image processor is handed over through the `tokenizer` slot,
    # presumably so it is saved and pushed together with the model; verify.
    tokenizer=processor,
)

In [None]:
# Start fine-tuning using the model, datasets and arguments held by `trainer`.
trainer.train()

## Push to the Hugging Face Hub

In [None]:
# Log in to the Hugging Face Hub from inside the notebook; this comes before
# the push step in the next cell.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# Upload the trained model; the text below is passed as the commit message.
trainer.push_to_hub(
    commit_message="Facebook's detection transformer architecture with resnet 50 supervised finetuned for detection and classification of vehicles",
)

# Inference

## Load the fine-tuned model

In [None]:
from transformers import DetrForObjectDetection, DetrImageProcessor
import torch
import requests
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Hub repository holding the fine-tuned weights and the processor config.
model_name = "xxx-i-am-raahul-m-xxx/detr-resnet-50-vehicle-finetuned"

processor = DetrImageProcessor.from_pretrained(model_name)
# .eval() returns the module itself, so it can be chained onto the load.
model = DetrForObjectDetection.from_pretrained(model_name).eval()

# Run on the GPU when one is available, otherwise stay on the CPU.
model.to("cuda" if torch.cuda.is_available() else "cpu")

## Load the image and run inference

In [None]:
# Load image
image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")

# Preprocess, placing the tensors on the same device as the model
inputs = processor(images=image, return_tensors="pt").to(model.device)

# Predict (gradients are not needed for inference)
with torch.no_grad():
    outputs = model(**inputs)

# Class probabilities for the first (only) image. Despite its name, `logits`
# holds softmax outputs from here on, not raw logits.
logits = outputs.logits.softmax(-1)[0, :, :-1]  # exclude the "no object" class
boxes = outputs.pred_boxes[0]

# Best class and its probability per prediction; keep the confident ones
scores, labels = logits.max(-1)
keep = scores > 0.8  # threshold
scores = scores[keep]
labels = labels[keep]
boxes = boxes[keep]

# Scale boxes to original image size. The scale factors are created on the
# same device as `boxes`; a plain CPU tensor here raises a device-mismatch
# error whenever the model runs on the GPU.
width, height = image.size
boxes = boxes * torch.tensor([width, height, width, height], device=boxes.device)

# Move everything to NumPy on the CPU for plotting
boxes = boxes.cpu().numpy()
labels = labels.cpu().numpy()
scores = scores.cpu().numpy()

## Visualization

In [None]:
# Draw every kept detection on top of the input image.
fig, ax = plt.subplots(figsize=(12, 8))
ax.imshow(image)

for (x_c, y_c, w, h), label, score in zip(boxes, labels, scores):
    # Each box is unpacked as centre x/y plus width/height; the rectangle
    # patch is anchored at a corner, so shift by half the size.
    left = x_c - w / 2
    top = y_c - h / 2

    ax.add_patch(
        patches.Rectangle((left, top), w, h, linewidth=2, edgecolor='red', facecolor='none')
    )

    # Caption: class name and score, placed at the box anchor.
    caption = f"{model.config.id2label[label]}: {score:.2f}"
    ax.text(left, top, caption, color="black", fontsize=12,
            bbox=dict(facecolor="yellow", alpha=0.5))

ax.axis("off")
plt.show()