# Model Training – Faster R-CNN

## Objective

This notebook trains a Faster R-CNN object detection model on the WAID dataset.
The goal is to evaluate a high-accuracy region-based detector for wildlife
detection in aerial drone imagery, using the preprocessing and data loading
pipeline defined in `src/data/`.


### Imports

In [None]:
# Core
import os
import sys
import time
import torch
import numpy as np

# PyTorch
from torch.utils.data import DataLoader
from torch.utils.data import WeightedRandomSampler
import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Utilities
from tqdm import tqdm
from collections import defaultdict, Counter

# Custom
from src.data.dataset import WAIDDataset
from src.data.augmentations import get_train_transforms, get_val_transforms


### Device & Reproducibility

In [2]:
torch.cuda.empty_cache()

In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

Using device: cuda


In [4]:
os.environ["CUDA_LAUNCH_BLOCKING"] = "1" 

### Paths & Classes

In [5]:
PROJECT_ROOT = os.path.abspath("..")
sys.path.insert(0, PROJECT_ROOT)

DATA_DIR = os.path.join(PROJECT_ROOT, "data")
RAW_DATA_DIR = os.path.join(DATA_DIR, "raw")

IMAGE_DIR = os.path.join(RAW_DATA_DIR, "images")
ANNOTATION_DIR = os.path.join(RAW_DATA_DIR, "annotations")

CLASSES_PATH = os.path.join(DATA_DIR, "classes.txt")

with open(CLASSES_PATH) as f:
    CLASS_NAMES = [line.strip() for line in f if line.strip()]

NUM_CLASSES = len(CLASS_NAMES) + 1  # +1 for background

SPLITS = ["train", "valid"]


### Load File Lists

In [None]:
image_files = defaultdict(list)

for split in SPLITS:
    split_dir = os.path.join(IMAGE_DIR, split)
    image_files[split] = sorted([
        f for f in os.listdir(split_dir)
        if f.lower().endswith((".jpg", ".png"))
    ])

    print(f"{split}: {len(image_files[split])} images")

train: 10056 images
valid: 2873 images


## Class Aware Sampling

### Compute class frequency per image

In [None]:
def compute_image_class_presence(annotation_dir, image_files):
    """
    Returns:
        image_classes: dict {img_name: set(class_ids)}
        class_counts: Counter {class_id: number of images containing it}
    """
    image_classes = {}
    class_counts = Counter()

    for img_name in image_files:
        ann_path = os.path.join(
            annotation_dir,
            os.path.splitext(img_name)[0] + ".txt"
        )

        classes_in_image = set()

        with open(ann_path) as f:
            for line in f:
                cid = int(line.split()[0])
                classes_in_image.add(cid)

        image_classes[img_name] = classes_in_image

        for cid in classes_in_image:
            class_counts[cid] += 1

    return image_classes, class_counts

In [8]:
train_image_classes, class_counts = compute_image_class_presence(
    annotation_dir=os.path.join(ANNOTATION_DIR, "train"),
    image_files=image_files["train"]
)

print("Class counts (images containing class):")
for cid, count in class_counts.items():
    print(f"{CLASS_NAMES[cid]}: {count}")

Class counts (images containing class):
cattle: 3267
sheep: 2920
seal: 2344
kiang: 546
zebra: 451
camelus: 528


### Compute image sampling weights

In [None]:
def compute_image_weights(image_classes, class_counts, num_classes):
    """
    Compute a sampling weight per image.
    """
    class_freq = {
        c: class_counts[c] for c in range(num_classes)
    }

    image_weights = []

    for img_name, classes in image_classes.items():
        if len(classes) == 0:
            image_weights.append(0.0)
            continue

        # Weight = average inverse frequency of classes in image
        weights = [
            1.0 / class_freq[c]
            for c in classes
            if class_freq[c] > 0
        ]

        image_weights.append(np.mean(weights))

    return image_weights

In [10]:
image_weights = compute_image_weights(
    train_image_classes,
    class_counts,
    NUM_CLASSES
)

print("Example image weights:", image_weights[:10])

Example image weights: [np.float64(0.00030609121518212427), np.float64(0.00030609121518212427), np.float64(0.00030609121518212427), np.float64(0.00030609121518212427), np.float64(0.00030609121518212427), np.float64(0.00030609121518212427), np.float64(0.00030609121518212427), np.float64(0.00030609121518212427), np.float64(0.00030609121518212427), np.float64(0.00030609121518212427)]


### Create a WeightedRandomSampler

In [None]:
sampler = WeightedRandomSampler(
    weights=image_weights,
    num_samples=len(image_weights),
    replacement=True
)

### Dataset Instantiation

In [None]:
train_dataset = WAIDDataset(
    image_dir=os.path.join(IMAGE_DIR, "train"),
    annotation_dir=os.path.join(ANNOTATION_DIR, "train"),
    image_files=image_files["train"],
    num_classes=len(CLASS_NAMES),
    transforms=get_train_transforms()
)

val_dataset = WAIDDataset(
    image_dir=os.path.join(IMAGE_DIR, "valid"),
    annotation_dir=os.path.join(ANNOTATION_DIR, "valid"),
    image_files=image_files["valid"],
    num_classes=len(CLASS_NAMES),
    transforms=get_val_transforms()
)

  from .autonotebook import tqdm as notebook_tqdm
  original_init(self, **validated_kwargs)
  self._set_keys()


### DataLoaders

In [13]:
def collate_fn(batch):
    images = [item["image"] for item in batch]

    targets = []
    for item in batch:
        targets.append({
            "bboxes": item["bboxes"],
            "labels": item["labels"],
            "image_size": item["image_size"]
        })

    return images, targets


train_loader = DataLoader(
    train_dataset,
    batch_size=3,
    sampler=sampler,
    num_workers=0,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=3,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn
)

### Load Pretrained Faster R-CNN

In [14]:
model = fasterrcnn_resnet50_fpn(weights="DEFAULT")

### Replace Classifier Head

In [15]:
in_features = model.roi_heads.box_predictor.cls_score.in_features

model.roi_heads.box_predictor = FastRCNNPredictor(
    in_features,
    NUM_CLASSES
)

model.to(DEVICE)

FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

### Optimizer & Scheduler

In [16]:
params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005
)

lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1
)

### Convert YOLO normalized boxes to absolute XYXY format

In [17]:
def yolo_to_xyxy(bboxes, image_size):
    h, w = image_size
    boxes = []

    # bboxes is a torch Tensor [N,4] in yolo normalized
    b = bboxes.detach().cpu().numpy() if hasattr(bboxes, "detach") else np.array(bboxes)

    for x, y, bw, bh in b:
        x1 = (x - bw/2) * w
        y1 = (y - bh/2) * h
        x2 = (x + bw/2) * w
        y2 = (y + bh/2) * h

        # reorder just in case
        x1, x2 = (min(x1, x2), max(x1, x2))
        y1, y2 = (min(y1, y2), max(y1, y2))

        # clip to image bounds
        x1 = max(0.0, min(x1, w - 1.0))
        y1 = max(0.0, min(y1, h - 1.0))
        x2 = max(0.0, min(x2, w - 1.0))
        y2 = max(0.0, min(y2, h - 1.0))

        # drop degenerate boxes
        if (x2 - x1) > 1.0 and (y2 - y1) > 1.0:
            boxes.append([x1, y1, x2, y2])

    if len(boxes) == 0:
        return torch.zeros((0, 4), dtype=torch.float32)

    return torch.tensor(boxes, dtype=torch.float32)

### One Epoch Training Function

In [18]:
def train_one_epoch(model, optimizer, dataloader, device):
    model.train()
    epoch_loss = 0.0

    for images, targets in tqdm(dataloader):
        images = [img.to(device) for img in images]

        formatted_targets = []
        for t in targets:
            boxes_xyxy = yolo_to_xyxy(
                t["bboxes"], t["image_size"]
            ).to(device)

            formatted_targets.append({
                "boxes": boxes_xyxy,
                "labels": t["labels"].to(device)
            })

        loss_dict = model(images, formatted_targets)
        loss = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        del loss, loss_dict

    return epoch_loss / len(dataloader)


### One Epoch Validation Function

In [19]:
@torch.no_grad()
def validate_one_epoch(model, dataloader, device):
    model.train()   # ⚠️ yes, TRAIN mode on purpose
    total_loss = 0.0
    n_batches = 0

    for images, targets in dataloader:
        images = [img.to(device) for img in images]

        formatted_targets = []
        for t in targets:
            boxes_xyxy = yolo_to_xyxy(
                t["bboxes"], t["image_size"]
            ).to(device)

            formatted_targets.append({
                "boxes": boxes_xyxy,
                "labels": t["labels"].to(device)
            })

        loss_dict = model(images, formatted_targets)
        loss = sum(v for v in loss_dict.values())

        total_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        return float("nan")

    return total_loss / n_batches

### Training and Validation Loop

In [20]:
NUM_EPOCHS = 5

for epoch in range(NUM_EPOCHS):
    start = time.time()

    train_loss = train_one_epoch(
        model, optimizer, train_loader, DEVICE
    )

    val_loss = validate_one_epoch(
        model, val_loader, DEVICE
    )

    lr_scheduler.step()

    print(
        f"Epoch [{epoch+1}/{NUM_EPOCHS}] | "
        f"Train Loss: {train_loss:.4f} | "
        f"Val Loss: {val_loss:.4f} | "
        f"Time: {time.time() - start:.1f}s"
    )

100%|██████████| 3352/3352 [44:06<00:00,  1.27it/s]


Epoch [1/5] | Train Loss: 2.0939 | Val Loss: 0.9256 | Time: 3026.7s


100%|██████████| 3352/3352 [44:28<00:00,  1.26it/s]


Epoch [2/5] | Train Loss: 2.0824 | Val Loss: 0.9589 | Time: 3047.6s


100%|██████████| 3352/3352 [44:05<00:00,  1.27it/s]


Epoch [3/5] | Train Loss: 2.0618 | Val Loss: 0.9341 | Time: 3024.1s


100%|██████████| 3352/3352 [43:58<00:00,  1.27it/s]


Epoch [4/5] | Train Loss: 1.8045 | Val Loss: 0.8704 | Time: 3015.0s


100%|██████████| 3352/3352 [44:00<00:00,  1.27it/s]


Epoch [5/5] | Train Loss: 1.7580 | Val Loss: 0.8331 | Time: 3018.5s


### Saving the Model

In [21]:
os.makedirs("outputs/models", exist_ok=True)

torch.save(
    model.state_dict(),
    "outputs/models/faster_rcnn_waid.pth"
)