# DrishT: SSDLite-MobileNetV3 Text Detection Training

**Model**: SSDLite320 + MobileNetV3-Large (~3.4M params)  
**Dataset**: 7,344 train / 915 val / 915 test images (COCO JSON)  
**Categories**: text, license_plate, traffic_sign, autorickshaw, tempo, truck, bus  
**GPU**: Kaggle T4 (free, 30h/week)  

## Setup
1. Add dataset `djt5ingh/drisht-detection` to this notebook
2. Enable GPU: Settings → Accelerator → GPU T4 x2 (training runs on a single GPU)
3. Run all cells

## Key Design Decisions
- **Pre-loaded tensor**: All images loaded into uint8 RAM tensor at init (~2.1 GB) → zero I/O during training
- **Backbone freeze**: First 5 epochs train only the SSD head, then full fine-tuning
- **Gradient clipping**: Max norm 10.0 to prevent exploding gradients from SSD loss
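- **Resumable checkpoints**: Every 10 epochs, a full checkpoint (model, optimizer, scheduler and scaler state, best metrics, history) is saved; resuming from it has to be done manually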
- **AMP**: Mixed precision — fp16 autocast with a GradScaler (SSD loss is stable in fp16, unlike CTC)

In [None]:
import os, sys, json, time, shutil, random
from pathlib import Path
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler
import torchvision
from torchvision.models.detection import ssdlite320_mobilenet_v3_large
from torchvision.models import MobileNet_V3_Large_Weights
import torchvision.transforms.functional as TF
import torchvision.transforms as T
from PIL import Image
from tqdm.auto import tqdm

# Reproducibility: seed the Python, NumPy and PyTorch RNGs from one constant.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {DEVICE.type == "cuda"}')

if DEVICE.type == 'cuda':
    torch.cuda.manual_seed_all(SEED)
    # Bit-exact determinism is traded for speed; benchmark mode autotunes conv
    # algorithms, which pays off because every input is a fixed 320x320.
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB')

## Configuration

In [None]:
# --- Auto-discover data root ---
OUTPUT_DIR = Path('/kaggle/working/detection_output')
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# DIAGNOSTIC: Show what's actually mounted
_INPUT = Path('/kaggle/input')
if _INPUT.exists():
    mounted = sorted(d.name for d in _INPUT.iterdir())
    print(f'Mounted datasets: {mounted}')
else:
    mounted = []
    print('WARNING: /kaggle/input does not exist (not running on Kaggle?)')


def _find_data_root(input_dir):
    """Return the first directory under *input_dir* that contains train/annotations.json.

    Looks one level deep (``<input>/<slug>``) and two levels deep
    (``<input>/<slug>/<subfolder>``). Returns None when nothing matches or
    *input_dir* does not exist.
    """
    if not input_dir.exists():
        return None
    # iterdir() order is filesystem-dependent; sorting makes the choice
    # deterministic when more than one mounted dataset matches.
    for base in sorted(input_dir.iterdir()):
        # Flat: /kaggle/input/<slug>/train/annotations.json
        if (base / 'train' / 'annotations.json').exists():
            return base
        # Nested: /kaggle/input/<slug>/<subfolder>/train/annotations.json
        if base.is_dir():
            for sub in sorted(base.iterdir()):
                if sub.is_dir() and (sub / 'train' / 'annotations.json').exists():
                    return sub
    return None


# Search ALL mounted datasets for our annotations.json
DATA_ROOT = _find_data_root(_INPUT)

if DATA_ROOT is None:
    drisht = [d for d in mounted if 'drisht' in d.lower() or 'detect' in d.lower()]
    raise FileNotFoundError(
        f'Cannot find train/annotations.json under /kaggle/input/.\n'
        f'Mounted datasets: {mounted}\n'
        f'Possible matches: {drisht}\n\n'
        f'FIX: Make sure you:\n'
        f'  1. Click "Add Input" (right sidebar) → search "drisht-detection" → Add\n'
        f'  2. RESTART the kernel session (adding data requires restart)\n'
        f'  3. Re-run all cells'
    )

print(f'DATA_ROOT: {DATA_ROOT}')

# Paths
TRAIN_JSON = DATA_ROOT / 'train' / 'annotations.json'
TRAIN_IMAGES = DATA_ROOT / 'train' / 'images'
VAL_JSON = DATA_ROOT / 'val' / 'annotations.json'
VAL_IMAGES = DATA_ROOT / 'val' / 'images'

# Model
NUM_CLASSES = 8  # 7 categories + background
INPUT_SIZE = 320

# Training
BATCH_SIZE = 32    # SSD320 is small enough for T4 at bs=32
NUM_WORKERS = 0    # Data pre-loaded into RAM — no I/O to parallelize
EPOCHS = 80
LR = 0.01
MOMENTUM = 0.9
WEIGHT_DECAY = 4e-5
LR_MIN = 1e-6
FREEZE_BACKBONE_EPOCHS = 5
PATIENCE = 12
USE_AMP = True

# Category names
CATEGORIES = {
    0: 'background', 1: 'text', 2: 'license_plate', 3: 'traffic_sign',
    4: 'autorickshaw', 5: 'tempo', 6: 'truck', 7: 'bus'
}

# Verify data. Explicit raise rather than `assert`, which `python -O` strips.
for p in [TRAIN_JSON, VAL_JSON, TRAIN_IMAGES, VAL_IMAGES]:
    if not p.exists():
        raise FileNotFoundError(f'Missing: {p}')
print(f'Train images: {len(list(TRAIN_IMAGES.iterdir()))}')
print(f'Val images: {len(list(VAL_IMAGES.iterdir()))}')
print(f'Workers: {NUM_WORKERS} (all data pre-loaded into RAM tensors)')

## Dataset

In [None]:
class COCODetectionDataset(Dataset):
    """Pre-loads ALL images into a uint8 tensor at init — zero disk I/O during training.
    7,344 images × 3 × 320 × 320 × 1 byte ≈ 2.1 GB (fits in Kaggle's 13 GB RAM).
    """

    def __init__(self, json_path, img_dir, augment=False, input_size=320):
        with open(json_path, 'r') as f:
            coco = json.load(f)
        self.input_size = input_size
        self.augment = augment

        images_meta = {img['id']: img for img in coco['images']}
        img_to_anns = defaultdict(list)
        for ann in coco['annotations']:
            img_to_anns[ann['image_id']].append(ann)
        img_ids = [iid for iid in images_meta if len(img_to_anns[iid]) > 0]

        # Pre-load ALL images into a single uint8 tensor
        n = len(img_ids)
        self.data = torch.zeros(n, 3, input_size, input_size, dtype=torch.uint8)
        self.targets = []  # List of (boxes_tensor, labels_tensor)
        img_dir = Path(img_dir)

        print(f'  Pre-loading {n} images into tensor...')
        t0 = time.time()
        for i, iid in enumerate(img_ids):
            img_info = images_meta[iid]
            try:
                image = Image.open(img_dir / img_info['file_name']).convert('RGB')
                orig_w, orig_h = image.size
                image = image.resize((input_size, input_size), Image.BILINEAR)
                arr = np.array(image)  # (H, W, 3) uint8
                self.data[i] = torch.from_numpy(arr.transpose(2, 0, 1))

                # Pre-scale boxes to input_size coordinates
                scale_x = input_size / orig_w
                scale_y = input_size / orig_h
                boxes, labels = [], []
                for ann in img_to_anns[iid]:
                    x, y, w, h = ann['bbox']
                    if w <= 0 or h <= 0: continue
                    x1 = max(0.0, x * scale_x)
                    y1 = max(0.0, y * scale_y)
                    x2 = min(float(input_size), (x + w) * scale_x)
                    y2 = min(float(input_size), (y + h) * scale_y)
                    if x2 - x1 > 1 and y2 - y1 > 1:
                        boxes.append([x1, y1, x2, y2])
                        labels.append(ann['category_id'])

                if boxes:
                    self.targets.append((
                        torch.tensor(boxes, dtype=torch.float32),
                        torch.tensor(labels, dtype=torch.int64),
                    ))
                else:
                    self.targets.append((
                        torch.zeros((0, 4), dtype=torch.float32),
                        torch.zeros((0,), dtype=torch.int64),
                    ))
            except Exception:
                self.targets.append((
                    torch.zeros((0, 4), dtype=torch.float32),
                    torch.zeros((0,), dtype=torch.int64),
                ))

            if (i + 1) % 2000 == 0:
                print(f'    {i+1}/{n} ({time.time()-t0:.0f}s)')

        elapsed = time.time() - t0
        mb = self.data.nbytes / 1024**2
        print(f'  Done: {self.data.shape} uint8 tensor, {mb:.0f} MB, {elapsed:.1f}s')

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        # uint8 → float32 [0, 1] (torchvision detection models expect this range)
        image = self.data[idx].float().div_(255.0)
        boxes, labels = self.targets[idx]
        boxes = boxes.clone()  # Don't mutate stored data

        if self.augment:
            # Horizontal flip
            if random.random() < 0.5:
                image = image.flip(-1)
                if boxes.numel() > 0:
                    boxes[:, [0, 2]] = self.input_size - boxes[:, [2, 0]]

            # Brightness
            if random.random() < 0.3:
                image.mul_(random.uniform(0.7, 1.3))

            # Contrast
            if random.random() < 0.3:
                mean = image.mean()
                image = (image - mean).mul_(random.uniform(0.8, 1.2)).add_(mean)

            # Saturation
            if random.random() < 0.2:
                gray = image.mean(dim=0, keepdim=True)
                factor = random.uniform(0.7, 1.3)
                image = image * factor + gray * (1 - factor)

            image.clamp_(0.0, 1.0)

            # Filter degenerate boxes after flip
            if boxes.numel() > 0:
                widths = boxes[:, 2] - boxes[:, 0]
                heights = boxes[:, 3] - boxes[:, 1]
                valid = (widths > 1) & (heights > 1)
                boxes = boxes[valid]
                labels = labels[valid]

        return image, {'boxes': boxes, 'labels': labels}


def collate_fn(batch):
    """Transpose a list of ``(image, target)`` pairs into ``(images, targets)`` lists.

    Detection models take variable-size target dicts, so samples are not
    stacked into a tensor — each field is just gathered into a list.
    """
    image_seq, target_seq = zip(*batch)
    return [*image_seq], [*target_seq]


print('Building datasets (pre-loading images into RAM)...')
train_ds = COCODetectionDataset(TRAIN_JSON, TRAIN_IMAGES, augment=True, input_size=INPUT_SIZE)
val_ds = COCODetectionDataset(VAL_JSON, VAL_IMAGES, augment=False, input_size=INPUT_SIZE)

# num_workers=0: data in RAM tensors, no I/O to parallelize.
# pin_memory: page-locked host memory speeds up CPU→GPU copies but is of no use
# without a CUDA device, so enable it only when training on the GPU.
_PIN_MEMORY = DEVICE.type == 'cuda'
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, collate_fn=collate_fn,
                          pin_memory=_PIN_MEMORY, drop_last=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, collate_fn=collate_fn,
                        pin_memory=_PIN_MEMORY)

print(f'\nTrain: {len(train_ds)} images, {len(train_loader)} batches (bs={BATCH_SIZE}, drop_last=True)')
print(f'Val: {len(val_ds)} images, {len(val_loader)} batches')
print(f'Workers: {NUM_WORKERS} (data pre-loaded, no multiprocessing needed)')

## Model

In [None]:
# SSDLite320 head on an ImageNet-pretrained MobileNetV3-Large backbone.
model = ssdlite320_mobilenet_v3_large(
    num_classes=NUM_CLASSES,
    weights_backbone=MobileNet_V3_Large_Weights.IMAGENET1K_V1,
).to(DEVICE)

# Parameter summary; nothing has been frozen yet at this point.
_params = list(model.parameters())
total = sum(p.numel() for p in _params)
trainable = sum(p.numel() for p in _params if p.requires_grad)
print(f'Total params: {total:,}')
print(f'Trainable:    {trainable:,}')
print(f'Size:         {sum(p.numel() * p.element_size() for p in _params) / 1024**2:.1f} MB')

## Training Utilities

In [None]:
def freeze_backbone(model, freeze=True):
    """Toggle ``requires_grad`` on every parameter of ``model.backbone``.

    ``freeze=True`` stops gradients from reaching backbone weights;
    ``freeze=False`` re-enables them. Only ``requires_grad`` changes — the
    train/eval mode of the backbone's layers is left as-is.
    """
    model.backbone.requires_grad_(not freeze)


def _box_iou_single(box1, box2):
    x1 = max(box1[0].item(), box2[0].item())
    y1 = max(box1[1].item(), box2[1].item())
    x2 = min(box1[2].item(), box2[2].item())
    y2 = min(box1[3].item(), box2[3].item())
    inter = max(0, x2 - x1) * max(0, y2 - y1)
    area1 = (box1[2] - box1[0]).item() * (box1[3] - box1[1]).item()
    area2 = (box2[2] - box2[0]).item() * (box2[3] - box2[1]).item()
    return inter / max(area1 + area2 - inter, 1e-6)


def _pairwise_box_iou(boxes_a, boxes_b):
    """IoU matrix between (P, 4) and (G, 4) xyxy box tensors, returned as (P, G) float64.

    Same formula as ``_box_iou_single`` (intersection clamped at 0, union floored
    at 1e-6), evaluated for all pairs at once.
    """
    a = boxes_a.double()
    b = boxes_b.double()
    top_left = torch.maximum(a[:, None, :2], b[None, :, :2])
    bottom_right = torch.minimum(a[:, None, 2:], b[None, :, 2:])
    wh = (bottom_right - top_left).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    union = area_a[:, None] + area_b[None, :] - inter
    return inter / union.clamp(min=1e-6)


def _voc07_ap(dets, n_gt):
    """11-point interpolated average precision (PASCAL VOC 2007 style).

    *dets* is a list of ``(score, is_true_positive)`` pairs for one class and
    *n_gt* (> 0) the number of ground-truth boxes of that class.
    """
    tp, fp = 0, 0
    prec, rec = [], []
    for _, is_tp in sorted(dets, key=lambda d: -d[0]):
        tp += is_tp
        fp += not is_tp
        prec.append(tp / (tp + fp))
        rec.append(tp / n_gt)
    thresholds = [i / 10 for i in range(11)]
    return sum(
        max((p for p, r in zip(prec, rec) if r >= t), default=0) for t in thresholds
    ) / 11


@torch.no_grad()
def compute_map(model, loader, device, iou_threshold=0.5, max_batches=50):
    """Approximate mAP@``iou_threshold`` over at most ``max_batches`` batches of *loader*.

    Each prediction, in the model's output order, is greedily matched to the
    highest-IoU still-unmatched ground-truth box of the same class. It counts as
    a true positive when that IoU is >= ``iou_threshold``. Per-class AP uses
    11-point VOC07 interpolation and is averaged over classes that have ground
    truth. The model is left in eval mode.

    Predictions are copied to the CPU once per image. Matching then runs on
    Python lists, avoiding one GPU sync per ``.item()`` call.

    Returns:
        ``(mean_ap, {class_id: ap})``
    """
    model.eval()
    all_dets = defaultdict(list)
    all_n_gt = defaultdict(int)

    for batch_idx, (images, targets) in enumerate(tqdm(loader, desc='mAP', leave=False)):
        if batch_idx >= max_batches:
            break
        images = [img.to(device, non_blocking=True) for img in images]
        preds = model(images)

        for pred, gt in zip(preds, targets):
            p_boxes = pred['boxes'].cpu()
            p_labels = pred['labels'].cpu().tolist()
            p_scores = pred['scores'].cpu().tolist()
            gt_boxes = gt['boxes'].cpu()
            gt_labels = gt['labels'].cpu().tolist()
            for lbl in gt_labels:
                all_n_gt[lbl] += 1

            ious = _pairwise_box_iou(p_boxes, gt_boxes).tolist()
            matched = set()
            for i, (cls, score) in enumerate(zip(p_labels, p_scores)):
                best_iou, best_idx = 0.0, -1
                for gi, g_cls in enumerate(gt_labels):
                    if g_cls != cls or gi in matched:
                        continue
                    if ious[i][gi] > best_iou:
                        best_iou, best_idx = ious[i][gi], gi
                is_tp = best_idx >= 0 and best_iou >= iou_threshold
                if is_tp:
                    matched.add(best_idx)
                all_dets[cls].append((score, is_tp))

    aps = {
        cls: _voc07_ap(all_dets.get(cls, []), n_gt)
        for cls, n_gt in sorted(all_n_gt.items())
        if n_gt > 0
    }
    return sum(aps.values()) / max(len(aps), 1), aps


def train_one_epoch(model, loader, optimizer, scaler, device, epoch):
    """Run one optimisation pass over *loader* and return the mean per-batch loss.

    With a ``scaler``, the forward pass runs under CUDA autocast and
    backward/step go through the GradScaler. Gradients are unscaled before
    clipping. Without one, training is plain fp32. Either way, gradients are
    clipped to a global L2 norm of 10.0.
    """
    model.train()
    running, steps = 0.0, 0
    progress = tqdm(loader, desc=f'Epoch {epoch}')
    for batch_images, batch_targets in progress:
        batch_images = [im.to(device, non_blocking=True) for im in batch_images]
        batch_targets = [
            {key: val.to(device, non_blocking=True) for key, val in tgt.items()}
            for tgt in batch_targets
        ]
        optimizer.zero_grad(set_to_none=True)

        if not scaler:
            loss = sum(model(batch_images, batch_targets).values())
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 10.0)
            optimizer.step()
        else:
            with autocast('cuda'):
                loss = sum(model(batch_images, batch_targets).values())
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # clip on true-scale gradients
            nn.utils.clip_grad_norm_(model.parameters(), 10.0)
            scaler.step(optimizer)
            scaler.update()

        loss_value = loss.item()
        running += loss_value
        steps += 1
        progress.set_postfix(loss=f'{loss_value:.4f}')
    return running / max(steps, 1)


@torch.no_grad()
def val_loss(model, loader, device, use_amp=False):
    """Return the mean per-batch detection loss over *loader* without modifying the model.

    torchvision detection models only return losses in training mode, so the
    model is switched to train mode for the forward pass. Every BatchNorm layer
    is then put back into eval mode: ``torch.no_grad()`` does NOT stop BatchNorm
    from updating its running mean/var in train mode. Without this, validation
    batches would leak into the statistics used at inference time every epoch.
    The model's original train/eval mode is restored on exit.

    Args:
        model: detection model whose train-mode ``forward(images, targets)``
            returns a dict of loss tensors.
        loader: yields ``(images, targets)`` batches of lists (see ``collate_fn``).
        device: device (or device string) to run on.
        use_amp: run the forward pass under CUDA autocast; ignored when
            *device* is not a CUDA device.
    """
    was_training = model.training
    model.train()  # Need train mode for SSD loss computation
    for module in model.modules():
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            module.eval()  # use running stats and leave them untouched
    amp_enabled = bool(use_amp) and torch.device(device).type == 'cuda'

    total, n = 0.0, 0
    try:
        for images, targets in tqdm(loader, desc='Val', leave=False):
            images = [img.to(device, non_blocking=True) for img in images]
            targets = [{k: v.to(device, non_blocking=True) for k, v in t.items()} for t in targets]
            with autocast('cuda', enabled=amp_enabled):
                loss = sum(model(images, targets).values())
            total += loss.item()
            n += 1
    finally:
        model.train(was_training)
    return total / max(n, 1)

# Summary of the helpers defined in this cell.
for _line in (
    'Utilities defined.',
    'Gradient clipping: max_norm=10.0',
    'Validation supports AMP for faster forward pass',
):
    print(_line)

## Train

In [None]:
# Stage 1: only the SSD head trains; the backbone stays frozen for the first
# FREEZE_BACKBONE_EPOCHS epochs.
freeze_backbone(model, freeze=True)

# Two param groups: head at LR, backbone at LR/10. Only parameters with
# requires_grad=True are collected, so the backbone group is empty while frozen.
head_params = [
    p for name, p in model.named_parameters()
    if p.requires_grad and not name.startswith('backbone')
]
backbone_params = [p for p in model.backbone.parameters() if p.requires_grad]

optimizer = optim.SGD(
    [
        {'params': head_params, 'lr': LR},
        {'params': backbone_params, 'lr': LR * 0.1},
    ],
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=LR_MIN)
# Loss scaling only applies to CUDA AMP; None selects the plain fp32 path.
scaler = GradScaler('cuda') if (USE_AMP and DEVICE.type == 'cuda') else None

# Best-model / early-stopping bookkeeping.
best_val, best_map = float('inf'), 0.0
patience_ctr = 0
history = []

print(f'Training for {EPOCHS} epochs on {DEVICE}')
print(f'AMP: {scaler is not None}, Backbone frozen for first {FREEZE_BACKBONE_EPOCHS} epochs')
print()

In [None]:
# Main loop: train, validate (loss every epoch, mAP every 5), step the LR
# schedule, save best/periodic checkpoints, and early-stop on val loss.
for epoch in range(1, EPOCHS + 1):
    t0 = time.time()

    # Stage 2: after FREEZE_BACKBONE_EPOCHS head-only epochs, unfreeze the
    # backbone and rebuild the optimizer. The head starts at the schedule's
    # last LR and the backbone at 1/10 of it. A fresh cosine schedule spans the
    # remaining epochs. Momentum buffers of the previous optimizer are dropped.
    if epoch == FREEZE_BACKBONE_EPOCHS + 1:
        freeze_backbone(model, freeze=False)
        backbone_params = list(model.backbone.parameters())
        head_params = [p for n, p in model.named_parameters() if not n.startswith('backbone')]
        optimizer = optim.SGD([
            {'params': head_params, 'lr': scheduler.get_last_lr()[0]},
            {'params': backbone_params, 'lr': scheduler.get_last_lr()[0] * 0.1},
        ], momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS - epoch + 1, eta_min=LR_MIN)
        print(f'  >> Backbone unfrozen at epoch {epoch}')

    tloss = train_one_epoch(model, train_loader, optimizer, scaler, DEVICE, epoch)
    vloss = val_loss(model, val_loader, DEVICE, use_amp=USE_AMP)

    # mAP@0.5 every 5 epochs and on the final epoch; other epochs record 0.0.
    mAP = 0.0
    if epoch % 5 == 0 or epoch == EPOCHS:
        mAP, per_cls = compute_map(model, val_loader, DEVICE)
        ap_str = ' | '.join(f'{CATEGORIES.get(c,c)}: {ap:.3f}' for c, ap in sorted(per_cls.items()))
        print(f'  mAP@0.5: {mAP:.4f}  [{ap_str}]')

    scheduler.step()
    # NOTE: read after scheduler.step(), so this logs (and stores in history)
    # the head LR for the *next* epoch, not the one just trained with.
    lr = optimizer.param_groups[0]['lr']
    elapsed = time.time() - t0

    print(f'Epoch {epoch:3d} | Train: {tloss:.4f} | Val: {vloss:.4f} | mAP: {mAP:.4f} | LR: {lr:.6f} | {elapsed:.1f}s')
    history.append({'epoch': epoch, 'train_loss': tloss, 'val_loss': vloss, 'mAP': mAP, 'lr': lr})

    # Save best by val loss; patience counts consecutive epochs without improvement.
    if vloss < best_val:
        best_val = vloss
        patience_ctr = 0
        torch.save({'model': model.state_dict(), 'epoch': epoch,
                    'best_val': best_val, 'best_map': best_map},
                   OUTPUT_DIR / 'best.pth')
        print(f'  -> Saved best model (val_loss={vloss:.4f})')
    else:
        patience_ctr += 1

    # Save best by mAP. mAP is 0.0 on non-evaluated epochs, so this can only
    # trigger on evaluation epochs, and best_map.pth is never written if no
    # evaluation ever exceeds 0.
    if mAP > best_map:
        best_map = mAP
        torch.save({'model': model.state_dict(), 'epoch': epoch, 'best_map': best_map},
                   OUTPUT_DIR / 'best_map.pth')
        print(f'  -> Saved best mAP model ({mAP:.4f})')

    # Full checkpoint every 10 epochs for crash recovery (holds the state needed
    # to resume; nothing in this notebook loads it back automatically).
    if epoch % 10 == 0:
        torch.save({
            'model': model.state_dict(), 'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'scaler': scaler.state_dict() if scaler else None,
            'best_val': best_val, 'best_map': best_map,
            'patience_ctr': patience_ctr, 'history': history,
        }, OUTPUT_DIR / f'checkpoint_epoch_{epoch}.pth')
        print(f'  -> Checkpoint saved (epoch {epoch}, resumable)')

    # Early stopping: PATIENCE consecutive epochs without a val-loss improvement.
    if patience_ctr >= PATIENCE:
        print(f'\nEarly stopping at epoch {epoch}')
        break

print(f'\nDone! Best val_loss: {best_val:.4f}, Best mAP: {best_map:.4f}')

## Training Curves

In [None]:
import matplotlib.pyplot as plt

# Left panel: train/val loss per epoch. Right panel: mAP@0.5 at the epochs
# where it is > 0 (non-evaluated epochs are stored as 0.0 in history).
epochs = [rec['epoch'] for rec in history]
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

for _key, _label in (('train_loss', 'Train'), ('val_loss', 'Val')):
    ax1.plot(epochs, [rec[_key] for rec in history], label=_label)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.set_title('Loss')

_evaluated = [rec for rec in history if rec['mAP'] > 0]
map_epochs = [rec['epoch'] for rec in _evaluated]
map_vals = [rec['mAP'] for rec in _evaluated]
ax2.plot(map_epochs, map_vals, 'go-')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('mAP@0.5')
ax2.set_title('mAP')

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'training_curves.png', dpi=150)
plt.show()

## Export ONNX

In [None]:
# Load the best model and export it to ONNX for mobile deployment.
# best_map.pth is only written when some evaluation reaches mAP > 0, so fall
# back to the best-val-loss checkpoint instead of crashing.
_ckpt_path = OUTPUT_DIR / 'best_map.pth'
if not _ckpt_path.exists():
    _ckpt_path = OUTPUT_DIR / 'best.pth'
    print(f'best_map.pth not found; exporting {_ckpt_path.name} instead')
best_ckpt = torch.load(_ckpt_path, map_location='cpu', weights_only=True)
model_export = ssdlite320_mobilenet_v3_large(num_classes=NUM_CLASSES)
model_export.load_state_dict(best_ckpt['model'])
model_export.eval()

onnx_path = OUTPUT_DIR / 'ssdlite_detection.onnx'
dummy = [torch.randn(3, INPUT_SIZE, INPUT_SIZE)]
# The detection model takes a *list* of images as its single positional
# argument. torch.onnx.export documents `args` as a tuple of positional
# arguments (or a single Tensor), so wrap the list explicitly.
torch.onnx.export(
    model_export, (dummy,), str(onnx_path),
    opset_version=17,
    input_names=['image'],
    output_names=['boxes', 'labels', 'scores'],
)
onnx_size = onnx_path.stat().st_size / 1024**2
print(f'ONNX exported: {onnx_size:.1f} MB')

## Save History & Download
Download the output files from `/kaggle/working/detection_output/`

In [None]:
# Persist per-epoch metrics, release cached GPU memory, then list the outputs.
(OUTPUT_DIR / 'history.json').write_text(json.dumps(history, indent=2))

# Free VRAM
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print('Output files:')
for out_file in sorted(OUTPUT_DIR.iterdir()):
    print(f'  {out_file.name}: {out_file.stat().st_size / 1024**2:.1f} MB')
print(f'\nDownload all from: {OUTPUT_DIR}')