# Billboard Segmentation Model Training

This notebook:
1. Downloads the clean billboard dataset (1200 CLIP-filtered images) from a GitHub release
2. Uses SAM to convert bbox annotations → segmentation polygon masks
3. Trains YOLOv8-seg for instance segmentation
4. Downloads the trained model

**Make sure to select GPU runtime:** Runtime → Change runtime type → T4 GPU

## Step 1: Check GPU & Install Dependencies

In [None]:
import torch
# Report CUDA availability and, when a device is present, its name and memory.
has_gpu = torch.cuda.is_available()
print(f"GPU available: {has_gpu}")
if not has_gpu:
    print("No GPU! Go to Runtime -> Change runtime type -> T4 GPU")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    vram_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"VRAM: {vram_bytes / 1e9:.1f} GB")

In [None]:
!pip install -q git+https://github.com/facebookresearch/segment-anything.git ultralytics opencv-python-headless
from segment_anything import sam_model_registry
print("All packages installed!")

## Step 2: Get Dataset from GitHub

In [None]:
import zipfile
import os

# Download dataset from GitHub release
!wget -q "https://github.com/fxsBulqit/billboard-segmentation/releases/download/v1.0-dataset/clean_dataset.zip"

# Extract
with zipfile.ZipFile('clean_dataset.zip', 'r') as z:
    z.extractall('.')

# Verify
n_images = len(os.listdir('clean_dataset/images'))
n_labels = len(os.listdir('clean_dataset/labels'))
print(f"Extracted: {n_images} images, {n_labels} labels")

## Step 3: Download SAM Model

In [None]:
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth -O sam_vit_b.pth
print(f"SAM model downloaded: {os.path.getsize('sam_vit_b.pth') / 1e6:.0f} MB")

## Step 4: Generate Segmentation Masks with SAM

This takes each image + its YOLO bbox, feeds the bbox to SAM as a prompt, and gets back a precise segmentation mask. The mask is then converted to a polygon in YOLO segmentation format.

In [None]:
import cv2
import numpy as np
from pathlib import Path
from segment_anything import sam_model_registry, SamPredictor
from tqdm.notebook import tqdm

# Load SAM
print("Loading SAM...")
# Prefer the GPU, but fall back to CPU so this cell does not crash on a
# runtime without CUDA (mask generation will be much slower there).
sam_device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")
sam.to(sam_device)
predictor = SamPredictor(sam)
print(f"SAM loaded on {'GPU' if sam_device == 'cuda' else 'CPU'}!")

# Input images, their YOLO bbox labels, and the output folder for the
# SAM-derived segmentation labels.
IMAGES_DIR = Path("clean_dataset/images")
BBOX_LABELS_DIR = Path("clean_dataset/labels")
SEG_LABELS_DIR = Path("clean_dataset/seg_labels")
SEG_LABELS_DIR.mkdir(exist_ok=True)


def bbox_yolo_to_xyxy(yolo_bbox, img_w, img_h):
    """Convert a normalized YOLO box to pixel corner coordinates.

    Args:
        yolo_bbox: Sequence ``(cx, cy, bw, bh)`` - box centre and size, each
            expressed as a fraction of the image width/height.
        img_w: Image width in pixels.
        img_h: Image height in pixels.

    Returns:
        ``[x1, y1, x2, y2]`` as ints (truncated toward zero), clamped so the
        box lies within ``[0, img_w]`` x ``[0, img_h]``.
    """
    cx, cy, bw, bh = yolo_bbox

    left = int((cx - bw / 2) * img_w)
    top = int((cy - bh / 2) * img_h)
    right = int((cx + bw / 2) * img_w)
    bottom = int((cy + bh / 2) * img_h)

    # Clamp each edge to the image bounds.
    left = max(0, left)
    top = max(0, top)
    right = min(img_w, right)
    bottom = min(img_h, bottom)
    return [left, top, right, bottom]


def mask_to_polygon(mask, epsilon_factor=0.005):
    """Turn a 2-D mask into a simplified, normalized polygon.

    The largest external contour of the mask is simplified with
    ``cv2.approxPolyDP`` using a tolerance of ``epsilon_factor`` times the
    contour perimeter.

    Args:
        mask: 2-D array; it is cast to uint8 before contour extraction.
        epsilon_factor: Simplification tolerance as a fraction of the
            contour perimeter.

    Returns:
        An ``(N, 2)`` float array of ``(x, y)`` points divided by the mask
        width/height and clipped to ``[0, 1]``, or ``None`` when no contour
        is found or the simplified outline has fewer than 4 vertices.
    """
    mask_h, mask_w = mask.shape

    found, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if len(found) == 0:
        return None

    # Keep only the biggest blob; smaller contours are discarded.
    largest = max(found, key=cv2.contourArea)
    tolerance = epsilon_factor * cv2.arcLength(largest, True)
    outline = cv2.approxPolyDP(largest, tolerance, True)
    if len(outline) < 4:
        return None

    # Pixel coordinates -> fractions of the mask width and height.
    normalized = outline.reshape(-1, 2).astype(float)
    normalized[:, 0] = normalized[:, 0] / mask_w
    normalized[:, 1] = normalized[:, 1] / mask_h
    return np.clip(normalized, 0.0, 1.0)


# Process all images: for every .jpg with a YOLO bbox label file, prompt SAM
# with each box and write one YOLO-seg polygon line per box.
images = sorted(IMAGES_DIR.glob("*.jpg"))
success = 0      # images for which a seg label file was written
failed = 0       # images skipped (no label file, unreadable, or no usable box)
total_masks = 0  # polygon lines written across all images

for img_path in tqdm(images, desc="Generating masks"):
    # with_suffix only swaps the extension; str.replace would also rewrite a
    # ".jpg" that happens to appear earlier in the file name.
    label_name = img_path.with_suffix('.txt').name
    label_path = BBOX_LABELS_DIR / label_name
    if not label_path.exists():
        failed += 1
        continue

    image = cv2.imread(str(img_path))
    if image is None:
        failed += 1
        continue

    # Parse the label file first: rows need a class id plus 4 bbox values.
    with open(label_path) as f:
        rows = [ln.split() for ln in f.read().splitlines()]
    boxes = [(r[0], [float(v) for v in r[1:5]]) for r in rows if len(r) >= 5]
    if not boxes:
        # Nothing to prompt SAM with, so skip the (expensive) image encoding.
        failed += 1
        continue

    img_h, img_w = image.shape[:2]
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(image_rgb)

    seg_lines = []
    for class_id, yolo_bbox in boxes:
        x1, y1, x2, y2 = bbox_yolo_to_xyxy(yolo_bbox, img_w, img_h)
        if x2 <= x1 or y2 <= y1:
            # Box collapsed to zero area after truncation/clamping; it is not
            # a meaningful prompt for SAM.
            continue

        input_box = np.array([x1, y1, x2, y2])
        masks, scores, _ = predictor.predict(
            box=input_box, multimask_output=True
        )

        # Keep the candidate mask SAM scored highest.
        polygon = mask_to_polygon(masks[scores.argmax()])

        if polygon is None:
            # Fallback to rectangle from bbox. Clip to [0, 1] like the SAM
            # polygons, since a box touching the image border can extend
            # slightly outside the normalized range.
            cx, cy, bw, bh = yolo_bbox
            polygon = np.clip(
                np.array([
                    [cx - bw / 2, cy - bh / 2],
                    [cx + bw / 2, cy - bh / 2],
                    [cx + bw / 2, cy + bh / 2],
                    [cx - bw / 2, cy + bh / 2],
                ]),
                0.0,
                1.0,
            )

        points_str = " ".join(f"{p[0]:.6f} {p[1]:.6f}" for p in polygon)
        seg_lines.append(f"{class_id} {points_str}")
        total_masks += 1

    if seg_lines:
        seg_path = SEG_LABELS_DIR / label_name
        with open(seg_path, 'w') as f:
            f.write('\n'.join(seg_lines) + '\n')
        success += 1
    else:
        failed += 1

print(f"\nDone! {success} images, {total_masks} masks generated, {failed} failed")

## Step 5: Visualize Some SAM Masks

Let's check that SAM did a good job before training.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import random

def show_mask_result(img_name):
    """Plot one image with its segmentation polygons overlaid.

    Args:
        img_name: File name (e.g. ``"foo.jpg"``) of an image in ``IMAGES_DIR``;
            the polygons are read from the same-named ``.txt`` file in
            ``SEG_LABELS_DIR``.

    Prints a message and returns without plotting when the label file is
    missing or the image cannot be read.
    """
    # Check the label first so no image is decoded when there is nothing to draw.
    seg_path = SEG_LABELS_DIR / img_name.replace('.jpg', '.txt')
    if not seg_path.exists():
        print(f"No seg label for {img_name}")
        return

    img = cv2.imread(str(IMAGES_DIR / img_name))
    if img is None:
        # cv2.imread returns None instead of raising on a missing/corrupt file.
        print(f"Could not read image {img_name}")
        return
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w = img.shape[:2]

    fig, ax = plt.subplots(1, 1, figsize=(10, 8))
    ax.imshow(img)

    with open(seg_path) as f:
        for line in f:
            parts = line.strip().split()
            coords = [float(x) for x in parts[1:]]
            # A drawable polygon needs at least 3 complete (x, y) pairs;
            # skip blank or truncated lines.
            if len(coords) < 6 or len(coords) % 2 != 0:
                continue

            # Label coordinates are normalized; scale back to pixels.
            points = [(coords[i] * w, coords[i + 1] * h) for i in range(0, len(coords), 2)]

            polygon = patches.Polygon(
                points, closed=True, fill=True,
                facecolor='lime', edgecolor='red',
                alpha=0.3, linewidth=2
            )
            ax.add_patch(polygon)

    ax.set_title(img_name)
    ax.axis('off')
    plt.tight_layout()
    plt.show()

# Show 6 random examples (fewer when fewer seg labels exist)
all_seg = []
for seg_file in sorted(SEG_LABELS_DIR.glob('*.txt')):
    # Map each label file name back to its image file name.
    all_seg.append(seg_file.name.replace('.txt', '.jpg'))

n_examples = min(6, len(all_seg))
samples = random.sample(all_seg, n_examples)
for s in samples:
    show_mask_result(s)

## Step 6: Prepare Dataset for YOLOv8-seg Training

Split into train/val/test and create data.yaml.

In [None]:
import shutil
import random

TRAIN_DIR = Path("billboard_seg_dataset")

# Rebuild the output tree from scratch so files from an earlier run cannot leak in.
if TRAIN_DIR.exists():
    shutil.rmtree(TRAIN_DIR)

for split in ('train', 'val', 'test'):
    for subdir in ('images', 'labels'):
        (TRAIN_DIR / split / subdir).mkdir(parents=True)

# Every sample is identified by the stem of its seg label file.
seg_files = sorted(SEG_LABELS_DIR.glob('*.txt'))
all_names = []
for seg_file in seg_files:
    all_names.append(seg_file.stem)

# Fixed seed -> the same shuffle, and therefore the same split, on every run.
random.seed(42)
random.shuffle(all_names)

# 80/15/5 split
n = len(all_names)
train_end = int(0.80 * n)
val_end = int(0.95 * n)
train_names = all_names[:train_end]
val_names = all_names[train_end:val_end]
test_names = all_names[val_end:]

def copy_files(names, split):
    """Copy the image and seg label of each named sample into a split folder.

    Args:
        names: Iterable of sample stems (file names without extension).
        split: Split folder under ``TRAIN_DIR`` to copy into
            (``'train'``, ``'val'`` or ``'test'``).

    For each name, ``<name>.jpg`` from ``IMAGES_DIR`` goes to
    ``<split>/images`` and ``<name>.txt`` from ``SEG_LABELS_DIR`` goes to
    ``<split>/labels``. A source file that does not exist is skipped.
    """
    split_root = TRAIN_DIR / split
    for name in names:
        transfers = (
            (IMAGES_DIR, split_root / 'images', f"{name}.jpg"),   # image
            (SEG_LABELS_DIR, split_root / 'labels', f"{name}.txt"),  # seg label
        )
        for src_dir, dst_dir, filename in transfers:
            source = src_dir / filename
            if source.exists():
                shutil.copy2(source, dst_dir / filename)

copy_files(train_names, 'train')
copy_files(val_names, 'val')
copy_files(test_names, 'test')

print(f"Train: {len(train_names)} images")
print(f"Val:   {len(val_names)} images")
print(f"Test:  {len(test_names)} images")

# Create data.yaml.
# The dataset root is taken from TRAIN_DIR instead of a hard-coded /content
# path, so the YAML always points at the folder that was just populated
# (on Colab with the default working directory this is still
# /content/billboard_seg_dataset).
dataset_root = TRAIN_DIR.resolve().as_posix()
data_yaml = f"""path: {dataset_root}
train: train/images
val: val/images
test: test/images

nc: 1
names: ['billboard']
"""

with open(TRAIN_DIR / 'data.yaml', 'w') as f:
    f.write(data_yaml)

print(f"\ndata.yaml written to {TRAIN_DIR / 'data.yaml'}")

## Step 7: Train YOLOv8-seg Model

Training a segmentation model on the clean, SAM-annotated dataset.

In [None]:
from ultralytics import YOLO

# Load YOLOv8 segmentation base model
model = YOLO('yolov8m-seg.pt')  # medium model - good balance of speed/accuracy

# Use the data.yaml that the dataset-preparation step wrote under TRAIN_DIR,
# rather than repeating an absolute /content path that only exists on Colab.
data_yaml_path = (TRAIN_DIR / 'data.yaml').resolve()

# Train
results = model.train(
    data=str(data_yaml_path),
    epochs=100,
    imgsz=640,
    batch=16,
    patience=15,         # early stopping if no improvement for 15 epochs
    device=0,            # GPU
    workers=2,
    name='billboard_seg',
    # Augmentation
    hsv_h=0.015,
    hsv_s=0.5,
    hsv_v=0.3,
    degrees=5.0,
    translate=0.1,
    scale=0.3,
    flipud=0.0,          # no vertical flip (billboards don't flip upside down)
    fliplr=0.5,
    mosaic=0.8,
)

## Step 8: View Training Results

In [None]:
from IPython.display import Image, display
from pathlib import Path

# Find the training run directory.
# Repeated runs get numeric suffixes (billboard_seg, billboard_seg2, ...), and
# a plain name sort ranks billboard_seg10 before billboard_seg2, so pick the
# most recently modified run directory instead of the last name.
run_candidates = [p for p in Path('runs/segment').glob('billboard_seg*') if p.is_dir()]
if not run_candidates:
    raise FileNotFoundError(
        "No 'billboard_seg*' run directory under runs/segment - run the training cell first"
    )
run_dir = max(run_candidates, key=lambda p: p.stat().st_mtime)
print(f"Results in: {run_dir}")

# Show training curves
if (run_dir / 'results.png').exists():
    display(Image(filename=str(run_dir / 'results.png'), width=900))

# Show validation predictions (only the batches that were actually saved)
for img_name in ['val_batch0_pred.jpg', 'val_batch1_pred.jpg']:
    img_path = run_dir / img_name
    if img_path.exists():
        print(f"\n{img_name}:")
        display(Image(filename=str(img_path), width=900))

## Step 9: Test on Sample Images

Run the trained model on test images to see segmentation quality.

In [None]:
import matplotlib.pyplot as plt
import cv2
import numpy as np

# Load best model
best_model = YOLO(str(run_dir / 'weights' / 'best.pt'))

# Get test images (at most 8, one per panel of the 2x4 grid)
test_images = sorted(Path('billboard_seg_dataset/test/images').glob('*.jpg'))[:8]

fig, axes = plt.subplots(2, 4, figsize=(20, 10))
axes = axes.flatten()

# Hide the frame of every panel up front, so panels left unused when fewer
# than 8 test images exist are blank instead of showing empty axes.
for ax in axes:
    ax.axis('off')

for idx, img_path in enumerate(test_images):
    results = best_model(str(img_path), verbose=False)
    result = results[0]

    # Plot with masks; result.plot() output is converted from BGR for matplotlib
    annotated = result.plot()
    annotated = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)

    axes[idx].imshow(annotated)
    axes[idx].set_title(img_path.name[:25], fontsize=9)

plt.suptitle('Segmentation Results on Test Set', fontsize=14)
plt.tight_layout()
plt.show()

## Step 10: Download Trained Model

Download the best model weights to use locally.

In [None]:
from google.colab import files
import shutil

best_weights = run_dir / 'weights' / 'best.pt'
output_name = 'billboard_seg_best.pt'

# Copy the best checkpoint to a descriptive file name and report its size.
shutil.copy2(best_weights, output_name)
model_size_mb = os.path.getsize(output_name) / 1e6
print(f"Model size: {model_size_mb:.1f} MB")

# Also zip the full results
shutil.make_archive('training_results', 'zip', str(run_dir))

# Announce and trigger each browser download in turn.
for announcement, artifact in (
    ("\nDownloading model...", output_name),
    ("\nDownloading full training results...", 'training_results.zip'),
):
    print(announcement)
    files.download(artifact)

## Step 11: Quick Replacement Test

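Test that the new segmentation model can drive billboard replacement using actual polygon masks instead of bboxes. The cell below defines the replacement helper and, as a demo, plots the model's segmentation output on a few test images.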

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

def replace_billboard_seg(image_path, ad_path, model, min_size=20):
    """Replace billboards using segmentation masks.

    Runs ``model`` on the image, and for every predicted polygon warps the ad
    onto a 4-corner approximation of the polygon, pasting the warped ad only
    inside the polygon itself.

    Args:
        image_path: Path of the scene image.
        ad_path: Path of the ad image to paste onto each billboard.
        model: Callable invoked as ``model(str(image_path), verbose=False)``;
            the first result's ``masks.xy`` polygons are used as pixel
            coordinates.
        min_size: Polygons whose bounding rectangle is narrower or shorter
            than this many pixels are skipped.

    Returns:
        The image as loaded by ``cv2.imread`` with the ad composited in, or
        the unmodified image when the model reports no masks.

    Raises:
        FileNotFoundError: If the scene image or the ad image cannot be read.
    """
    # cv2.imread signals failure by returning None; fail with a clear message
    # rather than an AttributeError further down.
    img = cv2.imread(str(image_path))
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    ad = cv2.imread(str(ad_path))
    if ad is None:
        raise FileNotFoundError(f"Could not read ad image: {ad_path}")

    results = model(str(image_path), verbose=False)
    result = results[0]

    if result.masks is None:
        print(f"No billboards found in {image_path}")
        return img

    output = img.copy()
    img_h, img_w = img.shape[:2]

    # Source quad (the whole ad) and a same-sized all-255 mask. Both are
    # identical for every billboard, so build them once.
    ad_h, ad_w = ad.shape[:2]
    ad_corners = np.array([
        [0, 0], [ad_w, 0],
        [ad_w, ad_h], [0, ad_h]
    ], dtype=np.float32)
    ad_coverage = np.full((ad_h, ad_w), 255, dtype=np.uint8)

    for polygon in result.masks.xy:
        if len(polygon) < 4:
            continue

        # Get the polygon points
        pts = polygon.astype(np.int32)

        # Bounding rect: used for the size filter and as the fallback quad
        x, y, w, h = cv2.boundingRect(pts)
        if w < min_size or h < min_size:
            continue

        # Find the 4 corner-like points for the perspective transform by
        # simplifying the convex hull.
        hull = cv2.convexHull(pts)
        epsilon = 0.02 * cv2.arcLength(hull, True)
        approx = cv2.approxPolyDP(hull, epsilon, True)

        if len(approx) == 4:
            # Perfect quad
            corners = approx.reshape(4, 2).astype(np.float32)
        else:
            # Use bounding rect corners
            corners = np.array([
                [x, y], [x + w, y],
                [x + w, y + h], [x, y + h]
            ], dtype=np.float32)

        # Sort corners: TL, TR, BR, BL.
        # Ordering by angle around the centroid gives a consistent winding
        # (image y grows downward, so ascending angle runs TL -> TR -> BR -> BL);
        # rolling puts the corner with the smallest x + y first.
        center = corners.mean(axis=0)
        angles = np.arctan2(corners[:, 1] - center[1], corners[:, 0] - center[0])
        sorted_corners = corners[np.argsort(angles)]
        tl_idx = np.argmin(sorted_corners[:, 0] + sorted_corners[:, 1])
        sorted_corners = np.roll(sorted_corners, -tl_idx, axis=0)

        # Perspective transform of the ad, plus of the coverage mask so we
        # know which output pixels actually received ad content.
        M = cv2.getPerspectiveTransform(ad_corners, sorted_corners)
        warped = cv2.warpPerspective(ad, M, (img_w, img_h))
        covered = cv2.warpPerspective(ad_coverage, M, (img_w, img_h))

        # Create mask from the actual segmentation polygon
        seg_mask = np.zeros((img_h, img_w), dtype=np.uint8)
        cv2.fillPoly(seg_mask, [pts], 255)

        # Paste only where the polygon and the warped ad overlap; polygon
        # pixels outside the fitted quad would otherwise be painted black.
        paste = (seg_mask > 0) & (covered > 0)
        output[paste] = warped[paste]

    return output


# Test on a few images (upload your ad image ss.png if you want)
# For demo, just show the segmentation masks
test_imgs = sorted(Path('billboard_seg_dataset/test/images').glob('*.jpg'))[:4]

if not test_imgs:
    # plt.subplots cannot build a grid with zero columns.
    print("No test images found in billboard_seg_dataset/test/images")
else:
    # squeeze=False keeps `axes` a 2-D array even for a single image, so the
    # indexing below works for any number of panels.
    fig, axes = plt.subplots(1, len(test_imgs), figsize=(20, 5), squeeze=False)
    for ax, img_path in zip(axes[0], test_imgs):
        results = best_model(str(img_path), verbose=False)
        annotated = results[0].plot()
        annotated = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
        ax.imshow(annotated)
        ax.axis('off')
        pred_masks = results[0].masks
        n_masks = len(pred_masks.xy) if pred_masks is not None else 0
        ax.set_title(f"{n_masks} billboard(s)")

    plt.suptitle('Segmentation Model Output')
    plt.tight_layout()
    plt.show()