In [3]:
# Deterministic, self-calibrating few-shot gesture recognizer.
# Pipeline: per-channel DCT over time, PCA whitening to embed_dim, a Gaussian head with
# covariance shrinkage, optional conformal abstention, and a SafetyGate for hold-to-confirm.
# MediaPipe is optional; without it the code falls back gracefully.

from __future__ import annotations

import os
import io
import json
import math
import time
import random
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict, Any, Deque
from collections import deque, defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Optional MediaPipe
try:
    import mediapipe as mp
    _HAS_MP = True
except Exception:
    mp = None
    _HAS_MP = False

# -----------------------------
# Configs
# -----------------------------

@dataclass
class ModelConfig:
    """Hyper-parameters for the recognizer: features, encoder, gate, head, calibration, augmentation."""

    # sequence config
    seq_len: int = 32              # stored support sequences are resampled to this many frames
    min_len: int = 12              # NOTE(review): not read in the visible code; presumably a minimum live-buffer length — confirm
    frame_dim: int = 42            # 21 points, x y (position features per frame)
    use_deltas: bool = True        # append frame-to-frame deltas, doubling the per-frame width

    # deterministic encoder
    dct_k_pos: int = 12            # keep lowest K temporal freqs per channel
    dct_k_delta: int = 10          # same, for the delta channels
    pca_embed_dim: int = 128       # upper bound on embedding size after PCA (smaller with few samples/features)
    pca_whiten: bool = True        # rescale each PCA component to unit variance
    energy_norm: bool = False      # L2-normalize each channel's DCT coefficients

    # safety gate (NOTE(review): presumably copied into GateConfig by the caller — confirm)
    conf_thresh: float = 0.70      # threshold on the EMA-smoothed confidence
    smooth_k: int = 7              # majority-vote window, in frames
    ema_alpha: float = 0.9         # EMA weight on the previous confidence
    confirm_secs: float = 2.5      # hold duration required before "confirm"

    # Gaussian head and uncertainty
    logit_scale: float = 5.0       # multiplies the negative NLL before softmax
    cov_reg: float = 1e-2          # ridge added to the diagonal of class covariances
    shrinkage_lambda: float = 0.20 # blend weight toward shrink_target
    shrink_target: str = "identity"  # "identity" (scaled identity) or "diag"; other values act like "identity"

    # Conformal abstention
    conformal_alpha: float = 0.10  # miscoverage level for the nonconformity quantile
    conformal_mondrian: bool = True  # per-class (Mondrian) thresholds instead of one global threshold
    calib_frac: float = 0.25       # fraction of support samples held out for calibration

    # augmentation
    pos_noise_std: float = 0.01    # std of Gaussian jitter on positions for one augmented copy (0 disables)
    time_warp_prob: float = 0.0    # probability of adding one time-warped copy per sample
    time_warp_segments: int = 3    # number of interior warp knots (capped by sequence length)
    time_warp_sigma: float = 0.20  # std of per-segment speed factors around 1.0


@dataclass
class GateConfig:
    """Settings consumed by SafetyGate."""

    conf_thresh: float   # minimum EMA-smoothed confidence for a frame to count as a vote
    smooth_k: int        # size of the majority-vote window, in frames
    ema_alpha: float     # EMA weight on the previous confidence (higher = smoother, slower)
    confirm_secs: float  # seconds a stable majority must be held before "confirm"


@dataclass
class Gesture:
    """A registered gesture: its display name and the action string bound to it."""

    name: str    # unique key used for lookup by name
    action: str  # NOTE(review): interpreted outside this file; defaults to the gesture name when unspecified


# -----------------------------
# Safety gate, continuous hold
# -----------------------------

class SafetyGate:
    """Hold-to-confirm gate that debounces a stream of per-frame predictions.

    Each ``update`` feeds one prediction ``(idx, conf)`` observed at time ``now``
    (seconds). Confidence is smoothed with an exponential moving average. While
    the smoothed confidence clears the threshold, class indices are collected in
    a majority-vote window of ``cfg.smooth_k`` frames. Hold time accumulates
    while one class stays the stable majority, and the state becomes
    ``"confirm"`` once it reaches ``cfg.confirm_secs``. Rejected or
    low-confidence frames (``idx < 0`` or EMA below threshold) decay the hold
    time at half the elapsed time instead of resetting it.

    States: ``"idle"`` (after reset), ``"arming"``, ``"countdown"``, ``"confirm"``.
    """

    def __init__(self, cfg: GateConfig) -> None:
        self.cfg = cfg
        self.reset()

    def reset(self) -> None:
        """Clear confidence smoothing, the vote window and any accumulated hold."""
        self.ema_conf = 0.0
        self.hist: Deque[int] = deque(maxlen=self.cfg.smooth_k)
        self.state = "idle"
        self.pending_idx: Optional[int] = None
        self.hold_s: float = 0.0
        self._last_t: Optional[float] = None

    def _votes_needed(self) -> int:
        """Number of agreeing votes that makes the majority stable."""
        # 60% of the window and at least 2 frames, but never more than the window can
        # hold; otherwise smooth_k == 1 could never become stable.
        return min(self.cfg.smooth_k, max(2, math.ceil(self.cfg.smooth_k * 0.6)))

    def update(self, idx: int, conf: float, now: float, conf_thresh_override: Optional[float] = None) -> str:
        """Feed one prediction and return the new gate state.

        Args:
            idx: predicted class index, or a negative value for "no prediction".
            conf: confidence of the prediction; negative values are treated as 0.
            now: current time in seconds (any monotonic clock, may start at 0).
            conf_thresh_override: optional threshold replacing ``cfg.conf_thresh``.
        """
        # Compare against None: 0.0 is a valid timestamp and must not be mistaken
        # for "no previous update" (which would silently drop the elapsed time).
        if self._last_t is None:
            self._last_t = now
        dt = max(0.0, now - self._last_t)
        self._last_t = now

        self.ema_conf = self.cfg.ema_alpha * self.ema_conf + (1 - self.cfg.ema_alpha) * max(conf, 0.0)
        thresh = float(conf_thresh_override) if conf_thresh_override is not None else self.cfg.conf_thresh

        if idx >= 0 and self.ema_conf >= thresh:
            self.hist.append(idx)
            items = list(self.hist)
            if not items:
                # Only reachable with smooth_k == 0 (a zero-length window keeps nothing).
                self.state = "arming"
                return self.state
            majority = max(set(items), key=items.count)
            if items.count(majority) >= self._votes_needed():
                if self.pending_idx != majority:
                    # A different class became the majority: restart the hold.
                    self.pending_idx = majority
                    self.hold_s = 0.0
                self.hold_s += dt
                self.state = "confirm" if self.hold_s >= self.cfg.confirm_secs else "countdown"
            else:
                self.state = "arming"
        else:
            self.state = "arming"
            # Decay rather than reset, so brief dropouts do not restart the countdown.
            self.hold_s = max(0.0, self.hold_s - 0.5 * dt)
        return self.state

    def remaining(self, now: float) -> float:
        """Seconds of hold still required (``now`` is accepted but not used)."""
        return max(0.0, self.cfg.confirm_secs - self.hold_s)

    def decide_fire(self, cancel: bool, now: float) -> bool:
        """Return True if the gate is confirmed and not cancelled; always resets the gate."""
        if cancel:
            self.reset()
            return False
        ok = self.state == "confirm"
        self.reset()
        return ok


# -----------------------------
# Async camera
# -----------------------------

import threading
import cv2

class VideoCaptureAsync:
    """Grab camera frames on a background thread and keep only the newest one.

    The worker thread is the only code that reads from ``self.cap``:
    ``cv2.VideoCapture`` is not documented as safe to read from two threads at
    once. ``read`` returns a copy of the latest published frame. After 50
    consecutive failed reads, the worker reopens the device and re-applies the
    buffer size and requested resolution.
    """

    def __init__(self, src: int = 0, width: Optional[int] = None, height: Optional[int] = None):
        self.src = src
        self.width = width
        self.height = height
        self.cap = cv2.VideoCapture(src)
        if not self.cap.isOpened():
            raise RuntimeError(f"Could not open camera index {src}")
        self._configure(self.cap)
        self.q: Deque[np.ndarray] = deque(maxlen=1)
        self.lock = threading.Lock()
        # Set once the worker publishes its first frame; read() waits on it instead of
        # touching the device itself.
        self._frame_ready = threading.Event()
        self.stop_flag = threading.Event()
        self.th = threading.Thread(target=self._worker, daemon=True)
        self.th.start()

    def _configure(self, cap) -> None:
        """Apply low-latency buffering (if supported) and the requested resolution."""
        try:
            cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
        except Exception:
            pass
        if self.width is not None:
            cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
        if self.height is not None:
            cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)

    def _reopen(self) -> None:
        """Best-effort reconnect after repeated read failures."""
        try:
            self.cap.release()
            time.sleep(0.1)
            self.cap = cv2.VideoCapture(self.src)
            # Re-apply settings; the original code dropped width/height here.
            self._configure(self.cap)
        except Exception:
            pass

    def _worker(self) -> None:
        miss = 0
        while not self.stop_flag.is_set():
            ok, frame = self.cap.read()
            if not ok or frame is None:
                miss += 1
                if miss > 50:
                    self._reopen()
                    miss = 0
                # Event.wait doubles as an interruptible sleep so stop() takes effect promptly.
                self.stop_flag.wait(0.02)
                continue
            miss = 0
            with self.lock:
                self.q.clear()
                self.q.append(frame)
            self._frame_ready.set()

    def read(self, timeout: float = 1.0) -> Tuple[bool, Optional[np.ndarray]]:
        """Return ``(True, copy_of_latest_frame)``, or ``(False, None)`` if none is available.

        Before the first frame arrives, this waits up to ``timeout`` seconds for the
        worker instead of reading the device concurrently with it.
        """
        if not self._frame_ready.is_set():
            self._frame_ready.wait(timeout)
        with self.lock:
            if self.q:
                return True, self.q[-1].copy()
        return False, None

    def stop(self) -> None:
        """Ask the worker thread to exit."""
        self.stop_flag.set()

    def release(self) -> None:
        """Stop the worker (waiting up to 0.5 s) and release the device."""
        try:
            self.stop()
        except Exception:
            pass
        try:
            if self.th.is_alive():
                self.th.join(timeout=0.5)
        except Exception:
            pass
        try:
            self.cap.release()
        except Exception:
            pass


class LandmarkSmoother:
    """Moving-average filter over the most recent ``window_size`` landmark arrays."""

    def __init__(self, window_size: int = 5):
        size = int(window_size)
        # A window of at least one frame, so apply() always has something to average.
        self.k = size if size > 1 else 1
        self.buf: Deque[np.ndarray] = deque(maxlen=self.k)

    def reset(self) -> None:
        """Forget every buffered frame."""
        self.buf.clear()

    def apply(self, pts_xy: np.ndarray) -> np.ndarray:
        """Push ``pts_xy`` into the window and return the element-wise mean of the window."""
        self.buf.append(pts_xy)
        window = np.stack(tuple(self.buf), axis=0)
        return np.mean(window, axis=0)


# -----------------------------
# Orientation normalizer
# -----------------------------

class OrientationNormalizer:
    """
    Smooth scale and rotation per stream. Reset when you clear the window or switch modes.

    ``normalize`` rotates landmarks so that the vector from point 0 to point 9
    (wrist and middle-finger MCP in MediaPipe's hand model) points along +y.
    The applied angle slews toward the target by at most ``max_deg_step``
    degrees per call. ``smooth_scale`` is an EMA over the hand scale.
    """
    def __init__(self, max_deg_step: float = 10.0, scale_alpha: float = 0.9):
        self.max_step = math.radians(max_deg_step)
        self.scale_alpha = float(scale_alpha)
        self.prev_theta: Optional[float] = None
        self.prev_scale: Optional[float] = None

    def reset(self) -> None:
        """Forget the smoothed angle and scale."""
        self.prev_theta = None
        self.prev_scale = None

    def normalize(self, rel: np.ndarray) -> np.ndarray:
        """Rotate ``rel`` (rows of x, y; rows 0 and 9 define the axis) by the smoothed angle."""
        v = rel[9] - rel[0]
        ang = math.atan2(v[1], v[0])
        theta = (math.pi / 2.0) - ang
        if self.prev_theta is None:
            self.prev_theta = float(theta)
        else:
            # Wrap the step into [-pi, pi]. atan2 jumps by 2*pi when the hand vector
            # crosses the negative x axis; an unwrapped difference would make the
            # clamped slew rotate the long way round.
            dtheta = math.remainder(float(theta - self.prev_theta), 2.0 * math.pi)
            dtheta = max(-self.max_step, min(self.max_step, dtheta))
            self.prev_theta = self.prev_theta + dtheta
        c = math.cos(self.prev_theta)
        s = math.sin(self.prev_theta)
        R = np.array([[c, -s], [s, c]], dtype=np.float32)
        return rel @ R.T

    def smooth_scale(self, scale: float) -> float:
        """Return the EMA-smoothed scale, floored at 1e-6."""
        if self.prev_scale is None:
            self.prev_scale = float(scale)
        else:
            self.prev_scale = self.scale_alpha * self.prev_scale + (1.0 - self.scale_alpha) * float(scale)
        return max(1e-6, float(self.prev_scale))


# -----------------------------
# Landmarks, rotation normalization
# -----------------------------

def _init_mp_hands(
    max_num_hands: int = 2,
    min_detection_confidence: float = 0.5,
    min_tracking_confidence: float = 0.5,
    model_complexity: int = 1,
):
    """Create a streaming (video-mode) MediaPipe Hands detector, or return None on failure.

    Returning None instead of raising keeps the module importable, so
    ``extract_landmarks`` can fall back to "no landmarks". This covers cases
    such as a MediaPipe build that does not expose ``mp.solutions``.
    """
    try:
        return mp.solutions.hands.Hands(
            static_image_mode=False,
            max_num_hands=max_num_hands,
            min_detection_confidence=min_detection_confidence,
            min_tracking_confidence=min_tracking_confidence,
            model_complexity=model_complexity,
        )
    except Exception as e:
        print(f"[MediaPipe] Hands init failed, landmark extraction disabled: {e}")
        return None

_MP_CTX = _init_mp_hands() if _HAS_MP else None

def extract_landmarks(
    frame_bgr: np.ndarray,
    smoother: Optional[LandmarkSmoother] = None,
    orient: Optional[OrientationNormalizer] = None,
    canonicalize_left_to_right: bool = True
) -> Tuple[Optional[torch.Tensor], Optional[str]]:
    """Detect one hand in a BGR frame and return normalized 2-D landmarks.

    Steps:
    1. Run MediaPipe Hands on the RGB-converted frame.
    2. Keep the hand with the highest handedness-classification score.
    3. Optionally smooth the points temporally; this is done in pixel units and
       then converted back.
    4. Translate so landmark 0 is the origin.
    5. Divide by the landmark-0 -> landmark-9 distance, falling back to the
       bounding-box diagonal when that is degenerate. ``orient`` can optionally
       EMA-smooth this scale.
    6. Optionally mirror left hands in x.
    7. Optionally apply rotation normalization.

    Args:
        frame_bgr: image in OpenCV BGR channel order (H x W x 3 assumed).
        smoother: optional moving-average filter over landmark positions.
        orient: optional stateful normalizer for scale smoothing and rotation.
        canonicalize_left_to_right: negate x for hands whose label starts with "left".

    Returns:
        ``(landmarks, handedness)``. ``landmarks`` is a float32 tensor of shape
        ``(num_landmarks, 2)`` and ``handedness`` is MediaPipe's label (or None).
        Returns ``(None, None)`` when MediaPipe is unavailable or no hand is found.

    NOTE(review): x and y are normalized by image width and height
    respectively. The scale and rotation angle are therefore not aspect-ratio
    corrected for non-square frames — confirm this is intended.
    """
    if not _HAS_MP or _MP_CTX is None:
        return None, None
    h, w = frame_bgr.shape[:2]
    img_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    res = _MP_CTX.process(img_rgb)
    if not res.multi_hand_landmarks:
        return None, None

    # Default to the first detection when handedness information is missing.
    idx_best = 0
    handed = None
    if res.multi_handedness:
        scores = []
        for i, hd in enumerate(res.multi_handedness):
            try:
                scores.append((i, hd.classification[0].score, hd.classification[0].label))
            except Exception:
                scores.append((i, 0.0, None))
        # Highest handedness-classification score first.
        scores.sort(key=lambda t: t[1], reverse=True)
        idx_best, _, handed = scores[0]
    lms = res.multi_hand_landmarks[idx_best]

    # x/y in MediaPipe's normalized image coordinates; z is ignored.
    pts = np.array([[lm.x, lm.y] for lm in lms.landmark], dtype=np.float32)
    if smoother is not None:
        # Smooth in pixel units, then convert back to normalized coordinates.
        pts_px = pts.copy()
        pts_px[:, 0] *= w
        pts_px[:, 1] *= h
        pts_px = smoother.apply(pts_px)
        pts = pts_px.copy()
        pts[:, 0] /= max(1, w)
        pts[:, 1] /= max(1, h)

    # Translate so landmark 0 sits at the origin.
    center = pts[0:1, :]
    rel = pts - center

    # Hand scale: landmark 0 -> landmark 9 distance; bbox diagonal if degenerate or non-finite.
    scale = np.linalg.norm(pts[9] - pts[0]) + 1e-6
    if not np.isfinite(scale) or scale < 1e-3:
        min_xy = pts.min(axis=0)
        max_xy = pts.max(axis=0)
        scale = float(np.linalg.norm(max_xy - min_xy) + 1e-6)
    if orient is not None:
        scale = orient.smooth_scale(float(scale))
    rel = rel / max(1e-6, float(scale))

    # Mirror left hands so both hands share one canonical orientation.
    # NOTE(review): MediaPipe's handedness label assumes a horizontally mirrored
    # (selfie) image; confirm frames are flipped upstream, or "Left"/"Right" are swapped.
    if canonicalize_left_to_right and handed and handed.lower().startswith("left"):
        rel[:, 0] = -rel[:, 0]

    if orient is not None:
        rel = orient.normalize(rel)

    return torch.from_numpy(rel.astype(np.float32)), handed


# -----------------------------
# Feature building and aug
# -----------------------------

def build_feature(lm_xy_rel: torch.Tensor, prev_lm_xy_rel: Optional[torch.Tensor], use_deltas: bool) -> Optional[torch.Tensor]:
    """Flatten one frame of landmarks into a feature vector, optionally with deltas.

    Args:
        lm_xy_rel: current landmarks, any shape (flattened row-major, e.g. ``(21, 2)``).
        prev_lm_xy_rel: previous frame's landmarks or None. Used only when
            ``use_deltas`` is true and it has the same number of elements.
        use_deltas: append ``current - previous``; zeros when no usable previous frame.

    Returns:
        A 1-D float32 tensor of length N (or 2N with deltas). Returns None when
        the input is not a tensor or the features contain NaN or inf.
    """
    if lm_xy_rel is None or not torch.is_tensor(lm_xy_rel):
        return None
    cur = lm_xy_rel.reshape(-1)
    if use_deltas:
        if prev_lm_xy_rel is not None and prev_lm_xy_rel.numel() == cur.numel():
            # Flatten both sides so equal-sized but differently shaped frames
            # (e.g. (21, 2) vs (42,)) subtract element-wise instead of failing to broadcast.
            delta = cur - prev_lm_xy_rel.reshape(-1)
        else:
            delta = torch.zeros_like(cur)
        feat = torch.cat([cur, delta], dim=0)
    else:
        feat = cur
    # Reject inf as well as NaN; either would poison the DCT/PCA features downstream.
    if not bool(torch.isfinite(feat).all()):
        return None
    return feat.float()


def resample_sequence(seq_td: torch.Tensor, target_len: int) -> torch.Tensor:
    """Linearly resample a ``(T, D)`` sequence along time to ``target_len`` frames.

    The first and last frames map onto the first and last output frames. When
    ``T == target_len`` the input object itself is returned. The sampling grid
    is created on ``seq_td``'s device, so non-CPU tensors work.

    Raises:
        ValueError: if ``seq_td`` is not 2-D, or is empty while ``target_len`` differs.
    """
    if seq_td.dim() != 2:
        raise ValueError(f"seq_td must be 2-D (T, D), got shape {tuple(seq_td.shape)}")
    T = seq_td.shape[0]
    if T == target_len:
        return seq_td
    if T == 0:
        raise ValueError("cannot resample an empty sequence")
    src_idx = torch.linspace(0, T - 1, steps=target_len, device=seq_td.device)
    idx0 = torch.clamp(src_idx.floor().long(), 0, T - 1)
    idx1 = torch.clamp(idx0 + 1, 0, T - 1)
    # Fractional position between the two neighbouring source frames.
    w = (src_idx - idx0.float()).unsqueeze(1)
    return (1 - w) * seq_td[idx0] + w * seq_td[idx1]


def time_warp(
    seq_td: torch.Tensor,
    segments: int = 3,
    sigma: float = 0.2,
    rng: Optional[np.random.Generator] = None,
) -> torch.Tensor:
    """Randomly stretch and compress time in a ``(T, D)`` sequence (piecewise-linear warp).

    Random interior knots split the timeline into segments. Each segment gets a
    speed factor drawn from ``N(1, sigma)`` and clipped to ``[0.3, 2.5]``, and
    the warped timeline is rescaled to keep its length. Every output frame is
    then linearly interpolated from the *full* input sequence at its source
    time, so all intermediate frames are kept. The first and last frames are
    preserved.

    Args:
        seq_td: input sequence, time along dim 0.
        segments: maximum number of interior knots (also capped at ``T // 4``).
        sigma: std of the per-segment speed factor; ``<= 0`` disables warping.
        rng: optional NumPy Generator for reproducibility; defaults to the global
            ``np.random`` state (the original behaviour).

    Returns:
        A new tensor with the same shape, dtype and device as ``seq_td``. The
        input itself is returned when ``T < 3``, ``segments < 1`` or ``sigma <= 0``.
    """
    T, D = seq_td.shape
    if T < 3 or segments < 1 or sigma <= 0:
        return seq_td
    r = np.random if rng is None else rng
    n_knots = min(segments, max(1, T // 4))
    knots = sorted(r.choice(np.arange(1, T - 1), size=n_knots, replace=False).tolist())
    points = np.array([0] + knots + [T - 1], dtype=np.float64)
    slopes = np.clip(r.normal(loc=1.0, scale=sigma, size=len(points) - 1), 0.3, 2.5)
    # Warped output time of each knot; strictly increasing because slopes > 0.
    t_warp = np.concatenate([[0.0], np.cumsum(np.diff(points) * slopes)])
    t_warp *= (T - 1) / max(t_warp[-1], 1e-6)
    grid = np.arange(T, dtype=np.float64)
    # Invert the knot map: output frame t shows the source signal at time src_t.
    src_t = np.interp(grid, t_warp, points)

    i0 = np.clip(np.floor(src_t).astype(np.int64), 0, T - 1)
    i1 = np.clip(i0 + 1, 0, T - 1)
    frac = (src_t - i0)[:, None]
    seq_np = seq_td.detach().cpu().numpy().astype(np.float64)
    out = (1.0 - frac) * seq_np[i0] + frac * seq_np[i1]
    return torch.from_numpy(out).to(dtype=seq_td.dtype, device=seq_td.device)


# -----------------------------
# Secure store, PBKDF2 with salt
# -----------------------------

class SecureStore:
    """
    Writes artifacts.pt under the project dir.
    If GESTURE_STORE_PASSPHRASE is set and cryptography is available, encrypts with PBKDF2 derived key and a per file salt.
    Set GESTURE_STORE_REQUIRE_ENCRYPTION=1 to refuse plaintext.

    Encrypted layout: ``b"ENC1"`` + 16-byte random salt + Fernet token. The key
    is PBKDF2-HMAC-SHA256 with 200,000 iterations over the passphrase. The
    iteration count is not stored in the file, so changing it would make
    existing files undecryptable. Writes go to a temporary file that is then
    renamed over the target, so an interrupted save never leaves a truncated
    artifacts file. Save/load errors are reported with ``print`` and swallowed
    (best effort); load returns None in that case.
    """
    def __init__(self, project_dir: str):
        self.project_dir = project_dir
        os.makedirs(self.project_dir, exist_ok=True)
        self.path = os.path.join(self.project_dir, "artifacts.pt")
        self._fernet_cls = None
        self._pass = ""
        self._require = False
        self._init_crypto()

    def _init_crypto(self) -> None:
        """Read the passphrase/policy environment variables and probe for `cryptography`."""
        self._require = os.environ.get("GESTURE_STORE_REQUIRE_ENCRYPTION", "0") == "1"
        self._pass = os.environ.get("GESTURE_STORE_PASSPHRASE", "") or ""
        try:
            from cryptography.fernet import Fernet  # type: ignore
            self._fernet_cls = Fernet
        except Exception:
            self._fernet_cls = None

    def _derive_key(self, salt: bytes) -> Optional[bytes]:
        """Derive a urlsafe-base64 Fernet key from the passphrase, or None if unavailable."""
        if not self._pass or self._fernet_cls is None:
            return None
        import base64, hashlib
        key = hashlib.pbkdf2_hmac("sha256", self._pass.encode("utf-8"), salt, 200_000, dklen=32)
        return base64.urlsafe_b64encode(key)

    def save_artifacts(self, payload: Dict[str, Any]) -> None:
        """Serialize ``payload`` with torch.save, encrypt if configured, and write atomically."""
        tmp_path = self.path + ".tmp"
        try:
            raw = io.BytesIO()
            torch.save(payload, raw)
            data = raw.getvalue()
            if self._fernet_cls is not None and self._pass:
                salt = os.urandom(16)
                fkey = self._derive_key(salt)
                if fkey is None:
                    raise RuntimeError("KDF failed")
                data = b"ENC1" + salt + self._fernet_cls(fkey).encrypt(data)
            elif self._require:
                raise RuntimeError("Encryption required but unavailable")
            elif self._pass:
                # The user asked for encryption but cannot get it; say so instead of
                # silently writing plaintext.
                print("[SecureStore] passphrase set but cryptography is unavailable; writing plaintext")
            with open(tmp_path, "wb") as fh:
                fh.write(data)
            # Atomic on the same filesystem: readers see either the old or the new file.
            os.replace(tmp_path, self.path)
        except Exception as e:
            print(f"[SecureStore] save failed: {e}")
            try:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            except OSError:
                pass

    def load_artifacts(self) -> Optional[Dict[str, Any]]:
        """Read, decrypt if needed, and deserialize the artifacts; None on any failure or policy refusal."""
        if not os.path.exists(self.path):
            return None
        try:
            with open(self.path, "rb") as fh:
                data = fh.read()
            if data[:4] == b"ENC1":
                salt = data[4:20]
                ct = data[20:]
                fkey = self._derive_key(salt)
                if fkey is None or self._fernet_cls is None:
                    print("[SecureStore] cannot decrypt, missing passphrase or cryptography")
                    return None
                f = self._fernet_cls(fkey)
                try:
                    data = f.decrypt(ct)
                except Exception:
                    print("[SecureStore] decrypt failed, wrong passphrase")
                    return None
            elif self._require:
                print("[SecureStore] plaintext refused by policy")
                return None
            # SECURITY NOTE(review): torch.load unpickles. Fernet tokens are HMAC-authenticated,
            # but plaintext files are not — only load plaintext artifacts from a trusted location.
            # Also, on torch versions that default to weights_only=True, NumPy arrays in the
            # payload may fail to load; confirm against the deployed torch version.
            buf = io.BytesIO(data)
            payload = torch.load(buf, map_location="cpu")
            return payload
        except Exception as e:
            print(f"[SecureStore] load failed: {e}")
            return None


# -----------------------------
# Deterministic encoder: DCT + PCA
# -----------------------------

class DeterministicEncoder(nn.Module):
    """Training-free sequence encoder: per-channel temporal DCT followed by PCA.

    Input batches are ``(B, T, D)`` with ``T == cfg.seq_len``. The first
    ``cfg.frame_dim`` channels are positions; when ``cfg.use_deltas`` is set,
    the next ``cfg.frame_dim`` channels are deltas. Each channel is projected
    onto its lowest orthonormal DCT-II frequencies and the coefficients are
    concatenated ("base features"). ``fit_pca`` learns a deterministic,
    optionally whitened projection from support features. ``transform`` and
    ``encode`` return L2-normalized rows.
    """

    def __init__(self, cfg: ModelConfig, device: torch.device):
        super().__init__()
        self.cfg = cfg
        self.device = device
        self.T = cfg.seq_len
        self.pos_dim = cfg.frame_dim
        self.delta_dim = cfg.frame_dim if cfg.use_deltas else 0
        # Non-persistent buffers: they follow module.to(...) but stay out of the
        # state_dict, since they are fully determined by cfg.
        self.register_buffer("_C_pos", self._make_dct_basis(self.T, cfg.dct_k_pos).to(device), persistent=False)
        self.register_buffer(
            "_C_del",
            self._make_dct_basis(self.T, cfg.dct_k_delta).to(device) if self.delta_dim else None,
            persistent=False,
        )

        # PCA parameters will be learned from support
        self.register_buffer("pca_mean", torch.zeros(1))        # 1 x F
        self.register_buffer("pca_components", torch.zeros(1))  # E x F
        self.register_buffer("pca_scale", torch.ones(1))        # 1 x E
        self._pca_ready = False

    @staticmethod
    def _make_dct_basis(T: int, K: int) -> torch.Tensor:
        """Return the first K orthonormal DCT-II basis vectors as a T x K matrix."""
        n = torch.arange(T).float().unsqueeze(1)  # T x 1
        k = torch.arange(K).float().unsqueeze(0)  # 1 x K
        basis = torch.cos(math.pi / T * (n + 0.5) * k)  # T x K
        basis[:, 0] *= math.sqrt(0.5)
        basis = basis * math.sqrt(2.0 / T)
        return basis  # T x K

    def _split_pos_delta(self, seq_bt_d: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Split channels into (positions, deltas); deltas is None when disabled."""
        if self.delta_dim:
            pos = seq_bt_d[:, :, :self.pos_dim]
            delt = seq_bt_d[:, :, self.pos_dim:self.pos_dim + self.delta_dim]
            return pos, delt
        else:
            return seq_bt_d, None

    def _dct_features(self, x_btd: torch.Tensor, C: torch.Tensor, k_keep: int) -> torch.Tensor:
        """Project every channel onto the first ``k_keep`` DCT bases; returns B x (K*D), K-major."""
        B, T, Dch = x_btd.shape
        # Move time first before flattening batch and channels together. Reshaping
        # (B, T, D) directly to (T, B*D) would interleave different samples whenever B > 1.
        x2 = x_btd.permute(1, 0, 2).reshape(T, B * Dch)  # T x (B*D)
        proj = torch.matmul(C[:, :k_keep].T, x2)  # K x (B*D)
        proj = proj.reshape(k_keep, B, Dch).permute(1, 0, 2)  # B x K x D
        if self.cfg.energy_norm:
            norm = torch.linalg.norm(proj, dim=1, keepdim=True).clamp_min(1e-8)
            proj = proj / norm
        return proj.reshape(B, -1)

    def seq_to_basefeat(self, seq_bt_d: torch.Tensor) -> torch.Tensor:
        """Concatenate DCT features of positions and (if enabled) deltas."""
        pos, delt = self._split_pos_delta(seq_bt_d)
        f_pos = self._dct_features(pos, self._C_pos, self.cfg.dct_k_pos)
        if delt is not None and self._C_del is not None:
            f_del = self._dct_features(delt, self._C_del, self.cfg.dct_k_delta)
            base = torch.cat([f_pos, f_del], dim=1)
        else:
            base = f_pos
        return base

    def fit_pca(self, feats_bf: torch.Tensor) -> None:
        """Fit PCA (optionally whitened) on ``(N, F)`` base features.

        Uses an exact SVD with sign-normalized components, so the fit is
        reproducible (``torch.pca_lowrank`` is randomized). At most
        ``cfg.pca_embed_dim`` components are kept, and components whose singular
        value is numerically zero are dropped. A centered ``N x F`` matrix has
        rank at most ``N - 1``, so with few support samples the trailing
        directions are pure round-off. Whitening them (dividing by ~0) would let
        that noise dominate every embedding after L2 normalization.

        Raises:
            ValueError: if ``feats_bf`` has no rows.
        """
        if feats_bf.shape[0] == 0:
            raise ValueError("fit_pca needs at least one feature row")
        mean = feats_bf.mean(dim=0, keepdim=True)
        X = feats_bf - mean
        _, S, Vh = torch.linalg.svd(X, full_matrices=False)
        # Numerical rank with the same tolerance rule as numpy.linalg.matrix_rank.
        tol = float(S.max()) * max(X.shape) * torch.finfo(S.dtype).eps if S.numel() else 0.0
        rank = int((S > tol).sum())
        E = max(1, min(self.cfg.pca_embed_dim, rank))
        comps = Vh[:E, :]
        S_use = S[:E]
        # SVD defines each component only up to sign; make the largest-magnitude
        # entry of every component positive so repeated fits agree exactly.
        pivot = comps.abs().argmax(dim=1, keepdim=True)
        signs = torch.sign(torch.gather(comps, 1, pivot))
        signs = torch.where(signs == 0, torch.ones_like(signs), signs)
        comps = comps * signs
        if self.cfg.pca_whiten and rank > 0:
            # Component std is S / sqrt(N - 1); multiply by its inverse.
            n = max(1, X.shape[0] - 1)
            scale = (math.sqrt(n) / S_use.clamp_min(1e-8)).unsqueeze(0)
        else:
            scale = torch.ones(1, E, device=comps.device, dtype=comps.dtype)
        self.pca_mean = mean.detach()
        self.pca_components = comps.detach()
        self.pca_scale = scale.detach()
        self._pca_ready = True

    def transform(self, feats_bf: torch.Tensor) -> torch.Tensor:
        """Project base features to L2-normalized embeddings.

        Before ``fit_pca`` has run, this falls back to the first ``pca_embed_dim``
        raw features.
        """
        if not self._pca_ready:
            E = min(self.cfg.pca_embed_dim, feats_bf.shape[1])
            z = F.normalize(feats_bf[:, :E], p=2, dim=1)
            return z
        X = feats_bf - self.pca_mean
        Z = torch.matmul(X, self.pca_components.t())
        Z = Z * self.pca_scale
        Z = F.normalize(Z, p=2, dim=1)
        return Z

    @torch.inference_mode()
    def encode(self, seq_bt_d: torch.Tensor) -> torch.Tensor:
        """Encode a ``(B, T, D)`` batch into ``(B, E)`` unit-norm embeddings."""
        base = self.seq_to_basefeat(seq_bt_d)
        z = self.transform(base)
        return z


# -----------------------------
# Gaussian head with shrinkage and conformal
# -----------------------------

class GaussianHead:
    """Per-class Gaussian classifier over embeddings, with optional conformal abstention.

    ``fit`` estimates a mean and a shrinkage-regularized covariance for every
    label present in the training data. Labels need not be contiguous, so the
    head records which label each fitted component belongs to (``class_ids``):

    - ``nll`` returns one value per fitted component, in that order.
    - ``predict_with_margin`` reports the *label*, not the component position.
    - Conformal thresholds (``q_per_class``) are keyed by label.
    """

    def __init__(self, embed_dim: int, cfg: ModelConfig):
        self.embed_dim = embed_dim
        self.cfg = cfg
        self.mu: Optional[torch.Tensor] = None             # K x E
        self.cov: Optional[torch.Tensor] = None            # K x E x E
        self.chol: Optional[torch.Tensor] = None           # K x E x E, Cholesky factors (may include extra jitter)
        self.logdet: Optional[torch.Tensor] = None         # K
        self.K: int = 0
        self.class_ids: List[int] = []                     # label owned by each of the K components
        self._pos_of: Dict[int, int] = {}                  # label -> component position
        self.q_global: Optional[float] = None
        self.q_per_class: Dict[int, float] = {}            # keyed by label

    @staticmethod
    def _shrink_target(sample_cov: torch.Tensor, mode: str) -> torch.Tensor:
        """Shrinkage target: diagonal of the covariance for "diag", else mean-variance identity."""
        E = sample_cov.shape[-1]
        if mode == "identity":
            trace = torch.trace(sample_cov)
            target = (trace / E) * torch.eye(E, dtype=sample_cov.dtype, device=sample_cov.device)
        elif mode == "diag":
            target = torch.diag(torch.diag(sample_cov))
        else:
            trace = torch.trace(sample_cov)
            target = (trace / E) * torch.eye(E, dtype=sample_cov.dtype, device=sample_cov.device)
        return target

    def fit(self, embs: torch.Tensor, labels: torch.Tensor) -> None:
        """Fit one Gaussian per distinct label in ``labels`` (``embs`` is N x E)."""
        device = embs.device
        labels = labels.to(torch.long)
        classes = labels.unique(sorted=True)
        K = classes.numel()
        E = embs.shape[1]
        mu_list, cov_list, chol_list, logdet_list = [], [], [], []
        lam = float(self.cfg.shrinkage_lambda)
        base_reg = float(self.cfg.cov_reg)

        for k in classes.tolist():
            idx = labels == k
            z = embs[idx]
            n_k = z.shape[0]
            if n_k < 4:
                # Too few samples for a full covariance: regularized diagonal variance.
                mu_k = z.mean(dim=0) if n_k > 0 else torch.zeros(E, device=device)
                var = z.var(dim=0, unbiased=True) if n_k > 1 else torch.ones(E, device=device)
                cov_k = torch.diag(var + base_reg)
            else:
                mu_k = z.mean(dim=0)
                D = (z - mu_k)
                S = (D.t() @ D) / max(1, n_k - 1)
                S = S + base_reg * torch.eye(E, device=device)
                Tgt = self._shrink_target(S, self.cfg.shrink_target)
                cov_k = (1.0 - lam) * S + lam * Tgt
            # Cholesky with escalating diagonal jitter if the covariance is not numerically PD.
            reg = 0.0
            L = None
            for _ in range(5):
                L_try, info = torch.linalg.cholesky_ex(cov_k + reg * torch.eye(E, device=device))
                if int(info.item()) == 0:
                    L = L_try
                    break
                reg = max(1e-6, 10.0 * (reg if reg > 0 else base_reg))
            if L is None:
                L = torch.linalg.cholesky(cov_k + (base_reg + 1e-3) * torch.eye(E, device=device))
            logdet_k = 2.0 * torch.log(torch.diag(L)).sum()
            mu_list.append(mu_k)
            cov_list.append(cov_k)
            chol_list.append(L)
            logdet_list.append(logdet_k)
        self.mu = torch.stack(mu_list, dim=0)
        self.cov = torch.stack(cov_list, dim=0)
        self.chol = torch.stack(chol_list, dim=0)
        self.logdet = torch.stack(logdet_list, dim=0)
        self.K = K
        self.class_ids = [int(c) for c in classes.tolist()]
        self._pos_of = {c: i for i, c in enumerate(self.class_ids)}

    def _position(self, label: int) -> Optional[int]:
        """Component index for ``label``, or None if the label has no fitted component."""
        if self._pos_of:
            return self._pos_of.get(label)
        if self.mu is None:
            return None
        # No mapping recorded (parameters assigned directly): treat labels as positions.
        return label if 0 <= label < self.mu.shape[0] else None

    def _label(self, pos: int) -> int:
        """Label owned by component ``pos``."""
        return self.class_ids[pos] if self.class_ids else pos

    def _m2_batch(self, z: torch.Tensor) -> torch.Tensor:
        """Squared Mahalanobis distance of ``z`` (length E) to every component; returns K values."""
        assert self.mu is not None and self.chol is not None
        diff = z.unsqueeze(0) - self.mu  # K x E
        b = diff.unsqueeze(-1)                       # K x E x 1
        y = torch.cholesky_solve(b, self.chol)       # K x E x 1
        m2 = (diff * y.squeeze(-1)).sum(dim=1)       # K
        return m2

    def nll(self, z: torch.Tensor) -> torch.Tensor:
        """Gaussian negative log-likelihood of ``z`` under each component (K values, component order)."""
        assert self.logdet is not None
        m2 = self._m2_batch(z)
        E = z.shape[-1]
        nll = 0.5 * (m2 + self.logdet + E * math.log(2 * math.pi))
        return nll

    def predict_with_margin(self, z: torch.Tensor) -> Tuple[int, float, float, float]:
        """Return ``(label, confidence, nll_of_prediction, top1_minus_top2_probability)``."""
        nll = self.nll(z)
        logits = -self.cfg.logit_scale * nll
        probs = torch.softmax(logits, dim=-1)
        conf, idx = probs.max(dim=-1)
        if probs.numel() >= 2:
            top2 = torch.topk(probs, k=2)
            margin = float(top2.values[0].item() - top2.values[1].item())
        else:
            margin = float(conf.item())
        pos = int(idx.item())
        return self._label(pos), float(conf.item()), float(nll[pos].item()), margin

    @staticmethod
    def _quantile(scores: List[float], alpha: float) -> float:
        """Split-conformal quantile: the ceil((n+1)(1-alpha))-th smallest score.

        The rank is clipped to n. For small n this returns the maximum score
        rather than the textbook +inf, which is less conservative. Empty input
        gives +inf.
        """
        n = len(scores)
        if n == 0:
            return float("inf")
        k = int(math.ceil((n + 1) * (1 - alpha)))
        k = min(max(1, k), n)
        s = sorted(scores)
        return s[k - 1]

    def fit_conformal(self, embs: torch.Tensor, labels: torch.Tensor) -> None:
        """Calibrate abstention thresholds from held-out ``(embs, labels)``.

        The nonconformity score of a sample is the NLL under its true label's
        component. Samples whose label has no fitted component are skipped,
        because their score is undefined. With Mondrian calibration, labels with
        fewer than 3 scores get an infinite threshold (they never abstain).
        """
        assert self.mu is not None
        scores_by_class: Dict[int, List[float]] = defaultdict(list)
        all_scores: List[float] = []
        for z, y in zip(embs, labels.long()):
            label = int(y.item())
            pos = self._position(label)
            if pos is None:
                continue
            s = float(self.nll(z)[pos].item())
            all_scores.append(s)
            scores_by_class[label].append(s)
        if self.cfg.conformal_mondrian:
            self.q_per_class = {}
            for c, arr in scores_by_class.items():
                self.q_per_class[c] = self._quantile(arr, self.cfg.conformal_alpha) if len(arr) >= 3 else float("inf")
            self.q_global = None
        else:
            self.q_global = self._quantile(all_scores, self.cfg.conformal_alpha)
            self.q_per_class = {}

    def abstain(self, pred_idx: int, nonconformity: float) -> bool:
        """True if ``nonconformity`` exceeds the threshold for label ``pred_idx`` (never when uncalibrated)."""
        if self.cfg.conformal_mondrian and self.q_per_class:
            q = self.q_per_class.get(pred_idx, float("inf"))
            return nonconformity > q
        if self.q_global is not None:
            return nonconformity > self.q_global
        return False


# -----------------------------
# Model manager
# -----------------------------

class ModelManager:
    def __init__(self, cfg: ModelConfig, device: torch.device):
        """Create an empty few-shot model for ``cfg`` whose tensors live on ``device``."""
        self.cfg = cfg
        self.device = device
        # Per-frame feature width: positions, plus equally wide deltas when enabled.
        self.frame_dim = cfg.frame_dim * 2 if cfg.use_deltas else cfg.frame_dim
        self.encoder = DeterministicEncoder(cfg, device=device).to(device)
        self.embed_dim = cfg.pca_embed_dim
        # Registered gestures; a gesture's id is its index in this list.
        self.gestures: List[Gesture] = []
        # Support set: stored sequences and their gesture ids, kept in parallel.
        self.seqs: List[torch.Tensor] = []
        self.labels: List[int] = []
        self.gauss = GaussianHead(embed_dim=self.embed_dim, cfg=cfg)

    def _find_gesture_id(self, name: str) -> Optional[int]:
        for i, g in enumerate(self.gestures):
            if g.name == name:
                return i
        return None

    def ensure_gesture(self, name: str, action: str) -> int:
        gid = self._find_gesture_id(name)
        if gid is not None:
            if self.gestures[gid].action != action:
                self.gestures[gid] = Gesture(name=name, action=action)
            return gid
        gid = len(self.gestures)
        self.gestures.append(Gesture(name=name, action=action))
        return gid

    def add_sample(self, name_or_id: Any, seq: torch.Tensor) -> int:
        if isinstance(name_or_id, int):
            gid = int(name_or_id)
            if gid < 0 or gid >= len(self.gestures):
                raise IndexError("gesture id out of range")
        else:
            gid = self.ensure_gesture(str(name_or_id), action=self._default_action(str(name_or_id)))
        if seq.dim() == 3 and seq.shape[0] == 1:
            seq = seq.squeeze(0)
        if seq.dim() != 2:
            raise ValueError("seq must be (T, D) or (1, T, D)")
        s = resample_sequence(seq.detach().cpu().float(), self.cfg.seq_len)
        for aug in self._augment_sequence(s):
            self.seqs.append(aug)
            self.labels.append(gid)
        self.update_geometry_and_conformal()
        return gid

    def add_example(self, name: str, seq: torch.Tensor) -> int:
        """Alias of ``add_sample`` for a gesture identified by name; returns its id."""
        gid = self.add_sample(name, seq)
        return gid

    def _default_action(self, name: str) -> str:
        return name

    def remove_gesture(self, idx: int) -> None:
        if idx < 0 or idx >= len(self.gestures):
            raise IndexError("gesture id out of range")
        del self.gestures[idx]
        new_seqs, new_labels = [], []
        for s, y in zip(self.seqs, self.labels):
            if y == idx:
                continue
            new_seqs.append(s)
            new_labels.append(y - 1 if y > idx else y)
        self.seqs, self.labels = new_seqs, new_labels
        self.update_geometry_and_conformal()

    @property
    def num_classes(self) -> int:
        return len(self.gestures)

    def load_encoder_weights(self, path: str) -> None:
        raise RuntimeError("Deterministic encoder does not use learned weights")

    def apply_temperature(self, T: float) -> None:
        pass

    @torch.inference_mode()
    def encode(self, seq_bt_d: torch.Tensor) -> torch.Tensor:
        """Embed a ``(B, T, D)`` batch, with the encoder switched to eval mode first."""
        encoder = self.encoder
        encoder.eval()
        return encoder.encode(seq_bt_d)

    def _augment_sequence(self, s: torch.Tensor) -> List[torch.Tensor]:
        out: List[torch.Tensor] = []
        T, D = s.shape
        pos_dim = min(self.cfg.frame_dim, D)
        pos = s[:, :pos_dim].clone()
        has_delta = self.cfg.use_deltas and D >= 2 * self.cfg.frame_dim
        if has_delta:
            out.append(s)
        else:
            out.append(pos)
        if self.cfg.pos_noise_std > 0:
            noisy_pos = pos + torch.randn_like(pos) * self.cfg.pos_noise_std
            if self.cfg.use_deltas:
                d = torch.zeros_like(noisy_pos)
                d[1:] = noisy_pos[1:] - noisy_pos[:-1]
                noisy = torch.cat([noisy_pos, d], dim=1)
            else:
                noisy = noisy_pos
            out.append(noisy)
        if random.random() < self.cfg.time_warp_prob:
            warped_pos = time_warp(pos, segments=self.cfg.time_warp_segments, sigma=self.cfg.time_warp_sigma)
            if self.cfg.use_deltas:
                d = torch.zeros_like(warped_pos)
                d[1:] = warped_pos[1:] - warped_pos[:-1]
                warped = torch.cat([warped_pos, d], dim=1)
            else:
                warped = warped_pos
            out.append(warped)
        return out

    def add_gesture_support(self, name: str, action: str, support: List[torch.Tensor]) -> None:
        assert len(support) > 0
        gid = len(self.gestures)
        self.gestures.append(Gesture(name=name, action=action))
        for s in support:
            s = resample_sequence(s.detach().cpu().float(), self.cfg.seq_len)
            for aug in self._augment_sequence(s):
                self.seqs.append(aug)
                self.labels.append(gid)
        self.update_geometry_and_conformal()

    def mirror_sequence(self, seq_td: torch.Tensor) -> torch.Tensor:
        T, D = seq_td.shape
        out = seq_td.clone()
        pos_dim = self.cfg.frame_dim if D >= self.cfg.frame_dim else D
        xs = out[:, 0:pos_dim:2]
        out[:, 0:pos_dim:2] = -xs
        if D > pos_dim:
            dxs = out[:, pos_dim:D:2]
            out[:, pos_dim:D:2] = -dxs
        return out

    def _compute_embeddings_and_fit_pca(self) -> Tuple[torch.Tensor, torch.Tensor]:
        if len(self.seqs) == 0:
            return torch.empty(0, self.embed_dim), torch.empty(0, dtype=torch.long)
        self.encoder.eval()
        with torch.inference_mode():
            bases: List[torch.Tensor] = []
            for s in self.seqs:
                x = s.unsqueeze(0).to(self.device)
                base = self.encoder.seq_to_basefeat(x)
                bases.append(base.squeeze(0).cpu())
            base_mat = torch.stack(bases, dim=0)
            self.encoder.fit_pca(base_mat.to(self.device))
            embs = self.encoder.transform(base_mat.to(self.device))
        labels = torch.tensor(self.labels, dtype=torch.long)
        return embs.detach().cpu(), labels

    def update_geometry_and_conformal(self) -> None:
        """Refit PCA, the Gaussian head and the conformal calibration.

        All stored sequences are re-embedded, which also refits PCA. The rows
        are then split per class, with a fixed seed, into a training part
        passed to ``gauss.fit`` and a calibration part passed to
        ``gauss.fit_conformal``. The Gaussian head's ``mu``/``cov``/``chol``/
        ``logdet`` tensors are moved to ``self.device`` after fitting. Does
        nothing when there are no sequences or no gestures.
        """
        if len(self.seqs) == 0 or len(self.gestures) == 0:
            return
        embs_mat, labels = self._compute_embeddings_and_fit_pca()
        if embs_mat.shape[0] == 0 or labels.numel() == 0:
            return

        # Stratified split. A single random permutation over all rows can, in
        # the few-shot regime, send every sample of a class to calibration (so
        # the Gaussian head never sees that class). It can also leave a class
        # with no calibration rows for Mondrian conformal.
        # NOTE(review): augmented copies of one support sample may still land
        # on both sides of the split, which makes calibration optimistic.
        gen = torch.Generator().manual_seed(42)
        train_parts: List[torch.Tensor] = []
        calib_parts: List[torch.Tensor] = []
        for cls in torch.unique(labels).tolist():  # sorted -> deterministic order
            members = torch.nonzero(labels == cls, as_tuple=True)[0]
            members = members[torch.randperm(members.numel(), generator=gen)]
            n_cls = members.numel()
            if n_cls == 1:
                # A lone sample is needed for fitting and cannot also be held out.
                train_parts.append(members)
                continue
            # Hold out at least one row and keep at least one for fitting.
            n_calib = min(n_cls - 1, max(1, int(self.cfg.calib_frac * n_cls)))
            calib_parts.append(members[:n_calib])
            train_parts.append(members[n_calib:])
        train_idx = torch.cat(train_parts)
        # Nothing could be held out (every class is a singleton): calibrate on
        # the training rows, as the previous single-sample fallback did.
        calib_idx = torch.cat(calib_parts) if calib_parts else train_idx

        self.gauss.fit(embs_mat[train_idx].to(self.device), labels[train_idx].to(self.device))
        for attr in ("mu", "cov", "chol", "logdet"):
            t = getattr(self.gauss, attr, None)
            if t is not None:
                setattr(self.gauss, attr, t.to(self.device))
        self.gauss.fit_conformal(embs_mat[calib_idx].to(self.device), labels[calib_idx].to(self.device))

    @torch.inference_mode()
    def infer(self, seq_bt_d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """Classify a sequence, abstaining when the Gaussian head rejects it.

        Args:
            seq_bt_d: input handed to ``self.encode`` after moving it to
                ``self.device``. Only the first row of the encoded batch is
                classified.

        Returns:
            ``(confidence, class_index, info)``. When no classes are registered
            or the head abstains, confidence is ``0.0`` and the index is ``-1``.
            ``info`` always carries a boolean ``"abstain"`` flag. It also holds
            either ``"reason"`` (no classes) or the head's ``"nll"`` and
            ``"margin"``.
        """
        if self.num_classes == 0:
            # Same "abstain" key as the other paths so consumers can rely on it.
            return torch.tensor(0.0), torch.tensor(-1), {"reason": "no_classes", "abstain": True}
        z = self.encode(seq_bt_d.to(self.device))[0]
        idx, conf, nll, margin = self.gauss.predict_with_margin(z)
        abstain = self.gauss.abstain(idx, nll)
        info: Dict[str, Any] = {"nll": nll, "abstain": bool(abstain), "margin": margin}
        if abstain:
            return torch.tensor(0.0), torch.tensor(-1), info
        # as_tensor: no copy-construct warning if the head already returns tensors.
        return torch.as_tensor(conf), torch.as_tensor(idx), info

    def export_artifacts(self) -> Dict[str, Any]:
        """Snapshot the manager's state into a plain dict of Python/numpy values.

        The dict always holds ``cfg``, ``gestures``, ``seqs`` and ``labels``.
        ``pca`` is added only when the encoder reports ``_pca_ready``, and
        ``gauss`` only when the Gaussian head has a fitted ``mu``. Tensors are
        converted to CPU numpy arrays. The conformal quantiles are passed
        through as stored on the head.
        """
        def as_numpy(t: torch.Tensor) -> np.ndarray:
            return t.detach().cpu().numpy()

        payload: Dict[str, Any] = {
            "cfg": dict(self.cfg.__dict__),
            "gestures": [{"name": gest.name, "action": gest.action} for gest in self.gestures],
            "seqs": [seq.cpu().numpy() for seq in self.seqs],
            "labels": [int(lbl) for lbl in self.labels],
        }

        enc = self.encoder
        if enc._pca_ready:
            payload["pca"] = {
                "mean": as_numpy(enc.pca_mean),
                "components": as_numpy(enc.pca_components),
                "scale": as_numpy(enc.pca_scale),
            }

        head = self.gauss
        if head.mu is not None:
            payload["gauss"] = {
                "mu": as_numpy(head.mu),
                "cov": as_numpy(head.cov),
                "logdet": as_numpy(head.logdet),
                "q_global": head.q_global,
                "q_per_class": head.q_per_class,
            }
        return payload

    def import_artifacts(self, payload: Dict[str, Any]) -> None:
        """Restore manager state from a payload produced by ``export_artifacts``.

        Config keys that ``self.cfg`` does not have are ignored. Gestures,
        sequences and labels are replaced wholesale.

        - If the payload has a ``"pca"`` entry, its projection is installed;
          otherwise the encoder is marked as not PCA-ready.
        - If it has a ``"gauss"`` entry, the Gaussian head is restored, with
          Cholesky factors recomputed from ``cov`` and a 1e-3 diagonal jitter
          fallback for non-PD matrices.
        - Otherwise everything is refit from the restored sequences via
          ``update_geometry_and_conformal``.

        Args:
            payload: mapping with the keys written by ``export_artifacts``.
                Array values may be numpy arrays or nested lists.
        """
        for key, value in payload.get("cfg", {}).items():
            if hasattr(self.cfg, key):
                setattr(self.cfg, key, value)
        self.gestures = [Gesture(**g) for g in payload.get("gestures", [])]
        self.seqs = [torch.from_numpy(np.array(arr)).float() for arr in payload.get("seqs", [])]
        self.labels = list(map(int, payload.get("labels", [])))

        def _to_device(arr: Any) -> torch.Tensor:
            # float32 on the manager's device, accepting arrays or nested lists.
            return torch.from_numpy(np.array(arr)).float().to(self.device)

        p = payload.get("pca", None)
        if p is not None:
            self.encoder.pca_mean = _to_device(p["mean"])
            self.encoder.pca_components = _to_device(p["components"])
            self.encoder.pca_scale = _to_device(p["scale"])
            self.encoder._pca_ready = True
        else:
            self.encoder._pca_ready = False

        g = payload.get("gauss", None)
        if g is None:
            self.update_geometry_and_conformal()
            return

        mu = _to_device(g.get("mu"))
        cov = _to_device(g.get("cov"))
        logdet = _to_device(g.get("logdet"))
        n_cls, dim, _ = cov.shape
        factors = []
        for k in range(n_cls):
            L, info = torch.linalg.cholesky_ex(cov[k])
            if int(info.item()) != 0:
                # NOTE(review): the jittered factor no longer matches the stored
                # cov/logdet exactly.
                eye = torch.eye(dim, dtype=cov.dtype, device=self.device)
                L = torch.linalg.cholesky(cov[k] + 1e-3 * eye)
            factors.append(L)
        self.gauss.mu = mu
        self.gauss.cov = cov
        self.gauss.chol = torch.stack(factors, dim=0)
        self.gauss.logdet = logdet
        self.gauss.K = mu.shape[0]
        self.gauss.q_global = g.get("q_global", None)
        # export_artifacts stores q_per_class verbatim, so it can be None when
        # per-class quantiles were never fit; keep None instead of crashing.
        q_per_class = g.get("q_per_class", {})
        self.gauss.q_per_class = (
            None if q_per_class is None
            else {int(k): float(v) for k, v in q_per_class.items()}
        )


# -----------------------------
# Unit check
# -----------------------------

def _unit_test_shapes():
    """Smoke-test that encoding one random sequence gives a bounded embedding.

    Builds a default ``ModelManager`` on CPU and encodes a single random
    ``(1, seq_len, D)`` batch. It then checks that the embedding's last
    dimension does not exceed ``cfg.pca_embed_dim``.
    """
    cfg = ModelConfig()
    mgr = ModelManager(cfg, device=torch.device("cpu"))
    T = cfg.seq_len
    # Width follows the config: positions, doubled when per-frame deltas are
    # concatenated. Previously this was hard-coded as 42 / 84.
    D = cfg.frame_dim * (2 if cfg.use_deltas else 1)
    # Private generator: deterministic input without reseeding the global RNG.
    gen = torch.Generator().manual_seed(0)
    seq = torch.randn(T, D, generator=gen)
    z = mgr.encode(seq.unsqueeze(0))
    assert z.shape[-1] <= cfg.pca_embed_dim


if __name__ == "__main__":
    print("Running minimal unit check...")
    _unit_test_shapes()
    print("OK")


Running minimal unit check...
OK


I0000 00:00:1754757338.076268 4204849 gl_context.cc:369] GL version: 2.1 (2.1 INTEL-22.5.11), renderer: Intel(R) Iris(TM) Plus Graphics OpenGL Engine (1x6x8 (fused) LP


W0000 00:00:1754757338.229874 4207319 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1754757338.384625 4207319 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
