In [None]:
# Cell 0 — Hugging Face token
# Read the token from Kaggle secrets when running on Kaggle. Outside Kaggle (or when the
# "hf_token" secret is not attached) get_secret raises, so fall back to the HF_TOKEN /
# HF_HUB_TOKEN environment variables; HF_TOKEN may end up None, which Cell 1 reports.
import os

try:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    HF_TOKEN = user_secrets.get_secret("hf_token")
except Exception:  # not on Kaggle, or secret missing / not attached to this notebook
    user_secrets = None
    HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HF_HUB_TOKEN")

In [None]:
# Cell 1 — Environment, imports, Hugging Face API init, logging, system probes
# -----------------------------------------------------------------------------
# Purpose: central imports and HF init. Uses psutil + pynvml (if available) to monitor memory.
import os
import sys
import time
import math
import logging
from pathlib import Path
from typing import Optional, Tuple, Dict, List, Any

# Core data / vision / torch imports
import numpy as np
import pandas as pd
import cv2
from tqdm.auto import tqdm

import torch
import torch.nn as nn

# optional monitoring libs (psutil and pynvml recommended for accurate metrics)
try:
    import psutil
except Exception:
    psutil = None

try:
    import pynvml
    pynvml.nvmlInit()
    _NVML_AVAILABLE = True
except Exception:
    _NVML_AVAILABLE = False

# Hugging Face API
from huggingface_hub import HfApi

# Logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
log = logging.getLogger("cataract_pipeline")

# Initialize the HF API client. HF_TOKEN is normally defined by Cell 0 (Kaggle secrets);
# fall back to the environment so this cell also runs when Cell 0 was skipped instead of
# failing with NameError.
HF_TOKEN = globals().get("HF_TOKEN") or os.getenv("HF_TOKEN") or os.getenv("HF_HUB_TOKEN")
if not HF_TOKEN:
    log.warning("HF_TOKEN not found in Kaggle secrets or environment — uploading will fail until HF_TOKEN is provided.")
hf_api = HfApi(token=HF_TOKEN)


In [None]:
# Cell 2 — Configuration (tweak as needed)
# -----------------------------------------------------------------------------
# Comments: set paths, memory thresholds, multi-GPU options, and other tuning.
CONFIG = {
    # --- paths ----------------------------------------------------------------
    "BASE_DIR": "/kaggle/input/cataract-101/cataract-101",   # dataset root; change if the dataset is elsewhere
    "WORK_DIR": "/kaggle/working/cataract-101-generated",    # all generated artifacts are written here
    # --- annotation timing / feature extraction --------------------------------
    "FPS": 10,                        # fps used to convert annotation times/timecodes to frame indices
    "FEATURE_BACKBONE": "resnet50",   # "resnet50" or "resnet18"
    "PRETRAINED_BACKBONE": True,
    "BATCH_FRAME": 8,                 # initial frames per model batch for encoding (will adapt on OOM)
    "IMG_SIZE": 112,                  # frames are resized to IMG_SIZE x IMG_SIZE before encoding
    "SEQ_LEN": 200,                   # clip length (in frames) for ClipDataset
    "SAMPLE_VIDEOS": 4,               # limit processed videos for testing (None = no limit)
    "PRINT_DEBUG": True,
    "MEMMAP_DIR": "/kaggle/working/cataract-101-generated/memmaps",
    # --- memory safety: processing stops once usage reaches these fractions (0..1) ---
    "MAX_RAM_FRAC": 0.90,             # host RAM
    "MAX_VRAM_FRAC": 0.92,            # VRAM of any single GPU
    "MIN_BATCH_FRAME": 1,             # never shrink the encoding batch below this
    "GPU_DEVICE_IDS": None,           # None => use every visible CUDA device
    "SAFE_SLEEP_SEC": 0.1,            # short pause after each video
    # --- Hugging Face Hub repository settings ----------------------------------
    "HF_REPO_ID": "Mateo4/cataract-101-generated",
    "HF_REPO_TYPE": "dataset",
}

# WORK_DIR is created up front so later cells can write into it unconditionally.
from pathlib import Path
WORK_DIR = Path(CONFIG["WORK_DIR"])
WORK_DIR.mkdir(exist_ok=True, parents=True)


In [None]:
# Cell 3 — Hardware / memory utilities and printer (read often)
# -----------------------------------------------------------------------------
# Comments: helper functions to inspect RAM and GPU VRAM, plus a guard function
def host_ram_usage_frac() -> float:
    """Return the fraction of host RAM in use (0..1).

    Usage is computed as ``(total - available) / total``. psutil documents ``used`` as
    platform-dependent and informational only, and recommends ``available`` for
    monitoring actual memory usage, which is what the memory guard needs.

    Returns 0.0 when psutil is not installed or reports a zero total. NOTE: a 0.0
    result means the host-RAM guard can never trip, so install psutil for real protection.
    """
    if psutil is None:
        return 0.0
    mem = psutil.virtual_memory()
    if not mem.total:
        return 0.0
    return float(mem.total - mem.available) / float(mem.total)

def gpu_memory_info() -> List[Dict[str, Any]]:
    """Snapshot memory usage of every GPU visible to NVML.

    Each entry is ``{"id", "used", "total", "used_frac"}`` with byte counts as floats;
    ``used_frac`` is 0.0 for a device reporting zero total memory. Returns an empty
    list when NVML could not be initialised.
    """
    if not _NVML_AVAILABLE:
        return []

    def _describe(index: int) -> Dict[str, Any]:
        info = pynvml.nvmlDeviceGetMemoryInfo(pynvml.nvmlDeviceGetHandleByIndex(index))
        used_bytes, total_bytes = float(info.used), float(info.total)
        frac = used_bytes / total_bytes if total_bytes > 0 else 0.0
        return {"id": index, "used": used_bytes, "total": total_bytes, "used_frac": frac}

    return [_describe(index) for index in range(pynvml.nvmlDeviceGetCount())]

def any_vram_exceeds(frac_thresh: float) -> bool:
    """True if at least one GPU's used-memory fraction is >= ``frac_thresh``."""
    return any(gpu["used_frac"] >= frac_thresh for gpu in gpu_memory_info())

def guard_memory(max_ram_frac: float, max_vram_frac: float) -> Tuple[bool,str]:
    """Check host RAM and GPU VRAM against the given usage fractions.

    Returns ``(True, "OK")`` when both are below their limits, otherwise
    ``(False, reason)`` naming the first limit reached (host RAM is checked first,
    then GPUs in NVML order). GPUs are only inspected when NVML is available.
    """
    ramf = host_ram_usage_frac()
    if ramf >= max_ram_frac:
        return False, f"Host RAM fraction {ramf:.3f} >= max {max_ram_frac}"

    gpus = gpu_memory_info() if _NVML_AVAILABLE else []
    g = next((info for info in gpus if info["used_frac"] >= max_vram_frac), None)
    if g is not None:
        return False, f"GPU {g['id']} VRAM fraction {g['used_frac']:.3f} >= max {max_vram_frac}"
    return True, "OK"

def print_system_summary():
    """Log CUDA availability, device count, host RAM usage and (if NVML works) per-GPU VRAM."""
    bytes_per_gb = 1e9
    log.info("Torch CUDA available: %s", torch.cuda.is_available())
    log.info("Torch CUDA devices: %d", torch.cuda.device_count())
    log.info("Host RAM used fraction: %.3f", host_ram_usage_frac())
    if not _NVML_AVAILABLE:
        log.info("NVML not available — GPU usage not shown.")
        return
    for gpu in gpu_memory_info():
        log.info("GPU %d VRAM used %.3f / %.3f (frac %.3f)",
                 gpu['id'], gpu['used'] / bytes_per_gb, gpu['total'] / bytes_per_gb, gpu['used_frac'])


In [None]:
# Cell 4 — Core helpers (col detection, to_frame) with added defensive checks and comments
# -----------------------------------------------------------------------------
import re
# Finds the first optionally-negative integer or decimal number in a string
# (e.g. "-12", "3.5"); group(1) is the whole number text, group(2) the optional
# fractional part. Used by to_frame for values that are not timecodes.
_int_re = re.compile(r'(-?\d+(\.\d+)?)')

def find_col(df: pd.DataFrame, keys: List[str]) -> Optional[str]:
    """Return the column of ``df`` that best matches ``keys``, or None.

    Matching is case-insensitive and ``keys`` are in priority order:
      1. an exact column-name match for any key (keys tried in order) wins;
      2. otherwise the first key (in order) that is a substring of some column
         name selects the first such column.
    This keeps a generic key such as 'start' from grabbing 'start_time' when
    'start_frame' exists, or 'frame' from grabbing 'start_frame' when 'frame_id'
    exists. Non-string column labels are compared via ``str()``.
    """
    wanted = [k.lower() for k in keys]
    lowered = [(col, str(col).lower()) for col in df.columns]
    for key in wanted:
        for col, name in lowered:
            if name == key:
                return col
    for key in wanted:
        for col, name in lowered:
            if key in name:
                return col
    return None

def to_frame(x: Any, fps: int = CONFIG["FPS"]) -> Optional[int]:
    """Convert one annotation value to a frame index.

    Accepted forms:
      * Python/numpy ints: returned unchanged as a frame index.
      * Timecodes "MM:SS" or "HH:MM:SS" (fields may be fractional): converted to
        seconds and multiplied by ``fps``.
      * Anything else: the first number found in ``str(x)`` is used. Numbers <= 20
        are treated as seconds (multiplied by ``fps``); larger numbers are taken as
        frame indices. NOTE(review): this heuristic also rescales small frame
        indices that arrive as floats or strings (e.g. 5.0 -> 5 * fps), so callers
        should pass integer frame numbers as ints.

    Returns None for missing values (None/NaN), strings without a number, non-finite
    numbers, and malformed timecodes. That includes timecodes with more than three
    ':'-separated fields (e.g. "HH:MM:SS:FF"), which previously collapsed to
    ``hours * fps``.
    """
    if pd.isnull(x):
        return None
    if isinstance(x, (int, np.integer)):
        return int(x)
    s = str(x).strip()
    if ':' in s:
        parts = s.split(':')
        if len(parts) > 3:
            # A 4th field (e.g. SMPTE frames) is in source-video frames, which cannot be
            # converted without the source fps; refuse rather than return garbage.
            return None
        try:
            fields = [float(p) for p in parts]
            if len(fields) == 3:
                secs = fields[0] * 3600 + fields[1] * 60 + fields[2]
            else:  # exactly two fields: MM:SS
                secs = fields[0] * 60 + fields[1]
            return int(round(secs * fps))
        except (ValueError, OverflowError):  # non-numeric field, or nan/inf
            return None
    m = _int_re.search(s)
    if m is None:
        return None
    xf = float(m.group(1))  # cannot fail: the regex only matches numeric text
    if not np.isfinite(xf):  # absurdly long digit strings overflow to inf
        return None
    if xf <= 20:
        return int(round(xf * fps))
    return int(round(xf))

def normalize_phase_map(phases_df: pd.DataFrame) -> Dict[str, int]:
    """Build a ``{phase_name: phase_id}`` dict from a phases table.

    Column resolution, in order:
      * columns named 'phase_id' and 'phase_name' (case-insensitive);
      * exactly two columns: first = id, second = name;
      * otherwise: names from the last column, ids = row position (0, 1, ...).
    Names are stringified and ids cast with ``int()``. Later rows overwrite earlier
    rows with the same name. Returns {} for None or a table with no rows.
    """
    if phases_df is None or phases_df.shape[0] == 0:
        return {}

    lowered = [c.lower() for c in phases_df.columns]

    def _from_columns(id_col, name_col) -> Dict[str, int]:
        result = {}
        for _, row in phases_df.iterrows():
            result[str(row[name_col])] = int(row[id_col])
        return result

    if 'phase_id' in lowered and 'phase_name' in lowered:
        return _from_columns(phases_df.columns[lowered.index('phase_id')],
                             phases_df.columns[lowered.index('phase_name')])
    if phases_df.shape[1] == 2:
        return _from_columns(*phases_df.columns[:2])

    # Fallback: treat the last column as names and number them by position.
    last_col_names = phases_df.iloc[:, -1].astype(str).tolist()
    return {name: position for position, name in enumerate(last_col_names)}


In [None]:
# Cell 5 — parse_annotations (step 1) with clear comments and safety checks
# -----------------------------------------------------------------------------
def _resolve_phase_id(raw: Any, phase_map: Dict[str, int]) -> int:
    """Map one raw phase value to an integer phase id, registering unseen names.

    Missing values (None / NaN / pd.NA) map to 0. A value whose string form is in
    ``phase_map`` uses that id. A numeric value is used directly. Any other name gets
    the next id after the largest id already in ``phase_map`` and is added to it
    (mutates ``phase_map``).
    """
    if raw is None or (pd.api.types.is_scalar(raw) and pd.isnull(raw)):
        return 0
    key = str(raw)
    if key in phase_map:
        return int(phase_map[key])
    try:
        return int(float(key))
    except (TypeError, ValueError, OverflowError):
        # max()+1 instead of len(): ids loaded from phases.csv need not be 0..n-1
        # (e.g. 1-based), and len() would then collide with an existing id.
        # NOTE(review): ids taken directly from numeric values are not tracked here.
        pid = max(phase_map.values(), default=-1) + 1
        phase_map[key] = pid
        return pid


def _parse_frame_cell(cell: Any, fps: int) -> List[Tuple[int, int, Optional[str]]]:
    """Parse one per-frame / compound annotation cell into (start, end, phase) triples.

    Layouts tried in order: a bare integer (single-frame segment); ';' / ',' / '|'
    separated triples "start;end;phase[;...]"; two tokens "start;end"; a range
    "start-end" / "start–end" / "start:end"; whitespace-separated "start end ...";
    a single value understood by to_frame. The phase element is a string only for
    the triple layout, otherwise None. Returns [] if nothing parses.
    """
    if pd.isnull(cell):
        return []
    if isinstance(cell, (int, np.integer)):
        v = int(cell)
        return [(v, v, None)]
    s = str(cell).strip()
    s2 = s.replace(',', ';').replace('|', ';')
    toks = [t.strip() for t in re.split(r'[;]+', s2) if t.strip()]
    if len(toks) >= 3 and len(toks) % 3 == 0:
        triples = []
        for i in range(0, len(toks), 3):
            a_f, b_f = to_frame(toks[i], fps), to_frame(toks[i + 1], fps)
            if a_f is not None and b_f is not None:
                triples.append((a_f, b_f, toks[i + 2]))
        if triples:
            return triples
    if len(toks) == 2:
        a_f, b_f = to_frame(toks[0], fps), to_frame(toks[1], fps)
        if a_f is not None and b_f is not None:
            return [(a_f, b_f, None)]
    m = re.match(r'^\s*(\S+)\s*[-–:]\s*(\S+)\s*$', s)
    if m:
        a_f, b_f = to_frame(m.group(1), fps), to_frame(m.group(2), fps)
        if a_f is not None and b_f is not None:
            return [(a_f, b_f, None)]
    parts = s.split()
    if len(parts) >= 2:
        a_f, b_f = to_frame(parts[0], fps), to_frame(parts[1], fps)
        if a_f is not None and b_f is not None:
            return [(a_f, b_f, None)]
    single = to_frame(s, fps)
    if single is not None:
        return [(single, single, None)]
    return []


def _write_label_arrays(seg_df: pd.DataFrame, labels_dir: Path) -> None:
    """Save one int32 per-frame label array per video as ``labels_dir/<video_id>.npy``.

    Each array has length max(end_frame) + 1 for that video. Frames not covered by any
    segment are -1. Where segments overlap, the later row of ``seg_df`` wins.
    """
    labels_dir.mkdir(parents=True, exist_ok=True)
    for vid, g in seg_df.groupby('video_id'):
        labels = np.full((int(g['end_frame'].max()) + 1,), -1, dtype=np.int32)
        for s, e, pid in zip(g['start_frame'].tolist(), g['end_frame'].tolist(), g['phase_id'].tolist()):
            s, e = max(int(s), 0), int(e)
            if e < s:
                continue
            labels[s:e + 1] = int(pid)
        np.save(labels_dir / f"{vid}.npy", labels, allow_pickle=False)


def parse_annotations(base_dir: str = CONFIG["BASE_DIR"], out_dir: Path = WORK_DIR, fps: int = CONFIG["FPS"], sample_videos: Optional[int] = CONFIG["SAMPLE_VIDEOS"]):
    """
    Read annotations.csv (and phases.csv if present) from ``base_dir`` and write into ``out_dir``:
      - segments_filled.csv           (video_id, start_frame, end_frame, phase_id)
      - labels_npy/<video_id>.npy     (per-frame int32 labels, -1 = unlabeled)
      - phase_map.csv                 (phase_name, phase_id)
      - video_file_map.csv            (only if <base_dir>/videos exists and has files)

    Handles explicit start/end columns, or a single per-frame / compound column (see
    _parse_frame_cell). Values go through to_frame, so MM:SS timecodes are accepted.
    If every start frame is >= 1, indices are assumed 1-based and shifted to 0-based
    (heuristic applied across all videos).

    ``sample_videos`` is accepted for signature symmetry but not used here;
    extract_features_streaming applies its own limit.

    NOTE(review): annotations.csv is read with pandas' default ',' separator; a
    ';'-separated file arrives as one column and is handled by the compound-cell parser.

    Returns (seg_df, phase_map). Raises FileNotFoundError if annotations.csv is missing,
    ValueError if no usable column is found, RuntimeError if no valid segments result.
    """
    base_dir = Path(base_dir)
    ann_path = base_dir / "annotations.csv"
    phases_path = base_dir / "phases.csv"
    if not ann_path.exists():
        raise FileNotFoundError(f"{ann_path} missing")
    ann = pd.read_csv(ann_path)
    log.info("Loaded annotations.csv (%d rows)", len(ann))

    # load phase map if available
    phase_map: Dict[str, int] = {}
    if phases_path.exists():
        try:
            phase_map = normalize_phase_map(pd.read_csv(phases_path))
            log.info("Loaded phase_map (%d entries)", len(phase_map))
        except Exception as e:
            log.warning("Failed to read phases.csv: %s", e)

    col_video = find_col(ann, ['video','video_id','videoid','vid','file','filename']) or ann.columns[0]
    col_phase = find_col(ann, ['phase','label','phase_id','phaseid','class','action','phase_name'])
    col_start = find_col(ann, ['start_frame','startframe','start_f','start'])
    col_end = find_col(ann, ['end_frame','endframe','end_f','end'])
    col_frame = find_col(ann, ['frame','frame_id','frameid'])
    log.debug("Columns detected: video=%s phase=%s start=%s end=%s frame=%s", col_video, col_phase, col_start, col_end, col_frame)

    # Iterate column-wise. DataFrame.iterrows() packs each row into a single Series, which
    # upcasts ints to float whenever any column is float (e.g. a phase column with NaN).
    # to_frame() treats floats <= 20 as seconds, so an int frame 5 read back as 5.0 would
    # silently become 5 * fps (and video id 269 would become "269.0"). tolist() keeps
    # each column's own dtype.
    videos = [str(v) for v in ann[col_video].tolist()]
    phases = ann[col_phase].tolist() if col_phase is not None else [None] * len(ann)

    segments = []
    if col_start is not None and col_end is not None:
        log.info("Parsing explicit start/end columns.")
        for vid, raw_s, raw_e, raw_p in zip(videos, ann[col_start].tolist(), ann[col_end].tolist(), phases):
            sfrm = to_frame(raw_s, fps=fps)
            efrm = to_frame(raw_e, fps=fps)
            if sfrm is None or efrm is None:
                continue
            pid = _resolve_phase_id(raw_p, phase_map)
            sfrm, efrm = sorted((int(sfrm), int(efrm)))
            segments.append((vid, sfrm, efrm, int(pid)))
    elif col_frame is not None:
        log.info("Parsing compound/per-frame column: %s", col_frame)
        for vid, cell, raw_p in zip(videos, ann[col_frame].tolist(), phases):
            for a, b, c in _parse_frame_cell(cell, fps):
                # An inline phase token (triple layout) overrides the phase column.
                has_inline_phase = c is not None and str(c).lower() not in ['nan', 'none', '']
                pid = _resolve_phase_id(c if has_inline_phase else raw_p, phase_map)
                sfrm, efrm = sorted((int(a), int(b)))
                segments.append((vid, sfrm, efrm, int(pid)))
    else:
        raise ValueError("No suitable start/end/frame column found in annotations.csv")

    seg_df = pd.DataFrame(segments, columns=['video_id','start_frame','end_frame','phase_id'])
    if seg_df.empty:
        raise RuntimeError("No segments parsed from annotations.csv")
    seg_df = seg_df[(seg_df['start_frame'] >= 0) & (seg_df['end_frame'] >= 0)].copy()
    if seg_df.empty:
        raise RuntimeError("No segments with non-negative frame indices parsed from annotations.csv")
    seg_df = seg_df.sort_values(['video_id','start_frame','end_frame']).reset_index(drop=True)

    # fix 1-based indexing -> convert to 0-based (heuristic: no segment starts at frame 0)
    min_start = int(seg_df['start_frame'].min())
    if min_start >= 1:
        log.info("Detected probable 1-based indexing (min start_frame=%d) -> converting to 0-based", min_start)
        seg_df['start_frame'] -= 1
        seg_df['end_frame'] -= 1
        seg_df = seg_df[seg_df['end_frame'] >= seg_df['start_frame']].reset_index(drop=True)

    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    seg_df.to_csv(out_dir / "segments_filled.csv", index=False)
    log.info("Saved segments_filled.csv (%d segments, %d videos)", len(seg_df), seg_df['video_id'].nunique())

    labels_dir = out_dir / "labels_npy"
    _write_label_arrays(seg_df, labels_dir)
    log.info("Saved label arrays to %s", labels_dir)

    # write phase_map, ordered by id
    phase_map_items = sorted(phase_map.items(), key=lambda item: item[1])
    pd.DataFrame(phase_map_items, columns=['phase_name','phase_id']).to_csv(out_dir / "phase_map.csv", index=False)

    # optional: write video_file_map if videos folder exists
    vids_folder = base_dir / "videos"
    if vids_folder.exists():
        vm = {p.stem: str(p) for p in vids_folder.iterdir() if p.is_file()}
        if vm:
            pd.DataFrame(list(vm.items()), columns=['video_id','video_file']).to_csv(out_dir / "video_file_map.csv", index=False)
            log.info("Saved video_file_map.csv")

    return seg_df, phase_map


In [None]:
# Cell 6 — build_backbone + preprocessing with multi-GPU / amp support
# -----------------------------------------------------------------------------
from torchvision import models
import torch.nn.functional as F

def build_backbone(backbone_name: str = CONFIG['FEATURE_BACKBONE'], pretrained: bool = CONFIG['PRETRAINED_BACKBONE'], device_ids: Optional[List[int]] = CONFIG['GPU_DEVICE_IDS']):
    """
    Build a torchvision ResNet feature extractor and return ``(model, feat_dim)``.

    Only the final fully-connected classifier is removed. The global average pool is
    kept, so the model maps (B, 3, H, W) images to (B, feat_dim, 1, 1) features.
    The model is in eval mode.

    Args:
        backbone_name: 'resnet50' (feat_dim 2048) or 'resnet18' (feat_dim 512), case-insensitive.
        pretrained: load the ImageNet weights that ``pretrained=True`` historically selected
            (``IMAGENET1K_V1``). On torchvision >= 0.13 this uses the ``weights=`` API,
            which replaces the deprecated ``pretrained=`` flag.
        device_ids: CUDA device indices. None/empty means all visible GPUs when CUDA is
            available. The model is placed on the first id and wrapped in
            ``nn.DataParallel`` when several ids are given. CPU is used without CUDA.

    Raises:
        ValueError: for an unsupported backbone name.
    """
    specs = {
        'resnet50': ('resnet50', 'ResNet50_Weights', 2048),
        'resnet18': ('resnet18', 'ResNet18_Weights', 512),
    }
    name = backbone_name.lower()
    if name not in specs:
        raise ValueError("Unsupported backbone: " + backbone_name)
    ctor_name, weights_enum_name, feat_dim = specs[name]
    ctor = getattr(models, ctor_name)
    weights_enum = getattr(models, weights_enum_name, None)
    if weights_enum is not None:
        # IMAGENET1K_V1 is exactly what the legacy pretrained=True loaded (DEFAULT is V2).
        base = ctor(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)
    else:  # torchvision < 0.13 has no weights enums
        base = ctor(pretrained=pretrained)

    # Drop only the final fc layer; the adaptive avgpool stays as the global pool.
    backbone = nn.Sequential(*list(base.children())[:-1])
    backbone.eval()

    # Device placement: prefer multi-gpu if requested and available
    if torch.cuda.is_available() and (device_ids is None or len(device_ids) == 0):
        device_ids = list(range(torch.cuda.device_count()))
    if torch.cuda.is_available() and device_ids:
        backbone = backbone.to(f"cuda:{device_ids[0]}")
        if len(device_ids) > 1:
            # DataParallel replicates the model on every listed GPU per forward pass.
            backbone = nn.DataParallel(backbone, device_ids=device_ids)
    else:
        backbone = backbone.to("cpu")

    # Fixed input size -> let cuDNN pick the fastest kernels.
    torch.backends.cudnn.benchmark = True
    return backbone, feat_dim

# ImageNet normalisation statistics in RGB order; they broadcast over the last (channel)
# axis of an HWC image. Module-level so they are not re-allocated for every frame.
_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess_frame_cv2(frame: np.ndarray, img_size: int = CONFIG['IMG_SIZE']):
    """Convert one BGR cv2 frame into a normalized CHW float32 array.

    Steps: resize to (img_size, img_size) with bilinear interpolation (skipped if
    already that size), BGR -> RGB, scale to [0, 1], ImageNet mean/std normalisation,
    then HWC -> CHW. The result is a transposed (non-contiguous) view; callers
    ``np.stack`` frames, which copies anyway.
    """
    # Resize before the colour swap: per-channel interpolation commutes with a channel
    # permutation, so the result is identical, but cvtColor then runs on the small image.
    if frame.shape[:2] != (img_size, img_size):
        frame = cv2.resize(frame, (img_size, img_size), interpolation=cv2.INTER_LINEAR)
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    arr = rgb.astype(np.float32) / 255.0
    arr = (arr - _IMAGENET_MEAN) / _IMAGENET_STD
    return arr.transpose(2, 0, 1)   # HWC -> CHW


In [None]:
# Cell 7 — extract_features_streaming (updated: uses torch.amp.autocast safely)
# -----------------------------------------------------------------------------
from contextlib import nullcontext

def amp_context_for_device(device: str):
    """
    Return a zero-argument factory that produces a mixed-precision context manager.

        amp_ctx = amp_context_for_device(device_str)
        with amp_ctx():
            ...

    If CUDA is available and ``device`` names a CUDA device (its string form starts
    with "cuda"), the factory returns ``torch.amp.autocast(device_type='cuda', enabled=True)``.
    Otherwise it returns a no-op ``nullcontext()``.
    """
    use_cuda_amp = torch.cuda.is_available() and bool(device) and str(device).startswith("cuda")
    if not use_cuda_amp:
        return nullcontext

    def _cuda_autocast():
        return torch.amp.autocast(device_type='cuda', enabled=True)

    return _cuda_autocast

class _MemoryThresholdReached(RuntimeError):
    """Raised when guard_memory trips mid-video; extraction stops instead of moving on."""


def _is_cuda_oom(err: BaseException) -> bool:
    """True if ``err`` looks like a CUDA out-of-memory error (any torch version)."""
    oom_cls = getattr(torch.cuda, "OutOfMemoryError", None)
    if oom_cls is not None and isinstance(err, oom_cls):
        return True
    return "out of memory" in str(err).lower()


def _encode_frames(backbone: nn.Module, frames: List[np.ndarray], device: torch.device, amp_ctx, batch_frame: int) -> Tuple[np.ndarray, int]:
    """Encode a non-empty list of preprocessed CHW frames, at most ``batch_frame`` per pass.

    On CUDA out-of-memory the chunk size is halved (never below CONFIG['MIN_BATCH_FRAME'])
    and the *same* chunk is retried, so the OOM retry actually shrinks the work. Rows
    already encoded are kept. Other errors, or OOM at the minimum size, are re-raised.

    Returns (float32 array with one row per frame, possibly-reduced batch_frame).
    """
    outputs: List[np.ndarray] = []
    start = 0
    with torch.no_grad():
        while start < len(frames):
            chunk = frames[start:start + batch_frame]
            hit_oom = False
            try:
                with amp_ctx():
                    tensor_batch = torch.from_numpy(np.stack(chunk, axis=0)).to(device)
                    out = backbone(tensor_batch).reshape(len(chunk), -1)
                outputs.append(out.float().cpu().numpy())
            except RuntimeError as e:
                if not _is_cuda_oom(e) or batch_frame <= CONFIG['MIN_BATCH_FRAME']:
                    raise
                hit_oom = True
            if hit_oom:
                # Drop references to the failed batch before releasing cached blocks.
                tensor_batch = out = None
                torch.cuda.empty_cache()
                batch_frame = max(CONFIG['MIN_BATCH_FRAME'], batch_frame // 2)
                log.warning("CUDA out of memory during encoding — reduced batch_frame to %d and retrying", batch_frame)
                time.sleep(0.5)
                continue
            start += len(chunk)
    return np.concatenate(outputs, axis=0), batch_frame


class _FeatureSink:
    """Collects encoded rows for one video and writes them to ``out_path`` (.npy).

    When the frame count is known, rows go into a preallocated memmap at
    ``<stem>.partial.npy``. Rows beyond that count (CAP_PROP_FRAME_COUNT is only an
    estimate), and all rows of unknown-length videos, are kept in RAM. ``finalize``
    writes the exact (n_rows, feat_dim) array and renames it to ``out_path`` with
    os.replace. ``discard`` removes temporary files, so a failed video never leaves a
    file at ``out_path`` that a later run would reuse as if it were complete.
    """

    def __init__(self, out_path: Path, feat_dim: int, expected_rows: Optional[int]):
        self.out_path = out_path
        self.feat_dim = feat_dim
        self.partial_path = out_path.with_name(out_path.stem + ".partial.npy")
        self.final_tmp_path = out_path.with_name(out_path.stem + ".tmp.npy")
        self.spill: List[np.ndarray] = []
        self.mm_rows = 0
        self.mm = None
        if expected_rows:
            self.mm = np.lib.format.open_memmap(str(self.partial_path), mode='w+', dtype='float32', shape=(expected_rows, feat_dim))

    @property
    def n_rows(self) -> int:
        return self.mm_rows + sum(chunk.shape[0] for chunk in self.spill)

    def add(self, feats: np.ndarray) -> None:
        feats = np.asarray(feats, dtype=np.float32)
        if self.mm is not None and self.mm_rows < self.mm.shape[0]:
            take = min(feats.shape[0], self.mm.shape[0] - self.mm_rows)
            self.mm[self.mm_rows:self.mm_rows + take] = feats[:take]
            self.mm_rows += take
            feats = feats[take:]
        if feats.shape[0]:
            self.spill.append(feats)

    def finalize(self) -> int:
        n = self.n_rows
        if self.mm is not None:
            self.mm.flush()
            if not self.spill and self.mm_rows == self.mm.shape[0]:
                # Exact fit: the memmap file already is the final array.
                self.mm = None
                os.replace(self.partial_path, self.out_path)
                return n
        parts = ([self.mm[:self.mm_rows]] if self.mm is not None else []) + self.spill
        full = np.concatenate(parts, axis=0) if parts else np.empty((0, self.feat_dim), dtype=np.float32)
        # Write to a *different* temp file: saving over the mmapped file would truncate
        # it while its pages are still being read (SIGBUS / garbage).
        np.save(self.final_tmp_path, full, allow_pickle=False)
        del full, parts
        self.mm = None
        self.spill = []
        os.replace(self.final_tmp_path, self.out_path)
        self.partial_path.unlink(missing_ok=True)
        return n

    def discard(self) -> None:
        self.mm = None
        self.spill = []
        for path in (self.partial_path, self.final_tmp_path):
            path.unlink(missing_ok=True)


def _encode_capture(cap, backbone: nn.Module, device: torch.device, amp_ctx, sink: _FeatureSink, batch_frame: int) -> int:
    """Read every frame from ``cap``, encode in batches and append the rows to ``sink``.

    guard_memory is checked before each full batch; _MemoryThresholdReached is raised
    when it trips. Returns the possibly-reduced batch_frame so later videos reuse it.
    """
    pending: List[np.ndarray] = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        pending.append(preprocess_frame_cv2(frame, img_size=CONFIG['IMG_SIZE']))
        if len(pending) < batch_frame:
            continue
        ok, msg = guard_memory(CONFIG['MAX_RAM_FRAC'], CONFIG['MAX_VRAM_FRAC'])
        if not ok:
            raise _MemoryThresholdReached(msg)
        feats, batch_frame = _encode_frames(backbone, pending, device, amp_ctx, batch_frame)
        sink.add(feats)
        pending = []
    if pending:
        feats, batch_frame = _encode_frames(backbone, pending, device, amp_ctx, batch_frame)
        sink.add(feats)
    return batch_frame


def _find_video_path(base_vid: str, video_file_map: Dict[str, str], video_files: Dict[str, str], base_dir: str) -> Optional[str]:
    """Locate the source video for ``base_vid``; None if nothing matches.

    Order: video_file_map.csv entry (relative paths resolved under <base_dir>/videos,
    used only if the file exists), then an exact stem match in <base_dir>/videos,
    then the first stem that starts with or contains ``base_vid``.
    NOTE(review): the fuzzy fallback can pick the wrong file (e.g. "26" matches "269").
    """
    if base_vid in video_file_map:
        candidate = Path(video_file_map[base_vid])
        if not candidate.is_absolute():
            candidate = Path(base_dir) / "videos" / candidate
        if candidate.exists():
            return str(candidate)
    if base_vid in video_files:
        return video_files[base_vid]
    for key, path in video_files.items():
        if key.startswith(base_vid) or base_vid in key:
            return path
    return None


def _slice_fragments(full_feats: np.ndarray, stems: List[str], labels_dir: Path, features_dir: Path) -> None:
    """Write features/<stem>.npy for each fragment that does not exist yet.

    A fragment gets the first L rows of the full-video features, where L is the
    length of its label array. If the video has fewer rows, the last row is repeated
    up to L. Fragments without labels, or with no rows available to pad from, are
    skipped with a warning.
    """
    n_available = full_feats.shape[0]
    for stem in stems:
        outp = features_dir / f"{stem}.npy"
        if outp.exists():
            continue
        lab_path = labels_dir / f"{stem}.npy"
        if not lab_path.exists():
            log.warning("Label missing for fragment %s — skipping", stem)
            continue
        n_labels = np.load(lab_path).shape[0]
        if n_available >= n_labels:
            frag_feats = full_feats[:n_labels]
        elif n_available == 0:
            log.warning("No encoded frames available to pad fragment %s — skipping", stem)
            continue
        else:
            pad_rows = np.repeat(full_feats[-1:], n_labels - n_available, axis=0)
            frag_feats = np.concatenate([full_feats, pad_rows], axis=0)
        np.save(outp, np.asarray(frag_feats), allow_pickle=False)


def extract_features_streaming(work_dir: Path = WORK_DIR, base_dir: str = CONFIG['BASE_DIR'], sample_videos: Optional[int] = CONFIG['SAMPLE_VIDEOS']):
    """
    Streaming feature extraction.

     - Reads labels_npy/*.npy to find the fragments to produce. Fragment stems are
       grouped by base video id (the text before the first ';', ',', '|' or ':').
       ``sample_videos`` limits the number of label *fragments* (sorted by stem).
     - For each base video, encodes every decoded frame into
       features_by_video/<base_vid>.npy (float32, one row per frame), then writes
       features/<stem>.npy per fragment (see _slice_fragments). Existing per-video
       files are reused and existing per-fragment files are kept.

    Memory techniques and robustness:
     - When the container reports a frame count, features are written through a memmap
       into '<base_vid>.partial.npy'. The file is renamed to its final name only once
       complete, so a failed or interrupted video is never mistaken for a finished one.
     - Mixed precision on CUDA via amp_context_for_device.
     - On CUDA OOM the batch size is halved and the failing batch is retried.
     - guard_memory (CONFIG MAX_RAM_FRAC / MAX_VRAM_FRAC) is checked before each video
       and before each batch. When it trips, processing stops.
    """
    work_dir = Path(work_dir)
    labels_dir = work_dir / 'labels_npy'
    features_dir = work_dir / "features"
    features_by_video_dir = work_dir / "features_by_video"
    features_dir.mkdir(parents=True, exist_ok=True)
    features_by_video_dir.mkdir(parents=True, exist_ok=True)

    # Explicit id -> file mapping written by parse_annotations, if present.
    video_file_map: Dict[str, str] = {}
    vm_path = work_dir / "video_file_map.csv"
    if vm_path.exists():
        vm_df = pd.read_csv(vm_path)
        video_file_map = {str(v): str(f) for v, f in zip(vm_df['video_id'], vm_df['video_file'])}
    videos_folder = Path(base_dir) / "videos"
    video_files: Dict[str, str] = {}
    if videos_folder.exists():
        video_files = {p.stem: str(p) for p in videos_folder.iterdir() if p.is_file()}

    label_stems = sorted(p.stem for p in labels_dir.glob('*.npy'))
    if sample_videos is not None:
        label_stems = label_stems[:sample_videos]
    groups: Dict[str, List[str]] = {}
    for stem in label_stems:
        base = re.split(r'[;,\|:]+', stem)[0].strip()
        groups.setdefault(base, []).append(stem)
    log.info("Found %d label fragments grouped into %d base videos", len(label_stems), len(groups))
    if not groups:
        return

    if CONFIG['GPU_DEVICE_IDS'] is None:
        device_ids = list(range(torch.cuda.device_count())) if torch.cuda.is_available() else []
    else:
        device_ids = CONFIG['GPU_DEVICE_IDS']
    backbone, feat_dim = build_backbone(CONFIG['FEATURE_BACKBONE'], CONFIG['PRETRAINED_BACKBONE'], device_ids=device_ids)
    try:
        device = next(backbone.parameters()).device
    except StopIteration:  # unexpected: parameter-free backbone
        device = torch.device("cpu")
    amp_ctx = amp_context_for_device(str(device))

    # Shared across videos; only ever reduced (on OOM).
    batch_frame = int(CONFIG['BATCH_FRAME'])

    for base_vid, stems in groups.items():
        log.info("Processing base video: %s | fragments: %d", base_vid, len(stems))
        ok, msg = guard_memory(CONFIG['MAX_RAM_FRAC'], CONFIG['MAX_VRAM_FRAC'])
        if not ok:
            log.warning("Aborting processing: memory guard tripped before starting video %s: %s", base_vid, msg)
            break

        full_feat_path = features_by_video_dir / f"{base_vid}.npy"
        if full_feat_path.exists():
            log.info("Full-video features already exist for %s — slicing fragments", base_vid)
            try:
                full_feats = np.load(full_feat_path, mmap_mode='r')
            except Exception as e:
                log.warning("Failed to mmap existing features file: %s (will recompute). Error: %s", full_feat_path, e)
                full_feat_path.unlink(missing_ok=True)
            else:
                try:
                    _slice_fragments(full_feats, stems, labels_dir, features_dir)
                except Exception as e:
                    log.error("Failed to slice fragment features for %s: %s", base_vid, e)
                del full_feats
                continue

        video_path = _find_video_path(base_vid, video_file_map, video_files, base_dir)
        if video_path is None:
            log.warning("No source video for %s — skipping", base_vid)
            continue
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            log.warning("Failed to open %s — skipping", video_path)
            cap.release()
            continue
        declared = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        total_frames = int(declared) if declared > 0 else None

        sink = None
        try:
            if total_frames:
                log.info("Creating memmap for %s frames=%d feat_dim=%d", base_vid, total_frames, feat_dim)
            sink = _FeatureSink(full_feat_path, feat_dim, total_frames)
            batch_frame = _encode_capture(cap, backbone, device, amp_ctx, sink, batch_frame)
            if sink.n_rows == 0:
                sink.discard()
                log.warning("No frames encoded for %s", base_vid)
                continue
            if total_frames and sink.n_rows != total_frames:
                log.debug("Container reported %d frames for %s, encoded %d", total_frames, base_vid, sink.n_rows)
            n_rows = sink.finalize()
            log.info("Saved full-video features: %s (%d frames)", full_feat_path, n_rows)
        except _MemoryThresholdReached as e:
            log.warning("Memory threshold reached during encoding of %s: %s — stopping", base_vid, e)
            if sink is not None:
                sink.discard()
            break
        except Exception as e:
            log.error("Error processing %s: %s", base_vid, e)
            if sink is not None:
                sink.discard()
            continue
        finally:
            cap.release()

        # Slice per-fragment features from the mmapped per-video file (minimal RAM).
        try:
            full_feats = np.load(full_feat_path, mmap_mode='r')
            _slice_fragments(full_feats, stems, labels_dir, features_dir)
            del full_feats
        except Exception as e:
            log.error("Failed to slice fragment features for %s: %s", base_vid, e)
        torch.cuda.empty_cache()
        time.sleep(CONFIG['SAFE_SLEEP_SEC'])

    log.info("Streaming feature extraction finished. Per-video features in %s and per-fragment features in %s", features_by_video_dir, features_dir)


In [None]:
# Cell 8 — ClipDataset (mmap-backed) with careful mem usage & docstring
# -----------------------------------------------------------------------------
import torch.utils.data as data

class ClipDataset(data.Dataset):
    """
    Dataset of fixed-length clips drawn from memmapped per-fragment feature files.

    Labels for the selected videos are loaded fully into memory (they are small),
    while feature arrays are opened lazily with ``np.load(..., mmap_mode='r')``
    on every access, so RAM usage stays low.

    Args:
        work_dir: pipeline working directory containing ``features/`` and ``labels_npy/``.
        seq_len: number of consecutive frames per clip (must be >= 1).
        sample_videos: if not None, keep only this many videos after a seeded shuffle.
        seed: RNG seed used for the video shuffle.
        stride: step between consecutive window starts. ``None`` keeps the default
            of ``max(1, seq_len // 2)`` (50% overlap between windows).

    Raises:
        ValueError: if ``seq_len < 1``.
        RuntimeError: if no video has both a label file and a feature file.

    Each item is a dict with keys ``video_id`` (str), ``start`` (int),
    ``feats`` (float32 tensor, first axis of length ``seq_len``; assumes 2-D
    feature files, i.e. frames x feature_dim) and ``labels`` (int64 tensor of
    length ``seq_len``).
    """
    def __init__(self, work_dir: Path = WORK_DIR, seq_len: int = CONFIG['SEQ_LEN'], sample_videos: Optional[int] = CONFIG['SAMPLE_VIDEOS'], seed: int = 42, stride: Optional[int] = None):
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        self.work_dir = Path(work_dir)
        self.features_dir = self.work_dir / 'features'
        self.labels_dir = self.work_dir / 'labels_npy'
        self.seq_len = seq_len
        self.stride = max(1, seq_len // 2) if stride is None else max(1, int(stride))

        label_stems = sorted(p.stem for p in self.labels_dir.glob('*.npy'))
        feature_stems = {p.stem for p in self.features_dir.glob('*.npy')}
        avail = [s for s in label_stems if s in feature_stems]
        if not avail:
            raise RuntimeError("No videos with both labels and features found. Run feature extraction first.")

        rng = np.random.RandomState(seed)
        rng.shuffle(avail)
        if sample_videos is not None:
            avail = avail[:sample_videos]
        self.videos = avail

        # Load labels fully (label arrays are small relative to features).
        self.video_labels = {v: np.load(self.labels_dir / f"{v}.npy") for v in self.videos}

        # Build the sliding-window index. Windows are bounded by the SHORTER of the
        # label and feature arrays so every clip returned by __getitem__ is full length
        # (a stale/short feature file would otherwise yield a truncated clip).
        self.index: List[Tuple[str, int]] = []
        for v in self.videos:
            try:
                # mmap only maps the file; the feature data itself is not read here.
                n_feat = np.load(self.features_dir / f"{v}.npy", mmap_mode='r').shape[0]
            except (OSError, ValueError) as e:
                log.warning("Unreadable features for %s (%s) — skipping", v, e)
                continue
            n_lab = len(self.video_labels[v])
            if n_feat != n_lab:
                log.warning("Length mismatch for %s: %d feature rows vs %d labels — using the shorter", v, n_feat, n_lab)
            n = min(n_feat, n_lab)
            if n < self.seq_len:
                continue
            self.index.extend((v, s) for s in range(0, n - self.seq_len + 1, self.stride))

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        vid, s = self.index[idx]
        feat_path = self.features_dir / f"{vid}.npy"
        if not feat_path.exists():
            raise FileNotFoundError(f"Missing features for {vid} at {feat_path}")
        feats = np.load(feat_path, mmap_mode='r')
        end = s + self.seq_len
        # astype() copies out of the read-only memmap, giving torch a writable buffer.
        clip_feats = feats[s:end].astype('float32')
        clip_lbls = self.video_labels[vid][s:end].astype('int64')
        return {'video_id': vid, 'start': s, 'feats': torch.from_numpy(clip_feats).float(), 'labels': torch.from_numpy(clip_lbls).long()}


In [None]:
# Cell 9 — quick_train_example with amp and gradient accumulation (optional)
# -----------------------------------------------------------------------------
def quick_train_example(work_dir: Path = WORK_DIR, seq_len: int = 100, batch_size: int = 2, epochs: int = 3, lr: float = 1e-3,
                        accum_steps: int = 1, num_classes: Optional[int] = None):
    """
    Small GRU training loop over ClipDataset clips, demonstrating safe GPU usage.

      - uses AMP (autocast + GradScaler) for mixed precision when CUDA is available
      - wraps the modules in DataParallel if multiple GPUs are detected
      - uses gradient accumulation (``accum_steps``) to control the effective batch
        size, which is ``batch_size * accum_steps``

    Args:
        work_dir: pipeline working directory (``features/`` + ``labels_npy/``).
        seq_len: clip length passed to ClipDataset.
        batch_size: clips per DataLoader batch.
        epochs: number of passes over the dataset.
        lr: Adam learning rate.
        accum_steps: batches whose gradients are summed before each optimizer step
            (1 = step after every batch).
        num_classes: output size of the classification head. ``None`` infers it as
            ``max(label) + 1`` over the loaded videos, so labels can never exceed the head.

    Returns:
        ``(rnn, head)``: the trained modules (possibly DataParallel-wrapped).

    Raises:
        ValueError: if ``accum_steps < 1``.
        RuntimeError: if the dataset yields no clips (every video shorter than ``seq_len``).
    """
    if accum_steps < 1:
        raise ValueError(f"accum_steps must be >= 1, got {accum_steps}")
    work_dir = Path(work_dir)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_amp = device == "cuda"

    ds = ClipDataset(work_dir=work_dir, seq_len=seq_len, sample_videos=CONFIG['SAMPLE_VIDEOS'])
    if len(ds) == 0:
        raise RuntimeError(f"No clips of length {seq_len} available — lower seq_len or extract more videos.")
    # Default collate stacks 'feats'/'labels' into batch tensors. Unlike a lambda it is
    # picklable, so the DataLoader also works with the 'spawn' worker start method.
    dl = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=use_amp)

    # Take the input dimension from a real clip so it always matches what the dataset yields.
    input_dim = ds[0]['feats'].shape[-1]
    if num_classes is None:
        max_label = max(int(lbl.max()) for lbl in ds.video_labels.values() if lbl.size)
        num_classes = max(1, max_label + 1)

    rnn = nn.GRU(input_dim, 64, batch_first=True)
    head = nn.Linear(64, num_classes)
    if device == "cuda" and torch.cuda.device_count() > 1:
        rnn = nn.DataParallel(rnn)
        head = nn.DataParallel(head)
    rnn.to(device); head.to(device)
    opt = torch.optim.Adam(list(rnn.parameters()) + list(head.parameters()), lr=lr)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    for epoch in range(epochs):
        log.info("Epoch %d/%d", epoch+1, epochs)
        running_loss, n_batches = 0.0, 0
        opt.zero_grad()
        for step, batch in enumerate(dl, start=1):
            feats = batch['feats'].to(device, non_blocking=True)
            labels = batch['labels'].to(device, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=use_amp):
                out, _ = rnn(feats)
                logits = head(out)  # B x T x C
                loss = nn.functional.cross_entropy(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1), ignore_index=-1)
            # Divide so the accumulated gradient is the mean over accum_steps batches.
            scaler.scale(loss / accum_steps).backward()
            # Step on every accum_steps-th batch and on the last batch of the epoch.
            if step % accum_steps == 0 or step == len(dl):
                scaler.step(opt)
                scaler.update()
                opt.zero_grad()
            running_loss += float(loss.detach())
            n_batches += 1
            if use_amp:
                torch.cuda.empty_cache()
        # Guard against an empty epoch (previously `loss` would be unbound here).
        mean_loss = running_loss / max(1, n_batches)
        log.info("Epoch %d done (mean loss %.4f over %d batches)", epoch+1, mean_loss, n_batches)
    return rnn, head


In [None]:
# Cell 10 — Main usage snippet (run in __main__ or cells sequentially)
# -----------------------------------------------------------------------------
# Pipeline: system summary -> per-fragment label arrays -> streamed feature extraction.
# The quick training demo is opt-in via CONFIG['RUN_QUICK_TRAIN'] (default False) instead
# of commented-out code; quick_train_example builds its own ClipDataset.
if __name__ == "__main__":
    print_system_summary()
    # 1) parse annotations & create label arrays (seg_df / phase_map kept for inspection)
    seg_df, phase_map = parse_annotations(CONFIG['BASE_DIR'], WORK_DIR, CONFIG['FPS'], CONFIG['SAMPLE_VIDEOS'])
    # 2) extract features streaming (fragments whose feature file already exists are skipped)
    extract_features_streaming(WORK_DIR, CONFIG['BASE_DIR'], CONFIG['SAMPLE_VIDEOS'])
    # 3) optional quick training demo on the extracted clips
    if CONFIG.get('RUN_QUICK_TRAIN', False):
        quick_train_example(work_dir=WORK_DIR, seq_len=CONFIG['SEQ_LEN'], batch_size=2, epochs=1)
    print("All done for this run. Review logs above.")


In [None]:
# Cell 11 — Upload output folder to Hugging Face (final step)
# -----------------------------------------------------------------------------
# Note: requires HF_TOKEN (read from the 'hf_token' Kaggle secret in Cell 0) and uses the
# hf_api client created in Cell 1. The target repo is created if it does not exist yet
# (private unless CONFIG['HF_PRIVATE'] is set to False); an existing repo is reused as-is.
folder_to_upload = str(WORK_DIR)
repo_id = CONFIG['HF_REPO_ID']
repo_type = CONFIG['HF_REPO_TYPE']

# Safety checks: token, target repo id and local output folder must all be present.
if not HF_TOKEN:
    log.error("HF_TOKEN missing — cannot upload. Add the 'hf_token' Kaggle secret (Cell 0) and re-run.")
elif not repo_id:
    log.error("CONFIG['HF_REPO_ID'] is empty — cannot upload.")
elif not Path(folder_to_upload).is_dir():
    log.error("Output folder %s does not exist — run the pipeline cells first.", folder_to_upload)
else:
    log.info("Uploading folder %s to Hugging Face repo %s (type=%s). This may take time.", folder_to_upload, repo_id, repo_type)
    try:
        # exist_ok=True makes this a no-op when the repo already exists; otherwise upload_folder
        # would fail on a missing repo.
        hf_api.create_repo(repo_id=repo_id, repo_type=repo_type, private=CONFIG.get('HF_PRIVATE', True), exist_ok=True)
        hf_api.upload_folder(folder_path=folder_to_upload, repo_id=repo_id, repo_type=repo_type)
        log.info("Upload finished (or the call returned). Check the HF repo to confirm.")
    except Exception as e:
        log.error("Upload failed: %s", e)
