# R-CNN Stage 4: Bounding Box Regression

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

from rcnn.config.config import RCNNConfig
from rcnn.train_bbox_regressor import BBoxRegressionTrainer

In [None]:
# Input data and output locations, given relative to the working directory.
ARTIFACTS_DIR = Path("artifacts/rcnn")
IMAGES_DIR = Path("data/airbus/images/train")
ANNOTATIONS_CSV = Path("data/airbus/annotations/train_annotations_filtered.csv")

# Choose the compute backend in order of preference: CUDA, then Apple's MPS,
# then CPU. The checks stay in an if/elif chain so the MPS probe only runs
# when CUDA is unavailable.
if torch.cuda.is_available():
    _device_name = "cuda"
    _device_banner = f"Using GPU: {torch.cuda.get_device_name(0)}"
elif torch.backends.mps.is_available():
    _device_name = "mps"
    _device_banner = "Using Apple Silicon GPU"
else:
    _device_name = "cpu"
    _device_banner = "Using CPU"

DEVICE = torch.device(_device_name)
print(_device_banner)

In [None]:
# Build an RCNNConfig with no constructor arguments and redirect its output
# directory to this notebook's artifact folder.
config = RCNNConfig()
config.output_dir = ARTIFACTS_DIR

# The trainer receives the device as a plain string (e.g. "cuda"), not as a
# torch.device object.
device_name = str(DEVICE)
trainer = BBoxRegressionTrainer(config, device=device_name)

In [None]:
# Training options for the bounding-box regressors.
# NOTE(review): iou_threshold is presumably the minimum overlap between a
# proposal and a ground-truth box for the pair to be used as a training
# sample -- confirm against BBoxRegressionTrainer.train.
training_options = dict(
    annotations_csv=ANNOTATIONS_CSV,
    images_dir=IMAGES_DIR,
    iou_threshold=0.6,
    batch_size=64,
)
trainer.train(**training_options)

In [None]:
# Persist the trained regressors under the artifact directory.
trainer.save(ARTIFACTS_DIR)

# Summarise what was trained: one regressor per class id.
# NOTE(review): the file names printed below are assumed to match what
# trainer.save() writes -- confirm against its implementation.
regressors = trainer.bbox_regressors
print(f"\nTrained {len(regressors)} class-specific regressors")
for class_id in regressors.keys():
    artifact_name = f"bbox_regressor_class_{class_id}.npz"
    print(f"  - Class {class_id}: {artifact_name}")

## Visualize Regression Performance

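Run the trained regressor on one randomly chosen training image and compare the first 20 region proposals (left, blue) with the same boxes after bounding-box regression (right, red). For simplicity, only the regressor of the first trained class is applied.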

In [None]:
from PIL import Image, ImageDraw
import random

# Number of leading proposals that are drawn and refined; a small number
# keeps the plot readable and the feature-extraction batch small.
MAX_PROPOSALS_SHOWN = 20


def _draw_boxes(ax, boxes, color):
    """Outline each box on ``ax`` using ``color`` as the edge colour.

    Each box is indexed as (x1, y1, x2, y2): the rectangle is anchored at
    (box[0], box[1]) with width box[2] - box[0] and height box[3] - box[1].
    """
    for box in boxes:
        ax.add_patch(
            plt.Rectangle(
                (box[0], box[1]), box[2] - box[0], box[3] - box[1],
                fill=False, edgecolor=color, linewidth=1, alpha=0.5,
            )
        )


# Pick one random image that has at least one annotation row.
annotations = pd.read_csv(ANNOTATIONS_CSV)
sample_image_id = random.choice(annotations['image_id'].unique())

image_path = IMAGES_DIR / sample_image_id
# Open inside a context manager so the underlying file handle is released
# as soon as the RGB copy has been made.
with Image.open(image_path) as raw_image:
    img = raw_image.convert("RGB")

# Annotation rows for the sampled image; not drawn below, but kept available
# for interactive inspection.
img_annotations = annotations[annotations['image_id'] == sample_image_id]

proposals = trainer.proposer.generate_proposals(img)
print(f"Generated {len(proposals)} proposals")

# Slice once so both panels and the regressor see exactly the same boxes.
top_proposals = proposals[:MAX_PROPOSALS_SHOWN]

fig, (ax_proposals, ax_refined) = plt.subplots(1, 2, figsize=(15, 5))

# Left panel: raw proposals.
ax_proposals.imshow(img)
ax_proposals.set_title("Original Proposals")
ax_proposals.axis('off')
_draw_boxes(ax_proposals, top_proposals, 'blue')

# Right panel: the same proposals after regression.
ax_refined.imshow(img)
ax_refined.set_title("After Bbox Regression")
ax_refined.axis('off')

if len(top_proposals) == 0:
    # torch.stack() raises on an empty list, so bail out before warping.
    print("No proposals to refine; skipping bbox regression")
elif len(trainer.bbox_regressors) == 0:
    # Without a regressor the features would be computed for nothing.
    print("No trained regressors available; skipping bbox regression")
else:
    warped_regions = [
        trainer.warper.warp_region(img, prop) for prop in top_proposals
    ]
    warped_batch = torch.stack(warped_regions)
    features = trainer.extract_features(warped_batch)

    # Demo only: apply the first trained class's regressor to every
    # proposal, regardless of which class the proposal would be assigned.
    class_id = next(iter(trainer.bbox_regressors))
    regressor = trainer.bbox_regressors[class_id]
    refined_boxes = regressor.predict(features, top_proposals)
    _draw_boxes(ax_refined, refined_boxes, 'red')

plt.tight_layout()
plt.show()