# In [None]:
import os
import random
import time
import yaml
import pandas as pd
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from ultralytics import YOLO

# For computing IoU
def bbox_iou(box1, box2):
    """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    Both arguments are indexable sequences whose first four entries are the
    left, top, right and bottom edges. Boxes that do not overlap yield 0.
    """
    # Overlap extent along each axis, floored at zero for disjoint boxes.
    overlap_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    overlap_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    intersection = overlap_w * overlap_h

    area_a = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_b = (box2[2] - box2[0]) * (box2[3] - box2[1])

    # The small epsilon keeps the division defined when both boxes are empty.
    return intersection / (area_a + area_b - intersection + 1e-9)


# Ensure reproducibility
def seed_everything(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    Args:
        seed: Integer seed applied to every generator.

    The MPS generator is seeded through ``torch.mps.manual_seed`` when that
    module exists and an MPS device is available. The lookups go through
    ``getattr`` so PyTorch builds without MPS support are skipped explicitly
    rather than by swallowing an exception.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    mps_backend = getattr(torch.backends, "mps", None)
    mps_module = getattr(torch, "mps", None)
    if (
        mps_backend is not None
        and mps_module is not None
        and hasattr(mps_module, "manual_seed")
        and mps_backend.is_available()
    ):
        mps_module.manual_seed(seed)


seed_everything(42)


# Settings shared by the dataset, the model and the training loop.
config = {
    # --- data ---
    "dataset_root": "pie_dataset",
    "seq_len": 8,              # frames per pedestrian sequence
    "img_size": 224,           # side length of each resized pedestrian crop
    # --- optimisation ---
    "batch_size": 8,
    "num_epochs": 30,
    "learning_rate": 1e-4,
    # --- runtime ---
    # Prefer Apple's MPS backend when PyTorch reports it, else run on CPU.
    "device": "mps" if torch.backends.mps.is_available() else "cpu",
    "num_workers": 2,
    # --- loss weights ---
    "lambda_cls": 1.0,
    "lambda_reg": 1.0,
    "lambda_det": 1.0,
    # --- outputs ---
    "save_dir": "./checkpoints",  # directory receiving model checkpoints
    # --- detector ---
    # Weights file handed to Ultralytics YOLO (placeholder for YOLOv12).
    "yolo_weights": "yolov8n.pt",
    "iou_threshold_matching": 0.5,
}

# Make sure the checkpoint directory exists before anything tries to save.
os.makedirs(config["save_dir"], exist_ok=True)
print("Using device:", config["device"])


class PIESequenceDataset(Dataset):
    """Sliding-window sequences of single pedestrian tracks.

    ``<root_dir>/annotations.csv`` is grouped by ``(video, track_id)``; every
    run of ``seq_len`` consecutive annotation rows (ordered by ``frame_id``)
    of one track becomes one sample. Labels come from the last row of the
    window. Frames are read from ``<root_dir>/images/<video>/<frame_id>.jpg``.

    Args:
        root_dir: Dataset directory holding ``annotations.csv`` and ``images``.
        seq_len: Number of frames per sample; must be at least 1.
        img_size: Side length used by ``patch_transform``.
        transform: Stored on the instance; not applied by this class.

    Raises:
        ValueError: If ``seq_len`` is smaller than 1 or a required CSV column
            is missing.
        FileNotFoundError: If ``annotations.csv`` does not exist.
    """

    def __init__(self, root_dir, seq_len=8, img_size=224, transform=None):
        super().__init__()
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")

        self.root_dir = root_dir
        self.seq_len = seq_len
        self.img_size = img_size
        self.transform = transform

        # 1. Load the annotations CSV.
        ann_path = os.path.join(root_dir, "annotations.csv")
        if not os.path.exists(ann_path):
            raise FileNotFoundError(f"Cannot find {ann_path}. Make sure annotations.csv exists.")

        self.df = pd.read_csv(ann_path)
        required_cols = {
            "video", "frame_id", "track_id",
            "x1", "y1", "x2", "y2",
            "crossing_label", "frames_to_cross"
        }
        if not required_cols.issubset(set(self.df.columns)):
            raise ValueError(f"annotations.csv must contain columns: {required_cols}")

        # 2. Turn every track into overlapping fixed-length windows.
        self.sequences = self._build_sequences(self.df, seq_len)
        print(f"Total sequences: {len(self.sequences)}")

        # Transform for a cropped pedestrian patch (ImageNet normalisation).
        self.patch_transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # Transform for the full frame fed to YOLO: resize and scale to [0, 1].
        self.frame_transform = transforms.Compose([
            transforms.Resize((640, 640)),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _build_sequences(df, seq_len):
        """Return the list of window dicts for every ``(video, track_id)`` group.

        Each dict has the keys ``video``, ``frame_ids``, ``bboxes``,
        ``label_cls`` and ``label_reg``. Tracks with fewer than ``seq_len``
        rows contribute nothing.
        """
        sequences = []
        for (video, _track_id), group in df.groupby(["video", "track_id"]):
            ordered = group.sort_values("frame_id")
            frame_ids = ordered["frame_id"].tolist()
            bboxes = list(zip(ordered["x1"], ordered["y1"],
                              ordered["x2"], ordered["y2"]))
            crossing_labels = ordered["crossing_label"].tolist()
            frames_to_cross = ordered["frames_to_cross"].tolist()

            # Stride-1 sliding window over the rows of this track.
            for start in range(len(frame_ids) - seq_len + 1):
                end = start + seq_len
                # Both labels are taken from the last frame of the window.
                ftc = frames_to_cross[end - 1]
                sequences.append({
                    "video": video,
                    "frame_ids": frame_ids[start:end],
                    "bboxes": bboxes[start:end],
                    # 1 when the last frame is annotated as crossing.
                    "label_cls": int(crossing_labels[end - 1] == 1),
                    # Negative (or NaN) frames_to_cross collapses to the
                    # sentinel -1, which the loss treats as "no target".
                    "label_reg": ftc if ftc >= 0 else -1,
                })
        return sequences

    def __len__(self):
        """Return the number of windows."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Load one window.

        Returns:
            A tuple ``(frames, bboxes, label_cls, label_reg)`` where ``frames``
            is ``(seq_len, 3, 640, 640)``, ``bboxes`` is ``(seq_len, 4)`` and
            both labels are scalar float32 tensors.

        Raises:
            FileNotFoundError: If a frame image is missing.
        """
        seq_info = self.sequences[idx]
        video = seq_info["video"]

        full_frames = []
        for fid in seq_info["frame_ids"]:
            img_path = os.path.join(self.root_dir, "images", video, f"{fid}.jpg")
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"Missing image: {img_path}")
            # The context manager releases the file handle deterministically;
            # convert() returns an independent in-memory image.
            with Image.open(img_path) as img:
                full_frames.append(self.frame_transform(img.convert("RGB")))

        frames_tensor = torch.stack(full_frames, dim=0)  # (seq_len, 3, 640, 640)

        # NOTE(review): boxes are returned exactly as stored in annotations.csv
        # while the frames are resized to 640x640. If the CSV holds coordinates
        # in the original image resolution they would need the same scaling —
        # confirm against the annotation export.
        bboxes_tensor = torch.tensor(seq_info["bboxes"], dtype=torch.float32)  # (seq_len, 4)

        return (
            frames_tensor,
            bboxes_tensor,
            torch.tensor(seq_info["label_cls"], dtype=torch.float32),
            torch.tensor(seq_info["label_reg"], dtype=torch.float32),
        )



class YOLOv12Backbone(nn.Module):
    """Detector-derived per-frame feature.

    For every frame an Ultralytics YOLO model is run, the detection that
    overlaps the ground-truth box most is selected, and the 5-vector
    ``[x1, y1, x2, y2, conf]`` is projected to 64 dimensions by a linear layer.
    Frames without a detection reaching ``iou_thresh`` use an all-zero vector.

    Args:
        weights_path: Weights file passed to ``YOLO``.
        iou_thresh: Minimum IoU with the ground-truth box for a detection to
            be accepted.
        device: Device the detector runs on and on which the 5-vectors are
            created.

    NOTE(review): if the installed Ultralytics ``YOLO`` class subclasses
    ``nn.Module``, assigning it to ``self.model`` registers it as a submodule,
    so its weights would appear in ``parameters()``/``state_dict()`` and
    ``.train()``/``.eval()`` calls on the parent would be forwarded to it —
    verify against the installed version.
    """

    def __init__(self, weights_path="yolov8n.pt", iou_thresh=0.5, device="cpu"):
        super().__init__()
        self.device = device
        # Load the Ultralytics YOLO model.
        self.model = YOLO(weights_path)
        self.model.to(device)
        self.iou_thresh = iou_thresh

        # Linear projection (x1, y1, x2, y2, conf) -> 64 dims.
        self.fc = nn.Linear(5, 64)

    def _best_match(self, det, bbox_gt):
        """Return ``[x1, y1, x2, y2, conf]`` of the best-overlapping detection.

        ``det`` is the ``boxes`` object of one prediction result and
        ``bbox_gt`` a plain ``[x1, y1, x2, y2]`` list. Five zeros are returned
        when nothing reaches ``self.iou_thresh``.
        """
        best_iou = 0.0
        best = [0.0] * 5
        if det.shape[0] > 0:
            boxes_xyxy = det.xyxy.cpu().numpy().tolist()  # (num_dets, 4)
            confs = det.conf.cpu().numpy().tolist()       # (num_dets,)
            for box, conf in zip(boxes_xyxy, confs):
                iou = bbox_iou(box, bbox_gt)
                # Strict comparison keeps the first detection on ties.
                if iou > best_iou:
                    best_iou = iou
                    best = [box[0], box[1], box[2], box[3], conf]

        if best_iou < self.iou_thresh:
            return [0.0] * 5
        return best

    def forward(self, frames, gt_bboxes):
        """Compute the detector feature for every frame.

        Args:
            frames: Tensor of shape ``(batch, seq_len, C, H, W)``.
            gt_bboxes: Tensor of shape ``(batch, seq_len, 4)``.

        Returns:
            Tensor of shape ``(batch, seq_len, 64)``.
        """
        batch_size, seq_len, _, _, _ = frames.shape

        matched = []
        # Only detection and matching are excluded from autograd. The
        # projection below must stay differentiable so that ``self.fc`` is
        # actually trained.
        with torch.no_grad():
            for b in range(batch_size):
                for t in range(seq_len):
                    results = self.model.predict(
                        source=frames[b, t].unsqueeze(0),
                        imgsz=640,
                        device=self.device,
                        save=False,
                        save_txt=False,
                        verbose=False,  # avoid one log line per frame
                    )
                    # ``results`` holds one entry for the single input frame.
                    matched.append(
                        self._best_match(results[0].boxes, gt_bboxes[b, t].tolist())
                    )

        det_feats = torch.tensor(
            matched, dtype=torch.float32, device=self.device
        ).reshape(batch_size, seq_len, 5)

        return self.fc(det_feats)  # (batch, seq_len, 64)


class SimpleCNNBackbone(nn.Module):
    """Two conv/pool stages followed by a linear layer producing 128 features.

    Args:
        img_size: Side length of the square input patches; fixes the input
            width of the final linear layer.
    """

    def __init__(self, img_size=224):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Each of the two poolings halves the spatial size.
        side = img_size // 4
        self.fc = nn.Linear(64 * side * side, 128)

    def forward(self, x):
        """Map ``(N, 3, img_size, img_size)`` patches to ``(N, 128)`` features."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(x.size(0), -1)
        return self.fc(flat)


class TemporalTransformer(nn.Module):
    """Transformer encoder that summarises a sequence through a CLS token.

    Args:
        embed_dim: Width of every token.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, embed_dim=256, num_heads=4, num_layers=2):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        # Learnable summary token, initialised to zeros.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, x):
        """Encode ``(B, S, D)`` tokens and return the CLS output ``(B, D)``."""
        batch, _, _ = x.size()
        # Prepend one copy of the CLS token to every sequence: (B, 1 + S, D).
        prefix = self.cls_token.expand(batch, -1, -1)
        encoded = self.transformer(torch.cat([prefix, x], dim=1))
        return encoded[:, 0, :]


class PredictionHead(nn.Module):
    """Two linear heads: crossing probability and time-to-cross.

    Args:
        embed_dim: Width of the incoming feature vector.
    """

    def __init__(self, embed_dim=256):
        super().__init__()
        self.intent_fc = nn.Linear(embed_dim, 1)
        self.time_fc = nn.Linear(embed_dim, 1)

    def forward(self, x):
        """Return ``(intent, ttc)`` with the trailing unit dimension removed.

        ``intent`` has passed through a sigmoid; ``ttc`` is the raw linear
        output.
        """
        intent_prob = self.intent_fc(x).sigmoid().squeeze(-1)
        time_to_cross = self.time_fc(x).squeeze(-1)
        return intent_prob, time_to_cross


class YOLOv12_IntentNet(nn.Module):
    """Crossing-intent and time-to-cross predictor over a frame sequence.

    Per frame, a 64-dim detector feature and a 128-dim CNN feature of the
    pedestrian crop are concatenated and projected to 256 dims. A transformer
    summarises the sequence and two linear heads produce the outputs.

    Args:
        cfg: Mapping providing ``device``, ``seq_len``, ``img_size``,
            ``yolo_weights`` and ``iou_threshold_matching``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.device = cfg["device"]
        self.seq_len = cfg["seq_len"]
        self.img_size = cfg["img_size"]

        # 1. Detector branch.
        self.yolo_backbone = YOLOv12Backbone(
            weights_path=cfg["yolo_weights"],
            iou_thresh=cfg["iou_threshold_matching"],
            device=cfg["device"]
        )

        # 2. Small CNN for cropped patches.
        self.cnn_backbone = SimpleCNNBackbone(img_size=self.img_size)

        # 3. After concatenation (64 + 128 = 192), project to 256.
        self.linear_fuse = nn.Linear(64 + 128, 256)

        # 4. Transformer for temporal modelling.
        self.temporal_model = TemporalTransformer(embed_dim=256, num_heads=4, num_layers=2)

        # 5. Prediction head.
        self.pred_head = PredictionHead(embed_dim=256)

    def _crop_patch(self, frame, box, to_model_input):
        """Crop ``box`` out of one ``(3, H, W)`` frame and normalise the patch.

        ``box`` is an integer ``[x1, y1, x2, y2]`` list. Coordinates are
        clamped to the frame; a box that is empty after clamping is replaced
        by a centred ``img_size`` square.
        """
        _, height, width = frame.shape
        # Back to an 8-bit PIL image so cropping/resizing happens in PIL.
        img_np = (frame.detach().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        pil = Image.fromarray(img_np)

        x1, y1, x2, y2 = box
        x1 = max(0, min(x1, width - 1))
        y1 = max(0, min(y1, height - 1))
        x2 = max(0, min(x2, width - 1))
        y2 = max(0, min(y2, height - 1))

        if x2 <= x1 or y2 <= y1:
            left = width // 2 - self.img_size // 2
            top = height // 2 - self.img_size // 2
            region = (left, top, left + self.img_size, top + self.img_size)
        else:
            region = (x1, y1, x2, y2)

        patch = pil.crop(region).resize((self.img_size, self.img_size))
        return to_model_input(patch)

    def forward(self, frames, gt_bboxes):
        """Predict intent and time-to-cross.

        Args:
            frames: Tensor of shape ``(B, S, C, H, W)`` with values in [0, 1].
            gt_bboxes: Tensor of shape ``(B, S, 4)`` in frame pixel coordinates.

        Returns:
            ``(intent_pred, ttc_pred)``, each of shape ``(B,)``.
        """
        B, S, C, H, W = frames.shape

        # 1. Detector features -> (B, S, 64).
        yolo_feats = self.yolo_backbone(frames, gt_bboxes)

        # 2. Crop patches and compute CNN features.
        # Built once per call rather than once per frame.
        to_model_input = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        boxes = gt_bboxes.detach().cpu().numpy().astype(int)
        patches = [
            self._crop_patch(frames[b, t], boxes[b, t].tolist(), to_model_input)
            for b in range(B)
            for t in range(S)
        ]

        # The patches are created on the CPU; move them to wherever the
        # network weights live, otherwise the CNN fails on non-CPU devices.
        target_device = self.linear_fuse.weight.device
        patches_tensor = torch.stack(patches, dim=0).to(target_device)  # (B*S, 3, img, img)
        cnn_feats = self.cnn_backbone(patches_tensor).view(B, S, 128)   # (B, S, 128)

        # 3. Concatenate both branches -> (B, S, 192), then project to 256.
        combined = torch.cat([yolo_feats, cnn_feats], dim=2)
        fused = self.linear_fuse(combined)  # (B, S, 256)

        # 4. Transformer -> (B, 256) CLS summary.
        cls_out = self.temporal_model(fused)

        # 5. Heads -> (intent, ttc), each (B,).
        intent_pred, ttc_pred = self.pred_head(cls_out)
        return intent_pred, ttc_pred


class CombinedLoss(nn.Module):
    """Weighted sum of a BCE intent loss and a Smooth-L1 time-to-cross loss.

    Args:
        lambda_cls: Weight of the classification term.
        lambda_reg: Weight of the regression term.
    """

    def __init__(self, lambda_cls=1.0, lambda_reg=1.0):
        super().__init__()
        self.lambda_cls = lambda_cls
        self.lambda_reg = lambda_reg
        self.bce = nn.BCELoss()
        self.smooth_l1 = nn.SmoothL1Loss()

    def forward(self, intent_pred, ttc_pred, label_cls, label_reg):
        """Return ``(total, cls_term, reg_term)``; the two terms are detached.

        Samples whose ``label_reg`` is negative carry no regression target and
        are left out of the Smooth-L1 term. If no sample has a target, the
        regression term is a constant zero.
        """
        cls_term = self.bce(intent_pred, label_cls)

        has_target = label_reg >= 0
        if has_target.any():
            reg_term = self.smooth_l1(ttc_pred[has_target], label_reg[has_target])
        else:
            reg_term = torch.tensor(0.0, device=intent_pred.device)

        total = self.lambda_cls * cls_term + self.lambda_reg * reg_term
        return total, cls_term.detach(), reg_term.detach()


def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Each batch is a ``(frames, bboxes, label_cls, label_reg)`` tuple. The
    averages of the total, classification and regression losses are printed
    every 20 batches and then reset. Nothing is returned.
    """
    log_every = 20
    model.train()
    window_total = 0.0
    window_cls = 0.0
    window_reg = 0.0

    for step, batch in enumerate(dataloader, start=1):
        # frames: (B, S, 3, 640, 640), bboxes: (B, S, 4)
        frames, bboxes, label_cls, label_reg = (item.to(device) for item in batch)

        optimizer.zero_grad()
        intent_pred, ttc_pred = model(frames, bboxes)
        loss, loss_cls, loss_reg = criterion(intent_pred, ttc_pred, label_cls, label_reg)
        loss.backward()
        optimizer.step()

        window_total += loss.item()
        window_cls += loss_cls.item()
        window_reg += loss_reg.item()

        if step % log_every != 0:
            continue

        print(f"  Batch {step}/{len(dataloader)} | "
              f"Loss: {window_total / log_every:.4f} | BCL: {window_cls / log_every:.4f} | S-L1: {window_reg / log_every:.4f}")
        window_total = 0.0
        window_cls = 0.0
        window_reg = 0.0


def validate_one_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Returns:
        ``(avg_loss, avg_cls, avg_reg, acc)``: the three losses averaged over
        batches, and the intent accuracy obtained by thresholding the
        predicted probability at 0.5.
    """
    model.eval()
    total_sum = 0.0
    cls_sum = 0.0
    reg_sum = 0.0
    scores = []
    targets = []

    with torch.no_grad():
        for batch in dataloader:
            frames, bboxes, label_cls, label_reg = (item.to(device) for item in batch)

            intent_pred, ttc_pred = model(frames, bboxes)
            loss, loss_cls, loss_reg = criterion(intent_pred, ttc_pred, label_cls, label_reg)

            total_sum += loss.item()
            cls_sum += loss_cls.item()
            reg_sum += loss_reg.item()

            # Kept on the CPU for the accuracy computation below.
            scores.append(intent_pred.cpu())
            targets.append(label_cls.cpu())

    num_batches = len(dataloader)
    avg_loss = total_sum / num_batches
    avg_cls = cls_sum / num_batches
    avg_reg = reg_sum / num_batches

    # Simple accuracy for the intent prediction.
    predicted = (torch.cat(scores, dim=0) >= 0.5).float()
    acc = (predicted == torch.cat(targets, dim=0)).float().mean().item()

    return avg_loss, avg_cls, avg_reg, acc


# %% [code]
# Main training loop
def main_training(config):
    """Train the intent network and checkpoint whenever validation improves.

    The windows of ``PIESequenceDataset`` are shuffled with Python's global
    ``random`` generator and split 80/20 into training and validation sets.
    The split is made over individual windows, so overlapping windows of one
    track can end up on both sides. A checkpoint is written to
    ``config["save_dir"]`` each time the validation loss reaches a new minimum.
    """
    device = config["device"]

    # 1. Dataset and loaders.
    full_dataset = PIESequenceDataset(
        root_dir=config["dataset_root"],
        seq_len=config["seq_len"],
        img_size=config["img_size"]
    )

    # 80/20 split over shuffled window indices.
    indices = list(range(len(full_dataset)))
    random.shuffle(indices)
    cut = int(0.8 * len(indices))
    train_dataset = torch.utils.data.Subset(full_dataset, indices[:cut])
    val_dataset = torch.utils.data.Subset(full_dataset, indices[cut:])

    loader_kwargs = {
        "batch_size": config["batch_size"],
        "num_workers": config["num_workers"],
        "pin_memory": device != "cpu",
    }
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # 2. Model, optimizer, loss.
    model = YOLOv12_IntentNet(config).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config["learning_rate"], weight_decay=1e-4)
    criterion = CombinedLoss(
        lambda_cls=config["lambda_cls"],
        lambda_reg=config["lambda_reg"]
    )

    lowest_val_loss = float("inf")
    for epoch_num in range(1, config["num_epochs"] + 1):
        print(f"\n===== Epoch {epoch_num}/{config['num_epochs']} =====")
        started = time.time()

        train_one_epoch(model, train_loader, optimizer, criterion, device)

        val_loss, val_cls_l, val_reg_l, val_acc = validate_one_epoch(model, val_loader, criterion, device)
        print(f"Validation → Total Loss: {val_loss:.4f} | BCE: {val_cls_l:.4f} | S‐L1: {val_reg_l:.4f} | "
              f"Intent Acc: {val_acc * 100:.2f}%")

        # Keep a checkpoint for every new best validation loss.
        if val_loss < lowest_val_loss:
            lowest_val_loss = val_loss
            ckpt_path = os.path.join(config["save_dir"], f"best_model_epoch{epoch_num}.pt")
            state = {
                "epoch": epoch_num,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_loss": val_loss,
            }
            torch.save(state, ckpt_path)
            print(f"  → Saved new best checkpoint: {ckpt_path}")

        minutes, seconds = divmod(time.time() - started, 60)
        print(f"Epoch {epoch_num} completed in {minutes:.0f}m {seconds:.0f}s")

    print("\nTraining finished.")


# Run training
if __name__ == "__main__":
    main_training(config)


import matplotlib.pyplot as plt
import matplotlib.patches as patches

def inference_example(config, checkpoint_path, seed=42):
    """Run a checkpoint on one random validation window and plot its frames.

    Args:
        config: Same mapping that was used for training.
        checkpoint_path: File written by ``main_training``.
        seed: Seed used to rebuild the 80/20 split. The split only matches the
            training one if training shuffled the indices right after seeding
            Python's ``random`` with this value, which is what the
            module-level ``seed_everything(42)`` followed by ``main_training``
            does in this file.

    Raises:
        IndexError: If the validation split is empty.
    """
    device = config["device"]

    # Rebuild the dataset and the validation split.
    full_dataset = PIESequenceDataset(
        root_dir=config["dataset_root"],
        seq_len=config["seq_len"],
        img_size=config["img_size"]
    )
    indices = list(range(len(full_dataset)))
    # A dedicated generator reproduces the shuffle that follows
    # ``random.seed(seed)`` regardless of how much the global generator has
    # been used since; reshuffling with the global state would give a
    # different split.
    random.Random(seed).shuffle(indices)
    val_idx = indices[int(0.8 * len(indices)):]
    if not val_idx:
        raise IndexError("Validation split is empty; nothing to run inference on.")

    # Pick one random validation sample.
    sample_idx = random.choice(val_idx)
    frames_tensor, bboxes_tensor, label_cls, label_reg = full_dataset[sample_idx]
    # Add the batch dimension.
    frames_batch = frames_tensor.unsqueeze(0).to(device)   # (1, S, 3, 640, 640)
    bboxes_batch = bboxes_tensor.unsqueeze(0).to(device)   # (1, S, 4)

    # Build the model and load the checkpoint.
    model = YOLOv12_IntentNet(config).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    with torch.no_grad():
        intent_p, ttc_p = model(frames_batch, bboxes_batch)
        intent_p = intent_p.item()
        ttc_p = ttc_p.item()

    print(f"Ground‐Truth Intent: {label_cls.item():.0f} | Predicted Intent: {intent_p:.3f}")
    print(f"Ground‐Truth ttc (frames): {label_reg.item():.1f} | Predicted ttc (frames): {ttc_p:.3f}")

    # Show the S frames with their ground-truth boxes.
    frames_np = (frames_tensor.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
    # squeeze=False always yields a 2-D axes array, so seq_len == 1 works too.
    _, axs = plt.subplots(1, config["seq_len"], figsize=(20, 3), squeeze=False)
    for t, ax in enumerate(axs[0]):
        ax.imshow(frames_np[t])
        x1, y1, x2, y2 = bboxes_tensor[t].tolist()
        ax.add_patch(patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                       linewidth=2, edgecolor='red', facecolor='none'))
        ax.axis("off")
    plt.suptitle(f"Pred Intent: {intent_p:.3f}, Pred TTC(frames): {ttc_p:.2f}")
    plt.show()