## Installations

In [None]:
# pip install torch torchvision matplotlib tqdm

## Imports

In [None]:
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from tqdm import tqdm

## 1. Define Dataset and Collate Function

In [None]:
# Per-sample preprocessing: ToTensor turns a PIL image / ndarray into a float
# CxHxW tensor in [0, 1]. Resizing and normalisation are left to the detector's
# own internal transform.
transform = T.Compose([T.ToTensor()])

# NOTE(review): VehicleDataset is defined outside this cell; presumably it pairs
# files in `images/` with annotation files in `labels/` — confirm.
dataset = VehicleDataset(
    image_dir="images",
    label_dir="labels",
    transform=transform,
)

def collate_fn(batch):
    images, targets = zip(*batch)
    return list(images), list(targets)

# Batches stay as lists (see collate_fn) because detection samples vary in
# number of boxes. Pinned host memory only speeds up host-to-GPU copies, so it
# is requested only when CUDA is available; on CPU-only machines DataLoader
# would otherwise warn and ignore it.
# NOTE(review): num_workers > 0 inside a notebook requires VehicleDataset to be
# importable by worker processes on spawn-based platforms — confirm.
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)

## 2. Load Faster R-CNN and Modify for Custom Classes

In [None]:
def _infer_num_classes(ds):
    """Return the number of classifier outputs needed for ``ds``, background included.

    Iterates the whole dataset, so every sample (image included) is loaded once.
    NOTE(review): if VehicleDataset can expose its labels without decoding
    images, prefer that for large datasets.

    Label values are converted to Python ints before deduplication. Iterating a
    label tensor yields 0-d tensors, and tensors hash by identity, so a ``set``
    of them would count every box rather than every distinct class.
    """
    seen = set()
    for _, tgt in ds:
        seen.update(int(lbl) for lbl in tgt["labels"])
    # Keep the intended "1 background + N classes" count as a lower bound, but
    # also make sure the largest label id is a valid index into the classifier
    # head (label ids need not be contiguous). Class index 0 is background in
    # torchvision detection models.
    return max(1 + len(seen), max(seen, default=0) + 1)


num_classes = _infer_num_classes(dataset)  # 1 background + N foreground classes

# NOTE: newer torchvision releases deprecate `pretrained=True` in favour of
# `weights=...`; kept as-is for compatibility with older versions.
model = fasterrcnn_resnet50_fpn(pretrained=True)
# Replace the pretrained (COCO-sized) box predictor with one sized for this
# dataset, reusing the feature width of the existing head.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

## 3. Optimizer & Training Loop

In [None]:
# Optimise only trainable parameters. torchvision may freeze the early
# backbone layers of a pretrained detector, and those must be skipped.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=1e-4)

num_epochs = 10
# In train mode torchvision detection models take (images, targets) and return
# a dict of loss terms instead of predictions.
model.train()

for epoch in range(num_epochs):
    total_loss = 0.0
    progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for images, targets in progress:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Faster R-CNN reports separate loss terms (classifier, box regression,
        # RPN objectness, RPN box regression); train on their sum.
        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() synchronises with the device, so read the scalar once per step.
        loss_value = loss.item()
        total_loss += loss_value
        progress.set_postfix(loss=loss_value)

    # Guard against an empty DataLoader, which would otherwise raise ZeroDivisionError.
    num_batches = max(len(dataloader), 1)
    print(f"Epoch {epoch+1}: Average Loss = {total_loss / num_batches:.4f}")

## 4. Inference

In [None]:
# Image tensor of the first dataset sample, moved to the model's device.
image = dataset[0][0].to(device)

# In eval mode torchvision detection models take only images and return predictions.
model.eval()
with torch.no_grad():
    # One-image batch: `predictions` is a list with one dict per input image,
    # holding "boxes", "labels" and "scores".
    predictions = model([image])