# Signature detection with custom model

### Imports

In [8]:
import torch
import torch.nn as nn, torch.nn.functional as F
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.retinanet import RetinaNetClassificationHead
from torchvision.models.detection.retinanet import RetinaNet_ResNet50_FPN_Weights
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms.functional as TF
from torchvision import transforms
import webdataset as wds
from PIL import Image
import numpy as np
import time
import matplotlib.pyplot as plt
import cv2
import json

In [9]:
# check GPU
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))

CUDA available: True
Device: NVIDIA GeForce MX450


### Hyperparameters

In [10]:
train_dataset = "datasets/custom_augmented/train-00000.tar"
val_dataset = "datasets/custom_augmented/val-00000.tar"
test_dataset = "datasets/custom_augmented/test-00000.tar"

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

num_classes = 2  # 1 class (signature) + background
imgsz = 256
epochs = 2
batch_size = 4
learning_rate = 0.005

### utility functions

In [11]:
def iou(boxA, boxB):
    # boxes are [x1,y1,x2,y2]
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])
    interW = max(0, xB - xA)
    interH = max(0, yB - yA)
    interArea = interW * interH
    boxAArea = max(0, (boxA[2] - boxA[0])) * max(0, (boxA[3] - boxA[1]))
    boxBArea = max(0, (boxB[2] - boxB[0])) * max(0, (boxB[3] - boxB[1]))
    denom = float(boxAArea + boxBArea - interArea)
    return interArea / denom if denom > 0 else 0.0

In [12]:
def train_one_epoch(model, loader, optimizer, device):
    model.train()
    running_loss = 0.0
    it = 0
    for images, targets in loader:
        images = list(img.to(device) for img in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        running_loss += losses.item()
        it += 1
    return running_loss / max(1, it)

In [13]:
def validate(model, loader, device):
    model.train()
    val_loss = 0.0
    it = 0

    with torch.no_grad():
        for images, targets in loader:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # Temporarily run in train mode to get loss dict (model() in eval returns list)
            loss_dict = model(images, targets)

            losses = sum(loss for loss in loss_dict.values())
            val_loss += losses.item()
            it += 1

    return val_loss / max(1, it)


In [14]:
def evaluate_precision_recall(model, loader, device, iou_th=0.5, score_th=0.5):
    model.eval()
    TP = 0
    FP = 0
    FN = 0
    with torch.no_grad():
        for images, targets in loader:
            img = images[0].to(device)
            gt = targets[0]
            preds = model([img])[0]
            pred_boxes = preds['boxes'].cpu().numpy()
            pred_scores = preds['scores'].cpu().numpy()
            gt_boxes = gt['boxes'].cpu().numpy() if gt['boxes'].size(0) > 0 else np.zeros((0,4))

            keep_idx = np.where(pred_scores >= score_th)[0]
            pred_boxes = pred_boxes[keep_idx]
            matched_gt = set()
            for pb in pred_boxes:
                best_iou = 0
                best_j = -1
                for j, gb in enumerate(gt_boxes):
                    if j in matched_gt:
                        continue
                    cur_iou = iou(pb, gb)
                    if cur_iou > best_iou:
                        best_iou = cur_iou
                        best_j = j
                if best_iou >= iou_th and best_j >= 0:
                    TP += 1
                    matched_gt.add(best_j)
                else:
                    FP += 1
            FN += (len(gt_boxes) - len(matched_gt))

    precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0
    recall = TP / (TP + FN) if (TP + FN) > 0 else 0.0
    return precision, recall


In [15]:
def training_loop(model, train_loader, val_loader, optimizer, lr_scheduler, device, epochs=10):
    train_losses = []
    val_losses = []
    for epoch in range(epochs):
        t0 = time.time()
        try:
            train_loss = train_one_epoch(model, train_loader, optimizer, device)
            lr_scheduler.step()
            val_loss = validate(model, val_loader, device)
            prec, rec = evaluate_precision_recall(model, val_loader, device)
            print(f'Epoch {epoch+1}/{epochs} — train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f}, prec: {prec:.3f}, rec: {rec:.3f}, time: {time.time()-t0:.1f}s')
            train_losses.append(train_loss)
            val_losses.append(val_loss)
        except RuntimeError as e:
            if 'out of memory' in str(e).lower():
                print('RuntimeError: CUDA out of memory during training.\nConsider:')
                # try to free cache and continue or abort
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                raise
            else:
                raise
    return model, train_losses, val_losses

### Dataset

In [16]:
transform = transforms.Compose([
    transforms.Resize((imgsz, imgsz)),
    transforms.ToTensor()
])

def preprocessSample(sample):
     # Detect image key dynamically
    img_key = None
    for k in sample.keys():
        if k.lower() in ["jpg", "jpeg", "png"]:
            img_key = k
            break
    if img_key is None:
        raise ValueError(f"No supported image format found in sample keys: {list(sample.keys())}")

    # Image already decoded to PIL
    img = sample[img_key]
    if img.mode != "RGB":
        img = img.convert("RGB")

    # Get original image size before resizing
    orig_w, orig_h = img.size

    # Resize image
    img_resized = transforms.Resize((imgsz, imgsz))(img)
    new_w, new_h = img_resized.size

    # Compute scale factors
    scale_x = new_w / orig_w
    scale_y = new_h / orig_h

    # Parse target
    target = sample["json"]

    # Convert boxes [x, y, w, h] → [x1, y1, x2, y2]
    boxes = []
    for (x, y, w, h) in target["boxes"]:
        x1, y1, x2, y2 = x, y, x + w, y + h
        # Scale coordinates
        x1 *= scale_x
        x2 *= scale_x
        y1 *= scale_y
        y2 *= scale_y
        boxes.append([x1, y1, x2, y2])

    # Convert to tensors
    target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)
    target["labels"] = torch.as_tensor(target["labels"], dtype=torch.int64)

    # Final transform to tensor
    img_tensor = transforms.ToTensor()(img_resized)

    return img_tensor, target

In [17]:
train_dataset = (
    wds.WebDataset(train_dataset)   # <- use pattern or list of tar paths
    .decode("pil")
    .map(preprocessSample)
)
val_dataset = (
    wds.WebDataset(val_dataset)   # <- use pattern or list of tar paths
    .decode("pil")
    .map(preprocessSample)
)
test_dataset = (
    wds.WebDataset(test_dataset)   # <- use pattern or list of tar paths
    .decode("pil")
    .map(preprocessSample)
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=0,
    collate_fn=lambda x: tuple(zip(*x))
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    num_workers=0,
    collate_fn=lambda x: tuple(zip(*x))
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    num_workers=0,
    collate_fn=lambda x: tuple(zip(*x))
)



## Model definition

In [18]:
class SimpleBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU()
        )
    def forward(self, x):
        return self.features(x)

class DetectionHead(nn.Module):
    def __init__(self, in_channels, num_anchors, num_classes):
        super().__init__()
        self.num_anchors = num_anchors
        self.num_classes = num_classes
        self.loc = nn.Conv2d(in_channels, num_anchors * 4, 1)
        self.cls = nn.Conv2d(in_channels, num_anchors * num_classes, 1)
    def forward(self, x):
        locs = self.loc(x)
        clss = self.cls(x)
        # reshape to (batch, anchors, 4) and (batch, anchors, num_classes)
        batch = x.shape[0]
        locs = locs.permute(0,2,3,1).contiguous().view(batch, -1, 4)
        clss = clss.permute(0,2,3,1).contiguous().view(batch, -1, self.num_classes)
        return locs, clss

class CustomDetector(nn.Module):
    def __init__(self, num_classes, anchor_stride=32):
        super().__init__()
        self.backbone = SimpleBackbone()
        # Calculate feature map size
        self.anchor_stride = anchor_stride
        fmap_size = imgsz // anchor_stride
        self.num_anchors = fmap_size * fmap_size
        self.head = DetectionHead(128, self.num_anchors, num_classes)
        # Generate anchor centers
        grid = torch.meshgrid(
            torch.arange(fmap_size), torch.arange(fmap_size), indexing='ij'
        )
        self.register_buffer('anchor_centers', torch.stack(grid, dim=-1).reshape(-1,2).float() * anchor_stride)
    def forward(self, x):
        feats = self.backbone(x)
        locs, clss = self.head(feats)
        # anchors are at self.anchor_centers
        return locs, clss, self.anchor_centers

## Training

In [20]:
model = CustomDetector(num_classes=2)

try:
    model.to(device)
except RuntimeError as e:
    print('Error moving model to device — falling back to CPU.\n', e)
    device = torch.device('cpu')
    model.to(device)

model

CustomDetector(
  (backbone): SimpleBackbone(
    (features): Sequential(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (3): ReLU()
      (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (5): ReLU()
      (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (7): ReLU()
    )
  )
  (head): DetectionHead(
    (loc): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
    (cls): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
  )
)

In [21]:
# clear cache and reduce fragmentation
if torch.cuda.is_available():
    torch.cuda.empty_cache()

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

In [25]:
# Custom training loop for CustomDetector
def custom_loss(locs, clss, targets, anchors, device):
    # targets: list of dicts with 'boxes' and 'labels'
    batch_size = len(targets)
    loc_loss = 0.0
    cls_loss = 0.0
    for i in range(batch_size):
        gt_boxes = targets[i]['boxes'].to(device)
        gt_labels = targets[i]['labels'].to(device)
        # For simplicity, match each GT to closest anchor
        if gt_boxes.numel() == 0:
            continue
        anchor_centers = anchors.to(device)
        pred_locs = locs[i]
        pred_clss = clss[i]
        for j, gt_box in enumerate(gt_boxes):
            # Find closest anchor
            anchor_dists = torch.norm(anchor_centers - gt_box[:2], dim=1)
            anchor_idx = torch.argmin(anchor_dists)
            # Box regression loss (MSE)
            loc_loss += F.mse_loss(pred_locs[anchor_idx], gt_box, reduction='mean')
            # Classification loss (CrossEntropy)
            cls_loss += F.cross_entropy(pred_clss[anchor_idx].unsqueeze(0), gt_labels[j].unsqueeze(0))
    return loc_loss + cls_loss

def train_custom_detector(model, loader, optimizer, device, epochs=2):
    model.train()
    losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        batch_count = 0
        for images, targets in loader:
            images = torch.stack([img.to(device) for img in images])
            locs, clss, anchors = model(images)
            loss = custom_loss(locs, clss, targets, anchors, device)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            batch_count += 1
        avg_loss = running_loss / max(1, batch_count)
        print(f"Epoch {epoch+1}/{epochs} — loss: {avg_loss:.4f}")
        losses.append(avg_loss)
    return model, losses

In [26]:
print("Starting training...")
model, train_losses = train_custom_detector(model, train_loader, optimizer, device, epochs)

Starting training...
Epoch 1/2 — loss: nan
Epoch 1/2 — loss: nan
Epoch 2/2 — loss: nan
Epoch 2/2 — loss: nan


In [None]:
def draw_boxes(img_tensor, boxes, color=(0,255,0), linewidth=2):
    img = (img_tensor.permute(1,2,0).cpu().numpy() * 255).astype(np.uint8).copy()
    for b in boxes:
        x1, y1, x2, y2 = map(int, b)
        cv2.rectangle(img, (x1, y1), (x2, y2), color, linewidth)
    return img

n=5
score_thresh=0.5
model.eval()
plt.figure(figsize=(12, 4 * n))

cnt = 1
with torch.no_grad():
    iter_dataset = iter(test_dataset)
    for i in range(n):
        try:
            img_tensor, target = next(iter_dataset)
        except StopIteration:
            print("End of dataset reached.")
            break

        img = img_tensor.to(device)
        preds = model([img])[0]

        pred_boxes = preds['boxes'].cpu().numpy()
        pred_scores = preds['scores'].cpu().numpy()
        keep = pred_scores >= score_thresh
        pred_boxes = pred_boxes[keep]

        gt_boxes = (
            target['boxes'].numpy()
            if target["boxes"].numel() > 0
            else np.zeros((0, 4))
        )

        # Draw boxes: GT (green), Predictions (red)
        vis_gt = draw_boxes(img_tensor, gt_boxes, color=(0,255,0))
        vis_pred = draw_boxes(img_tensor, pred_boxes, color=(0,0,255))

        # Combine side-by-side
        combined = np.concatenate([vis_gt, vis_pred], axis=1)
        plt.subplot(n, 1, cnt)
        plt.axis("off")
        plt.title(f"GT (green) | Pred (red) — Sample {i}")
        plt.imshow(combined[:,:,::-1])
        cnt += 1

plt.tight_layout()
plt.show()