In [None]:
# =========================================================
# GRID-BASED POTHOLE DETECTOR (YOLO-LITE) — SINGLE CELL
# =========================================================

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from PIL import Image
import torchvision.transforms as T
import cv2
import matplotlib.pyplot as plt

# --------------------
# device
# --------------------
# Pick the compute backend in order of preference: Apple MPS, then CUDA,
# and finally plain CPU when no accelerator is reported as available.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# --------------------
# constants
# --------------------
IMG_SIZE = 224  # images are resized to IMG_SIZE x IMG_SIZE pixels before entering the network
S = 7  # the image is divided into an S x S grid; targets and predictions are (S, S, 5)
BATCH_SIZE = 16  # number of samples per batch handed to the training DataLoader
EPOCHS = 12  # number of full passes over the training set
CONF_THRESH = 0.2  # minimum sigmoid objectness for a grid cell to be drawn at inference

# --------------------
# Dataset
# --------------------
class GridPotholeDataset(Dataset):
    """Images paired with box labels encoded onto an ``S x S`` grid.

    Each sample is ``(img, target)``:

    * ``img`` -- the image converted to RGB, resized to
      ``IMG_SIZE x IMG_SIZE`` and turned into a float tensor by
      ``torchvision.transforms.ToTensor``.
    * ``target`` -- a ``(S, S, 5)`` float32 tensor. A cell containing a box
      centre holds ``[1, tx, ty, w, h]`` where ``tx``/``ty`` are the centre
      offsets inside the cell; every other cell is all zeros.

    Labels are read from ``<lbl_dir>/<image stem>.txt`` with one box per
    line as five whitespace-separated numbers ``class xc yc w h``. The class
    id is ignored. Coordinates are presumably normalised to ``[0, 1]``
    (YOLO convention) -- the code multiplies them by ``S`` to find the cell.
    When two boxes fall into the same cell the later line wins.

    Args:
        img_dir: Directory containing the images.
        lbl_dir: Directory containing the matching ``.txt`` label files.
            An image without a label file yields an all-zero target.
    """

    # Suffixes accepted as images. Filtering keeps stray files such as
    # ".DS_Store" or misplaced label files from crashing ``Image.open``.
    IMAGE_SUFFIXES = frozenset(
        {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
    )

    def __init__(self, img_dir, lbl_dir):
        self.img_dir = Path(img_dir)
        self.lbl_dir = Path(lbl_dir)
        # Sorted so that index -> file is stable across runs and filesystems
        # (glob order is otherwise arbitrary).
        self.images = sorted(
            p for p in self.img_dir.glob("*")
            if p.is_file() and p.suffix.lower() in self.IMAGE_SUFFIXES
        )
        self.transform = T.Compose([
            T.Resize((IMG_SIZE, IMG_SIZE)),
            T.ToTensor()
        ])

    def __len__(self):
        """Return the number of image files found in ``img_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` and build its grid target.

        Raises:
            ValueError: If a non-blank label line does not contain exactly
                five numeric fields.
        """
        img_path = self.images[idx]
        # ``convert`` returns an independent, fully loaded copy, so the
        # underlying file can be closed as soon as the block exits.
        with Image.open(img_path) as raw:
            img = self.transform(raw.convert("RGB"))

        target = torch.zeros((S, S, 5), dtype=torch.float32)

        label_path = self.lbl_dir / (img_path.stem + ".txt")
        if label_path.exists():
            with open(label_path, encoding="utf-8") as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        # Blank / trailing empty lines carry no box.
                        continue
                    if len(fields) != 5:
                        raise ValueError(
                            f"{label_path}: expected 5 fields per line, "
                            f"got {len(fields)}: {line.strip()!r}"
                        )
                    _, xc, yc, w, h = map(float, fields)
                    # Clamp the cell index on both sides: a centre at 1.0
                    # would index cell S, and a negative centre would wrap
                    # around to the opposite edge via negative indexing.
                    cx = min(max(int(xc * S), 0), S - 1)
                    cy = min(max(int(yc * S), 0), S - 1)
                    tx = xc * S - cx
                    ty = yc * S - cy
                    target[cy, cx] = torch.tensor([1.0, tx, ty, w, h])

        return img, target

# --------------------
# Model
# --------------------
class GridDetector(nn.Module):
    """Tiny fully-convolutional detector emitting five values per grid cell.

    Four 3x3 stride-2 convolutions (each followed by ReLU) widen the
    channels 3 -> 16 -> 32 -> 64 -> 128, adaptive average pooling brings
    the feature map to ``S x S``, and a 1x1 convolution maps every cell to
    five outputs.
    """

    def __init__(self):
        super().__init__()
        channel_plan = (3, 16, 32, 64, 128)

        stages = []
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            # kernel 3, stride 2, padding 1
            stages.append(nn.Conv2d(c_in, c_out, 3, 2, 1))
            stages.append(nn.ReLU())
        # Force the spatial size to the grid resolution.
        stages.append(nn.AdaptiveAvgPool2d(S))
        # Per-cell head: 5 output channels.
        stages.append(nn.Conv2d(channel_plan[-1], 5, 1))

        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run the network and move channels last: (B, 5, S, S) -> (B, S, S, 5)."""
        features = self.net(x)
        return features.permute(0, 2, 3, 1)

# --------------------
# Balanced loss
# --------------------
def yolo_lite_loss(pred, target, noobj_weight=0.3, box_weight=6.0):
    """Objectness + box-regression loss for the grid detector.

    The last dimension of both tensors is read as
    ``(objectness, tx, ty, w, h)``; ``pred[..., 0]`` is treated as a logit.

    The loss is::

        BCE(positive cells) + noobj_weight * BCE(negative cells)
            + box_weight * MSE(box values of positive cells)

    Each term is averaged over its own set of cells, so the many empty
    cells cannot drown out the few positive ones. A term whose cell set is
    empty contributes exactly zero (instead of the NaN produced by a mean
    over no elements), while staying attached to the autograd graph.

    Args:
        pred: Network output, same shape as ``target``.
        target: Ground truth; ``target[..., 0]`` is 1 for cells that contain
            an object and 0 otherwise.
        noobj_weight: Weight of the objectness loss on empty cells.
        box_weight: Weight of the box-regression loss.

    Returns:
        Scalar tensor with the combined loss.
    """
    obj_logits = pred[..., 0]
    obj_true = target[..., 0]

    pos_mask = obj_true == 1
    neg_mask = obj_true == 0

    pos_logits = obj_logits[pos_mask]
    neg_logits = obj_logits[neg_mask]

    # ``.sum()`` of an empty selection is an exact, differentiable 0, which
    # is the right contribution for a term that has no cells to average.
    if pos_logits.numel():
        pos_loss = nn.functional.binary_cross_entropy_with_logits(
            pos_logits, obj_true[pos_mask]
        )
        box_loss = nn.functional.mse_loss(
            pred[pos_mask][..., 1:], target[pos_mask][..., 1:]
        )
    else:
        pos_loss = pos_logits.sum()
        box_loss = pred[pos_mask][..., 1:].sum()

    if neg_logits.numel():
        neg_loss = nn.functional.binary_cross_entropy_with_logits(
            neg_logits, obj_true[neg_mask]
        )
    else:
        neg_loss = neg_logits.sum()

    obj_loss = pos_loss + noobj_weight * neg_loss
    return obj_loss + box_weight * box_loss

# --------------------
# DataLoader
# --------------------
# Training split: images and their per-image ".txt" label files.
train_ds = GridPotholeDataset(
    img_dir="../data/RDD2020/images/train",
    lbl_dir="../data/RDD2020/labels/train",
)

# Batches are reshuffled every epoch; num_workers=0 loads data in the
# main process (no worker subprocesses).
train_loader = DataLoader(
    dataset=train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

# --------------------
# Training
# --------------------
# Build the network on the selected device and optimise all of its
# parameters with Adam.
model = GridDetector()
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(EPOCHS):
    # Sum of per-batch losses; averaged over the batch count when reporting.
    total_loss = 0.0

    for batch_imgs, batch_targets in train_loader:
        batch_imgs, batch_targets = batch_imgs.to(device), batch_targets.to(device)

        # Drop gradients left over from the previous step, then do the
        # usual forward -> loss -> backward -> update sequence.
        optimizer.zero_grad()
        batch_loss = yolo_lite_loss(model(batch_imgs), batch_targets)
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()

    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {total_loss/len(train_loader):.4f}")

# Persist only the weights (state dict), not the whole module object.
torch.save(model.state_dict(), "grid_pothole_detector.pt")
print("Saved grid_pothole_detector.pt")

# --------------------
# Inference + visualization
# --------------------
# Put the model in evaluation mode before running it on test images.
model.eval()

# Same preprocessing as the training dataset: resize, then to tensor.
transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor()
])

TEST_DIR = Path("../data/RDD2020/images/test")

# Preview at most 20 test images. Filtering by suffix keeps stray files
# (e.g. ".DS_Store") out of the loop, and sorting makes the selection
# reproducible -- glob order is otherwise arbitrary.
_IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
img_paths = sorted(
    p for p in TEST_DIR.glob("*")
    if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES
)[:20]

for img_path in img_paths:
    # Canvas for drawing. OpenCV rotates by EXIF orientation by default,
    # PIL does not; ignoring the orientation keeps the canvas pixel-aligned
    # with the image the model actually sees.
    bgr = cv2.imread(
        str(img_path), cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
    )
    if bgr is None:
        # imread signals failure by returning None instead of raising.
        print(f"Skipping unreadable image: {img_path}")
        continue
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # Scale boxes by the size of the array they are drawn on.
    oh, ow = img.shape[:2]

    with Image.open(img_path) as raw:
        img_pil = raw.convert("RGB")
    x = transform(img_pil).unsqueeze(0).to(device)

    with torch.no_grad():
        # Copy the (S, S, 5) prediction to the CPU once, rather than
        # synchronising with the device for every grid cell.
        pred = model(x)[0].cpu()

    scores = torch.sigmoid(pred[..., 0])
    boxes = []

    for cy in range(S):
        for cx in range(S):
            obj = scores[cy, cx].item()
            if obj < CONF_THRESH:
                continue
            tx, ty, w, h = pred[cy, cx, 1:].tolist()
            # Cell index + in-cell offset -> centre as a fraction of the image.
            xc = (cx + tx) / S
            yc = (cy + ty) / S
            # Corner coordinates in pixels, clamped to the canvas so the
            # rectangle and its score label stay visible.
            x1 = min(max(int((xc - w / 2) * ow), 0), ow - 1)
            y1 = min(max(int((yc - h / 2) * oh), 0), oh - 1)
            x2 = min(max(int((xc + w / 2) * ow), 0), ow - 1)
            y2 = min(max(int((yc + h / 2) * oh), 0), oh - 1)
            boxes.append((obj, x1, y1, x2, y2))

    # Keep the five most confident cells.
    boxes = sorted(boxes, reverse=True)[:5]

    for obj, x1, y1, x2, y2 in boxes:
        cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
        # Keep the score label at least 15 px below the top edge.
        cv2.putText(img, f"{obj:.2f}", (x1, max(y1 - 5, 15)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 0, 0), 1)

    plt.figure(figsize=(5,5))
    plt.imshow(img)
    plt.title(img_path.name)
    plt.axis("off")
    plt.show()


Using device: mps
Epoch 1/12 | Loss: 0.8751
Epoch 2/12 | Loss: 0.7837
Epoch 3/12 | Loss: 0.7745
Epoch 4/12 | Loss: 0.7511
Epoch 5/12 | Loss: 0.7201
Epoch 6/12 | Loss: 0.7250
Epoch 7/12 | Loss: 0.7031
Epoch 8/12 | Loss: 0.6966
Epoch 9/12 | Loss: 0.7173
Epoch 10/12 | Loss: 0.6854
Epoch 11/12 | Loss: 0.6849
