# Training

## Installations

In [None]:
# pip install transformers datasets peft bitsandbytes accelerate torchvision
!pip install bitsandbytes

## Imports

In [None]:
import torch
from transformers import (
    BitsAndBytesConfig,
    DetrForObjectDetection,
    DetrImageProcessor,
    Trainer,
    TrainingArguments,
)
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training, TaskType
from datasets import load_dataset, Features, Sequence, Value, Array2D, Array3D
import bitsandbytes as bnb

## Id <-> label look-up maps

In [None]:
# Vehicle class names; a name's position in this tuple is its class id.
_CLASS_NAMES = (
    "Auto",
    "2-Wheeler",
    "Bicycle",
    "Bus",
    "Hatchback",
    "LCV",
    "Mini-bus",
    "MUV",
    "Sedan",
    "SUV",
    "Tempo-traveller",
    "Truck",
    "Van",
    "Vehicle (others)",
)

# Forward (id -> name) and reverse (name -> id) look-up tables.
id2label = dict(enumerate(_CLASS_NAMES))
label2id = {name: class_id for class_id, name in id2label.items()}

## Load processor and model (4-bit quantized)

In [None]:
# Image processor: resizes/normalises images and converts COCO-style
# annotations into the target format DETR expects.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")

# 4-bit NF4 quantization settings. The config class is provided by
# `transformers` (`BitsAndBytesConfig`); bitsandbytes itself has no
# `QuantizationConfig`.
# The classification head is re-initialised for our label set (see
# `ignore_mismatched_sizes` below), so it is kept out of quantization and
# stays in regular floating point, where it can be trained in full.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    llm_int8_skip_modules=["class_labels_classifier"],
)

model = DetrForObjectDetection.from_pretrained(
    "facebook/detr-resnet-50",
    num_labels=len(id2label),
    # The checkpoint's classifier was trained for a different label set, so
    # its weights cannot be reused and the size mismatch must be tolerated.
    ignore_mismatched_sizes=True,
    id2label=id2label,
    label2id=label2id,
    device_map="auto",
    # `load_in_4bit` must not be passed next to `quantization_config`;
    # the config object already carries that setting.
    quantization_config=quantization_config,
)

## Configure LoRA

In [None]:
# Freeze the base weights and prepare the quantized model for training.
# Gradient checkpointing is switched off: the PEFT helper sets it up through
# the model's input-embedding layer, and DETR consumes pixels rather than
# token embeddings.
# NOTE(review): confirm against the installed transformers/peft versions.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)

# The classification head is freshly initialised for our label set, so it
# should be trained (and saved with the adapter) in full. That is only
# possible when the head is a regular float layer, i.e. it was excluded from
# 4-bit quantization when the model was loaded.
classifier_is_quantized = isinstance(model.class_labels_classifier, bnb.nn.Linear4bit)
if classifier_is_quantized:
    print(
        "Warning: class_labels_classifier is 4-bit quantized and will stay frozen; "
        "reload the model with it excluded from quantization to train it."
    )

peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    # No `task_type`: PEFT's TaskType enum has no object-detection member.
    # Leaving it unset wraps the model in the generic PeftModel.
    # Targets: DETR's attention projections and its feed-forward layers,
    # which are named `fc1` / `fc2` (there is no module called "ffn").
    target_modules=["q_proj", "k_proj", "v_proj", "fc1", "fc2"],
    modules_to_save=None if classifier_is_quantized else ["class_labels_classifier"],
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

## Load dataset

In [None]:
# Small slices keep this example quick to run (for example/testing).
# Keys are the names used in the rest of the notebook; values are split specs.
split_spec = {
    "train": "train[:100]",
    "val": "validation[:50]",
}

# Sample COCO-style dataset (replace with yours).
dataset = load_dataset("ybelkada/coco-2017", split=split_spec)

## Preprocessing

In [None]:
def _as_coco_annotations(objects):
    """Return `objects` as a list of per-object dicts carrying `category_id`.

    Accepts either a list of dicts (shallow-copied) or a columnar dict of
    equal-length lists, which is transposed into one dict per object. A
    `category` field is mirrored to `category_id` when the latter is absent.
    """
    if isinstance(objects, dict):
        keys = list(objects)
        rows = [dict(zip(keys, values)) for values in zip(*(objects[k] for k in keys))]
    else:
        rows = [dict(obj) for obj in (objects or [])]
    for row in rows:
        if "category_id" not in row and "category" in row:
            row["category_id"] = row["category"]
    return rows


def preprocess(example):
    """Encode one dataset row into DETR model inputs.

    Args:
        example: A row with `image`, `image_id` and `objects` fields.
            NOTE(review): `objects` is assumed to describe COCO-style boxes
            (`bbox`, `area`, `category_id` or `category`) - confirm against
            the dataset actually used.

    Returns:
        dict: the processor's outputs for this single image with the batch
        axis removed; `labels` is a plain dict of per-image targets.
    """
    target = {
        "image_id": example["image_id"],
        "annotations": _as_coco_annotations(example["objects"]),
    }
    encoding = processor(images=example["image"], annotations=target, return_tensors="pt")

    # The processor always returns a batch. Tensors carry a leading batch
    # axis of size 1, while `labels` is a list with one entry per image, so
    # every field is indexed (a list has no `.squeeze()`).
    features = {name: value[0] for name, value in encoding.items() if name != "labels"}
    # Convert to a plain dict so the dataset writer sees a regular mapping.
    features["labels"] = dict(encoding["labels"][0])
    return features

# Run `preprocess` over both splits, dropping the raw columns so that only
# the encoded features remain.
train_dataset, val_dataset = (
    dataset[split_name].map(preprocess, remove_columns=dataset[split_name].column_names)
    for split_name in ("train", "val")
)

## Training parameters

In [None]:
# Hyper-parameters and bookkeeping for the Trainer.
args = TrainingArguments(
    output_dir="./detr-resnet-50-vehicle-lora",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=2,
    num_train_epochs=10,
    learning_rate=2e-5,
    # Spelled `evaluation_strategy` before transformers 4.41; that spelling
    # is deprecated and slated for removal, so use the current name.
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=10,
    fp16=True,
    # The mapped datasets contain only model inputs (raw columns were removed
    # during preprocessing), so there is nothing to prune. Disabling this
    # keeps the Trainer from dropping columns based on the wrapped model's
    # forward signature.
    remove_unused_columns=False,
    push_to_hub=False,
    report_to="none",
)

## Training loop configuration

In [None]:
def collate_fn(features):
    """Batch preprocessed rows for DETR.

    Images in a batch can differ in size, so each one is zero-padded to the
    largest height/width in the batch, and `pixel_mask` records which pixels
    are real (1) versus padding (0). `labels` stays a list with one dict of
    tensors per image instead of being stacked.

    Args:
        features: list of rows, each with `pixel_values` (C, H, W) and a
            `labels` mapping.

    Returns:
        dict with `pixel_values` (B, C, H_max, W_max), `pixel_mask`
        (B, H_max, W_max) and `labels` (list of B dicts).
    """
    images = [torch.as_tensor(row["pixel_values"], dtype=torch.float32) for row in features]
    max_height = max(img.shape[-2] for img in images)
    max_width = max(img.shape[-1] for img in images)

    pixel_values = torch.zeros(len(images), images[0].shape[0], max_height, max_width)
    pixel_mask = torch.zeros(len(images), max_height, max_width, dtype=torch.long)
    for idx, img in enumerate(images):
        height, width = img.shape[-2:]
        pixel_values[idx, :, :height, :width] = img
        pixel_mask[idx, :height, :width] = 1

    labels = []
    for row in features:
        target = {key: torch.as_tensor(value) for key, value in row["labels"].items()}
        # Rows stored in a dataset come back as plain lists, so an image
        # without objects yields empty, untyped fields; restore dtype/shape.
        if "class_labels" in target:
            target["class_labels"] = target["class_labels"].long()
        if "boxes" in target:
            target["boxes"] = target["boxes"].float().reshape(-1, 4)
        labels.append(target)

    return {"pixel_values": pixel_values, "pixel_mask": pixel_mask, "labels": labels}


# The default collator cannot batch variable-size images or per-image label
# dicts, so the custom `collate_fn` is passed explicitly.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
    tokenizer=processor,
)

In [None]:
# Run fine-tuning with the Trainer configured in the previous cell.
trainer.train()

## Push to the Hugging Face Hub

In [None]:
# Interactive login to the Hugging Face Hub; required before pushing.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# Upload the training output to the Hub. The string is passed as the first
# positional argument of `push_to_hub`.
# NOTE(review): that argument is presumably the commit message - confirm it is not meant as a model description.
trainer.push_to_hub("Facebook's detection transformer architecture with resnet 50 supervised finetuned for detection and classification of vehicles")

# Inference

## Load base model with LoRA weights

In [None]:
from transformers import BitsAndBytesConfig, DetrImageProcessor, DetrForObjectDetection
from peft import PeftModel
import torch
import requests
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches

base_model_name = "facebook/detr-resnet-50"
model_name = "your-username/detr-resnet-50-vehicle-lora"

# Quantize the base model with explicit NF4 / double-quantization settings
# rather than the bare `load_in_4bit=True` defaults, so the adapter sees the
# base weights in the form it was trained against. The classification head
# stays in floating point so that fine-tuned head weights stored with the
# adapter can be loaded into it.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    llm_int8_skip_modules=["class_labels_classifier"],
)

# Load base model
base_model = DetrForObjectDetection.from_pretrained(
    base_model_name,
    num_labels=len(id2label),
    ignore_mismatched_sizes=True,
    id2label=id2label,
    label2id=label2id,
    device_map="auto",
    quantization_config=quantization_config,
)

# Apply PEFT weights
model = PeftModel.from_pretrained(base_model, model_name)
model.eval()

# Load processor
processor = DetrImageProcessor.from_pretrained(base_model_name)

## Load an image and run inference

In [None]:
# Load an image
url = "https://c8.alamy.com/comp/2BFNHGX/group-of-cute-cats-on-white-background-2BFNHGX.jpg"
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()  # surface HTTP errors instead of a confusing decode failure
# Force three channels; the processor normalises per RGB channel.
image = Image.open(response.raw).convert("RGB")

# Preprocess, then move the inputs to the device the model lives on: the
# processor returns CPU tensors, while `device_map="auto"` decides where the
# model weights are placed.
device = next(model.parameters()).device
inputs = processor(images=image, return_tensors="pt").to(device)

# Inference
with torch.no_grad():
    outputs = model(**inputs)

# Get logits and boxes
logits = outputs.logits
boxes = outputs.pred_boxes

# Apply softmax to get class probabilities
probs = logits.softmax(-1)[0, :, :-1]  # Remove "no-object" class
scores, labels = probs.max(-1)

# Thresholding detections
threshold = 0.9
keep = scores > threshold

## Visualization

In [None]:
fig, ax = plt.subplots(figsize=(10, 8))
ax.imshow(image)

# Boxes are scaled from relative (center-x, center-y, width, height) values
# to pixel units using the image size.
pixel_scale = torch.tensor([image.width, image.height, image.width, image.height])

for score, label, box in zip(scores[keep], labels[keep], boxes[0][keep]):
    x_center, y_center, width, height = box.cpu() * pixel_scale
    # Matplotlib rectangles are anchored at a corner, not at the centre.
    left = x_center - width / 2
    top = y_center - height / 2

    ax.add_patch(
        patches.Rectangle((left, top), width, height, linewidth=2, edgecolor='red', facecolor='none')
    )
    caption = f"{model.config.id2label[label.item()]}: {score:.2f}"
    ax.text(left, top, caption, bbox=dict(facecolor='yellow', alpha=0.5))

ax.axis('off')
ax.set_title("DETR Object Detection")
plt.show()