# In [None]:
import os, random, time, xml.etree.ElementTree as ET
import pandas as pd, numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from torchvision.io import read_video

from ultralytics import YOLO

# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch random generators.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # The MPS seeding entry point is `torch.mps.manual_seed`; the previously
    # used `torch.backends.mps.manual_seed_all` does not exist, so the call
    # always raised AttributeError and was silently swallowed.
    mps = getattr(torch, "mps", None)
    if mps is not None and torch.backends.mps.is_available():
        mps.manual_seed(seed)

def bbox_iou(box1, box2):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    Only the first four entries of each box are read. A small epsilon in the
    denominator keeps the result finite when both boxes have zero area.
    """
    # Corners of the overlap rectangle (may be inverted when disjoint).
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])

    # Clamp each side at zero so disjoint boxes give zero overlap.
    overlap_w = max(0, right - left)
    overlap_h = max(0, bottom - top)
    overlap = overlap_w * overlap_h

    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area1 + area2 - overlap
    return overlap / (union + 1e-9)

# Seed every RNG once at import time so the code below (including the
# random.shuffle-based train/val split in main()) starts from a fixed state.
seed_everything(42)

# ─────────────────────────────────────────────────────────────────────────────
# Config
# ─────────────────────────────────────────────────────────────────────────────
# Resolve the compute device first: MPS when PyTorch reports it as both
# available and built, CPU otherwise.
device = torch.device(
    "mps"
    if torch.backends.mps.is_available() and torch.backends.mps.is_built()
    else "cpu"
)

# Every tunable of the run lives in this one dictionary.
config = dict(
    dataset_root="./PIE Dataset",   # <-- your top-level folder
    seq_len=8,
    img_size=320,
    batch_size=4,
    num_epochs=30,
    learning_rate=1e-4,
    device=device,
    num_workers=0,
    lambda_cls=1.0,
    lambda_reg=1.0,
    iou_threshold_matching=0.5,
    save_dir="./checkpoints",
    yolo_weights="yolov8n.pt",
)

# Checkpoints are written here by main(); create the folder up front.
os.makedirs(config["save_dir"], exist_ok=True)

print("Using device:", config["device"])

# ─────────────────────────────────────────────────────────────────────────────
# Dataset
# ─────────────────────────────────────────────────────────────────────────────
class PIESequenceDataset(Dataset):
    """Fixed-length pedestrian track windows cut from annotated videos.

    Layout read from disk::

        <root_dir>/annotations/<set>/<video>_annt.xml
        <root_dir>/<set>/<video>.mp4

    Every ``<box>`` of every ``<track>`` becomes one row of ``self.df``.
    Rows are grouped per (set, video, track), sorted by frame id and sliced
    into windows of ``seq_len`` consecutive annotation rows (stride 1). Gaps
    in the frame numbering are not detected. A window takes its labels from
    its last box.

    Each item is a tuple ``(frames, bboxes, label_cls, label_reg)``:
        frames:    tensor (seq_len, C, frame_size, frame_size)
        bboxes:    float32 tensor (seq_len, 4), xyxy rescaled to the resized
                   frame
        label_cls: float32 scalar, 1.0 when ``crossing_label == 1`` else 0.0
        label_reg: float32 scalar, ``frames_to_cross`` or -1 when it is
                   missing or negative

    Args:
        root_dir: Dataset root folder.
        seq_len: Number of boxes per window.
        img_size: Stored on the instance; not used by this class itself.
        frame_size: Side length every frame is resized to (square).
    """

    # Column order of the per-box table. Passing it explicitly keeps the
    # table well-formed (and groupable) even when no box was parsed.
    _COLUMNS = [
        "set", "video", "frame_id", "track_id",
        "x1", "y1", "x2", "y2",
        "crossing_label", "frames_to_cross",
    ]

    def __init__(self, root_dir, seq_len=8, img_size=224, frame_size=640):
        self.root = root_dir
        self.seq_len = seq_len
        self.img_size = img_size
        self.frame_size = frame_size

        # 1) Parse every annotation file into one flat table.
        self.df = pd.DataFrame(self._read_annotations(), columns=self._COLUMNS)

        # 2) Build sliding-window sequences
        self.sequences = self._build_sequences()

        print(f"Total sequences: {len(self.sequences)}")

        self.frame_tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((frame_size, frame_size)),
            transforms.ToTensor(),
        ])

        # (set, video) -> (decoded frames, (H0, W0)); entries are never evicted.
        self.video_cache = {}

    def _read_annotations(self):
        """Return one dict per annotated box found under ``annotations/``."""
        ann_root = os.path.join(self.root, "annotations")
        records = []
        for set_name in sorted(os.listdir(ann_root)):
            ann_set = os.path.join(ann_root, set_name)
            vid_set = os.path.join(self.root, set_name)
            # Skip stray files in the annotation folder and sets that have
            # no matching video folder.
            if not (os.path.isdir(ann_set) and os.path.isdir(vid_set)):
                continue

            for xml_file in sorted(os.listdir(ann_set)):
                if not xml_file.endswith(".xml"):
                    continue
                video_id = xml_file.replace("_annt.xml", "")
                root = ET.parse(os.path.join(ann_set, xml_file)).getroot()

                for track in root.findall("track"):
                    track_id = int(track.attrib.get("id", 0))
                    for box in track.findall("box"):
                        # Defaults used when the attribute is absent.
                        cl_lab, ftc = 0, -1
                        for attr in box.findall("attribute"):
                            name = attr.attrib.get("name", "")
                            if name == "crossing_label":
                                cl_lab = int(attr.text)
                            elif name == "frames_to_cross":
                                ftc = int(attr.text)

                        records.append({
                            "set": set_name,
                            "video": video_id,
                            "frame_id": int(box.attrib["frame"]),
                            "track_id": track_id,
                            "x1": float(box.attrib["xtl"]),
                            "y1": float(box.attrib["ytl"]),
                            "x2": float(box.attrib["xbr"]),
                            "y2": float(box.attrib["ybr"]),
                            "crossing_label": cl_lab,
                            "frames_to_cross": ftc,
                        })
        return records

    def _build_sequences(self):
        """Slice every track of ``self.df`` into stride-1 windows."""
        seq_len = self.seq_len
        sequences = []
        grouped = self.df.groupby(["set", "video", "track_id"])
        for (set_name, video_id, _track_id), g in grouped:
            g = g.sort_values("frame_id")
            fids = g.frame_id.tolist()
            bbs = list(zip(g.x1, g.y1, g.x2, g.y2))
            cll = g.crossing_label.tolist()
            ftc = g.frames_to_cross.tolist()

            # Tracks shorter than seq_len yield an empty range.
            for i in range(len(fids) - seq_len + 1):
                last = i + seq_len - 1
                sequences.append({
                    "set": set_name,
                    "video": video_id,
                    "frame_ids": fids[i:i + seq_len],
                    "bboxes": bbs[i:i + seq_len],
                    "label_cls": int(cll[last] == 1),
                    "label_reg": ftc[last] if ftc[last] >= 0 else -1,
                })
        return sequences

    def load_video(self, set_name, video_id):
        """Decode ``<root>/<set>/<video>.mp4`` once and memoise the result.

        Returns:
            ``(frames, (H0, W0))`` where ``frames`` is the decoded video
            permuted to (T, C, H0, W0).
        """
        key = (set_name, video_id)
        if key in self.video_cache:
            return self.video_cache[key]

        video_path = os.path.join(self.root, set_name, f"{video_id}.mp4")
        v, _, _ = read_video(video_path, pts_unit="sec")  # [T, H, W, C]
        v = v.permute(0, 3, 1, 2)                           # [T, C, H0, W0]
        H0, W0 = v.shape[2], v.shape[3]
        self.video_cache[key] = (v, (H0, W0))
        return self.video_cache[key]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        v, (H0, W0) = self.load_video(seq["set"], seq["video"])
        # Boxes are annotated in original-resolution pixels; bring them into
        # the coordinate system of the resized frame.
        scale_x, scale_y = self.frame_size / W0, self.frame_size / H0

        frames = []
        scaled_bbs = []
        for fid, (x1, y1, x2, y2) in zip(seq["frame_ids"], seq["bboxes"]):
            # The annotation frame id is used directly as an index into the
            # decoded video.
            frames.append(self.frame_tf(v[fid]))
            scaled_bbs.append([
                x1 * scale_x, y1 * scale_y,
                x2 * scale_x, y2 * scale_y,
            ])

        frames = torch.stack(frames, dim=0)                    # (S, C, fs, fs)
        bboxes = torch.tensor(scaled_bbs, dtype=torch.float32) # (S, 4)
        return (
            frames,
            bboxes,
            torch.tensor(seq["label_cls"], dtype=torch.float32),
            torch.tensor(seq["label_reg"], dtype=torch.float32),
        )

# ─────────────────────────────────────────────────────────────────────────────
# Model definition
# ─────────────────────────────────────────────────────────────────────────────
class YOLOv12Backbone(nn.Module):
    """Embed, per frame, the detector box that best matches the ground truth.

    For every frame the wrapped detector is run, the detection with the
    highest IoU against that frame's ground-truth box is selected, and the
    5-vector ``(x1, y1, x2, y2, conf)`` is projected to 64 dimensions by
    ``fc``. Frames with no detection reaching ``iou_thresh`` use an all-zero
    5-vector.

    Args:
        weights_path: Weights file handed to ``YOLO``.
        iou_thresh: Minimum IoU for a detection to count as a match.
        device: Device the detector and the output live on.
    """

    def __init__(self, weights_path="yolov8n.pt", iou_thresh=0.5, device="cpu"):
        super().__init__()
        self.device = device

        self.model = YOLO(weights_path).to(device)
        self.iou_thresh = iou_thresh

        self.fc = nn.Linear(5, 64)

    def train(self, mode: bool = True):
        """Set the mode flag on this module only, without touching children.

        NOTE(review): presumably overridden because ``nn.Module.train``
        recurses into submodules and the wrapped ``YOLO`` object's own
        ``train`` is not a mode toggle -- confirm against ultralytics.
        """
        self.training = mode
        return self

    def forward(self, frames, gt_bboxes):
        """Return a (B, S, 64) embedding of the matched detections.

        Args:
            frames: Tensor of shape (B, S, C, H, W).
            gt_bboxes: Tensor indexed as ``[b, t]`` giving one xyxy box per
                frame, in the pixel coordinates of ``frames``.
        """
        # Detection and matching need no gradient, but `fc` is a trainable
        # layer: applying it inside no_grad (as before) meant it could never
        # be updated by the optimizer.
        feats = self._match_detections(frames, gt_bboxes)
        return self.fc(feats)

    @torch.no_grad()
    def _match_detections(self, frames, gt_bboxes):
        """Return a (B, S, 5) tensor of best-match ``(xyxy, conf)`` rows."""
        B, S, C, H, W = frames.shape
        feats = torch.zeros(B, S, 5)
        # One host copy of all ground-truth boxes instead of one per frame.
        gt = gt_bboxes.detach().cpu().tolist()

        for b in range(B):
            for t in range(S):
                single = frames[b, t].unsqueeze(0).to(self.device)
                results = self.model.predict(
                    source=single,
                    imgsz=H,
                    device=self.device,
                    save=False,
                    save_txt=False
                )[0].boxes

                if results.shape[0] == 0:
                    continue

                boxes = results.xyxy.cpu().numpy()   # (N,4)
                confs = results.conf.cpu().numpy()   # (N,)

                best_iou, best_idx = 0.0, -1
                for i, box in enumerate(boxes):
                    iou = bbox_iou(box.tolist(), gt[b][t])
                    if iou > best_iou:
                        best_iou, best_idx = iou, i

                # Keep the row at zero unless a detection overlaps enough.
                if best_idx >= 0 and best_iou >= self.iou_thresh:
                    feats[b, t, :4] = torch.from_numpy(boxes[best_idx])
                    feats[b, t, 4] = float(confs[best_idx])

        return feats.to(self.device)

class SimpleCNNBackbone(nn.Module):
    """Two conv+ReLU+max-pool stages followed by a linear layer to 128 dims.

    Args:
        img_size: Side length of the square input; used to size the linear
            layer (each pooling stage halves the spatial resolution).
    """

    def __init__(self, img_size=224):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Two 2x2 poolings shrink each side by 4; conv2 emits 64 channels.
        side = img_size // 4
        self.fc = nn.Linear(side * side * 64, 128)

    def forward(self, x):
        """Map a (N, 3, img_size, img_size) batch to (N, 128)."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(x.shape[0], -1)
        return self.fc(flat)

class TemporalTransformer(nn.Module):
    """Transformer encoder that summarises a sequence through a CLS token.

    Args:
        embed_dim: Token width (``d_model``); the feed-forward width is 4x.
        num_heads: Attention heads per layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, embed_dim=256, num_heads=4, num_layers=2):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=4 * embed_dim,
            batch_first=True,
        )
        self.tr = nn.TransformerEncoder(layer, num_layers=num_layers)
        # Learnable summary token, initialised to zeros.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, x):
        """Encode ``x`` of shape (B, S, D) and return the CLS output (B, D)."""
        batch, _, _ = x.size()
        # Prepend one copy of the CLS token to every sequence in the batch.
        tokens = torch.cat((self.cls_token.expand(batch, -1, -1), x), dim=1)
        encoded = self.tr(tokens)
        return encoded[:, 0]

class PredictionHead(nn.Module):
    """Two linear heads: a sigmoid intent probability and a raw time value.

    Args:
        embed_dim: Width of the incoming feature vector.
    """

    def __init__(self, embed_dim=256):
        super().__init__()
        self.intent_fc = nn.Linear(embed_dim, 1)
        self.time_fc = nn.Linear(embed_dim, 1)

    def forward(self, x):
        """Return ``(intent, ttc)``, each with the trailing unit dim removed."""
        intent_logit = self.intent_fc(x)
        time_out = self.time_fc(x)
        return torch.sigmoid(intent_logit).squeeze(-1), time_out.squeeze(-1)

class YOLOv12_IntentNet(nn.Module):
    """Fuse detector-box and image-crop features over time.

    Per frame, a 64-d feature from ``YOLOv12Backbone`` and a 128-d feature
    from ``SimpleCNNBackbone`` (computed on the ground-truth box crop) are
    concatenated, projected to 256-d, passed through ``TemporalTransformer``
    and decoded by ``PredictionHead``.

    Args:
        cfg: Mapping providing ``device``, ``img_size``, ``yolo_weights`` and
            ``iou_threshold_matching``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.device = cfg["device"]
        # forward() sizes every crop with this value. It used to be read
        # without ever being assigned, so the first forward pass raised
        # AttributeError.
        self.img_size = cfg["img_size"]
        self.yolo_backbone = YOLOv12Backbone(
            cfg["yolo_weights"], cfg["iou_threshold_matching"], cfg["device"]
        )
        self.cnn_backbone  = SimpleCNNBackbone(cfg["img_size"])
        self.linear_fuse   = nn.Linear(64 + 128, 256)
        self.temporal_model= TemporalTransformer(256, 4, 2)
        self.pred_head     = PredictionHead(256)

    def forward(self, frames, gt_bboxes):
        """Return ``(intent, ttc)`` for a batch of frame sequences.

        Args:
            frames: Tensor of shape (B, S, C, H, W); values are scaled by 255
                and cast to uint8 before cropping.
            gt_bboxes: Tensor indexed as ``[b, t]`` holding one xyxy box per
                frame in the pixel coordinates of ``frames``.
        """
        B, S, C, H, W = frames.shape
        y_feats = self.yolo_backbone(frames, gt_bboxes)  # (B, S, 64)

        patches = self._crop_patches(frames, gt_bboxes).to(self.device)
        cnn_feats = self.cnn_backbone(patches).view(B, S, 128)

        fus = self.linear_fuse(torch.cat([y_feats, cnn_feats], dim=2))
        cls_tok = self.temporal_model(fus)
        return self.pred_head(cls_tok)

    def _crop_patches(self, frames, gt_bboxes):
        """Crop, resize and normalise one patch per frame -> (B*S, 3, s, s)."""
        B, S, C, H, W = frames.shape
        size = self.img_size
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

        # Copy the batch to the host once rather than once per frame.
        pixels = (
            frames.detach().permute(0, 1, 3, 4, 2).cpu().numpy() * 255
        ).astype(np.uint8)
        boxes = gt_bboxes.detach().cpu().int().tolist()

        patches = []
        for b in range(B):
            for t in range(S):
                pil = Image.fromarray(np.ascontiguousarray(pixels[b, t]))
                x1, y1, x2, y2 = boxes[b][t]
                if x2 <= x1 or y2 <= y1:
                    # Degenerate box: fall back to a centred square crop.
                    left = W // 2 - size // 2
                    top = H // 2 - size // 2
                    crop_box = (left, top, left + size, top + size)
                else:
                    crop_box = (x1, y1, x2, y2)
                patch = pil.crop(crop_box).resize((size, size))
                patches.append(normalize(to_tensor(patch)))
        return torch.stack(patches, 0)

class CombinedLoss(nn.Module):
    """Weighted sum of a BCE intent loss and a masked Smooth-L1 time loss.

    Args:
        lambda_cls: Weight of the classification term.
        lambda_reg: Weight of the regression term.
    """

    def __init__(self, lambda_cls=1.0, lambda_reg=1.0):
        super().__init__()
        self.bce = nn.BCELoss()
        self.l1 = nn.SmoothL1Loss()
        self.lc = lambda_cls
        self.lr = lambda_reg

    def forward(self, intent_p, ttc_p, label_cls, label_reg):
        """Return ``(total, cls_loss, reg_loss)``; the last two are detached.

        Regression is evaluated only where ``label_reg >= 0``; if no sample
        qualifies, the regression term is a zero scalar.
        """
        cls_loss = self.bce(intent_p, label_cls)

        valid = label_reg >= 0
        if valid.any():
            reg_loss = self.l1(ttc_p[valid], label_reg[valid])
        else:
            reg_loss = torch.tensor(0.0, device=intent_p.device)

        total = self.lc * cls_loss + self.lr * reg_loss
        return total, cls_loss.detach(), reg_loss.detach()

# ─────────────────────────────────────────────────────────────────────────────
# Training / Validation
# ─────────────────────────────────────────────────────────────────────────────
def train_one_epoch(model, loader, opt, crit, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Called as ``model(frames, bboxes)``; returns two predictions.
        loader: Iterable of 4-tuples of tensors
            ``(frames, bboxes, label_cls, label_reg)``.
        opt: Optimizer stepped once per batch.
        crit: Called with both predictions and both labels; its first
            return value is back-propagated.
        device: Target device for every batch tensor.
    """
    model.train()
    for step, batch in enumerate(loader, start=1):
        frames, boxes, cls_target, reg_target = batch
        frames = frames.to(device)
        boxes = boxes.to(device)
        cls_target = cls_target.to(device)
        reg_target = reg_target.to(device)

        opt.zero_grad()
        intent_pred, ttc_pred = model(frames, boxes)
        total_loss, _cls_part, _reg_part = crit(
            intent_pred, ttc_pred, cls_target, reg_target
        )
        total_loss.backward()
        opt.step()

        # Progress line every 20th batch.
        if step % 20 == 0:
            print(f"  Batch {step}/{len(loader)} | Loss {total_loss.item():.4f}")

def validate_one_epoch(model, loader, crit, device):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        ``(mean_loss, accuracy)`` where ``mean_loss`` is the per-batch mean
        of the second value returned by ``crit`` and ``accuracy`` is the
        fraction of samples whose first prediction, thresholded at 0.5,
        equals the classification label.
    """
    model.eval()
    batch_losses = []
    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch in loader:
            frames, boxes, cls_target, reg_target = batch
            frames = frames.to(device)
            boxes = boxes.to(device)
            cls_target = cls_target.to(device)
            reg_target = reg_target.to(device)

            intent_prob, ttc_pred = model(frames, boxes)
            loss_parts = crit(intent_prob, ttc_pred, cls_target, reg_target)
            batch_losses.append(loss_parts[1].item())

            hard_pred = (intent_prob >= 0.5).float()
            pred_chunks.append(hard_pred.cpu())
            label_chunks.append(cls_target.cpu())

    hits = torch.cat(pred_chunks) == torch.cat(label_chunks)
    accuracy = hits.float().mean().item()
    return np.mean(batch_losses), accuracy

# ─────────────────────────────────────────────────────────────────────────────
def main():
    """Build the dataset, train the model, and checkpoint on validation loss.

    All settings are read from the module-level ``config``. Sequences are
    shuffled and split 80/20 into train and validation subsets. After each
    epoch the model's state dict is saved under ``config["save_dir"]``
    whenever the validation loss improves on the best value seen so far.
    """
    # Take the device from config like every other setting, instead of
    # depending on a separate module-level `device` global.
    device = torch.device(config["device"])

    # Dataset & Split
    ds = PIESequenceDataset(config["dataset_root"], config["seq_len"], config["img_size"])
    idx = list(range(len(ds)))
    random.shuffle(idx)
    split = int(0.8 * len(idx))
    train_ds = torch.utils.data.Subset(ds, idx[:split])
    val_ds   = torch.utils.data.Subset(ds, idx[split:])

    train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True,
                              num_workers=config["num_workers"], pin_memory=False)
    val_loader   = DataLoader(val_ds,   batch_size=config["batch_size"], shuffle=False,
                              num_workers=config["num_workers"], pin_memory=False)

    model = YOLOv12_IntentNet(config).to(device)
    opt   = torch.optim.AdamW(model.parameters(), lr=config["learning_rate"], weight_decay=1e-4)
    crit  = CombinedLoss(config["lambda_cls"], config["lambda_reg"])

    best_val = float("inf")
    for ep in range(1, config["num_epochs"]+1):
        print(f"\n===== Epoch {ep}/{config['num_epochs']} =====")
        t0 = time.time()
        train_one_epoch(model, train_loader, opt, crit, device)
        val_loss, val_acc = validate_one_epoch(model, val_loader, crit, device)
        print(f"Validation → Loss: {val_loss:.4f} | Acc: {val_acc*100:.2f}% | Time: {(time.time()-t0):.0f}s")

        # Each improvement is written to its own epoch-numbered file.
        if val_loss < best_val:
            best_val = val_loss
            ckpt = os.path.join(config["save_dir"], f"best_epoch{ep}.pt")
            torch.save(model.state_dict(), ckpt)
            print("  → Saved checkpoint:", ckpt)

    # Release cached MPS memory once training is over.
    if device.type == "mps":
        torch.mps.empty_cache()

    print("\nTraining finished.")

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
