<a href="https://colab.research.google.com/github/elenancalima/Troph_Min_5to2/blob/main/Minimal_5_to_2_backbone.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# PIP-only, GPU-capable stack (T4 / CUDA 12.1)
# - numpy==1.26.4  (avoids NumPy 2.x ABI churn in Colab)
# - torch==2.4.1 (cu121)
# - pillow (PNG I/O)
import subprocess, sys, os

def sh(cmd):
    print("$", cmd)
    r = subprocess.run(cmd, shell=True)
    if r.returncode != 0:
        raise SystemExit(f"Command failed: {cmd}")

# Remove potentially conflicting preinstalls
sh("pip -q uninstall -y torch torchvision torchaudio numpy || true")

# Pin NumPy FIRST
sh("pip -q install --no-deps --only-binary=:all: numpy==1.26.4")

# PyTorch for CUDA 12.1 (works on CPU too; will use GPU if present)
TORCH_IDX = "https://download.pytorch.org/whl/cu121"
sh(f"pip -q install --index-url {TORCH_IDX} torch==2.4.1")

# Pillow for PNG I/O
sh("pip -q install --upgrade pillow")

print("✅ Installed. The runtime will now restart to load the new NumPy cleanly...")
# Hard restart so the new binary NumPy is used (prevents dtype-size mismatch).
os.kill(os.getpid(), 9)


$ pip -q uninstall -y torch torchvision torchaudio numpy || true
$ pip -q install --no-deps --only-binary=:all: numpy==1.26.4
$ pip -q install --index-url https://download.pytorch.org/whl/cu121 torch==2.4.1
$ pip -q install --upgrade pillow


BELOW SETUP IS UNTESTED

In [None]:
# Smart, GPU-ready setup for Colab (T4). Reuses preinstalls if healthy.
import os, sys, subprocess

def sh(cmd):
    print("$", cmd)
    r = subprocess.run(cmd, shell=True)
    return r.returncode == 0

def torch_ok():
    try:
        import torch
        # Needs to be CUDA-enabled and have a known CUDA build (prefer cu121 for T4)
        if not torch.cuda.is_available(): return False
        cu = getattr(torch.version, "cuda", None) or ""
        # Accept cu121 (preferred) or anything 12.x that actually runs on GPU
        return cu.startswith("12")  # keep flexible, but you can tighten to "12.1"
    except Exception:
        return False

def numpy_ok():
    try:
        import numpy as np
        # Quick sanity: can we import compiled random pieces?
        import numpy.random._bounded_integers  # fails if ABI mismatch
        return True
    except Exception:
        return False

need_torch = not torch_ok()
need_numpy = not numpy_ok()

print(f"precheck -> need_numpy={need_numpy}, need_torch={need_torch}")

if need_numpy or need_torch:
    # Clean conflicting wheels first to avoid mixed ABIs
    sh("pip -q uninstall -y torch torchvision torchaudio numpy || true")
    # Pin NumPy FIRST to avoid ABI churn
    if not sh("pip -q install --no-deps --only-binary=:all: numpy==1.26.4"):
        raise SystemExit("Failed to install numpy 1.26.4")
    # Install a GPU-capable Torch (cu121 works well on T4)
    TORCH_IDX = "https://download.pytorch.org/whl/cu121"
    if not sh(f"pip -q install --index-url {TORCH_IDX} torch==2.4.1"):
        raise SystemExit("Failed to install torch 2.4.1 (cu121)")
    # Pillow for PNG I/O
    sh("pip -q install --upgrade pillow")
    print("✅ Installed fixed stack. Restarting kernel to load cleanly…")
    os.kill(os.getpid(), 9)
else:
    print("✅ Environment looks good—no reinstall needed.")


ACTUAL RUN STARTS HERE

In [7]:
import numpy as np, torch, PIL, platform, sys
print("python", sys.version.split()[0], "|", platform.platform())
print("numpy", np.__version__)
print("torch", torch.__version__, "cuda?", torch.cuda.is_available())
if torch.cuda.is_available():
    print("cuda device:", torch.cuda.get_device_name(0))
print("pillow", PIL.__version__)

# Deterministic-ish
import random
random.seed(0); np.random.seed(0); torch.manual_seed(0)
try:
    torch.use_deterministic_algorithms(True)
except Exception:
    pass
torch.backends.cudnn.benchmark = False

device = "cuda" if torch.cuda.is_available() else "cpu"


python 3.12.12 | Linux-6.6.105+-x86_64-with-glibc2.35
numpy 1.26.4
torch 2.4.1+cu121 cuda? True
cuda device: Tesla T4
pillow 11.3.0


In [8]:
import torch.nn as nn

# Stage 1: RGB (850→170) -> 2 dense maps @170x170
class Stage1Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 2, 1)
        )
    def forward(self, x170_rgb):
        return torch.sigmoid(self.net(x170_rgb))  # (B,2,170,170)

# Stage 2: concat [2 stage1 + 1 abdomen + 1 front + 1 simple_troph + 1 simple_part] -> 2 final maps
class Stage2Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(6, 16, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, 2, 1)
        )
    def forward(self, x6):
        return torch.sigmoid(self.net(x6))       # (B,2,170,170)

stage1 = Stage1Net().to(device).eval()
stage2 = Stage2Net().to(device).eval()


In [9]:
import os, re, glob
from pathlib import Path
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np

TARGET_850 = (850, 850)   # (W,H) for Pillow
TARGET_170 = (170, 170)

def ensure_exists(*folders):
    for d in folders:
        if not os.path.isdir(d):
            raise RuntimeError(f"Required folder missing: {d}")

def list_indexed_pngs(folder):
    paths = sorted(glob.glob(os.path.join(folder, "*.png")))
    if not paths:
        raise RuntimeError(f"No .png found in: {folder}")
    items = []
    seen = set()
    for p in paths:
        m = re.search(r'(\d+)(?=\D*$)', os.path.basename(p))  # last digit block
        if not m:
            raise RuntimeError(f"No numeric frame index in filename: {p}")
        s = m.group(1); idx = int(s); pad = len(s)
        if idx in seen:
            raise RuntimeError(f"Duplicate frame index {idx} in {folder}")
        seen.add(idx)
        items.append((idx, p, pad))
    d = {idx: p for idx, p, _ in items}
    pad_width = max(pad for _,_,pad in items)
    return d, pad_width

def load_rgb_850_as_tensor(path):
    im = Image.open(path).convert("RGB")
    if im.size != TARGET_850:
        raise RuntimeError(f"input_vid must be 850x850 (W,H); got {im.size} for {path}")
    arr = np.asarray(im, dtype=np.float32) / 255.0      # (H,W,3)
    return torch.from_numpy(arr).permute(2,0,1).unsqueeze(0)  # (1,3,850,850)

def load_bin_850_as_tensor01(path):
    im = Image.open(path).convert("L")
    if im.size != TARGET_850:
        raise RuntimeError(f"Binary 850 must be 850x850 (W,H); got {im.size} for {path}")
    arr = (np.asarray(im, dtype=np.uint8) >= 128).astype(np.float32)  # 0/1
    return torch.from_numpy(arr).unsqueeze(0).unsqueeze(0)            # (1,1,850,850)

def load_bin_170_as_tensor01(path):
    im = Image.open(path).convert("L")
    if im.size != TARGET_170:
        raise RuntimeError(f"Binary 170 must be 170x170 (W,H); got {im.size} for {path}")
    arr = (np.asarray(im, dtype=np.uint8) >= 128).astype(np.float32)
    return torch.from_numpy(arr).unsqueeze(0).unsqueeze(0)            # (1,1,170,170)

def save_binary_170_png(path, t01):
    if t01.ndim == 4: t01 = t01[0,0]
    arr = (t01.cpu().numpy() >= 0.5).astype(np.uint8) * 255
    Image.fromarray(arr, mode="L").save(path)


In [10]:
def run_inference(video_root: str, threshold: float = 0.5, verbose_first=False):
    # Inputs
    d_in   = os.path.join(video_root, "input_vid")
    d_abd  = os.path.join(video_root, "abdomen_mask")
    d_fb   = os.path.join(video_root, "front_body_mask")
    d_st   = os.path.join(video_root, "simple_troph_point_heatmap")
    d_spl  = os.path.join(video_root, "simple_participant_line_heatmap")
    ensure_exists(d_in, d_abd, d_fb, d_st, d_spl)

    # Outputs
    d_out_tp  = os.path.join(video_root, "pred_troph_point")
    d_out_pl  = os.path.join(video_root, "pred_participant_line")
    Path(d_out_tp).mkdir(parents=True, exist_ok=True)
    Path(d_out_pl).mkdir(parents=True, exist_ok=True)

    # Index maps (idx -> path) and pad width from input_vid
    M_in,  pad_width = list_indexed_pngs(d_in)
    M_abd, _ = list_indexed_pngs(d_abd)
    M_fb,  _ = list_indexed_pngs(d_fb)
    M_st,  _ = list_indexed_pngs(d_st)
    M_spl, _ = list_indexed_pngs(d_spl)

    S_in, S_abd, S_fb, S_st, S_spl = map(set, (M_in.keys(), M_abd.keys(), M_fb.keys(), M_st.keys(), M_spl.keys()))
    if not (S_in == S_abd == S_fb == S_st == S_spl):
        all_idx = sorted(S_in | S_abd | S_fb | S_st | S_spl)
        def miss(name, S):
            missing = [i for i in all_idx if i not in S]
            return f"{name}: missing {len(missing)} → {missing[:20]}{' ...' if len(missing)>20 else ''}"
        msg = "Frame-index mismatch across folders:\n  " + "\n  ".join([
            miss("input_vid", S_in), miss("abdomen_mask", S_abd), miss("front_body_mask", S_fb),
            miss("simple_troph_point_heatmap", S_st), miss("simple_participant_line_heatmap", S_spl)
        ])
        raise RuntimeError(msg)

    idx_list = sorted(S_in); n = len(idx_list)
    print(f"[OK] Found {n} aligned frame indices under: {video_root}")
    print(f"     Writing 170×170 binaries to:\n       {d_out_tp}\n       {d_out_pl}")

    with torch.inference_mode():
        for k, idx in enumerate(idx_list):
            # --- load to CPU
            x_rgb850 = load_rgb_850_as_tensor(M_in[idx])
            abd850   = load_bin_850_as_tensor01(M_abd[idx])
            fb850    = load_bin_850_as_tensor01(M_fb[idx])
            st170    = load_bin_170_as_tensor01(M_st[idx])
            spl170   = load_bin_170_as_tensor01(M_spl[idx])

            # --- move to device
            x_rgb850 = x_rgb850.to(device, non_blocking=True)
            abd850   = abd850.to(device, non_blocking=True)
            fb850    = fb850.to(device, non_blocking=True)
            st170    = st170.to(device, non_blocking=True)
            spl170   = spl170.to(device, non_blocking=True)

            # --- downsample to 170
            x_rgb170 = F.interpolate(x_rgb850, size=(170,170), mode="bilinear", align_corners=False)
            abd170   = F.interpolate(abd850,   size=(170,170), mode="nearest")
            fb170    = F.interpolate(fb850,    size=(170,170), mode="nearest")

            # --- Stage 1 & 2
            f12 = stage1(x_rgb170)                                   # (1,2,170,170)
            x6  = torch.cat([f12, abd170, fb170, st170, spl170], dim=1)  # (1,6,170,170)
            y2  = stage2(x6)                                          # (1,2,170,170)

            if verbose_first and k == 0:
                print("Shapes @first:",
                      "f12", tuple(f12.shape),
                      "abd", tuple(abd170.shape),
                      "fb",  tuple(fb170.shape),
                      "st",  tuple(st170.shape),
                      "spl", tuple(spl170.shape))

            # --- save (threshold to binary)
            a = (y2[:,0:1] >= threshold).float()
            b = (y2[:,1:2] >= threshold).float()
            save_binary_170_png(os.path.join(d_out_tp, f"{idx:0{pad_width}d}.png"), a)
            save_binary_170_png(os.path.join(d_out_pl, f"{idx:0{pad_width}d}.png"), b)

            if (k+1) % 100 == 0 or k == n-1:
                print(f"  processed {k+1}/{n}")

    print("[DONE] Inference complete.")


In [11]:
from PIL import Image, ImageDraw
import os, glob
from pathlib import Path

VIDEO_ROOT = "/content/fake_video_root_t4"
for sd in ["input_vid","abdomen_mask","front_body_mask","simple_troph_point_heatmap","simple_participant_line_heatmap"]:
    Path(os.path.join(VIDEO_ROOT, sd)).mkdir(parents=True, exist_ok=True)

H=W=850; h=w=170; N=20
for i in range(N):
    # input 850 rgb
    img = Image.new("RGB", (W,H), (0,0,0))
    d = ImageDraw.Draw(img)
    d.ellipse([(100+(i*7)%W-40,200-40),(100+(i*7)%W+40,200+40)], fill=(40,120,240))
    d.rectangle([(300,300),(500,500)], fill=(200,60,60))
    d.text((600,100), f"{i+1:03d}", fill=(210,210,210))
    img.save(os.path.join(VIDEO_ROOT,"input_vid",f"frame_{i+1:05d}.png"))

    # abdomen 850 bin
    a = Image.new("L", (W,H), 0); d = ImageDraw.Draw(a)
    d.ellipse([(200+(i*5)%W-120,600-120),(200+(i*5)%W+120,600+120)], fill=255)
    a.save(os.path.join(VIDEO_ROOT,"abdomen_mask",f"frame_{i+1:05d}.png"))

    # front 850 bin
    f = Image.new("L", (W,H), 0); d = ImageDraw.Draw(f)
    y = (i*9) % H; d.line([(50,y),(W-50,y)], fill=255, width=5)
    f.save(os.path.join(VIDEO_ROOT,"front_body_mask",f"frame_{i+1:05d}.png"))

    # simple 170 bins
    s1 = Image.new("L", (w,h), 0); d = ImageDraw.Draw(s1)
    d.ellipse([(20+(i*3)%w-3,30-3),(20+(i*3)%w+3,30+3)], fill=255)
    s1.save(os.path.join(VIDEO_ROOT,"simple_troph_point_heatmap",f"frame_{i+1:05d}.png"))

    s2 = Image.new("L", (w,h), 0); d = ImageDraw.Draw(s2)
    y2 = (i*4) % h; d.line([(10,y2),(w-10,y2)], fill=255, width=1)
    s2.save(os.path.join(VIDEO_ROOT,"simple_participant_line_heatmap",f"frame_{i+1:05d}.png"))

print("[OK] Fake data at:", VIDEO_ROOT)
run_inference(VIDEO_ROOT, threshold=0.5, verbose_first=True)

A = sorted(glob.glob(os.path.join(VIDEO_ROOT, "pred_troph_point", "*.png")))
B = sorted(glob.glob(os.path.join(VIDEO_ROOT, "pred_participant_line", "*.png")))
print("Output counts:", len(A), len(B))


[OK] Fake data at: /content/fake_video_root_t4
[OK] Found 20 aligned frame indices under: /content/fake_video_root_t4
     Writing 170×170 binaries to:
       /content/fake_video_root_t4/pred_troph_point
       /content/fake_video_root_t4/pred_participant_line
Shapes @first: f12 (1, 2, 170, 170) abd (1, 1, 170, 170) fb (1, 1, 170, 170) st (1, 1, 170, 170) spl (1, 1, 170, 170)


  Image.fromarray(arr, mode="L").save(path)


  processed 20/20
[DONE] Inference complete.
Output counts: 20 20


This takes considerably longer than YOLO inference, beware.

In [12]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# --- Paths you gave ---
videoName   = "FoodLimit2.5D_gR0033_feeding"
GDRIVE_ROOT = "/content/drive/MyDrive"
VIDEO_ROOT  = f"{GDRIVE_ROOT}/MainConnection_VidRoots/{videoName}"  # this is the video root

print("VIDEO_ROOT =", VIDEO_ROOT)

# Optional: quick peek to confirm subfolders exist (won't list files, just names)
import os
expected = [
    "input_vid", "abdomen_mask", "front_body_mask",
    "simple_troph_point_heatmap", "simple_participant_line_heatmap"
]
missing = [d for d in expected if not os.path.isdir(os.path.join(VIDEO_ROOT, d))]
if missing:
    raise RuntimeError(f"Missing expected subfolders under VIDEO_ROOT: {missing}")

# Run inference (uses GPU if available)
run_inference(VIDEO_ROOT, threshold=0.5, verbose_first=True)


Mounted at /content/drive
VIDEO_ROOT = /content/drive/MyDrive/MainConnection_VidRoots/FoodLimit2.5D_gR0033_feeding
[OK] Found 207 aligned frame indices under: /content/drive/MyDrive/MainConnection_VidRoots/FoodLimit2.5D_gR0033_feeding
     Writing 170×170 binaries to:
       /content/drive/MyDrive/MainConnection_VidRoots/FoodLimit2.5D_gR0033_feeding/pred_troph_point
       /content/drive/MyDrive/MainConnection_VidRoots/FoodLimit2.5D_gR0033_feeding/pred_participant_line
Shapes @first: f12 (1, 2, 170, 170) abd (1, 1, 170, 170) fb (1, 1, 170, 170) st (1, 1, 170, 170) spl (1, 1, 170, 170)


  Image.fromarray(arr, mode="L").save(path)


  processed 100/207
  processed 200/207
  processed 207/207
[DONE] Inference complete.
