In [None]:
# =========================
# SMALLER ViT + PROGRESS BAR
# =========================

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.models import vit_b_32, ViT_B_32_Weights
from pathlib import Path
from PIL import Image
import torchvision.transforms as T
from tqdm import tqdm

# -------------------------
# device
# -------------------------
# Pick the compute backend: Apple MPS is preferred, then CUDA, and the CPU
# is the fallback when neither accelerator is reported as available.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# -------------------------
# Dataset
# -------------------------
class PotholeDataset(Dataset):
    """Images paired with per-image label files, reduced to one box each.

    Every image ``<stem>.<ext>`` in ``img_dir`` is matched with
    ``<stem>.txt`` in ``lbl_dir``. Each non-blank label line must hold five
    whitespace-separated numbers: ``class xc yc w h``. The class value is
    ignored and only the box with the largest ``w * h`` is kept.

    Each item is ``(image, target)`` where ``image`` is the resized RGB image
    as a tensor and ``target`` is a float tensor ``[obj, xc, yc, w, h]``.
    ``obj`` is 1.0 when a box was found; otherwise the whole target is zeros.

    Args:
        img_dir: Directory containing the images.
        lbl_dir: Directory containing the ``.txt`` label files.
        img_size: Side length the images are resized to (square).
    """

    # Lower-case suffixes accepted as images when scanning ``img_dir``; this
    # keeps stray entries (e.g. ``.DS_Store``, sub-folders) out of the dataset.
    IMG_EXTENSIONS = frozenset(
        {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tif", ".tiff", ".webp"}
    )

    def __init__(self, img_dir, lbl_dir, img_size=224):
        self.img_dir = Path(img_dir)
        self.lbl_dir = Path(lbl_dir)
        # Sorted so the index -> file mapping does not depend on the order
        # in which the filesystem happens to list the directory.
        self.images = sorted(
            p for p in self.img_dir.glob("*")
            if p.is_file() and p.suffix.lower() in self.IMG_EXTENSIONS
        )
        self.transform = T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor()
        ])

    def __len__(self):
        """Return the number of images found in ``img_dir``."""
        return len(self.images)

    @staticmethod
    def _read_target(label_path):
        """Build the ``[obj, xc, yc, w, h]`` target from one label file.

        A missing file, an empty file, or a file holding only blank lines all
        mean "no object" and yield an all-zero target.

        Raises:
            ValueError: If a non-blank line does not hold exactly five numbers.
        """
        try:
            with open(label_path, encoding="utf-8") as f:
                rows = [line.split() for line in f]
        except FileNotFoundError:
            return torch.zeros(5)

        boxes = []
        for fields in rows:
            if not fields:  # tolerate blank / trailing lines
                continue
            _, xc, yc, w, h = map(float, fields)
            boxes.append((w * h, xc, yc, w, h))

        if not boxes:
            # The label file exists but lists no boxes: a negative sample.
            return torch.zeros(5)

        # Area is the first tuple field, so max() selects the largest box.
        _, xc, yc, w, h = max(boxes)
        return torch.tensor([1.0, xc, yc, w, h])

    def __getitem__(self, idx):
        """Return ``(image_tensor, target)`` for the image at position ``idx``."""
        img_path = self.images[idx]
        # The context manager releases the file handle as soon as the pixel
        # data has been converted.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        img = self.transform(img)

        label_path = self.lbl_dir / (img_path.stem + ".txt")
        target = self._read_target(label_path)

        return img, target.float()

# -------------------------
# ViT-B/32 detector (smaller + faster)
# -------------------------
class ViTPotholeDetector(nn.Module):
    """Frozen ImageNet-pretrained ViT-B/32 backbone with a trainable linear head.

    The backbone's own classification head is replaced by an identity, so the
    backbone output is fed straight into ``self.head``, a 768 -> 5 linear
    layer. Only ``self.head`` has trainable parameters.
    """

    def __init__(self):
        super().__init__()
        backbone = vit_b_32(weights=ViT_B_32_Weights.IMAGENET1K_V1)
        # Drop the ImageNet classifier so the backbone returns its features.
        backbone.heads = nn.Identity()
        # Freeze every backbone parameter; the new head below stays trainable.
        backbone.requires_grad_(False)
        self.vit = backbone

        self.head = nn.Linear(768, 5)

    def forward(self, x):
        """Run ``x`` through the backbone and map its output to 5 values."""
        features = self.vit(x)
        return self.head(features)

# -------------------------
# Loss
# -------------------------
def detection_loss(pred, target, box_weight=5.0):
    """Objectness + box-regression loss for single-box predictions.

    Both tensors are indexed as 2-D: column 0 is the objectness entry and the
    remaining columns are the box values.

    Args:
        pred: Raw model outputs; column 0 is an objectness *logit* (no sigmoid
            applied), the remaining columns are predicted box values.
        target: Ground truth laid out like ``pred``; column 0 is 1.0 for rows
            that contain an object and 0.0 otherwise.
        box_weight: Multiplier applied to the box term before it is added to
            the classification term.

    Returns:
        Scalar tensor ``cls_loss + box_weight * box_loss``. The box term is
        averaged over positive rows only and is 0 when the batch has none.
    """
    obj_logit = pred[:, 0]
    box_pred = pred[:, 1:]

    obj_true = target[:, 0]
    box_true = target[:, 1:]

    # Logit-based BCE: sigmoid followed by BCELoss saturates for large
    # |logit|, which clamps the loss and zeroes the gradient exactly when the
    # prediction is most wrong.
    cls_loss = nn.functional.binary_cross_entropy_with_logits(obj_logit, obj_true)

    # Only rows that actually contain an object contribute to box regression.
    mask = obj_true == 1
    if mask.any():
        box_loss = nn.functional.mse_loss(box_pred[mask], box_true[mask])
    else:
        box_loss = 0.0

    return cls_loss + box_weight * box_loss

# -------------------------
# DataLoader
# -------------------------
# Training split: images and label files live in parallel RDD2020 folders,
# and every image is resized to 224x224.
train_ds = PotholeDataset(
    "../data/RDD2020/images/train",
    "../data/RDD2020/labels/train",
    img_size=224,
)

# Batch size kept small for MPS; batches are reshuffled every epoch and
# loaded in the main process (no worker subprocesses).
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=0)

# -------------------------
# Training with progress
# -------------------------
# Build the detector on the selected device; only the linear head's
# parameters are handed to the optimizer.
model = ViTPotholeDetector().to(device)
optimizer = torch.optim.Adam(model.head.parameters(), lr=1e-3)

epochs = 5

for epoch in range(epochs):
    # Sum of per-batch losses, averaged over the batch count after the epoch.
    total_loss = 0.0
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

    for batch_imgs, batch_targets in progress:
        batch_imgs = batch_imgs.to(device)
        batch_targets = batch_targets.to(device)

        # Clear the previous step's gradients, then forward / backward / update.
        optimizer.zero_grad()
        batch_loss = detection_loss(model(batch_imgs), batch_targets)
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()

    print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss/len(train_loader):.4f}", flush=True)

# -------------------------
# Save
# -------------------------
# Persist the model's state dict (parameters and buffers, not the Python
# object) to the working directory.
state = model.state_dict()
torch.save(state, "vit_b32_pothole_detector.pt")
print("Saved vit_b32_pothole_detector.pt")


Using device: mps
Downloading: "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth" to /Users/vamsi.k/.cache/torch/hub/checkpoints/vit_b_32-d86f8d99.pth


100%|██████████| 337M/337M [00:17<00:00, 19.9MB/s] 
Epoch 1/5: 100%|█████████▉| 472/473 [00:45<00:00, 10.44it/s]