**Prereqs**
---

In [None]:
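# If needed (transforms.v2 ToImage / ToDtype used below require torchvision >= 0.16)
# %pip install torch torchvision pandas opencv-python tqdm pillow
# ffmpeg is optional: the clip extraction below decodes with OpenCV only. If you want it anyway:
# Debian/Ubuntu system packages (in a terminal): sudo apt-get install ffmpeg
# conda install -c conda-forge ffmpeg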

**Config**
---

In [None]:
from pathlib import Path

# --- input / workspace -------------------------------------------------------
# One source video for now; this can be extended to several videos later.
INPUT_VIDEO = Path("edu_data") / "1908 2nd-observation.mov"

# Extracted clip folders, the metadata CSV and checkpoints all live under WORK_DIR.
WORK_DIR = INPUT_VIDEO.parent / "1908_2nd_observation_prepped_nb"
WORK_DIR.mkdir(exist_ok=True, parents=True)

# --- clip extraction ---------------------------------------------------------
SEGMENT_SECONDS = 5   # length of each clip in seconds
TARGET_FPS = 15       # approximate output frame rate (frames are subsampled by an integer stride)
SHORT_SIDE = 224      # shorter frame side after resizing; aspect kept, dims rounded up to even

# --- self-supervised training ------------------------------------------------
EPOCHS = 100
BATCH_SIZE = 4
CLIP_FRAMES = 16      # frames per training sample
NUM_WORKERS = 0       # DataLoader worker processes; 0 is the safest choice inside notebooks
LR = 0.001

**Extract 5‑second clips without FFmpeg (OpenCV → frame folders)**
---

In [None]:
import cv2, math, os, pandas as pd
from tqdm import tqdm
import numpy as np

def _resize_keep_short_side(frame, short_side=224):
    h, w = frame.shape[:2]
    if h <= 0 or w <= 0:
        return frame
    if h < w:
        new_h = short_side
        new_w = int(round(w * (short_side / h)))
    else:
        new_w = short_side
        new_h = int(round(h * (short_side / w)))
    # make even dims for downstream 3D models
    new_h += new_h % 2
    new_w += new_w % 2
    return cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_CUBIC)

def _clip_metadata_row(clip_idx, clip_dir, input_video, start_time, num_frames, eff_fps):
    """Build the metadata dict (one CSV row) describing one finished clip folder."""
    end_time = start_time + (num_frames / eff_fps)
    return dict(
        clip_id=f"{clip_idx:06d}",
        clip_dir=str(clip_dir),
        source_video=str(input_video),
        start_s=round(start_time, 3),
        end_s=round(end_time, 3),
        fps=round(eff_fps, 3),
        num_frames=num_frames,
        duration_s=round(end_time - start_time, 3),
    )


def extract_clip_folders_opencv(
    input_video: Path,
    out_dir: Path,
    segment_seconds: int = 5,
    target_fps: int = 15,
    short_side: int = 224,
    img_ext: str = ".png"  # .jpg is smaller but lossy; pick what you prefer
) -> pd.DataFrame:
    """
    Decode ``input_video`` with OpenCV and write it out as folders of frames, one per clip.

    Every ``stride``-th decoded frame is kept (stride chosen so the kept rate is
    ~``target_fps``), resized with ``_resize_keep_short_side``, and written to
    ``out_dir / "clips_frames" / <clip_idx:06d> / <frame_idx:06d><img_ext>``.
    Each clip holds ``segment_seconds`` worth of kept frames; the last clip may be
    shorter. Image files with ``img_ext`` left in a clip folder by an earlier run are
    removed first, so the folder only contains this run's frames.

    Also writes ``out_dir / "clips_metadata.csv"``.

    Returns:
        DataFrame with one row per clip and columns clip_id, clip_dir, source_video,
        start_s, end_s, fps, num_frames, duration_s (empty but with these columns if
        no frames could be decoded).

    Raises:
        AssertionError: if OpenCV cannot open ``input_video``.
        IOError: if ``cv2.imwrite`` reports a failed write.
    """
    out_clips_dir = out_dir / "clips_frames"
    out_clips_dir.mkdir(parents=True, exist_ok=True)

    cap = cv2.VideoCapture(str(input_video))
    assert cap.isOpened(), f"Could not open: {input_video}"

    src_fps = cap.get(cv2.CAP_PROP_FPS)
    if not (src_fps and math.isfinite(src_fps) and src_fps > 0):
        src_fps = 30.0  # container reported no usable fps; assume 30
    raw_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    # Only used for the progress bar; may be 0 / unknown for some containers.
    total_frames = int(raw_count) if raw_count and math.isfinite(raw_count) and raw_count > 0 else 0

    # Determine frame sampling stride to approximate target_fps
    stride = max(int(round(src_fps / target_fps)), 1)
    eff_fps = src_fps / stride
    # At least one frame per clip, even for tiny segment_seconds * eff_fps.
    frames_per_clip = max(int(round(segment_seconds * eff_fps)), 1)

    rows = []
    clip_idx = 0
    frame_idx = 0
    sampled_idx = 0
    frames_in_clip = 0
    clip_dir = None
    start_time = 0.0

    try:
        with tqdm(total=total_frames or None, desc="Decoding & writing frames") as pbar:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # keep about target_fps by skipping frames
                if frame_idx % stride == 0:
                    frame = _resize_keep_short_side(frame, short_side=short_side)

                    # start a new clip folder if needed
                    if frames_in_clip == 0:
                        clip_dir = out_clips_dir / f"{clip_idx:06d}"
                        clip_dir.mkdir(parents=True, exist_ok=True)
                        # Remove frames from an earlier run (e.g. with other fps/segment
                        # settings); the dataset globs the whole folder and would read them.
                        for stale in clip_dir.glob(f"*{img_ext}"):
                            stale.unlink()
                        start_time = sampled_idx / eff_fps

                    out_path = clip_dir / f"{frames_in_clip:06d}{img_ext}"
                    if not cv2.imwrite(str(out_path), frame):
                        raise IOError(f"cv2.imwrite failed for {out_path}")
                    frames_in_clip += 1
                    sampled_idx += 1

                    # close the clip once it holds segment_seconds worth of sampled frames
                    if frames_in_clip >= frames_per_clip:
                        rows.append(_clip_metadata_row(
                            clip_idx, clip_dir, input_video, start_time, frames_in_clip, eff_fps))
                        clip_idx += 1
                        frames_in_clip = 0
                        clip_dir = None

                frame_idx += 1
                pbar.update(1)
    finally:
        # release the decoder even if decoding / writing raised
        cap.release()

    # flush the trailing (possibly shorter) clip
    if frames_in_clip > 0 and clip_dir is not None:
        rows.append(_clip_metadata_row(
            clip_idx, clip_dir, input_video, start_time, frames_in_clip, eff_fps))

    columns = ["clip_id", "clip_dir", "source_video", "start_s", "end_s",
               "fps", "num_frames", "duration_s"]
    df = pd.DataFrame(rows, columns=columns)
    csv_path = out_dir / "clips_metadata.csv"
    df.to_csv(csv_path, index=False)
    print(f"Wrote {len(df)} clips to {out_clips_dir}")
    print(f"Metadata: {csv_path}")
    return df

df_meta = extract_clip_folders_opencv(
    INPUT_VIDEO, WORK_DIR,
    segment_seconds=SEGMENT_SECONDS, target_fps=TARGET_FPS, short_side=SHORT_SIDE,
    img_ext=".jpg"  # switch to ".png" if you prefer lossless (then also change the "*.jpg" globs below)
)
# Fail fast here instead of with an IndexError / empty DataLoader in later cells.
assert len(df_meta) > 0, f"No clips were extracted from {INPUT_VIDEO}"

df_meta.head()

**Visual sanity check (first clip)**
---

In [None]:
from IPython.display import display
from PIL import Image
import glob

first_clip_dir = Path(df_meta.iloc[0]["clip_dir"])
first_frames = sorted(glob.glob(str(first_clip_dir / "*.jpg")))[:4]
assert first_frames, f"No .jpg frames found in {first_clip_dir} (was a different img_ext used?)"
# Show the first (up to) four frames so frame-to-frame motion is visible.
display(*(Image.open(p) for p in first_frames))


**Dataset & SSL model (SimCLR on 3D ResNet‑18)**
---

In [None]:
import os, random, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.transforms import v2 as T
from PIL import Image
import glob
from tqdm import tqdm

class FrameFolderDatasetSSL(Dataset):
    """Self-supervised clip dataset over the frame folders listed in a metadata CSV.

    Each item is a pair ``(v1, v2)`` of independently augmented views of the same
    ``clip_frames``-long window of one clip. Each view is a float32, normalized
    tensor of shape ``[C, T, H, W]``.

    Args:
        metadata_csv: CSV with (at least) a ``clip_dir`` column, one row per clip.
        clip_frames: number of frames ``T`` per sample.
        resize_hw: every frame is resized to ``resize_hw x resize_hw`` (non-square
            frames are squashed; aspect ratio is not preserved).
        training: if True, a random temporal window and random flip / color jitter /
            grayscale are used; otherwise the centered window and no augmentation.
        img_glob: glob for frame files inside a clip folder; must match the
            extension used at extraction time.
    """

    def __init__(self, metadata_csv, clip_frames=16, resize_hw=224, training=True, img_glob="*.jpg"):
        import pandas as pd
        self.df = pd.read_csv(metadata_csv)
        self.clip_frames = clip_frames
        self.resize_hw = resize_hw
        self.training = training
        self.img_glob = img_glob
        # PIL image -> uint8 [C, H, W]; frames are then stacked into [T, C, H, W].
        self._to_image = T.ToImage()
        # Applied to the whole stacked clip in one call: v2 transforms accept leading
        # batch dims and draw a single set of random parameters per call, so all
        # frames of a view share the same flip / jitter / grayscale decision.
        self.tx = T.Compose([
            T.ToImage(),
            T.Resize((resize_hw, resize_hw)),
            T.RandomHorizontalFlip(p=0.5) if training else T.Identity(),
            T.ColorJitter(0.4, 0.4, 0.4, 0.2) if training else T.Identity(),
            T.RandomGrayscale(p=0.2) if training else T.Identity(),
            T.ToDtype(torch.float32, scale=True),
            T.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225]),
        ])

    def __len__(self):
        return len(self.df)

    def _load_clip_frames(self, clip_dir):
        """Sorted paths of the frame files in ``clip_dir`` matching ``img_glob``."""
        return sorted(glob.glob(os.path.join(clip_dir, self.img_glob)))

    def _sample_indices(self, total, Treq):
        """Choose ``Treq`` consecutive frame indices out of ``total`` frames.

        Clips with ``total <= Treq`` are looped from the start to fill ``Treq``
        indices. Longer clips use a random window when training and the centered
        window otherwise.

        Raises:
            ValueError: if ``total`` is not positive.
        """
        if total <= 0:
            raise ValueError("cannot sample frame indices from an empty clip")
        if total <= Treq:
            reps = -(-Treq // total)  # ceil(Treq / total)
            return (list(range(total)) * reps)[:Treq]
        if self.training:
            start = random.randint(0, total - Treq)
        else:
            start = (total - Treq) // 2
        return list(range(start, start + Treq))

    def _two_views(self, pil_list):
        """Return two augmented ``[C, T, H, W]`` views of the frames in ``pil_list``."""
        clip = torch.stack([self._to_image(img) for img in pil_list])  # [T, C, H, W] uint8
        v1 = self.tx(clip).permute(1, 0, 2, 3).contiguous()
        v2 = self.tx(clip).permute(1, 0, 2, 3).contiguous()
        return v1, v2

    def __getitem__(self, i):
        clip_dir = self.df.iloc[i]["clip_dir"]
        files = self._load_clip_frames(clip_dir)
        if not files:
            # Usually means img_glob does not match the extracted extension; make it visible.
            import warnings
            warnings.warn(
                f"No frames matching {self.img_glob!r} in {clip_dir}; substituting a black clip"
            )
            dummy = Image.fromarray(np.zeros((self.resize_hw, self.resize_hw, 3), dtype=np.uint8))
            return self._two_views([dummy] * self.clip_frames)

        idx = self._sample_indices(len(files), self.clip_frames)
        pil_list = []
        for k in idx:
            # close each file handle once the RGB copy has been made
            with Image.open(files[k]) as img:
                pil_list.append(img.convert("RGB"))
        return self._two_views(pil_list)

class Projector(nn.Module):
    """SimCLR projection head: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        in_dim: size of the incoming feature vector.
        hid: hidden width.
        out_dim: size of the produced embedding.
    """

    def __init__(self, in_dim, hid=2048, out_dim=256):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hid),
            nn.BatchNorm1d(hid),
            nn.ReLU(inplace=True),
            nn.Linear(hid, out_dim),
        ]
        # One Sequential called `net` keeps state_dict keys as net.0 ... net.3,
        # which saved checkpoints rely on.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of features (2-D ``[B, in_dim]`` as fed by SimCLR3D) to ``[B, out_dim]``."""
        return self.net(x)

class SimCLR3D(nn.Module):
    """SimCLR model: randomly initialised ``torchvision`` r3d_18 video backbone with its
    classification layer removed, followed by a ``Projector`` head.

    ``forward`` returns L2-normalised embeddings of size ``proj_dim``.
    """

    def __init__(self, proj_dim=256):
        super().__init__()
        backbone = torchvision.models.video.r3d_18(weights=None)
        n_features = backbone.fc.in_features
        backbone.fc = nn.Identity()  # expose pooled features instead of class logits
        self.backbone = backbone
        self.projector = Projector(n_features, 2048, proj_dim)

    def forward(self, x):  # x: [B,C,T,H,W]
        """Embed a batch of clips; each output row has unit L2 norm."""
        features = self.backbone(x)
        return F.normalize(self.projector(features), dim=1)

def info_nce_loss(z1, z2, temperature=0.2):
    """NT-Xent (SimCLR) contrastive loss for two batches of paired embeddings.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` form a positive pair; every other row
    of the concatenated ``[2B, D]`` batch is a negative.

    Args:
        z1, z2: embeddings of shape ``[B, D]``. They are expected to be L2-normalised
            (``SimCLR3D.forward`` does this), so ``z @ z.T`` is cosine similarity.
        temperature: softmax temperature applied to the similarities.

    Returns:
        Scalar mean cross-entropy over all ``2B`` anchors.

    Raises:
        ValueError: if ``z1`` and ``z2`` have different shapes.
    """
    if z1.shape != z2.shape:
        raise ValueError(f"z1 and z2 must have the same shape, got {tuple(z1.shape)} vs {tuple(z2.shape)}")
    B = z1.size(0)
    z = torch.cat([z1, z2], dim=0)
    # Under CUDA autocast the matmul runs in fp16; compute the logits in fp32 instead.
    sim = (z @ z.t()).float() / temperature
    # Exclude self-similarity. -inf is representable in every float dtype (a large finite
    # sentinel like -9e15 is outside the fp16 range) and gives exactly zero softmax weight.
    sim.fill_diagonal_(float("-inf"))
    targets = torch.arange(B, device=z.device)
    targets = torch.cat([targets + B, targets], dim=0)  # positive of row i is row i±B
    return F.cross_entropy(sim, targets)

**Train (self‑supervised pretraining)**
---

In [None]:
import os, math, glob, json, time
from pathlib import Path

device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed Python's `random` (temporal window offsets in the dataset) and torch (weight init,
# DataLoader shuffling, torchvision random transforms) so reruns are repeatable; some GPU
# kernels may still be nondeterministic.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# --- dataloader ---
# img_glob must match the img_ext passed to extract_clip_folders_opencv (".jpg" above);
# otherwise every clip falls back to a black dummy clip.
ds = FrameFolderDatasetSSL(
    metadata_csv=WORK_DIR/"clips_metadata.csv",
    clip_frames=CLIP_FRAMES, resize_hw=224, training=True, img_glob="*.jpg"
)
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True,
                num_workers=NUM_WORKERS, pin_memory=(device=="cuda"), drop_last=True)

# --- model / optimizer ---
model = SimCLR3D(proj_dim=256).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)

# --- AMP (mixed precision; enabled only on GPU, where it speeds up training) ---
use_amp = (device == "cuda")
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# --- optional scheduler: linear warmup -> cosine decay ---
steps_per_epoch = len(dl)
total_steps = EPOCHS * steps_per_epoch
# Warm up for half an epoch but never fewer than 100 steps; cap at the run length so
# short runs still reach the base LR, and keep it >= 1 so lr_schedule never divides by 0.
warmup_steps = max(1, min(max(100, steps_per_epoch // 2), total_steps))

def lr_schedule(step, warmup=None, total=None, min_scale=0.1):
    """Multiplier for the base LR at global optimizer step ``step`` (0-based).

    Linear warmup from ``1/warmup`` to 1.0 over the first ``warmup`` steps, then
    cosine decay from 1.0 to ``min_scale`` at step ``total``. Beyond ``total``
    (e.g. when resuming with a smaller EPOCHS) the multiplier stays at ``min_scale``
    instead of rising again along the cosine.

    Args:
        step: global optimizer step.
        warmup: warmup length; defaults to the notebook-level ``warmup_steps``.
        total: total number of steps; defaults to the notebook-level ``total_steps``.
        min_scale: multiplier reached at the end of the cosine decay.
    """
    if warmup is None:
        warmup = warmup_steps
    if total is None:
        total = total_steps
    if step < warmup:
        return (step + 1) / warmup
    # cosine decay to min_scale x base LR
    t = (step - warmup) / max(1, total - warmup)
    t = min(max(t, 0.0), 1.0)
    return min_scale + (1 - min_scale) * 0.5 * (1 + math.cos(math.pi * t))

# --- checkpointing helpers ---
# All checkpoint artifacts live in one subfolder of the workspace.
CKPT_DIR = Path(WORK_DIR).joinpath("checkpoints_ssl")
CKPT_DIR.mkdir(exist_ok=True, parents=True)
LATEST, BEST, METRICS_JSON = (CKPT_DIR / fname for fname in ("latest.pt", "best.pt", "metrics.json"))

def save_ckpt(path, epoch, step, best_loss, best_epoch):
    """Save model, optimizer and AMP-scaler state plus training progress to ``path``.

    Args:
        path: destination file (e.g. ``LATEST`` or ``BEST``).
        epoch: epoch number stored in the checkpoint; the training loop resumes at ``epoch + 1``.
        step: global optimizer step count.
        best_loss: best proxy loss seen so far.
        best_epoch: epoch at which ``best_loss`` was reached.

    The data is written to ``<path>.tmp`` first and then moved over ``path`` with
    ``os.replace`` (atomic on POSIX), so interrupting the kernel mid-save cannot leave
    a truncated checkpoint behind.
    """
    path = Path(path)
    tmp_path = path.with_name(path.name + ".tmp")
    torch.save({
        "model": model.state_dict(),
        "optimizer": opt.state_dict(),
        "scaler": scaler.state_dict(),
        "epoch": epoch,
        "global_step": step,
        "best_loss": best_loss,
        "best_epoch": best_epoch,
        "hparams": {
            "clip_frames": CLIP_FRAMES,
            "target_fps": TARGET_FPS,
            "short_side": SHORT_SIDE,
            "batch_size": BATCH_SIZE,
            "lr_base": LR,
            "epochs": EPOCHS,
        },
    }, tmp_path)
    os.replace(tmp_path, path)

def load_ckpt_if_exists():
    """Restore model / optimizer / AMP-scaler state from ``LATEST`` if that file exists.

    Returns:
        ``(epoch, global_step, best_loss, best_epoch)`` read from the checkpoint
        (missing keys default to ``0, 0, inf, -1``), or ``(0, 0, inf, -1)`` when
        there is no checkpoint.
    """
    if not LATEST.exists():
        return 0, 0, float("inf"), -1
    ckpt = torch.load(LATEST, map_location=device)
    model.load_state_dict(ckpt["model"])
    opt.load_state_dict(ckpt["optimizer"])
    # A scaler saved while AMP was disabled (CPU run) has an empty state; nothing to restore.
    scaler_state = ckpt.get("scaler")
    if scaler_state:
        try:
            scaler.load_state_dict(scaler_state)
        except Exception as exc:
            print(f"Warning: could not restore GradScaler state ({exc}); using a fresh scaler.")
    return (ckpt.get("epoch", 0), ckpt.get("global_step", 0),
            ckpt.get("best_loss", float("inf")), ckpt.get("best_epoch", -1))

# --- resume from LATEST when a checkpoint is present (fresh start otherwise) ---
(start_epoch,
 global_step,
 best_running_loss,
 best_epoch) = load_ckpt_if_exists()
print(f"Resuming from epoch {start_epoch}, step {global_step} (best_loss={best_running_loss:.4f} @ epoch {best_epoch})")

# --- training loop ---
log_every = 50                      # steps between loss printouts
save_every_steps = steps_per_epoch  # periodic LATEST refresh; with this value it coincides with epoch ends
grad_clip = 1.0

if steps_per_epoch == 0:
    raise RuntimeError(
        f"The DataLoader yields no batches ({len(ds)} clips, batch_size={BATCH_SIZE}, drop_last=True); "
        "extract more clips or lower BATCH_SIZE."
    )

running_loss, running_steps = 0.0, 0  # accumulators for the periodic log line
t0 = time.time()

for epoch in range(start_epoch + 1, EPOCHS + 1):
    model.train()
    epoch_loss_sum, epoch_steps = 0.0, 0

    for v1, v2 in dl:
        # LR is a pure function of the global step, so it continues correctly after a resume.
        lr_scale = lr_schedule(global_step)
        for g in opt.param_groups:
            g["lr"] = LR * lr_scale

        v1 = v1.to(device, non_blocking=True)
        v2 = v2.to(device, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=use_amp):
            z1, z2 = model(v1), model(v2)
            loss = info_nce_loss(z1, z2, temperature=0.2)

        opt.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        # unscale first so gradient clipping sees the true gradient magnitudes
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(opt)
        scaler.update()

        loss_value = loss.item()
        running_loss += loss_value
        running_steps += 1
        epoch_loss_sum += loss_value
        epoch_steps += 1
        global_step += 1

        if global_step % log_every == 0:
            # average over the steps actually accumulated (fewer than log_every right after a resume)
            avg = running_loss / running_steps
            elapsed = time.time() - t0
            print(f"[epoch {epoch:03d} | step {global_step:07d}] loss={avg:.4f} lr={opt.param_groups[0]['lr']:.2e} ({elapsed:.1f}s)")
            running_loss, running_steps = 0.0, 0
            t0 = time.time()

        # Mid-epoch safety save: record the last *completed* epoch (epoch - 1) so a resume
        # re-runs this epoch instead of skipping its remaining batches. The epoch's final
        # step is skipped here because the end-of-epoch save below covers it.
        if global_step % save_every_steps == 0 and epoch_steps < steps_per_epoch:
            save_ckpt(LATEST, epoch - 1, global_step, best_running_loss, best_epoch)

    # "Best" gate on the mean training loss over the whole epoch (much less noisy than the
    # last batch's loss). There is no held-out set here, so this is only a proxy.
    epoch_loss = epoch_loss_sum / max(epoch_steps, 1)
    print(f"[epoch {epoch:03d}] mean loss={epoch_loss:.4f}")

    if epoch_loss < best_running_loss:
        best_running_loss = epoch_loss
        best_epoch = epoch
        save_ckpt(BEST, epoch, global_step, best_running_loss, best_epoch)
        print(f"Saved BEST at epoch {epoch} with proxy loss ~{best_running_loss:.4f}")

    # always refresh LATEST at end of epoch
    save_ckpt(LATEST, epoch, global_step, best_running_loss, best_epoch)

# --- export backbone (+ projector) weights for downstream tasks (same filename pattern as before) ---
# NOTE: these are the weights at the end of training; the best-proxy-loss weights are in BEST.
final_ssl = (WORK_DIR/"clips_metadata.csv").with_suffix(".ssl_r3d18.pt")
torch.save({
    "backbone_state_dict": model.backbone.state_dict(),
    "projector_state_dict": model.projector.state_dict(),
    "hparams": dict(clip_frames=CLIP_FRAMES, target_fps=TARGET_FPS, short_side=SHORT_SIDE)
}, final_ssl)
print(f"Exported backbone to: {final_ssl}")

# write compact metrics summary; json.dump would emit non-standard `Infinity` when no
# epoch has set a best loss yet, so write null in that case to keep the file valid JSON
best_loss_json = best_running_loss if math.isfinite(best_running_loss) else None
with open(METRICS_JSON, "w") as f:
    json.dump({"best_loss_proxy": best_loss_json, "best_epoch": best_epoch,
               "final_global_step": global_step}, f, indent=2)
print(f"Checkpoints in: {CKPT_DIR}")