In [4]:
import sys
import os
import json
from PIL import Image
from transformers import OwlViTProcessor, OwlViTForObjectDetection
from concurrent.futures import ThreadPoolExecutor, as_completed
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import torchvision.ops as ops
import math
import random
from dotenv import load_dotenv

load_dotenv()
YT_ROOTS = [os.getenv("YT_ROOT_1"), os.getenv("YT_ROOT_2")]


def extract_bbox_from_points(points):
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    x1, x2 = min(xs), max(xs)
    y1, y2 = min(ys), max(ys)
    return [x1, y1, x2 - x1, y2 - y1]  # x, y, w, h


def get_brake_status(raw):
    if "BrakeOn" in raw:
        return "brake_on"
    if "BrakeOff" in raw:
        return "brake_off"
    return "unknown"


def get_turn_signal(ts):
    if ts == "left":
        return "left_signal"
    if ts == "right":
        return "right_signal"
    if ts == "hazard":
        return "hazard"
    if ts == "off":
        return "off"
    return "unknown"


def normalize_label(shape):
    raw = shape["label"]
    ts = shape.get("attributes", {}).get("turn_signal", "")
    return get_brake_status(raw), get_turn_signal(ts)


def process_json(json_path):
    img_path = json_path[:-5] + ".jpg"
    if not os.path.exists(img_path):
        return None

    try:
        with open(json_path, "r") as f:
            ann = json.load(f)
    except:
        return None

    objects = []
    for shape in ann.get("shapes", []):
        bbox = extract_bbox_from_points(shape["points"])
        brake_status, turn_signal = normalize_label(shape)
        if brake_status != "unknown":
            objects.append({"bbox": bbox, "label": brake_status})
        if turn_signal not in ["unknown", "off"]:
            objects.append({"bbox": bbox, "label": turn_signal})

    return {"image_path": img_path, "objects": objects}


def load_yt_dataset_fast(root_dirs, max_workers=16):
    json_files = [
        os.path.join(dirpath, f)
        for root in root_dirs
        for dirpath, _, filenames in os.walk(root)
        for f in filenames
        if f.lower().endswith(".json")
    ]
    print("finished finding all JSON files")

    samples = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(process_json, jf): jf for jf in json_files}
        for idx, future in enumerate(as_completed(futures)):
            res = future.result()
            if res:
                samples.append(res)
            if (idx + 1) % 1000 == 0:
                print(f"Processed {idx+1} / {len(json_files)} JSON files")

    return samples

In [5]:
print('starting to load YT dataset...')
samples = load_yt_dataset_fast(YT_ROOTS)
print("Total samples:", len(samples))


starting to load YT dataset...


TypeError: expected str, bytes or os.PathLike object, not NoneType

In [None]:
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
model.to(device)
model.eval()

Where does OWL-ViT fail in this safety-critical scenario?


In [6]:
label_map = {
    "brake_off": "car with brake light off",
    "brake_on": "car with brake light on",
    "left_signal": "car with left signal on",
    "right_signal": "car with right signal on",
    "hazard": "car with hazard lights on",
}

text_queries = list(label_map.values())
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
label_to_idx = {key: idx for idx, key in enumerate(label_map.keys())}


class OwlViTDataset(Dataset):
    def __init__(self, records, label_map, min_boxes=1):
        self.label_map = label_map
        self.records = [
            record
            for record in records
            if sum(obj["label"] in label_map for obj in record["objects"]) >= min_boxes
        ]
        self.label_to_idx = label_to_idx

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        item = self.records[idx]
        img_path = item["image_path"]
        objects = item["objects"]

        image = Image.open(img_path).convert("RGB")
        width, height = image.size

        boxes = []
        labels_idx = []
        for obj in objects:
            class_name = obj["label"]
            if class_name not in self.label_map:
                continue

            x, y, w, h = obj["bbox"]
            x1 = max(x / width, 0.0)
            y1 = max(y / height, 0.0)
            x2 = min((x + w) / width, 1.0)
            y2 = min((y + h) / height, 1.0)

            boxes.append([float(x1), float(y1), float(x2), float(y2)])
            labels_idx.append(self.label_to_idx[class_name])

        if boxes:
            boxes_tensor = torch.tensor(boxes, dtype=torch.float32)
        else:
            boxes_tensor = torch.zeros((0, 4), dtype=torch.float32)

        if labels_idx:
            labels_tensor = torch.tensor(labels_idx, dtype=torch.long)
        else:
            labels_tensor = torch.zeros((0,), dtype=torch.long)

        return {
            "image": image,
            "boxes": boxes_tensor,
            "labels": labels_tensor,
            "image_path": img_path,
            "size": (height, width),
        }

In [8]:
def train_val_split(records, train_ratio=0.85, seed=42):
    indices = list(range(len(records)))
    random.Random(seed).shuffle(indices)
    split_idx = max(1, int(len(indices) * train_ratio))
    train_records = [records[i] for i in indices[:split_idx]]
    val_records = [records[i] for i in indices[split_idx:]]
    if not val_records:
        val_records = train_records[-1:]
        train_records = train_records[:-1]
    return train_records, val_records


def owlvit_collate_fn(batch):
    images = [item["image"] for item in batch]
    batch_text = [text_queries] * len(images)
    encoded = processor(images=images, text=batch_text, return_tensors="pt", padding=True)

    return {
        "pixel_values": encoded["pixel_values"],
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "gt_boxes": [item["boxes"] for item in batch],
        "gt_labels": [item["labels"] for item in batch],
        "image_paths": [item["image_path"] for item in batch],
    }


def cxcywh_to_xyxy(boxes):
    cx, cy, w, h = boxes.unbind(-1)
    x1 = cx - 0.5 * w
    y1 = cy - 0.5 * h
    x2 = cx + 0.5 * w
    y2 = cy + 0.5 * h
    return torch.stack([x1, y1, x2, y2], dim=-1)


def xyxy_to_cxcywh(boxes):
    x1, y1, x2, y2 = boxes.unbind(-1)
    w = x2 - x1
    h = y2 - y1
    cx = x1 + 0.5 * w
    cy = y1 + 0.5 * h
    return torch.stack([cx, cy, w, h], dim=-1)


def owlvit_detection_loss(outputs, gt_boxes, gt_labels, cls_weight=1.0, box_weight=2.0, iou_weight=1.0):
    pred_boxes = outputs.pred_boxes
    pred_logits = outputs.logits
    device = pred_boxes.device
    total_loss = torch.zeros(1, device=device)
    matched_batches = 0

    for batch_idx in range(len(gt_boxes)):
        boxes = gt_boxes[batch_idx].to(device)
        labels = gt_labels[batch_idx].to(device)
        if boxes.numel() == 0:
            continue

        preds_xyxy = cxcywh_to_xyxy(pred_boxes[batch_idx])
        ious = ops.box_iou(preds_xyxy, boxes)
        best_idx = torch.argmax(ious, dim=0)
        matched_logits = pred_logits[batch_idx][best_idx]
        cls_loss = F.cross_entropy(matched_logits, labels)

        pred_cxcywh = pred_boxes[batch_idx][best_idx]
        target_cxcywh = xyxy_to_cxcywh(boxes)
        box_loss = F.l1_loss(pred_cxcywh, target_cxcywh)

        matched_ious = ious[best_idx, torch.arange(len(boxes), device=device)]
        iou_loss = (1.0 - matched_ious.clamp(0.0, 1.0)).mean()

        total_loss = total_loss + cls_weight * cls_loss + box_weight * box_loss + iou_weight * iou_loss
        matched_batches += 1

    if matched_batches == 0:
        return total_loss

    return total_loss / matched_batches

In [9]:
train_records, val_records = train_val_split(samples, train_ratio=0.70, seed=42)
train_dataset = OwlViTDataset(train_records, label_map)
val_dataset = OwlViTDataset(val_records, label_map)

print(f"Train samples: {len(train_dataset)} | Val samples: {len(val_dataset)}")

NameError: name 'samples' is not defined

In [10]:
batch_size = 16
num_epochs = 4
learning_rate = 5e-6
weight_decay = 0.01
grad_clip = 1.0

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=owlvit_collate_fn,
    num_workers=0,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=owlvit_collate_fn,
    num_workers=0,
)

optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
total_train_steps = max(1, len(train_loader) * num_epochs)
scheduler = CosineAnnealingLR(optimizer, T_max=total_train_steps)

In [11]:
def run_epoch(data_loader, train=True):
    if train:
        model.train()
    else:
        model.eval()

    epoch_loss = 0.0
    steps = 0
    loop = tqdm(data_loader, desc="train" if train else "val", leave=False)

    for batch in loop:
        pixel_values = batch["pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        if train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(train):
            outputs = model(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
            )
            loss = owlvit_detection_loss(
                outputs,
                gt_boxes=batch["gt_boxes"],
                gt_labels=batch["gt_labels"],
            )

            if train:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

        epoch_loss += loss.item()
        steps += 1
        loop.set_postfix(loss=epoch_loss / max(1, steps))

    return epoch_loss / max(1, steps)


for epoch in range(num_epochs):
    train_loss = run_epoch(train_loader, train=True)
    val_loss = run_epoch(val_loader, train=False)
    print(f"Epoch {epoch + 1}/{num_epochs} | train loss: {train_loss:.4f} | val loss: {val_loss:.4f}")

train:   0%|          | 0/3401 [00:00<?, ?it/s]

val:   0%|          | 0/1460 [00:00<?, ?it/s]

Epoch 1/4 | train loss: 0.4499 | val loss: 0.3777


train:   0%|          | 0/3401 [00:00<?, ?it/s]

val:   0%|          | 0/1460 [00:00<?, ?it/s]

Epoch 2/4 | train loss: 0.3202 | val loss: 0.3385


train:   0%|          | 0/3401 [00:00<?, ?it/s]

val:   0%|          | 0/1460 [00:00<?, ?it/s]

Epoch 3/4 | train loss: 0.2547 | val loss: 0.3320


train:   0%|          | 0/3401 [00:00<?, ?it/s]

val:   0%|          | 0/1460 [00:00<?, ?it/s]

Epoch 4/4 | train loss: 0.2199 | val loss: 0.3324


In [None]:
save_dir = os.path.join("artifacts", f"owlvit-finetune")
os.makedirs(save_dir, exist_ok=True)
model.save_pretrained(save_dir)
processor.save_pretrained(save_dir)
print(f"Saved fine-tuned OWL-ViT artifacts to {save_dir}")

Saved fine-tuned OWL-ViT artifacts to artifacts\owlvit-finetune


In [10]:
ckpt_dir = "artifacts/owlvit-finetune" 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

finetuned_processor = OwlViTProcessor.from_pretrained(ckpt_dir)
finetuned_model = OwlViTForObjectDetection.from_pretrained(ckpt_dir).to(device)
finetuned_model.eval()

OwlViTForObjectDetection(
  (owlvit): OwlViTModel(
    (text_model): OwlViTTextTransformer(
      (embeddings): OwlViTTextEmbeddings(
        (token_embedding): Embedding(49408, 512)
        (position_embedding): Embedding(16, 512)
      )
      (encoder): OwlViTEncoder(
        (layers): ModuleList(
          (0-11): 12 x OwlViTEncoderLayer(
            (self_attn): OwlViTAttention(
              (k_proj): Linear(in_features=512, out_features=512, bias=True)
              (v_proj): Linear(in_features=512, out_features=512, bias=True)
              (q_proj): Linear(in_features=512, out_features=512, bias=True)
              (out_proj): Linear(in_features=512, out_features=512, bias=True)
            )
            (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (mlp): OwlViTMLP(
              (activation_fn): QuickGELUActivation()
              (fc1): Linear(in_features=512, out_features=2048, bias=True)
              (fc2): Linear(in_features=2048, out_

In [None]:
# Evaluation configuration
max_eval_samples = 300  # cap for speed; set None to use all validation samples
sample_stride = 1       # take every Nth sample if dataset large
iou_threshold = 0.5     # IoU threshold for a correct localization
score_threshold = 0.20  # Probability threshold after softmax over label prompts (0-1)
visualize = True        # toggle visualization output
visualization_samples = 8  # number of random samples to visualize (<= max_eval_samples)
random_seed = 123

# Ensure deterministic subset selection
random.Random(random_seed).shuffle(val_records)

print(f"Eval config: max_eval_samples={max_eval_samples} stride={sample_stride} iou_th={iou_threshold} prob_th={score_threshold} visualize={visualize}")

NameError: name 'val_records' is not defined

In [None]:
def box_iou_xyxy(a, b):
    if a.numel() == 0 or b.numel() == 0:
        return torch.zeros((a.shape[0], b.shape[0]))
    inter_x1 = torch.max(a[:, None, 0], b[:, 0])
    inter_y1 = torch.max(a[:, None, 1], b[:, 1])
    inter_x2 = torch.min(a[:, None, 2], b[:, 2])
    inter_y2 = torch.min(a[:, None, 3], b[:, 3])
    inter_w = (inter_x2 - inter_x1).clamp(min=0)
    inter_h = (inter_y2 - inter_y1).clamp(min=0)
    inter = inter_w * inter_h
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    union = area_a[:, None] + area_b - inter
    return inter / union.clamp(min=1e-6)


def match_predictions_to_gt(pred_boxes, pred_scores, gt_boxes, gt_labels, iou_th=0.5, score_th=0.0):
    if pred_boxes.numel() == 0:
        return 0, 0, 0
    keep = pred_scores >= score_th
    if keep.sum() == 0:
        return 0, 0, gt_boxes.shape[0]
    pred_boxes = pred_boxes[keep]
    pred_scores = pred_scores[keep]

    cx, cy, w, h = pred_boxes.unbind(-1)
    x1 = cx - 0.5 * w
    y1 = cy - 0.5 * h
    x2 = cx + 0.5 * w
    y2 = cy + 0.5 * h
    pred_xyxy = torch.stack([x1, y1, x2, y2], dim=-1)
    if gt_boxes.numel() == 0:
        return 0, pred_boxes.shape[0], 0

    ious = box_iou_xyxy(pred_xyxy, gt_boxes)
    matched_gt = set()
    tp = 0
    fp = 0
    for pred_idx in torch.argsort(pred_scores, descending=True):
        best_gt = torch.argmax(ious[pred_idx])
        best_iou = ious[pred_idx, best_gt].item()
        if best_iou >= iou_th and best_gt.item() not in matched_gt:
            matched_gt.add(best_gt.item())
            tp += 1
        else:
            fp += 1
    fn = gt_boxes.shape[0] - len(matched_gt)
    return tp, fp, fn


def evaluate_models(records, pretrained_model, finetuned_model, processor, label_map, max_samples=None, stride=1, iou_th=0.5, score_th=0.0):
    label_texts = list(label_map.values())
    subset = records[::stride]
    if max_samples is not None:
        subset = subset[:max_samples]

    metrics = {
        'pretrained': {'tp':0,'fp':0,'fn':0},
        'finetuned': {'tp':0,'fp':0,'fn':0},
    }

    for r in tqdm(subset, desc='eval', leave=False):
        image = Image.open(r['image_path']).convert('RGB')
        gt_boxes_list = []
        gt_labels_list = []
        for obj in r['objects']:
            if obj['label'] in label_map:
                x, y, w, h = obj['bbox']
                width, height = image.size
                x1 = x / width
                y1 = y / height
                x2 = (x + w) / width
                y2 = (y + h) / height
                gt_boxes_list.append([x1, y1, x2, y2])
                gt_labels_list.append(label_to_idx[obj['label']])
        if gt_boxes_list:
            gt_boxes_tensor = torch.tensor(gt_boxes_list, dtype=torch.float32)
        else:
            gt_boxes_tensor = torch.zeros((0,4), dtype=torch.float32)
        if gt_labels_list:
            gt_labels_tensor = torch.tensor(gt_labels_list, dtype=torch.long)
        else:
            gt_labels_tensor = torch.zeros((0,), dtype=torch.long)

        enc = processor(images=[image], text=[label_texts], return_tensors='pt')
        with torch.no_grad():
            pt_out = pretrained_model(**{k: v.to(device) for k,v in enc.items() if k in ['pixel_values','input_ids','attention_mask']})
            ft_out = finetuned_model(**{k: v.to(device) for k,v in enc.items() if k in ['pixel_values','input_ids','attention_mask']})

        def extract_probs(outputs):
            logits = outputs.logits[0]
            if logits.shape[-1] != len(label_texts):
                sz = min(logits.shape[-1], len(label_texts))
                logits = logits[:, :sz]
            probs = torch.softmax(logits, dim=-1)  # [num_queries, num_labels]
            max_probs, pred_label_idx = torch.max(probs, dim=-1)
            return max_probs
        pt_probs = extract_probs(pt_out).cpu()
        ft_probs = extract_probs(ft_out).cpu()
        pred_boxes_pt = pt_out.pred_boxes[0].cpu()
        pred_boxes_ft = ft_out.pred_boxes[0].cpu()

        tp, fp, fn = match_predictions_to_gt(pred_boxes_pt, pt_probs, gt_boxes_tensor, gt_labels_tensor, iou_th=iou_th, score_th=score_th)
        metrics['pretrained']['tp'] += tp
        metrics['pretrained']['fp'] += fp
        metrics['pretrained']['fn'] += fn

        tp, fp, fn = match_predictions_to_gt(pred_boxes_ft, ft_probs, gt_boxes_tensor, gt_labels_tensor, iou_th=iou_th, score_th=score_th)
        metrics['finetuned']['tp'] += tp
        metrics['finetuned']['fp'] += fp
        metrics['finetuned']['fn'] += fn

    def finalize(m):
        precision = m['tp'] / max(1, (m['tp'] + m['fp']))
        recall = m['tp'] / max(1, (m['tp'] + m['fn']))
        f1 = 2*precision*recall / max(1e-6, (precision + recall))
        return {**m, 'precision': precision, 'recall': recall, 'f1': f1}

    return {k: finalize(v) for k,v in metrics.items()}


def visualize_examples(records, finetuned_model, pretrained_model, processor, label_map, num_samples=6):
    samples_vis = random.sample(records, min(num_samples, len(records)))
    label_texts = list(label_map.values())
    for r in samples_vis:
        image = Image.open(r['image_path']).convert('RGB')
        width, height = image.size
        fig, axes = plt.subplots(1,2, figsize=(12,6))
        for ax, mdl, title in zip(axes, [pretrained_model, finetuned_model], ['Pretrained','Finetuned']):
            enc = processor(images=[image], text=[label_texts], return_tensors='pt')
            with torch.no_grad():
                out = mdl(**{k: v.to(device) for k,v in enc.items() if k in ['pixel_values','input_ids','attention_mask']})
            logits = out.logits[0]
            if logits.shape[-1] != len(label_texts):
                sz = min(logits.shape[-1], len(label_texts))
                logits = logits[:, :sz]
            probs = torch.softmax(logits, dim=-1)
            max_probs, _ = torch.max(probs, dim=-1)
            keep = max_probs.cpu() >= score_threshold
            pred_boxes = out.pred_boxes[0].cpu()[keep]
            cx, cy, w, h = pred_boxes.unbind(-1)
            x1 = (cx - 0.5*w) * width
            y1 = (cy - 0.5*h) * height
            x2 = (cx + 0.5*w) * width
            y2 = (cy + 0.5*h) * height
            ax.imshow(image)
            for bx1, by1, bx2, by2, prob in zip(x1, y1, x2, y2, max_probs[keep]):
                rect = patches.Rectangle((bx1.item(), by1.item()), (bx2-bx1).item(), (by2-by1).item(), linewidth=1, edgecolor='cyan', facecolor='none')
                ax.add_patch(rect)
                ax.text(bx1.item(), by1.item()-5, f"{prob.item():.2f}", color='cyan', fontsize=8)
            for obj in r['objects']:
                if obj['label'] in label_map:
                    gx, gy, gw, gh = obj['bbox']
                    rect = patches.Rectangle((gx, gy), gw, gh, linewidth=2, edgecolor='red', facecolor='none')
                    ax.add_patch(rect)
            ax.set_title(f"{title} (prob>={score_threshold})")
            ax.axis('off')
        plt.tight_layout()
        plt.show()

print("Loaded evaluation helpers with probability thresholding.")

Loaded evaluation helpers.


In [None]:
# Run evaluation
records_for_eval = val_records
metrics = evaluate_models(records_for_eval, model, finetuned_model, processor, label_map,
                          max_samples=max_eval_samples, stride=sample_stride, iou_th=iou_threshold, score_th=score_threshold)
print("Evaluation metrics (IoU>=%.2f, prob>=%.2f):" % (iou_threshold, score_threshold))
for k,v in metrics.items():
    print(f"{k}: tp={v['tp']} fp={v['fp']} fn={v['fn']} precision={v['precision']:.3f} recall={v['recall']:.3f} f1={v['f1']:.3f}")

NameError: name 'val_records' is not defined

In [14]:
if visualize:
    print(f"Visualizing {visualization_samples} examples...")
    visualize_examples(val_records[:max_eval_samples or len(val_records)], finetuned_model, model, processor, label_map, num_samples=visualization_samples)
else:
    print("Visualization disabled. Set visualize=True to enable.")

Visualizing 8 examples...


NameError: name 'val_records' is not defined