
# Ground→Space Trajectory Modeling (Microscopy + Optional RNA Auxiliary Head)

This extension adapts the existing ConvLSTM tutorial to your dataset layout:
- TIFF microscopy stacks per sample/timepoint
- Optional RNA features (when available) used via an auxiliary head
- Trajectory construction grouped by `(Spaceflight, Material, Medium)`, sorted by `Time`
- Primary objective: predict **Space** trajectories **from Ground** trajectories

> If `s_OSD-627.txt` is provided, the parser below will extract `(Spaceflight, Material, Medium, Time, SampleID)` for clean trajectory linking. If not provided, we fall back to the provided manifest CSVs to align images and RNA columns.


In [4]:

# --- Setup & Imports ---
import os, re, json, zipfile, io
from pathlib import Path
import numpy as np
import pandas as pd

try:
    import tifffile as tiff
except Exception as e:
    raise RuntimeError("tifffile is required. Please install via pip if missing.") from e

# Torch is optional at load time; only required for the model/training section
try:
    import torch
    import torch.nn as nn
    from torch.utils.data import Dataset, DataLoader
except Exception as e:
    print("Note: PyTorch not found at import time. You can still run preprocessing and mapping.")

INPUT_ZIP  = Path("microscopy.zip")      # your original
OUTPUT_ZIP = Path("microscopy_ds.zip")   # where to write downsampled tiffs
DOWNSAMPLE = 2     # integer factor (2, 3, 4, ...) — uses stride decimation (fast & low RAM)
COMPRESSION = 'deflate'  # 'deflate' (zlib), 'lzma', or None
DATA_DIR_RNA = Path("Normalized_counts")
# Expected optional files:
# S_OSD_TXT = DATA_DIR / "s_OSD-627.txt"  # Provide this to unlock full parser
S_OSD_TXT = Path("s_OSD-627.txt")
MANIFEST_CSV = Path("updates_manifest.tsv")
PROVISIONAL_CSV = Path("Provisional_image_RNA_manifest__edit__chosen_rna_col__later_.csv")

# Example RNA xlsx (already uploaded) – update/extend as you add the rest:
RNA_SHEETS = [
    DATA_DIR_RNA / "SSMicro_day1v3_bytime.xlsx",
    DATA_DIR_RNA / "SSGround_day1v3_bytime.xlsx",
    DATA_DIR_RNA / "SS_day3_bygravity.xlsx",
    DATA_DIR_RNA / "Micro_day3_bymaterial.xlsx",
    DATA_DIR_RNA / "LIS_day3_bygravity.xlsx",
    DATA_DIR_RNA / "Ground_day3_bymaterial.xlsx"
]


In [None]:
# def _decimate(arr, factor):
#     """Fast, low-memory decimation by integer factor (area-shrink alternative without deps)."""
#     if arr.ndim == 2:        # [H,W]
#         return arr[::factor, ::factor]
#     elif arr.ndim == 3:
#         # We assume grayscale stack [T,H,W] (your case). If it's RGB per frame, see note below.
#         return arr[:, ::factor, ::factor]
#     else:
#         # squeeze and try again
#         arr2 = np.squeeze(arr)
#         if arr2.ndim in (2,3):
#             return _decimate(arr2, factor)
#         raise ValueError(f"Unexpected array shape {arr.shape}")

# def _ensure_stack_THW(arr):
#     """Return [T,H,W] float32."""
#     arr = np.asarray(arr)
#     if arr.ndim == 2:    # [H,W]
#         arr = arr[None, ...]
#     elif arr.ndim != 3:
#         arr = np.squeeze(arr)
#         if arr.ndim == 2:
#             arr = arr[None, ...]
#         elif arr.ndim != 3:
#             raise ValueError(f"Cannot coerce shape {arr.shape} to [T,H,W]")
#     return arr.astype(np.float32, copy=False)

# def downsample_tiff_bytes(in_bytes, factor, compression=COMPRESSION):
#     """
#     Read a TIFF from bytes, downsample frame-by-frame by integer factor,
#     and return new TIFF bytes (compressed).
#     Uses streaming write to keep memory low.
#     """
#     out_buf = io.BytesIO()
#     with tiff.TiffFile(io.BytesIO(in_bytes)) as tf, tiff.TiffWriter(out_buf, bigtiff=True) as tw:
    #     # Try to iterate by TIFF pages/series to avoid loading entire stacks
    #     # Many microscopy stacks expose a single series with multiple pages (frames)
    #     series = tf.series[0]
    #     # When the series is paged, iterate pages; else fall back to asarray()
    #     if len(series.pages) > 1:
    #         for page in series.pages:
    #             frame = page.asarray()          # [H,W] (expected grayscale)
    #             frame = _decimate(frame, factor)
    #             tw.write(
    #                 frame.astype(np.float32),
    #                 photometric='minisblack',
    #                 compression=compression
    #             )
    #     else:
    #         # Single array; could be [T,H,W] or [H,W]
    #         arr = _ensure_stack_THW(series.asarray())
    #         arr_ds = _decimate(arr, factor)    # [T,H',W']
    #         # Write as multi-page TIFF
    #         for t in range(arr_ds.shape[0]):
    #             tw.write(
    #                 arr_ds[t].astype(np.float32),
    #                 photometric='minisblack',
    #                 compression=compression
    #             )
    # out_buf.seek(0)
    # return out_buf.read()

# def _is_tiff_magic(b4: bytes) -> bool:
#     return b4 in (b'II*\x00', b'MM\x00*', b'II+\x00', b'MM\x00+')

# # Process all TIFFs in the input zip and write downsampled versions to output zip
# assert INPUT_ZIP.exists(), f"Missing {INPUT_ZIP}"
# with zipfile.ZipFile(INPUT_ZIP, "r") as zin, zipfile.ZipFile(OUTPUT_ZIP, "w", compression=zipfile.ZIP_STORED) as zout:
#     names = [n for n in zin.namelist()
#              if n.lower().endswith((".tif", ".tiff"))
#              and not n.endswith("/")
#              and not n.startswith("__MACOSX/")]
#     print(f"Found {len(names)} TIFF-like entries")
#     skipped = 0
#     for i, name in enumerate(sorted(names)):
#         with zin.open(name) as f:
#             head = f.read(4)
#             if not _is_tiff_magic(head):
#                 skipped += 1
#                 continue
#             data = head + f.read()

#         try:
#             ds_bytes = downsample_tiff_bytes(data, DOWNSAMPLE, compression=COMPRESSION)
#             # Write to zip with same relative name (you can prefix if you want)
#             zout.writestr(name, ds_bytes)
#         except Exception as e:
#             skipped += 1
#             print(f"[skip] {name}: {e}")
#             continue

#         if (i+1) % 10 == 0:
#             print(f"Processed {i+1}/{len(names)}...")

# print(f"Done. Wrote downsampled TIFFs to {OUTPUT_ZIP}. Skipped {skipped} items.")

# ===============================================
# NEW TIFF → PNG MERGE + PNG TILING PIPELINE
# ===============================================

import os
from PIL import Image
import numpy as np
from pathlib import Path

# -------- TIFF → single PNG (combine z-layers) --------

def combine_tiff_layers_to_png(tiff_path, output_png):
    """
    Combine all layers of a multi-layer TIFF into a single PNG.
    Uses max projection (recommended for confocal morphology).
    """
    tiff = Image.open(tiff_path)

    frames = []
    try:
        while True:
            frame = np.array(tiff.convert("L"), dtype=np.float32)
            frames.append(frame)
            tiff.seek(tiff.tell() + 1)
    except EOFError:
        pass

    if len(frames) == 0:
        raise ValueError(f"No frames found in TIFF: {tiff_path}")

    # max projection across Z
    combined = np.max(frames, axis=0)
    combined = np.clip(combined, 0, 255).astype(np.uint8)

    Image.fromarray(combined).save(output_png)

# -------- Resize + tile PNGs --------

def resize_to_512(img):
    return img.resize((512, 512), Image.LANCZOS)

def split_image(img, n):
    tiles = []
    w, h = img.size
    tile_w = w // n
    tile_h = h // n
    for r in range(n):
        for c in range(n):
            box = (c*tile_w, r*tile_h, (c+1)*tile_w, (r+1)*tile_h)
            tiles.append(img.crop(box))
    return tiles

def process_png_folder(input_folder, output_folder, n=2):
    """
    Reads processed PNGs (generated from TIFF).
    Resizes each to 512x512.
    Splits into NxN tiles.
    Saves tiles with suffix _01, _02, ...
    """
    os.makedirs(output_folder, exist_ok=True)

    for fname in os.listdir(input_folder):
        if not fname.lower().endswith(".png"):
            continue

        img = Image.open(os.path.join(input_folder, fname))
        img = resize_to_512(img)
        tiles = split_image(img, n)

        base = os.path.splitext(fname)[0]
        for i, tile in enumerate(tiles, start=1):
            outname = f"{base}_{i:02d}.png"
            tile.save(os.path.join(output_folder, outname))

    print("✓ PNG tiling complete.")


Found 479 TIFFs
Processed 10/479 …
Processed 20/479 …
Processed 30/479 …
Processed 40/479 …
Processed 50/479 …
Processed 60/479 …
Processed 70/479 …
Processed 80/479 …
Processed 90/479 …
Processed 100/479 …
Processed 110/479 …
Processed 120/479 …
Processed 130/479 …
Processed 140/479 …
Processed 150/479 …
Processed 160/479 …
Processed 170/479 …
Processed 180/479 …
Processed 190/479 …
Processed 200/479 …
Processed 210/479 …
Processed 220/479 …
Processed 230/479 …
Processed 240/479 …
Processed 250/479 …
Processed 260/479 …
Processed 270/479 …
Processed 280/479 …
Processed 290/479 …
Processed 300/479 …
Processed 310/479 …
Processed 320/479 …
Processed 330/479 …
Processed 340/479 …
Processed 350/479 …
Processed 360/479 …
Processed 370/479 …
Processed 380/479 …
Processed 390/479 …
Processed 400/479 …
Processed 410/479 …
Processed 420/479 …
Processed 430/479 …
Processed 440/479 …
Processed 450/479 …
Processed 460/479 …
Processed 470/479 …
Done. Skipped: 0


In [None]:
# ======================================
# Preview PNG tiles instead of TIFF
# ======================================
from PIL import Image
import matplotlib.pyplot as plt

png_folder = Path("png_tiles")   # <-- change to your output folder
png_files = list(png_folder.glob("*.png"))

print(f"Found {len(png_files)} PNG tiles.")

if png_files:
    img = Image.open(png_files[0])
    plt.imshow(img, cmap="gray")
    plt.title(f"Preview: {png_files[0].name}")
    plt.axis("off")


# if OUTPUT_ZIP.exists():
#     with zipfile.ZipFile(OUTPUT_ZIP, "r") as zf:
#         # Collect all TIFF-like members
#         tif_files = [name for name in zf.namelist() if name.lower().endswith((".tif", ".tiff"))]
#         print(f"Found {len(tif_files)} TIFFs in zip:", tif_files[:5], "..." if len(tif_files) > 5 else "")

#         # Example: read one TIFF file into a numpy array
#         def read_tif_from_zip(zip_file, filename):
#             with zip_file.open(filename) as f:
#                 img_bytes = io.BytesIO(f.read())
#                 return tiff.imread(img_bytes)

#         # Example: preview shape of the first TIFF
#         if len(tif_files) > 0:
#             sample_img = read_tif_from_zip(zf, tif_files[0])
#             print("Example image shape:", sample_img.shape)

# else:
#     raise FileNotFoundError(f"ZIP archive not found at {OUTPUT_ZIP!r}. Please update the path.")

Found: 479 TIFFs.
First few: ['microscopy/LSDS-55_microscopy_1.1.tif', 'microscopy/LSDS-55_microscopy_1.1001.tif', 'microscopy/LSDS-55_microscopy_1.1002.tif', 'microscopy/LSDS-55_microscopy_1.1003.tif', 'microscopy/LSDS-55_microscopy_1.2.tif']
Shape of parsed TIFF stack: (1, 256, 256)


In [8]:
from pathlib import Path
import pandas as pd
import re, unicodedata

S_OSD_TXT = Path("s_OSD-627.txt")

def _norm(s: str) -> str:
    s = unicodedata.normalize("NFKC", str(s))
    s = s.lower().strip()
    s = re.sub(r"\s+", " ", s)
    # keep words but drop punctuation so we can match phrases robustly
    s = re.sub(r"[\[\]\(\)\{\}:;,_\-\/\\]+", " ", s)
    return s

def _find_col(cols, *aliases):
    """
    Return the first column whose normalized name contains ALL words
    of any alias phrase provided.
    """
    norm_map = {c: _norm(c) for c in cols}
    for alias in aliases:
        phrases = alias if isinstance(alias, (list, tuple)) else [alias]
        for phrase in phrases:
            phrase_norm = _norm(phrase)
            words = phrase_norm.split()
            for col, n in norm_map.items():
                if all(w in n for w in words):
                    return col
    return None

def parse_s_osd(path: Path) -> pd.DataFrame:
    df = pd.read_csv(path, sep="\t", dtype=str)
    df.columns = [c.strip() for c in df.columns]

    # Core columns
    col_sample = _find_col(
        df.columns,
        "sample name", "sample id"
    )
    col_source = _find_col(
        df.columns,
        "source name"
    )
    col_sf = _find_col(
        df.columns,
        # common ISA-tab encodings
        "factor value spaceflight",
        "characteristics spaceflight",
        "spaceflight"
    )
    col_time = _find_col(
        df.columns,
        "factor value time point",
        "factor value time",
        "time point",
        "time"
    )

    # Material shows up under many names; search broadly
    col_material = _find_col(
        df.columns,
        "factor value material",
        "characteristics material",
        "material",
        "growth surface",
        "factor value growth surface",
        "characteristics growth surface",
        "substrate",
        "coupon material",
        "alloy",
        "stainless steel"
    )

    # Medium often appears as Parameter Value[...] in ISA-tab
    col_medium = _find_col(
        df.columns,
        "parameter value sample media information",
        "sample media information",
        "factor value medium",
        "characteristics medium",
        "medium",
    )

    # Hard requirements (don’t fail on Material/Medium)
    required = {
        "Sample Name": col_sample,
        "Source Name": col_source,
        "Spaceflight": col_sf,
        "Time": col_time,
    }
    missing = [k for k, v in required.items() if v is None]
    if missing:
        raise ValueError(f"s_OSD file is missing expected columns (none matched): {missing}\n"
                         f"Available columns:\n{list(df.columns)}")

    # Build output
    out = pd.DataFrame({
        "sample_id": df[col_sample].str.strip(),
        "source_id": df[col_source].str.strip(),
        "spaceflight": df[col_sf].str.strip(),
        "time": df[col_time].str.strip(),
        # Optional fields with safe fallbacks
        "material": (df[col_material].str.strip() if col_material else "UnknownMaterial"),
        "medium": (df[col_medium].str.strip() if col_medium else "UnknownMedium"),
    })

    # Normalize numeric time
    out["time"] = pd.to_numeric(out["time"], errors="coerce").astype("Int64")

    # Fill blanks after strip
    out["material"] = out["material"].fillna("UnknownMaterial").replace("", "UnknownMaterial")
    out["medium"]   = out["medium"].fillna("UnknownMedium").replace("", "UnknownMedium")

    # Generate TIFF candidate names that match your naming convention
    out["microscopy_tif"] = "LSDS-55_microscopy_" + out["sample_id"] + ".tif"
    out["microscopy_tif_alt_lower_g"] = "LSDS-55_microscopy_" + out["sample_id"].str.replace(r"^G", "g", regex=True) + ".tif"
    out["microscopy_tif_alt_double_dot"] = "LSDS-55_microscopy_" + out["sample_id"].str.replace(r"^([Gg]\d+)\.(\d+)$", r"\1..\2", regex=True) + ".tif"

    # Debug: show what we matched so you can confirm quickly
    print("Matched columns →",
          dict(sample=col_sample, source=col_source, spaceflight=col_sf,
               time=col_time, material=col_material, medium=col_medium))

    return out[[
        "sample_id","source_id","spaceflight","material","medium","time",
        "microscopy_tif","microscopy_tif_alt_lower_g","microscopy_tif_alt_double_dot"
    ]]

# Use the parser
if S_OSD_TXT.exists():
    s_osd_df = parse_s_osd(S_OSD_TXT)
else:
    s_osd_df = pd.DataFrame(columns=[
        "sample_id","source_id","spaceflight","material","medium","time",
        "microscopy_tif","microscopy_tif_alt_lower_g","microscopy_tif_alt_double_dot"
    ])

display(s_osd_df.head(10))
print("Parsed rows:", len(s_osd_df))
print("Unique materials (first 10):", s_osd_df["material"].unique()[:10])
print("Unique media (first 10):", s_osd_df["medium"].unique()[:10])


Unnamed: 0,sample_id,source_id,spaceflight,material,medium,time,microscopy_tif,microscopy_tif_alt_lower_g,microscopy_tif_alt_double_dot


Parsed rows: 0
Unique materials (first 10): []
Unique media (first 10): []


In [9]:

# --- Fallback linking via provided manifests ---
# Map images (their filename suffix) to suggested RNA column names.
manifest_df = pd.read_csv(MANIFEST_CSV) if MANIFEST_CSV.exists() else pd.DataFrame()
provisional_df = pd.read_csv(PROVISIONAL_CSV) if PROVISIONAL_CSV.exists() else pd.DataFrame()

print("Manifest preview:")
display(manifest_df.head(5) if not manifest_df.empty else "No manifest found")

print("Provisional preview:")
display(provisional_df.head(5) if not provisional_df.empty else "No provisional mapping found")

# Build an image->RNA columns mapping (list) using suggested_rna_cols
image_to_rna_cols = {}
if not provisional_df.empty:
    for _,row in provisional_df.iterrows():
        key = str(row.get("tif_suffix","")).strip()
        cols = []
        for c in ["suggested_rna_cols","suggested_rna_cols_2","suggested_rna_cols_3","suggested_rna_cols_4"]:
            val = row.get(c, None)
            if isinstance(val, str) and val.strip():
                cols.extend([s.strip() for s in val.split(",") if s.strip()])
        image_to_rna_cols[key] = sorted(set(cols))

image_to_rna_cols


Manifest preview:


Unnamed: 0,tif_file\tsample_id\tsample_key\tspaceflight\ttime\tmaterial_type\tmedium\trna_available\trna_sample_name
0,LSDS-55_microscopy_1.1.tif\t1.1\t1.1\tSpace Fl...
1,LSDS-55_microscopy_1.1001.tif\t1.1\t1.1\tSpace...
2,LSDS-55_microscopy_1.1002.tif\t1.1\t1.1\tSpace...
3,LSDS-55_microscopy_1.1003.tif\t1.1\t1.1\tSpace...
4,LSDS-55_microscopy_1.2.tif\t1.2\t1.2\tSpace Fl...


Provisional preview:


Unnamed: 0,tif_file,sample_key,spaceflight,time,material_type,rna_available,rna_sample_name,tif_relpath,modality
0,LSDS-55_microscopy_G1..1001.tif,G1.1,Ground,1.0,Cells,False,,LSDS-55_microscopy_G1..1001.tif,microscopy
1,LSDS-55_microscopy_G1.1.tif,G1.1,Ground,1.0,Cells,False,,LSDS-55_microscopy_G1.1.tif,microscopy
2,LSDS-55_microscopy_G1.1002.tif,G1.1,Ground,1.0,Cells,False,,LSDS-55_microscopy_G1.1002.tif,microscopy
3,LSDS-55_microscopy_G1.1003.tif,G1.1,Ground,1.0,Cells,False,,LSDS-55_microscopy_G1.1003.tif,microscopy
4,LSDS-55_microscopy_G1.2.tif,G1.2,Ground,1.0,Cells,False,,LSDS-55_microscopy_G1.2.tif,microscopy


{'': []}

In [26]:
!pip install openpyxl



In [10]:

# --- Load RNA tables and standardize column keys ---
def load_rna_tables(xlsx_paths):
    tables = {}
    for p in xlsx_paths:
        if not Path(p).exists():
            print(f"RNA sheet missing: {p}")
            continue
        try:
            df = pd.read_excel(p)
        except Exception:
            # Some Excel files need engine specification depending on environment; try again
            df = pd.read_excel(p, engine="openpyxl")
        # Normalize column names (strip spaces)
        df.columns = [str(c).strip() for c in df.columns]
        tables[Path(p).stem] = df
    return tables

rna_tables = load_rna_tables(RNA_SHEETS)
for name, df in rna_tables.items():
    print(name, df.shape, "cols:", list(df.columns[:12]), "...")


RNA sheet missing: Normalized_counts\SSMicro_day1v3_bytime.xlsx
RNA sheet missing: Normalized_counts\SSGround_day1v3_bytime.xlsx
RNA sheet missing: Normalized_counts\SS_day3_bygravity.xlsx
RNA sheet missing: Normalized_counts\Micro_day3_bymaterial.xlsx
RNA sheet missing: Normalized_counts\LIS_day3_bygravity.xlsx
RNA sheet missing: Normalized_counts\Ground_day3_bymaterial.xlsx


In [11]:
import re

def collect_available_rna_prefixes(rna_tables):
    prefixes_by_sf = {'G': set(), 'F': set()}
    cols_by_prefix = {}
    pat = re.compile(r'^(?P<sf>[GF])(?P<prefix>\d+)\.(?P<rest>.+)$')
    for _, df in (rna_tables or {}).items():
        for c in df.columns:
            m = pat.match(str(c))
            if not m:
                continue
            sf = m.group('sf')
            pfx = int(m.group('prefix'))
            prefixes_by_sf[sf].add(pfx)
            cols_by_prefix.setdefault((sf, pfx), set()).add(c)
    return {k: sorted(v) for k,v in prefixes_by_sf.items()}, cols_by_prefix

def parse_img_token(img_name):
    m = re.search(r'([GF])(\d+)\.\.?(\d+)', img_name)
    if not m: return None
    sf = m.group(1); base_idx = int(m.group(2)); rest = m.group(3)
    return sf, base_idx, rest

def nearest_prefix(sf, want_prefix, prefixes_by_sf):
    avail = prefixes_by_sf.get(sf, [])
    if not avail: return None
    return min(avail, key=lambda x: abs(x - want_prefix))

def map_img_to_rna_cols(img_name, prefixes_by_sf, cols_by_prefix):
    tok = parse_img_token(img_name)
    if tok is None: return []
    sf, base_idx, rest = tok
    pfx = nearest_prefix(sf, base_idx, prefixes_by_sf)
    if pfx is None: return []
    cand = f"{sf}{pfx}.{rest}"
    pool = cols_by_prefix.get((sf, pfx), set())
    return [cand] if cand in pool else []

prefixes_by_sf, cols_by_prefix = collect_available_rna_prefixes(rna_tables)


In [None]:
import io, zipfile
import numpy as np
import tifffile as tiff
from pathlib import Path

# --- Configure your zip path (already set earlier) ---
# DATA_ZIP_PATH = Path("microscopy_images.zip")

# --- Magic byte check for TIFF/BigTIFF ---
def _is_tiff_magic(first4: bytes) -> bool:
    # Classic TIFF:  b'II*\x00' or b'MM\x00*'
    # BigTIFF:       b'II+\x00' or b'MM\x00+'
    return first4 in (b'II*\x00', b'MM\x00*', b'II+\x00', b'MM\x00+')

def read_tif_from_zip_validated(zf: zipfile.ZipFile, member: str) -> np.ndarray:
    # Quick header check
    with zf.open(member) as f:
        head = f.read(4)
        if not _is_tiff_magic(head):
            raise ValueError(f"Not a TIFF by magic: {member!r} header={head!r}")
        # Read full payload now
        data = head + f.read()
    # arr = tiff.imread(io.BytesIO(data))
    arr = np.array(Image.open(io.BytesIO(data)))

    # Ensure [T,H,W]
    if arr.ndim == 2:
        arr = arr[None, ...]
    else:
        arr = np.squeeze(arr)
        if arr.ndim == 2:
            arr = arr[None, ...]
        elif arr.ndim != 3:
            raise ValueError(f"Unexpected array ndim={arr.ndim} for {member}")
    return arr.astype(np.float32)

# --- Build tif_map safely from ZIP ---
tif_map = {}
skipped = []

if not OUTPUT_ZIP.exists():
    raise FileNotFoundError(f"ZIP archive not found: {OUTPUT_ZIP}")

with zipfile.ZipFile(OUTPUT_ZIP, "r") as zf:
    candidates = [
        name for name in zf.namelist()
        if name.lower().endswith((".tif", ".tiff"))
        and not name.endswith("/")
        and not name.startswith("__MACOSX/")
    ]
    print(f"Scanning {len(candidates)} TIFF-like entries in ZIP...")

    for name in sorted(candidates):
        try:
            arr = read_tif_from_zip_validated(zf, name)  # or your read_tif_from_zip(...)
            base = Path(name).name                        # <<<<<< KEY CHANGE: basename
            tif_map[base] = arr
        except Exception as e:
            skipped.append((name, str(e)))

print(f"Loaded {len(tif_map)} TIFFs into memory (keyed by basename).")
if skipped:
    print(f"Skipped {len(skipped)} entries (not TIFF or unreadable); showing up to 10:")
    for n, msg in skipped[:10]:
        print("  -", n, "->", msg)

# Optional: quick peek at shapes
summary = {k: v.shape for k, v in list(tif_map.items())[:5]}
print("Sample shapes:", summary)


Scanning 479 TIFF-like entries in ZIP...
Loaded 479 TIFFs into memory (keyed by basename).
Sample shapes: {'LSDS-55_microscopy_1.1.tif': (1, 256, 256), 'LSDS-55_microscopy_1.1001.tif': (1, 256, 256), 'LSDS-55_microscopy_1.1002.tif': (1, 256, 256), 'LSDS-55_microscopy_1.1003.tif': (1, 256, 256), 'LSDS-55_microscopy_1.2.tif': (1, 256, 256)}


In [13]:
def resolve_tif_name(row, available_names: set[str]):
    """
    Given a row of s_osd_df and a set of available TIFF names (keys of tif_map),
    return the first matching filename or None.
    """
    candidates = []
    for col in ("microscopy_tif", "microscopy_tif_alt_lower_g", "microscopy_tif_alt_double_dot"):
        if col in row and pd.notna(row[col]) and str(row[col]).strip():
            candidates.append(str(row[col]).strip())
    for c in candidates:
        if c in available_names:
            return c
    return None


In [14]:
# --- Trajectory construction (fixed) ---
def build_trajectories(s_osd_df, image_to_rna_cols, tif_map):
    groups = {}
    if not s_osd_df.empty:
        for (sf, mat, med), sub in s_osd_df.groupby(['spaceflight','material','medium']):
            sub2 = sub.sort_values('time')
            key = (sf, mat, med)
            times, samples, images, rna_cols = [], [], [], []
            for _, r in sub2.iterrows():
                times.append(int(r['time']))
                samples.append(r['sample_id'])
                img_name = r['microscopy_tif']
                images.append(img_name)

                cols = image_to_rna_cols.get(r['sample_id'], [])
                if not cols:
                    cols = map_img_to_rna_cols(img_name, prefixes_by_sf, cols_by_prefix)
                rna_cols.append(cols)
            groups[key] = dict(times=times, samples=samples, images=images, rna_cols=rna_cols)
    else:
        # Fallback: infer by numeric suffix present in available names
        key = ("UnknownSpaceflight", "UnknownMaterial", "UnknownMedium")
        entries = []
        for tif_name in sorted(available):
            m = re.search(r"(\d+\.\d+)", tif_name)
            t = None
            if m:
                try:
                    t = int(m.group(1).split(".")[1])
                except Exception:
                    t = None
            suffix = m.group(1) if m else None
            entries.append((tif_name, t, suffix))
        entries.sort(key=lambda x: (x[1] if x[1] is not None else 9999, x[0]))
        times   = [e[1] for e in entries]
        images  = [e[0] for e in entries]
        samples = [e[2] for e in entries]  # suffix as a stand-in
        rna_cols = [image_to_rna_cols.get(e[2], []) for e in entries]
        groups[key] = dict(times=times, samples=samples, images=images, rna_cols=rna_cols)

    return groups


In [15]:
# --- Patch 3: Rebuild and inspect trajectories ---
groups = build_trajectories(s_osd_df, image_to_rna_cols, tif_map)

# Debug: print a quick summary
for k, b in list(groups.items())[:5]:
    n_valid = sum(1 for x in b['images'] if x)
    print(f"{k}: times={len(b['times'])}, resolvable_images={n_valid}/{len(b['images'])}")


NameError: name 'available' is not defined

In [33]:
from typing import Optional, Tuple, Dict
from sklearn.decomposition import PCA

def _infer_day_from_table_name(name: str) -> Optional[int]:
    low = name.lower()
    if "day1" in low:
        return 1
    if "day3" in low:
        return 3
    return None

def _candidate_rna_columns(df: pd.DataFrame) -> list:
    # Keep only columns that look like RNA sample columns such as G4.1, F16.2, etc.
    keep_pref = ("G4.", "F4.", "G10.", "F10.", "G16.", "F16.", "G17.", "F17.")
    return [c for c in df.columns if any(c.startswith(p) for p in keep_pref)]

def _prefix_of(col: str) -> Optional[str]:
    # "G4.1" -> "G4", "F16.7" -> "F16"
    m = re.match(r"^([GF]\d+)\.\d+$", col)
    return m.group(1) if m else None

def build_rna_group_features(
    rna_tables: Dict[str, pd.DataFrame],
    use_pca: bool = True,
    pca_dim: int = 16
) -> Tuple[Dict[Tuple[str, str, str, int], np.ndarray], dict]:
    """
    Returns:
      rna_group_vecs: dict keyed by (spaceflight, material, medium, day) -> np.ndarray (fixed length)
      feature_meta:   dict with keys:
                      - 'cols': list of aligned feature columns (before PCA)
                      - 'pca':  fitted PCA object or None
                      - 'rna_dim': final dimension after PCA (or original width if PCA not used)
                      - 'map': helper mapping (rna_prefix -> row index in aligned matrix)
    Notes:
      * Because we don't have an exact mapping from (sf,mat,med,day) to RNA prefix (e.g., G4/F16),
        we first build vectors per RNA *prefix* (G4, F16, ...). Later, when attaching to images/groups,
        we’ll drop a zero vector if no compatible prefix is found.
    """
    # 1) Collect all RNA-like columns across all tables
    table_keep = {name: _candidate_rna_columns(df) for name, df in rna_tables.items()}
    all_cols = sorted({c for cols in table_keep.values() for c in cols})
    if not all_cols:
        # No RNA-style columns anywhere -> return empty/zero config
        return {}, {'cols': [], 'pca': None, 'rna_dim': 0, 'map': {}}

    # 2) Aggregate per-table column means for stability
    per_table_means = {}
    for name, df in rna_tables.items():
        cols = table_keep.get(name, [])
        if not cols:
            continue
        means = {}
        for c in cols:
            s = pd.to_numeric(df[c], errors="coerce")
            means[c] = float(np.nanmean(s.to_numpy())) if s.notna().any() else 0.0
        per_table_means[name] = means

    # 3) Build per-RNA-prefix Series (index = all_cols, values = means, fill missing with 0)
    #    Example prefixes: G4, F4, G10, F10, G16, F16, G17, F17
    prefix_set = sorted({_prefix_of(c) for c in all_cols if _prefix_of(c) is not None})
    if not prefix_set:
        return {}, {'cols': [], 'pca': None, 'rna_dim': 0, 'map': {}}

    rows = []
    row_keys = []   # each row key will be the RNA-prefix (e.g., "G4")
    for pref in prefix_set:
        # Merge from all tables (some cols might exist in >1 table; average them)
        vals = {}
        for col in all_cols:
            if _prefix_of(col) == pref:
                # collect from all tables that have this column
                have = [per_table_means[t][col] for t in per_table_means if col in per_table_means[t]]
                vals[col] = float(np.mean(have)) if len(have) > 0 else 0.0
            else:
                vals[col] = 0.0
        rows.append(pd.Series(vals, index=all_cols, dtype=float))
        row_keys.append(pref)

    aligned = pd.DataFrame(rows, index=row_keys, columns=all_cols).fillna(0.0)  # [num_prefixes, num_features]

    # 4) Optional PCA to fixed dim (safe guards)
    fitted_pca = None
    X = aligned.to_numpy(dtype=np.float32)
    # If all-zero rows, PCA will complain; keep only rows with some variance to fit
    fit_mask = (np.abs(X).sum(axis=1) > 0)
    X_fit = X[fit_mask]
    final_dim = X.shape[1]
    if use_pca and X_fit.shape[0] >= 2 and X_fit.shape[1] > 1:
        n_comp = min(pca_dim, X_fit.shape[1], X_fit.shape[0])  # cannot exceed #rows or #features
        if n_comp >= 2:
            fitted_pca = PCA(n_components=n_comp, svd_solver='auto', random_state=0)
            fitted_pca.fit(X_fit)
            X = fitted_pca.transform(X)
            final_dim = X.shape[1]
        else:
            # too few samples to PCA meaningfully; skip
            fitted_pca = None
            final_dim = X.shape[1]
    else:
        fitted_pca = None
        final_dim = X.shape[1]

    # 5) Package: we do not yet know (sf, mat, med, day) mapping here, so we return prefix vectors only.
    #    Callers will choose a prefix (e.g., based on sample-id like 'G4.*') or fall back to zeros.
    rna_group_vecs = {}  # we’ll fill at attach time; keeping empty here is fine
    feature_meta = {
        'cols': all_cols,
        'pca': fitted_pca,
        'rna_dim': int(final_dim),
        'map': {pref: X[i] for i, pref in enumerate(row_keys)}
    }
    return rna_group_vecs, feature_meta


In [34]:
rna_group_vecs, rna_feature_meta = build_rna_group_features(rna_tables, pca_dim=16, use_pca=True)
GLOBAL_RNA_DIM = rna_feature_meta.get('rna_dim', 0)
RNA_PREFIX_TO_VEC = rna_feature_meta.get('map', {})  # e.g., {'G4': np.array([...]), 'F16': ...}
print("RNA feature meta:", rna_feature_meta)

RNA feature meta: {'cols': ['F10.1', 'F10.2', 'F10.3', 'F10.4', 'F10.5', 'F10.6', 'F10.8', 'F16.1', 'F16.2', 'F16.3', 'F16.4', 'F16.5', 'F16.6', 'F16.7', 'F17.1', 'F17.2', 'F17.3', 'F17.4', 'F4.1', 'F4.2', 'F4.3', 'F4.4', 'F4.6', 'F4.7', 'F4.8', 'G10.1', 'G10.2', 'G10.3', 'G10.4', 'G10.5', 'G10.6', 'G10.8', 'G16.1', 'G16.2', 'G16.3', 'G16.4', 'G16.5', 'G16.6', 'G16.7', 'G17.1', 'G17.2', 'G17.3', 'G17.4', 'G4.1', 'G4.2', 'G4.3', 'G4.4', 'G4.6', 'G4.7', 'G4.8'], 'pca': PCA(n_components=8, random_state=0), 'rna_dim': 8, 'map': {'F10': array([-1.1397461e+03, -2.7714600e+03,  2.5533074e+03,  7.7301483e+02,
        8.2930435e+01,  2.1947708e+01,  4.6491943e+01, -3.3378601e-05],
      dtype=float32), 'F16': array([-6.6841602e+02, -3.9546271e+02, -2.4458098e+03,  2.2803062e+03,
        1.3993834e+02,  3.5055450e+01,  6.7505676e+01,  1.2874603e-04],
      dtype=float32), 'F17': array([-3.37156677e+02, -1.30221024e+02, -4.32754852e+02, -6.27714844e+02,
       -2.20994293e+02, -8.36751175e+01, -9

In [35]:
def _guess_rna_prefix(sample_or_img: str) -> Optional[str]:
    # Look for patterns like 'G4.' or 'F16.' in the name; return 'G4' or 'F16'
    m = re.search(r"([GF]\d+)\.(?:\d+)", sample_or_img)
    return m.group(1) if m else None


In [36]:
import torch.nn.functional as F

RESIZE_HW = (512, 512)
# --- Dataset with optional RNA auxiliary features ---
class TrajectoryDataset(Dataset):
    def __init__(self, groups, tif_map, use_rna=True, resize_hw=RESIZE_HW):
        self.samples = []
        self.use_rna = use_rna and (GLOBAL_RNA_DIM > 0)
        self.resize_hw = resize_hw

        for key, bundle in groups.items():
            times = bundle['times']
            images = bundle['images']

            for i in range(len(times)-1):
                img_in_name = images[i]
                img_out_name = images[i+1]
                if img_in_name not in tif_map or img_out_name not in tif_map:
                    continue

                X_seq = tif_map[img_in_name]   # [T,H,W] or [1,H,W]
                Y_seq = tif_map[img_out_name]  # [T,H,W] or [1,H,W]

                # --- NEW: fixed-length RNA vector (or zeros) ---
                if self.use_rna and GLOBAL_RNA_DIM > 0:
                    pref_in  = _guess_rna_prefix(img_in_name)
                    pref_out = _guess_rna_prefix(img_out_name)
                    # use the input prefix if available; fallback to output; else zeros
                    vec = None
                    if pref_in in RNA_PREFIX_TO_VEC:
                        vec = RNA_PREFIX_TO_VEC[pref_in]
                    elif pref_out in RNA_PREFIX_TO_VEC:
                        vec = RNA_PREFIX_TO_VEC[pref_out]
                    if vec is None:
                        vec = np.zeros((GLOBAL_RNA_DIM,), dtype=np.float32)
                    rna_vec = vec.astype(np.float32)
                else:
                    rna_vec = None  # will be converted to zeros later to avoid collate errors

                self.samples.append({
                    "X_seq": X_seq,
                    "Y_seq": Y_seq,
                    "rna_seq": rna_vec,
                    "meta": dict(key=key, t_in=times[i], t_out=times[i+1],
                                 img_in=img_in_name, img_out=img_out_name)
                })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = self.samples[idx]
        # X_seq, Y_seq are [T,H,W] numpy arrays
        X = item["X_seq"][None, ...]  # [1,T,H,W] treat T as channels
        Y = item["Y_seq"][None, ...]  # [1,T,H,W]
        X = torch.from_numpy(X).float()
        Y = torch.from_numpy(Y).float()

        # --- normalize spatial sizes ---
        if self.resize_hw is not None:
            if X.shape[-2:] != self.resize_hw:
                X = F.interpolate(X, size=self.resize_hw, mode="area")
            if Y.shape[-2:] != self.resize_hw:
                Y = F.interpolate(Y, size=self.resize_hw, mode="area")

        # RNA vector (as you implemented) -> aux tensor or zeros
        if self.use_rna and GLOBAL_RNA_DIM > 0:
            if item["rna_seq"] is None:
                aux = torch.zeros((GLOBAL_RNA_DIM,), dtype=torch.float32)
            else:
                aux = torch.from_numpy(item["rna_seq"]).float()
        else:
            aux = None

        return X, Y, aux, item["meta"]

print("Sample TIFF keys:", list(tif_map.keys())[:5])
dataset = TrajectoryDataset(groups, tif_map, use_rna=True)
print("Dataset samples:", len(dataset))
if len(dataset) > 0:
    X, Y, aux, meta = dataset[0]
    print("Sample[0] shapes:", X.shape, Y.shape)
if len(dataset):
    X, Y, aux, meta = dataset[0]
    print("X shape:", (tuple(X.shape) if hasattr(X, 'shape') else type(X)))
    print("Y shape:", (tuple(Y.shape) if hasattr(Y, 'shape') else type(Y)))
    print("Aux RNA shape:", (tuple(aux.shape) if aux is not None else None))
    print("Meta:", meta)


Sample TIFF keys: ['LSDS-55_microscopy_1.1.tif', 'LSDS-55_microscopy_1.1001.tif', 'LSDS-55_microscopy_1.1002.tif', 'LSDS-55_microscopy_1.1003.tif', 'LSDS-55_microscopy_1.2.tif']
Dataset samples: 467
Sample[0] shapes: torch.Size([1, 1, 512, 512]) torch.Size([1, 1, 512, 512])
X shape: (1, 1, 512, 512)
Y shape: (1, 1, 512, 512)
Aux RNA shape: (8,)
Meta: {'key': ('Ground', 'Cells', 'LB broth (Lennox) supplemented with KNO3'), 't_in': 1, 't_out': 1, 'img_in': 'LSDS-55_microscopy_G1.1002.tif', 'img_out': 'LSDS-55_microscopy_G1.1003.tif'}


In [37]:

# --- Simple ConvNet baseline with optional RNA head (replace with ConvLSTM as needed) ---
if 'torch' in globals():
    class ConvLSTMCell(nn.Module):
        def __init__(self, in_channels, hidden_channels, kernel_size=3, padding=1):
            super().__init__()
            self.hidden_channels = hidden_channels
            self.conv = nn.Conv2d(in_channels + hidden_channels,
                                4 * hidden_channels,
                                kernel_size=kernel_size,
                                padding=padding)

        def forward(self, x_t, state):
            # x_t: [B, C, H, W]; state: (h, c) each [B, hidden, H, W]
            h_prev, c_prev = state
            gates = self.conv(torch.cat([x_t, h_prev], dim=1))
            i, f, o, g = torch.chunk(gates, 4, dim=1)
            i = torch.sigmoid(i); f = torch.sigmoid(f); o = torch.sigmoid(o); g = torch.tanh(g)
            c = f * c_prev + i * g
            h = o * torch.tanh(c)
            return h, c

        def init_state(self, B, H, W, device=None, dtype=None):
            h = torch.zeros(B, self.hidden_channels, H, W, device=device, dtype=dtype)
            c = torch.zeros(B, self.hidden_channels, H, W, device=device, dtype=dtype)
            return h, c

    class ConvLSTMLayer(nn.Module):
        """
        Full-sequence ConvLSTM.
        Input : [B, T, C, H, W]
        Output: [B, T, hidden, H, W]
        """
        def __init__(self, in_channels, hidden_channels, kernel_size=3, padding=1):
            super().__init__()
            self.cell = ConvLSTMCell(in_channels, hidden_channels, kernel_size, padding)

        def forward(self, x):
            B, T, C, H, W = x.shape
            device, dtype = x.device, x.dtype
            h, c = self.cell.init_state(B, H, W, device, dtype)
            outs = []
            for t in range(T):
                h, c = self.cell(x[:, t], (h, c))
                outs.append(h)
            return torch.stack(outs, dim=1)  # [B, T, hidden, H, W]


# ===== ConvLSTM SmallPredictor (Keras-equivalent) =====
    class SmallPredictor(nn.Module):
        """
        Matches:
        [ConvLSTM x4 (40 ch, return_sequences), BN between] -> Conv3D(1, k=(1,3,3)) -> sigmoid
        Input : [B, 1, T, H, W]
        Output: [B, 1, T, H, W]
        """
        def __init__(self, use_rna=False, rna_dim=0):
            super().__init__()
            self.use_rna = use_rna and (rna_dim > 0)

            self.l1 = ConvLSTMLayer(1, 40, 3, 1)
            self.bn1 = nn.BatchNorm3d(40)
            self.l2 = ConvLSTMLayer(40, 40, 3, 1)
            self.bn2 = nn.BatchNorm3d(40)
            self.l3 = ConvLSTMLayer(40, 40, 3, 1)
            self.bn3 = nn.BatchNorm3d(40)
            self.l4 = ConvLSTMLayer(40, 40, 3, 1)
            self.bn4 = nn.BatchNorm3d(40)

            # temporal kernel = 1 preserves T
            self.head = nn.Conv3d(40, 1, kernel_size=(1, 3, 3), padding=(0, 1, 1))

            if self.use_rna:
                self.rna_mlp = nn.Sequential(
                    nn.Linear(rna_dim, 16),
                    nn.ReLU(),
                    nn.Linear(16, 1)  # scalar per-batch
                )

        @staticmethod
        def _bn3d(bn, y_btchw):
            y_bcthw = y_btchw.permute(0, 2, 1, 3, 4)  # [B,C,T,H,W]
            y_bcthw = bn(y_bcthw)
            return y_bcthw.permute(0, 2, 1, 3, 4)     # [B,T,C,H,W]

        def forward(self, x, rna=None):
            # x: [B, 1, T, H, W] -> [B, T, 1, H, W]
            B, C, T, H, W = x.shape
            y = x.permute(0, 2, 1, 3, 4)

            y = self._bn3d(self.bn1, self.l1(y))
            y = self._bn3d(self.bn2, self.l2(y))
            y = self._bn3d(self.bn3, self.l3(y))
            y = self._bn3d(self.bn4, self.l4(y))

            # head expects [B,C,T,H,W]
            y = y.permute(0, 2, 1, 3, 4)  # [B,40,T,H,W]

            if self.use_rna and (rna is not None):
                bias = self.rna_mlp(rna).view(B, 1, 1, 1, 1)
                y = y + bias

            y = self.head(y)           # [B,1,T,H,W]
            y = torch.sigmoid(y)
            return y

    # Probe RNA dim from dataset (first non-null)
    rna_dim = 0
    for i in range(len(dataset)):
        _,_,aux,_ = dataset[i]
        if aux is not None:
            rna_dim = aux.shape[-1]
            break
    print("RNA dim:", rna_dim)

    model = SmallPredictor(use_rna=(rna_dim>0), rna_dim=rna_dim)
    print(model)

else:
    print("Torch not available; skip model construction.")


RNA dim: 8
SmallPredictor(
  (l1): ConvLSTMLayer(
    (cell): ConvLSTMCell(
      (conv): Conv2d(41, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (bn1): BatchNorm3d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (l2): ConvLSTMLayer(
    (cell): ConvLSTMCell(
      (conv): Conv2d(80, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (bn2): BatchNorm3d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (l3): ConvLSTMLayer(
    (cell): ConvLSTMCell(
      (conv): Conv2d(80, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (bn3): BatchNorm3d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (l4): ConvLSTMLayer(
    (cell): ConvLSTMCell(
      (conv): Conv2d(80, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (bn4): BatchNorm3d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (head): Conv3d(40, 1, kernel_size=(1, 3, 3), stride=(1,

In [38]:
import torch

def robust_norm01(t: torch.Tensor, q_low=0.01, q_high=0.99, eps=1e-6):
    # t: [B,1,H,W] float
    # compute robust min/max per-sample using quantiles
    q1 = torch.quantile(t.flatten(2), q_low, dim=2, keepdim=True)
    q9 = torch.quantile(t.flatten(2), q_high, dim=2, keepdim=True)
    span = (q9 - q1).clamp_min(eps)
    t_flat = (t.flatten(2) - q1).clamp_min(0) / span
    t = t_flat.view_as(t).clamp(0, 1)
    return t


In [39]:
# --- Memory-safe ConvLSTM training loop (drop-in replacement) ---
if 'torch' in globals() and len(dataset) > 0:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader
    from torch.nn.utils import clip_grad_norm_

    # --------- knobs you can tweak without touching the core loop ----------
    BATCH_SIZE = 1                 # bump if you can
    NUM_WORKERS = 2                # >0 speeds CPU->GPU input
    PIN_MEMORY = True              # good for CUDA
    T_SUB = 1                      # e.g., 2 = use every 2nd frame (cuts memory ~1/2)
    CROP_HW = None                 # e.g., (256,256) to crop HxW from top-left
    MAX_GRAD_NORM = 1.0
    LR = 3e-4
    WD = 1e-5
    HUBER_BETA = 0.01

    # cudnn autotune can help perf; safe for fixed-size inputs
    torch.backends.cudnn.benchmark = True

    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        drop_last=False,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WD)
    loss_fn = nn.SmoothL1Loss(beta=HUBER_BETA)

    use_amp = (device.type == "cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # ---------- helpers ----------
    def robust_norm_seq(x):
        # x: [B, C, T, H, W]; returns [B, C, T, H, W] in [0,1]
        B, C, T, H, W = x.shape
        xf = x.reshape(B, C, T * H * W)
        q1 = torch.quantile(xf, 0.01, dim=2, keepdim=True)
        q9 = torch.quantile(xf, 0.99, dim=2, keepdim=True)
        span = (q9 - q1).clamp_min(1e-6)
        xf = ((xf - q1).clamp_min(0) / span).clamp(0, 1)
        return xf.view(B, C, T, H, W)

    def maybe_subsample_time(z, step=T_SUB):
        # z: [B, C, T, H, W]
        if step is None or step <= 1:
            return z
        return z[:, :, ::step]

    def maybe_crop_hw(z, crop=CROP_HW):
        # z: [B, C, T, H, W]
        if not crop:
            return z
        Hc, Wc = crop
        return z[..., :Hc, :Wc]

    # ---------- train ----------
    model.train()
    for step, (X, Y, aux, meta) in enumerate(loader):
        # Shapes to device
        X = X.to(device, dtype=torch.float32, non_blocking=True)  # [B,1,T,H,W]
        Y = Y.to(device, dtype=torch.float32, non_blocking=True)  # [B,1,T,H,W]

        # Optional temporal subsampling & cropping (reduces memory a lot)
        if T_SUB and T_SUB > 1:
            X = maybe_subsample_time(X, T_SUB)
            Y = maybe_subsample_time(Y, T_SUB)
        if CROP_HW:
            X = maybe_crop_hw(X, CROP_HW)
            Y = maybe_crop_hw(Y, CROP_HW)

        # RNA aux handling
        if hasattr(model, "rna_mlp") and hasattr(model.rna_mlp[0], "in_features"):
            aux_dim = model.rna_mlp[0].in_features
        else:
            aux_dim = 0
        if aux_dim > 0:
            aux = aux.to(device, dtype=torch.float32, non_blocking=True) if (aux is not None) \
                  else torch.zeros((X.shape[0], aux_dim), device=device)
        else:
            aux = None

        # Normalization OUTSIDE autograd to avoid huge graphs
        with torch.no_grad():
            Xn = robust_norm_seq(X)
            Yn = robust_norm_seq(Y)

        opt.zero_grad(set_to_none=True)

        # Forward + loss in AMP (on GPU) to halve activation memory
        with torch.cuda.amp.autocast(enabled=use_amp):
            pred = model(Xn, rna=aux)  # [B,1,T,H,W]

            # If a spatial mismatch slipped in, align Yn to pred
            if Yn.shape[-2:] != pred.shape[-2:]:
                B, C, T, H, W = Yn.shape
                Yn_4d = Yn.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)  # [B*T,1,H,W]
                Yn_4d = F.interpolate(Yn_4d, size=pred.shape[-2:], mode="bilinear", align_corners=False)
                Hp, Wp = pred.shape[-2:]
                Yn = Yn_4d.reshape(B, T, C, Hp, Wp).permute(0, 2, 1, 3, 4)

            loss = loss_fn(pred, Yn)

        # Backward (scaled) + clip + step
        scaler.scale(loss).backward()
        scaler.unscale_(opt)
        clip_grad_norm_(model.parameters(), max_norm=MAX_GRAD_NORM)
        scaler.step(opt)
        scaler.update()

        # (Optional) lightweight logging
        if (step % 10) == 0:
            print(f"step {step:05d} | loss {loss.item():.5f} | X {tuple(X.shape)}")

        # Emergency break if something grows without bound (rare)
        if not torch.isfinite(loss):
            print("Non-finite loss detected; breaking to protect the run.")
            break

else:
    print("Skipping training (no torch or no samples).")


  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
  with torch.cuda.amp.autocast(enabled=use_amp):


step 00000 | loss 0.48366 | X (1, 1, 1, 512, 512)
step 00010 | loss 0.35357 | X (1, 1, 1, 512, 512)
step 00020 | loss 0.29567 | X (1, 1, 1, 512, 512)
step 00030 | loss 0.25087 | X (1, 1, 1, 512, 512)
step 00040 | loss 0.20535 | X (1, 1, 1, 512, 512)
step 00050 | loss 0.20392 | X (1, 1, 1, 512, 512)
step 00060 | loss 0.20267 | X (1, 1, 1, 512, 512)
step 00070 | loss 0.17091 | X (1, 1, 1, 512, 512)
step 00080 | loss 0.19620 | X (1, 1, 1, 512, 512)
step 00090 | loss 0.22131 | X (1, 1, 1, 512, 512)
step 00100 | loss 0.15991 | X (1, 1, 1, 512, 512)
step 00110 | loss 0.25544 | X (1, 1, 1, 512, 512)
step 00120 | loss 0.19646 | X (1, 1, 1, 512, 512)
step 00130 | loss 0.13182 | X (1, 1, 1, 512, 512)
step 00140 | loss 0.22160 | X (1, 1, 1, 512, 512)
step 00150 | loss 0.24231 | X (1, 1, 1, 512, 512)
step 00160 | loss 0.22744 | X (1, 1, 1, 512, 512)
step 00170 | loss 0.15848 | X (1, 1, 1, 512, 512)
step 00180 | loss 0.15031 | X (1, 1, 1, 512, 512)
step 00190 | loss 0.16474 | X (1, 1, 1, 512, 512)


In [43]:
# --- Save trained model ---
MODEL_SAVE_PATH = "convlstm_bacteria_growth.pt"
torch.save(model.state_dict(), MODEL_SAVE_PATH)
print(f"Model saved to {MODEL_SAVE_PATH}")


Model saved to convlstm_bacteria_growth.pt


In [40]:
# --- meta normalization helpers ---
def _unwrap_meta(meta_batch):
    """
    Make a single-sample meta dict from what DataLoader collates.
    Handles:
      - list/tuple of dicts
      - dict of lists
      - plain dict
    """
    if isinstance(meta_batch, (list, tuple)) and len(meta_batch) > 0:
        return meta_batch[0]
    if isinstance(meta_batch, dict):
        # dict of lists? take the first element from each list
        any_val = next(iter(meta_batch.values())) if len(meta_batch) else None
        if isinstance(any_val, (list, tuple)) and len(any_val) > 0:
            return {k: v[0] for k, v in meta_batch.items()}
        return meta_batch
    return {"key": "UNK"}

def _normalize_group_key(g):
    """
    Convert group key into a hashable thing for dicts:
      - list -> tuple
      - other iterables -> tuple
      - None -> "UNK"
    """
    if g is None:
        return "UNK"
    if isinstance(g, (list, tuple)):
        return tuple(g)
    # strings are fine
    if isinstance(g, str):
        return g
    # anything iterable (e.g., numpy array)
    try:
        return tuple(g)
    except Exception:
        return str(g)


In [47]:
import math, time, torch, torch.nn.functional as F
import numpy as np
from collections import defaultdict
from torch.utils.data import DataLoader

# ---- knobs to control eval cost ----
EVAL_MAX_BATCHES = 50      # e.g., 50 to cap runtime
EVAL_T_SUB = 1               # e.g., 2 -> every 2nd frame
EVAL_CROP_HW = (256,256)          # e.g., (256,256)
EVAL_COMPUTE_SSIM = False    # turn off to speed up a lot
EVAL_DOWNSAMPLE_FOR_METRICS = 2  # e.g., 2 -> metrics on half-res (cheap)
EVAL_NUM_WORKERS = 2
EVAL_PIN_MEMORY  = True

def _to_float01(x: torch.Tensor) -> torch.Tensor:
    return x.clamp(0.0, 1.0)

def mae(pred, tgt):
    pred = _to_float01(pred); tgt = _to_float01(tgt)
    return torch.mean(torch.abs(pred - tgt)).item()

def mse(pred, tgt):
    pred = _to_float01(pred); tgt = _to_float01(tgt)
    return torch.mean((pred - tgt) ** 2).item()

def psnr(pred, tgt, data_range=1.0):
    m = mse(pred, tgt)
    if m <= 1e-12:
        return float("inf")
    return 20.0 * math.log10(data_range) - 10.0 * math.log10(m)

# fast Gaussian kernel for SSIM (only used if EVAL_COMPUTE_SSIM=True)
def _gaussian_kernel(size=11, sigma=1.5, device="cpu", dtype=torch.float32):
    coords = torch.arange(size, device=device, dtype=dtype) - size // 2
    g = torch.exp(-(coords**2)/(2*sigma**2))
    g = g / g.sum()
    k2 = (g[:, None] @ g[None, :])
    return (k2 / k2.sum())

def ssim(pred, tgt, data_range=1.0, size=11, sigma=1.5):
    k = _gaussian_kernel(size=size, sigma=sigma, device=pred.device, dtype=pred.dtype)[None, None, :, :]
    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2
    def filt(x): return F.conv2d(x, k, padding=size//2, groups=1)
    mu_x, mu_y = filt(pred), filt(tgt)
    mu_x2, mu_y2, mu_xy = mu_x*mu_x, mu_y*mu_y, mu_x*mu_y
    sigma_x2 = filt(pred*pred) - mu_x2
    sigma_y2 = filt(tgt*tgt) - mu_y2
    sigma_xy = filt(pred*tgt) - mu_xy
    ssim_map = ((2*mu_xy + C1) * (2*sigma_xy + C2)) / ((mu_x2 + mu_y2 + C1) * (sigma_x2 + sigma_y2 + C2))
    return ssim_map.mean().item()

# fast min/max normalization (no quantiles)
def fast_norm_seq(x):
    # x: [B, C, T, H, W]
    B, C, T, H, W = x.shape
    # x_flat = x.view(B, C, -1)            # ❌ causes error if x is non-contiguous
    x_flat = x.reshape(B, C, -1)           # ✅ handles non-contiguous tensors
    lo = torch.amin(x_flat, dim=2, keepdim=True)
    hi = torch.amax(x_flat, dim=2, keepdim=True)
    span = (hi - lo).clamp_min(1e-6)
    x_norm = ((x_flat - lo) / span).clamp(0, 1)
    # return x_norm.view(B, C, T, H, W)    # ❌
    return x_norm.reshape(B, C, T, H, W)   # ✅

def maybe_subsample_time(z, step):
    return z if (step is None or step <= 1) else z[:, :, ::step]

def maybe_crop_hw(z, crop):
    if not crop: return z
    Hc, Wc = crop
    return z[..., :Hc, :Wc]

def maybe_downsample_for_metrics(x4, factor):
    # x4: [B,1,H,W]
    if factor is None or factor <= 1: return x4
    return F.interpolate(x4, scale_factor=1.0/factor, mode="area")

@torch.no_grad()
def evaluate_model_fast(model, dataset, device, reduce_over_time="mean"):
    loader = DataLoader(
        dataset, batch_size=1, shuffle=False,
        num_workers=EVAL_NUM_WORKERS, pin_memory=EVAL_PIN_MEMORY
    )
    model.eval()
    use_amp = (device.type == "cuda")

    agg = dict(mae=[], mse=[], psnr=[])
    if EVAL_COMPUTE_SSIM:
        agg["ssim"] = []

    t0 = time.time()
    for step, (X, Y, aux, meta) in enumerate(loader):
        X = X.to(device, dtype=torch.float32, non_blocking=True)  # [B,1,T,H,W]
        Y = Y.to(device, dtype=torch.float32, non_blocking=True)

        # throttle cost
        if EVAL_T_SUB and EVAL_T_SUB > 1:
            X = maybe_subsample_time(X, EVAL_T_SUB)
            Y = maybe_subsample_time(Y, EVAL_T_SUB)
        if EVAL_CROP_HW:
            X = maybe_crop_hw(X, EVAL_CROP_HW)
            Y = maybe_crop_hw(Y, EVAL_CROP_HW)

        # fast normalization
        Xn = fast_norm_seq(X)
        Yn = fast_norm_seq(Y)

        # aux
        if hasattr(model, "rna_mlp") and hasattr(model.rna_mlp[0], "in_features"):
            k = model.rna_mlp[0].in_features
            aux = (aux.to(device, dtype=torch.float32, non_blocking=True)
                   if (aux is not None) else torch.zeros((X.shape[0], k), device=device))
        else:
            aux = None

        with torch.cuda.amp.autocast(enabled=use_amp):
            pred_seq = model(Xn, rna=aux)  # [B,1,T,H,W]

        # reduce to 2D
        if reduce_over_time == "mean":
            pred2d = pred_seq.mean(dim=2)
            Y2d    = Yn.mean(dim=2)
        else:  # 'last'
            pred2d = pred_seq[:, :, -1]
            Y2d    = Yn[:, :, -1]

        # downsample for metric speed, if requested
        if EVAL_DOWNSAMPLE_FOR_METRICS and EVAL_DOWNSAMPLE_FOR_METRICS > 1:
            pred2d_m = maybe_downsample_for_metrics(pred2d, EVAL_DOWNSAMPLE_FOR_METRICS)
            Y2d_m    = maybe_downsample_for_metrics(Y2d,    EVAL_DOWNSAMPLE_FOR_METRICS)
        else:
            pred2d_m, Y2d_m = pred2d, Y2d

        # metrics
        agg["mae"].append(mae(pred2d_m, Y2d_m))
        agg["mse"].append(mse(pred2d_m, Y2d_m))
        agg["psnr"].append(psnr(pred2d_m, Y2d_m, data_range=1.0))
        if EVAL_COMPUTE_SSIM:
            agg["ssim"].append(ssim(pred2d_m, Y2d_m, data_range=1.0))

        if (step % 10) == 0:
            dt = time.time() - t0
            print(f"[eval] {step} samples in {dt:.1f}s")

        if (EVAL_MAX_BATCHES is not None) and (step + 1 >= EVAL_MAX_BATCHES):
            break

    # summarize
    out = {k: (float(np.mean(v)) if len(v)>0 else float("nan")) for k,v in agg.items()}
    return out

# ---- run fast eval ----
if 'torch' in globals() and len(dataset) > 0:
    device_eval = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device_eval)
    summary = evaluate_model_fast(model, dataset, device_eval, reduce_over_time="mean")
    print("== Fast eval ==")
    for k,v in summary.items():
        print(f"{k.upper():5s}: {v:.4f}")
else:
    print("Skip eval: no torch or empty dataset.")


  with torch.cuda.amp.autocast(enabled=use_amp):


[eval] 0 samples in 5.3s
[eval] 10 samples in 22.0s
[eval] 20 samples in 35.6s
[eval] 30 samples in 47.8s
[eval] 40 samples in 60.1s
== Fast eval ==
MAE  : 0.0825
MSE  : 0.0141
PSNR : 20.7919


In [44]:
# import math, torch, torch.nn.functional as F
# import numpy as np
# from collections import defaultdict

# # --- helpers ---
# def _to_float01(x: torch.Tensor) -> torch.Tensor:
#     """Clamp to [0,1]; assumes your pipeline already scales to [0,1]."""
#     return x.clamp(0.0, 1.0)

# def mae(pred, tgt):
#     pred = _to_float01(pred); tgt = _to_float01(tgt)
#     return torch.mean(torch.abs(pred - tgt)).item()

# def mse(pred, tgt):
#     pred = _to_float01(pred); tgt = _to_float01(tgt)
#     return torch.mean((pred - tgt) ** 2).item()

# def psnr(pred, tgt, data_range=1.0):
#     m = mse(pred, tgt)
#     if m <= 1e-12:
#         return float("inf")
#     return 20.0 * math.log10(data_range) - 10.0 * math.log10(m)

# # --- SSIM (PyTorch) ---
# def _gaussian_kernel(size=11, sigma=1.5, device="cpu", dtype=torch.float32):
#     coords = torch.arange(size, device=device, dtype=dtype) - size // 2
#     g = torch.exp(-(coords**2)/(2*sigma**2))
#     g = g / g.sum()
#     k2 = (g[:, None] @ g[None, :])
#     return (k2 / k2.sum())

# def ssim(pred, tgt, data_range=1.0, size=11, sigma=1.5):
#     """
#     pred/tgt: [B,1,H,W] tensors in [0,1]
#     Returns mean SSIM over batch.
#     """
#     device, dtype = pred.device, pred.dtype
#     pred = _to_float01(pred); tgt = _to_float01(tgt)

#     C1 = (0.01 * data_range) ** 2
#     C2 = (0.03 * data_range) ** 2

#     k = _gaussian_kernel(size=size, sigma=sigma, device=device, dtype=dtype)[None, None, :, :]
#     def filt(x): return F.conv2d(x, k, padding=size//2, groups=1)

#     mu_x, mu_y = filt(pred), filt(tgt)
#     mu_x2, mu_y2, mu_xy = mu_x*mu_x, mu_y*mu_y, mu_x*mu_y

#     sigma_x2 = filt(pred*pred) - mu_x2
#     sigma_y2 = filt(tgt*tgt) - mu_y2
#     sigma_xy = filt(pred*tgt) - mu_xy

#     ssim_map = ((2*mu_xy + C1) * (2*sigma_xy + C2)) / ((mu_x2 + mu_y2 + C1) * (sigma_x2 + sigma_y2 + C2))
#     return ssim_map.mean().item()

# # --- safe meta helpers (no-ops if you already have your own) ---
# def _unwrap_meta(meta):
#     # meta may be a list/tuple of dicts when coming from DataLoader
#     if isinstance(meta, (list, tuple)) and len(meta) > 0:
#         meta = meta[0]
#     return meta if isinstance(meta, dict) else {}

# def _normalize_group_key(k):
#     # Expect a tuple like (spaceflight, material, medium), else make a string key
#     if isinstance(k, (tuple, list)):
#         return tuple(k)
#     return str(k)

# # --- normalization like training ---
# def robust_norm_seq(x):
#     # x: [B, C, T, H, W] -> [B, C, T, H, W] in [0,1]
#     B, C, T, H, W = x.shape
#     xf = x.reshape(B, C, T * H * W)
#     q1 = torch.quantile(xf, 0.01, dim=2, keepdim=True)
#     q9 = torch.quantile(xf, 0.99, dim=2, keepdim=True)
#     span = (q9 - q1).clamp_min(1e-6)
#     xf = ((xf - q1).clamp_min(0) / span).clamp(0, 1)
#     return xf.view(B, C, T, H, W)

# # --- evaluation loop ---
# @torch.no_grad()
# def evaluate_model(model, loader, device, max_batches=None, reduce_over_time="mean"):
#     """
#     reduce_over_time: 'mean' (default) or 'last' or None (for per-frame aggregation)
#     """
#     model.eval()
#     agg = dict(mae=[], mse=[], psnr=[], ssim=[])
#     per_group = defaultdict(lambda: dict(mae=[], mse=[], psnr=[], ssim=[]))

#     seen = 0
#     for step, (X, Y, aux, meta) in enumerate(loader):
#         # Move to device
#         X = X.to(device=device, dtype=torch.float32)  # [B,1,T,H,W]
#         Y = Y.to(device=device, dtype=torch.float32)  # [B,1,T,H,W]

#         # Normalize like training (no_grad)
#         with torch.no_grad():
#             Xn = robust_norm_seq(X)
#             Yn = robust_norm_seq(Y)

#         # Prepare aux if the model expects it
#         if hasattr(model, "rna_mlp") and hasattr(model.rna_mlp[0], "in_features"):
#             aux_dim = model.rna_mlp[0].in_features
#             aux = (aux.to(device, dtype=torch.float32) if (aux is not None) else
#                    torch.zeros((X.shape[0], aux_dim), device=device))
#         else:
#             aux = None

#         # Forward
#         pred_seq = model(Xn, rna=aux)           # [B,1,T,H,W]

#         # Reduce to 2D for metrics
#         if reduce_over_time == "mean":
#             pred2d = pred_seq.mean(dim=2)       # [B,1,H,W]
#             Y2d    = Yn.mean(dim=2)             # [B,1,H,W]
#         elif reduce_over_time == "last":
#             pred2d = pred_seq[:, :, -1]         # [B,1,H,W]
#             Y2d    = Yn[:, :, -1]
#         else:
#             # Per-frame metrics aggregated over T (uncomment to use)
#             # B, C, T, H, W = pred_seq.shape
#             # p4 = pred_seq.permute(0,2,1,3,4).reshape(B*T, C, H, W)
#             # y4 = Yn.permute(0,2,1,3,4).reshape(B*T, C, H, W)
#             # m_mae  = mae(p4, y4); m_mse = mse(p4, y4)
#             # m_psnr = psnr(p4, y4, data_range=1.0); m_ssim = ssim(p4, y4, data_range=1.0)
#             # (then append and continue)
#             raise NotImplementedError("Set reduce_over_time to 'mean' or 'last' for 2D metrics.")

#         # Metrics (2D)
#         m_mae  = mae(pred2d, Y2d)
#         m_mse  = mse(pred2d, Y2d)
#         m_psnr = psnr(pred2d, Y2d, data_range=1.0)
#         m_ssim = ssim(pred2d, Y2d, data_range=1.0)

#         agg["mae"].append(m_mae)
#         agg["mse"].append(m_mse)
#         agg["psnr"].append(m_psnr)
#         agg["ssim"].append(m_ssim)

#         meta0 = _unwrap_meta(meta)
#         group_key = _normalize_group_key(meta0.get("key", "UNK"))
#         per_group[group_key]["mae"].append(m_mae)
#         per_group[group_key]["mse"].append(m_mse)
#         per_group[group_key]["psnr"].append(m_psnr)
#         per_group[group_key]["ssim"].append(m_ssim)

#         seen += 1
#         if (max_batches is not None) and (seen >= max_batches):
#             break

#     def _summ(d):
#         return {k: (float(np.mean(v)) if len(v) > 0 else float("nan")) for k, v in d.items()}

#     summary = _summ(agg)
#     per_group_summary = {g: _summ(mdict) for g, mdict in per_group.items()}
#     return summary, per_group_summary

# # --- Run eval ---
# if 'torch' in globals() and len(dataset) > 0:
#     from torch.utils.data import DataLoader
#     eval_loader = DataLoader(dataset, batch_size=1, shuffle=False)
#     device_eval = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = model.to(device_eval)
#     summary, per_group = evaluate_model(model, eval_loader, device_eval, max_batches=200, reduce_over_time="mean")
#     print("== Global metrics ==")
#     for k, v in summary.items():
#         print(f"{k.upper():5s}: {v:.4f}")
#     print("\n== Per-group (first 8) ==")
#     for i, (g, s) in enumerate(per_group.items()):
#         if i >= 8: break
#         gstr = " | ".join(map(str, g)) if isinstance(g, (tuple, list)) else str(g)
#         print(f"{gstr}\n  " + "  ".join([f"{k}:{s[k]:.4f}" for k in ("mae","mse","psnr","ssim")]))
# else:
#     print("Skip eval: no torch or empty dataset.")


KeyboardInterrupt: 

In [48]:
# # ===== Visualization Utilities =====
# import numpy as np
# from pathlib import Path
# import tifffile as tiff
# import matplotlib.pyplot as plt

# def _to_np_uint8(x01):
#     x = (np.clip(x01, 0, 1) * 255.0).round().astype(np.uint8)
#     return x

# def save_timelapse_tiff(input_seq, pred_seq, target_seq, out_path: Path):
#     """
#     input_seq, pred_seq, target_seq: numpy arrays in [T, H, W] scaled to [0,1]
#     Writes a multi-page TIFF where each page is an H x (3W) panel: [input | pred | target].
#     """
#     out_path = Path(out_path)
#     out_path.parent.mkdir(parents=True, exist_ok=True)

#     T, H, W = input_seq.shape
#     pages = []
#     for t in range(T):
#         panel = np.concatenate([
#             _to_np_uint8(input_seq[t]),
#             _to_np_uint8(pred_seq[t]),
#             _to_np_uint8(target_seq[t])
#         ], axis=1)  # H x (3W)
#         pages.append(panel)

#     with tiff.TiffWriter(str(out_path), bigtiff=True) as tw:
#         for p in pages:
#             tw.write(p, photometric='minisblack')

# def plot_growth_curve(input_seq, pred_seq, target_seq, out_png: Path, title="Bacterial growth proxy"):
#     """
#     Simple growth proxy = sum of normalized intensities per frame.
#     """
#     out_png = Path(out_png)
#     out_png.parent.mkdir(parents=True, exist_ok=True)

#     g_in  = input_seq.reshape(input_seq.shape[0], -1).sum(axis=1)
#     g_pr  = pred_seq.reshape(pred_seq.shape[0], -1).sum(axis=1)
#     g_tar = target_seq.reshape(target_seq.shape[0], -1).sum(axis=1)

#     plt.figure(figsize=(6,4))
#     plt.plot(g_in, label="input")
#     plt.plot(g_pr, label="prediction")
#     plt.plot(g_tar, label="target")
#     plt.xlabel("frame (t)")
#     plt.ylabel("sum of intensities")
#     plt.title(title)
#     plt.legend()
#     plt.tight_layout()
#     plt.savefig(out_png, dpi=150)
#     plt.close()

# def make_viz_for_sample(model, dataset, idx=0, device="cpu", out_dir="viz"):
#     """
#     Runs a single dataset sample through the model and writes:
#       - timelapse panels: viz/sample_{idx}.tif
#       - growth curve:     viz/sample_{idx}_growth.png
#     """
#     model.eval()
#     X, Y, aux, meta = dataset[idx]   # X/Y: [1,T,H,W] torch
#     with torch.no_grad():
#         X = X.unsqueeze(0).to(device)        # [1,1,T,H,W]
#         Y = Y.unsqueeze(0).to(device)        # [1,1,T,H,W]
#         # normalize like training (frame-wise)
#         def _robust_norm_seq_torch(z):
#             B,C,T,H,W = z.shape
#             zf = z.view(B,C,-1)
#             q1 = torch.quantile(zf, 0.01, dim=2, keepdim=True)
#             q9 = torch.quantile(zf, 0.99, dim=2, keepdim=True)
#             span = (q9 - q1).clamp_min(1e-6)
#             zf = ((zf - q1).clamp_min(0) / span).clamp(0,1)
#             return zf.view(B,C,T,H,W)
#         Xn = _robust_norm_seq_torch(X)
#         Yn = _robust_norm_seq_torch(Y)

#         # RNA zeros if needed
#         if hasattr(model, "rna_mlp") and hasattr(model.rna_mlp[0], "in_features"):
#             aux_dim = model.rna_mlp[0].in_features
#             aux = torch.zeros((1, aux_dim), device=device)
#         else:
#             aux = None

#         Pred = model(Xn, rna=aux)  # [1,1,T,H,W]

#     # to numpy [T,H,W] in [0,1]
#     x_np = Xn.squeeze(0).squeeze(0).cpu().numpy()
#     y_np = Yn.squeeze(0).squeeze(0).cpu().numpy()
#     p_np = Pred.squeeze(0).squeeze(0).cpu().numpy()

#     out_dir = Path(out_dir)
#     out_dir.mkdir(parents=True, exist_ok=True)
#     tiff_path = out_dir / f"sample_{idx}.tif"
#     png_path  = out_dir / f"sample_{idx}_growth.png"

#     save_timelapse_tiff(x_np, p_np, y_np, tiff_path)
#     plot_growth_curve(x_np, p_np, y_np, png_path, title=f"Growth (sample {idx})")

#     print(f"[viz] wrote {tiff_path} and {png_path}")


In [49]:
# try:
#     import imageio.v2 as imageio
#     def write_gif_from_tiff(tiff_path, gif_path):
#         stack = tiff.imread(str(tiff_path))  # [T, H, 3W] uint8
#         imageio.mimsave(gif_path, [frame for frame in stack], duration=0.2)
#         print(f"[viz] wrote {gif_path}")
#     write_gif_from_tiff(tiff_path, tiff_path.with_suffix(".gif"))
# except Exception as _e:
#     pass


In [None]:
# ===== Visualization Utilities =====
from pathlib import Path
import numpy as np
import tifffile as tiff
import matplotlib.pyplot as plt
import torch

def _to_np_uint8(x01: np.ndarray) -> np.ndarray:
    x = (np.clip(x01, 0, 1) * 255.0).round().astype(np.uint8)
    return x

def save_timelapse_tiff(input_seq, pred_seq, target_seq, out_path: Path):
    """
    input_seq, pred_seq, target_seq: numpy arrays in [T, H, W] scaled to [0,1]
    Writes a multi-page TIFF where each page is an H x (3W) panel: [input | pred | target].
    """
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    T, H, W = input_seq.shape
    with tiff.TiffWriter(str(out_path), bigtiff=True) as tw:
        for t_idx in range(T):
            panel = np.concatenate([
                _to_np_uint8(input_seq[t_idx]),
                _to_np_uint8(pred_seq[t_idx]),
                _to_np_uint8(target_seq[t_idx]),
            ], axis=1)  # H x (3W)
            tw.write(panel, photometric='minisblack')

def plot_growth_curve(input_seq, pred_seq, target_seq, out_png: Path, title="Bacterial growth proxy"):
    """
    Simple growth proxy = sum of normalized intensities per frame.
    """
    out_png = Path(out_png)
    out_png.parent.mkdir(parents=True, exist_ok=True)

    g_in  = input_seq.reshape(input_seq.shape[0], -1).sum(axis=1)
    g_pr  = pred_seq.reshape(pred_seq.shape[0], -1).sum(axis=1)
    g_tar = target_seq.reshape(target_seq.shape[0], -1).sum(axis=1)

    plt.figure(figsize=(6,4))
    plt.plot(g_in,  label="input")
    plt.plot(g_pr,  label="prediction")
    plt.plot(g_tar, label="target")
    plt.xlabel("frame (t)")
    plt.ylabel("sum of intensities")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def _fast_norm_seq_torch(z: torch.Tensor) -> torch.Tensor:
    """Min/max normalization to [0,1] per-sample across T*H*W. Safer & faster than quantiles."""
    B,C,T,H,W = z.shape
    zf = z.reshape(B, C, -1)
    lo = torch.amin(zf, dim=2, keepdim=True)
    hi = torch.amax(zf, dim=2, keepdim=True)
    span = (hi - lo).clamp_min(1e-6)
    zf = ((zf - lo) / span).clamp(0,1)
    return zf.reshape(B,C,T,H,W)

def make_viz_for_sample(model, dataset, idx=0, device="cpu", out_dir="viz", make_gif=True):
    """
    Runs a single dataset sample through the model and writes:
      - timelapse panels: viz/sample_{idx}.tif
      - growth curve:     viz/sample_{idx}_growth.png
      - optional GIF:     viz/sample_{idx}.gif
    Returns (tiff_path, png_path, gif_path or None)
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    if idx < 0 or idx >= len(dataset):
        raise IndexError(f"idx {idx} out of range for dataset of length {len(dataset)}")

    model.eval()
    X, Y, aux, meta = dataset[idx]   # X/Y: [1,T,H,W] torch
    with torch.no_grad():
        X = X.unsqueeze(0).to(device)  # [1,1,T,H,W]
        Y = Y.unsqueeze(0).to(device)  # [1,1,T,H,W]

        Xn = _fast_norm_seq_torch(X)
        Yn = _fast_norm_seq_torch(Y)

        # RNA zeros if needed
        if hasattr(model, "rna_mlp") and hasattr(model.rna_mlp[0], "in_features"):
            aux_dim = model.rna_mlp[0].in_features
            aux = torch.zeros((1, aux_dim), device=device)
        else:
            aux = None

        Pred = model(Xn, rna=aux)  # [1,1,T,H,W]

    # to numpy [T,H,W] in [0,1]
    x_np = Xn.squeeze(0).squeeze(0).detach().cpu().numpy()
    y_np = Yn.squeeze(0).squeeze(0).detach().cpu().numpy()
    p_np = Pred.squeeze(0).squeeze(0).detach().cpu().numpy()

    tiff_path = out_dir / f"sample_{idx}.tif"
    png_path  = out_dir / f"sample_{idx}_growth.png"

    save_timelapse_tiff(x_np, p_np, y_np, tiff_path)
    plot_growth_curve(x_np, p_np, y_np, png_path, title=f"Growth (sample {idx})")

    gif_path = None
    if make_gif:
        try:
            import imageio.v2 as imageio
            # stack = tiff.imread(str(tiff_path))  # [T, H, 3W] uint8
            stack = np.array(Image.open(str(tiff_path)))
            imageio.mimsave(tiff_path.with_suffix(".gif"), [frame for frame in stack], duration=0.2)
            gif_path = tiff_path.with_suffix(".gif")
        except Exception as e:
            print(f"[viz] GIF creation skipped: {e}")

    print(f"[viz] wrote {tiff_path} and {png_path}" + (f" and {gif_path}" if gif_path else ""))
    return tiff_path, png_path, gif_path


In [51]:
# Pick a sample, run viz, and get the paths back
device_viz = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device_viz)

tiff_path, png_path, gif_path = make_viz_for_sample(
    model, dataset, idx=0, device=device_viz, out_dir="viz", make_gif=True
)
print("Outputs:", tiff_path, png_path, gif_path)


[viz] wrote viz/sample_0.tif and viz/sample_0_growth.png and viz/sample_0.gif
Outputs: viz/sample_0.tif viz/sample_0_growth.png viz/sample_0.gif


In [52]:
# ==== Predict + show: input | prediction | ground-truth trajectory ====
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import tifffile as tiff
import torch

def _fast_norm_seq_torch(z: torch.Tensor) -> torch.Tensor:
    """Min/max normalization to [0,1] per-sample across T*H*W."""
    B,C,T,H,W = z.shape
    zf = z.reshape(B, C, -1)
    lo = torch.amin(zf, dim=2, keepdim=True)
    hi = torch.amax(zf, dim=2, keepdim=True)
    span = (hi - lo).clamp_min(1e-6)
    zf = ((zf - lo) / span).clamp(0,1)
    return zf.reshape(B,C,T,H,W)

def _to_uint8(x01: np.ndarray) -> np.ndarray:
    return np.clip(x01,0,1).astype(np.float32) * 255.0

@torch.no_grad()
def predict_and_viz_triptych(model, dataset, idx=0, device=None, out_dir="viz",
                             sample_every=1, max_cols=16, make_gif=True):
    """
    Produces:
      - triptych grid PNG:   viz/triptych_{idx}.png
      - multipage panel TIFF viz/triptych_{idx}.tif  (each page: [Input | Pred | GT])
      - optional GIF         viz/triptych_{idx}.gif
    """
    out_dir = Path(out_dir); out_dir.mkdir(parents=True, exist_ok=True)
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval(); model = model.to(device)

    # --- pull sample ---
    X, Y, aux, meta = dataset[idx]      # X,Y: [1,T,H,W]
    X = X.unsqueeze(0).to(device)       # [1,1,T,H,W]
    Y = Y.unsqueeze(0).to(device)       # [1,1,T,H,W]

    # normalize like training (fast)
    Xn = _fast_norm_seq_torch(X)
    Yn = _fast_norm_seq_torch(Y)

    # RNA aux (zeros if needed)
    if hasattr(model, "rna_mlp") and hasattr(model.rna_mlp[0], "in_features"):
        aux_dim = model.rna_mlp[0].in_features
        aux_t = torch.zeros((1, aux_dim), device=device)
    else:
        aux_t = None

    # --- predict ---
    Pred = model(Xn, rna=aux_t)  # [1,1,T,H,W]

    # --- to numpy [T,H,W] in [0,1] ---
    x_np = Xn[0,0].detach().cpu().numpy()   # [T,H,W]
    y_np = Yn[0,0].detach().cpu().numpy()   # [T,H,W]
    p_np = Pred[0,0].detach().cpu().numpy() # [T,H,W]

    # --- subsample frames for the grid ---
    T = x_np.shape[0]
    stride = max(1, sample_every)
    cols = min(max_cols, (T + stride - 1) // stride)
    sel = list(range(0, T, stride))[:cols]

    # --- make a single PNG grid: rows = (Input, Pred, GT), columns = time ---
    fig_h = 3.2
    fig_w = 1.6 * cols
    fig, axes = plt.subplots(3, cols, figsize=(fig_w, fig_h), squeeze=False)
    rows = [("Input", x_np), ("Prediction", p_np), ("Ground truth", y_np)]
    for r, (title, arr) in enumerate(rows):
        for c, t_idx in enumerate(sel):
            ax = axes[r, c]
            ax.imshow(arr[t_idx], cmap="gray", vmin=0, vmax=1)
            ax.axis("off")
            if c == 0:
                ax.set_title(title, fontsize=10, pad=4, loc="left")
            if r == 0:
                ax.set_xlabel(f"t={t_idx}", fontsize=9)
    plt.tight_layout()
    png_path = out_dir / f"triptych_{idx}.png"
    fig.savefig(png_path, dpi=150)
    plt.close(fig)

    # --- multipage TIFF with side-by-side panels per frame: [Input | Pred | GT] ---
    tiff_path = out_dir / f"triptych_{idx}.tif"
    with tiff.TiffWriter(str(tiff_path), bigtiff=True) as tw:
        for t in range(T):
            panel = np.concatenate([
                _to_uint8(x_np[t]),
                _to_uint8(p_np[t]),
                _to_uint8(y_np[t]),
            ], axis=1).astype(np.uint8)  # H x (3W)
            tw.write(panel, photometric='minisblack')

    # --- optional GIF flipping through time ---
    gif_path = None
    if make_gif:
        try:
            import imageio.v2 as imageio
            # build frames as wide panels
            frames = []
            for t in range(T):
                panel = np.concatenate([
                    _to_uint8(x_np[t]),
                    _to_uint8(p_np[t]),
                    _to_uint8(y_np[t]),
                ], axis=1).astype(np.uint8)
                frames.append(panel)
            gif_path = out_dir / f"triptych_{idx}.gif"
            imageio.mimsave(gif_path, frames, duration=0.2)
        except Exception as e:
            print(f"[viz] GIF skipped: {e}")
            gif_path = None

    print(f"[viz] wrote:\n - {png_path}\n - {tiff_path}" + (f"\n - {gif_path}" if gif_path else ""))
    return png_path, tiff_path, gif_path


In [57]:
from pathlib import Path
import torch

def viz_subset(
    model, dataset, indices=None, out_dir="viz_subset",
    sample_every=1, max_cols=None, make_gif=False
):
    out_dir = Path(out_dir); out_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()

    # choose first N if not provided
    if indices is None:
        N = min(10, len(dataset))  # visualize at most 10
        indices = list(range(N))

    for j, idx in enumerate(indices):
        # fetch once to inspect T
        X, Y, *_ = dataset[idx]         # [1,T,H,W] each
        T = X.shape[1] if X.ndim == 4 else 1
        print(f"[subset] idx={idx} -> T={T}, HxW={X.shape[-2]}x{X.shape[-1]}")

        # force multiple columns if T>1
        mc = max_cols if (max_cols is not None) else (T if T <= 24 else 24)
        se = min(sample_every, max(1, T//mc) if mc else 1)

        # unique file prefix so nothing gets overwritten
        # include idx and T so you can trace what was rendered
        prefix = f"s{idx:04d}_T{T}"

        png_path, tiff_path, gif_path = predict_and_viz_triptych(
            model, dataset, idx=idx, device=device, out_dir=out_dir,
            sample_every=se, max_cols=mc, make_gif=make_gif,
        )

        # rename to unique names if function used generic names
        pn2  = out_dir / f"{prefix}__{png_path.name}"
        tf2  = out_dir / f"{prefix}__{tiff_path.name}"
        png_path.rename(pn2); tiff_path.rename(tf2)
        if gif_path:
            gf2 = out_dir / f"{prefix}__{gif_path.name}"
            gif_path.rename(gf2)
            print(f"[viz] wrote: {pn2}, {tf2}, {gf2}")
        else:
            print(f"[viz] wrote: {pn2}, {tf2}")



In [58]:
# visualize first 5 samples, show as many frames as possible (cap to 24 columns)
viz_subset(
    model, dataset, indices=list(range(5)),
    out_dir="viz_subset",
    sample_every=1,   # try every frame first
    max_cols=None,    # None -> auto use up to 24 columns based on T
    make_gif=False    # flip to True later
)



[subset] idx=0 -> T=1, HxW=512x512
[viz] wrote:
 - viz_subset/triptych_0.png
 - viz_subset/triptych_0.tif
[viz] wrote: viz_subset/s0000_T1__triptych_0.png, viz_subset/s0000_T1__triptych_0.tif
[subset] idx=1 -> T=1, HxW=512x512
[viz] wrote:
 - viz_subset/triptych_1.png
 - viz_subset/triptych_1.tif
[viz] wrote: viz_subset/s0001_T1__triptych_1.png, viz_subset/s0001_T1__triptych_1.tif
[subset] idx=2 -> T=1, HxW=512x512
[viz] wrote:
 - viz_subset/triptych_2.png
 - viz_subset/triptych_2.tif
[viz] wrote: viz_subset/s0002_T1__triptych_2.png, viz_subset/s0002_T1__triptych_2.tif
[subset] idx=3 -> T=1, HxW=512x512
[viz] wrote:
 - viz_subset/triptych_3.png
 - viz_subset/triptych_3.tif
[viz] wrote: viz_subset/s0003_T1__triptych_3.png, viz_subset/s0003_T1__triptych_3.tif
[subset] idx=4 -> T=1, HxW=512x512
[viz] wrote:
 - viz_subset/triptych_4.png
 - viz_subset/triptych_4.tif
[viz] wrote: viz_subset/s0004_T1__triptych_4.png, viz_subset/s0004_T1__triptych_4.tif


In [59]:
# ==== Predict + show: input | prediction | ground-truth trajectory (with contrast stretch) ====
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import tifffile as tiff
import torch

def _fast_norm_seq_torch(z: torch.Tensor) -> torch.Tensor:
    """Min/max normalization to [0,1] per-sample across T*H*W."""
    B,C,T,H,W = z.shape
    zf = z.reshape(B, C, -1)
    lo = torch.amin(zf, dim=2, keepdim=True)
    hi = torch.amax(zf, dim=2, keepdim=True)
    span = (hi - lo).clamp_min(1e-6)
    zf = ((zf - lo) / span).clamp(0,1)
    return zf.reshape(B,C,T,H,W)

def _to_uint8(x01: np.ndarray) -> np.ndarray:
    return np.clip(x01,0,1).astype(np.float32) * 255.0

def _stretch01(x, lo=2, hi=98):
    """Percentile-based contrast stretch."""
    a, b = np.percentile(x, [lo, hi])
    return np.clip((x - a) / max(b - a, 1e-6), 0, 1)

@torch.no_grad()
def predict_and_viz_triptych(model, dataset, idx=0, device=None, out_dir="viz",
                             sample_every=1, max_cols=16, make_gif=True):
    """
    Produces:
      - PNG grid (Input / Prediction / Ground truth)
      - multipage TIFF (per-frame [Input|Pred|GT])
      - optional GIF
    Applies contrast stretching to prediction frames for better visibility.
    """
    out_dir = Path(out_dir); out_dir.mkdir(parents=True, exist_ok=True)
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval(); model = model.to(device)

    # --- pull sample ---
    X, Y, aux, meta = dataset[idx]      # X,Y: [1,T,H,W]
    X = X.unsqueeze(0).to(device)       # [1,1,T,H,W]
    Y = Y.unsqueeze(0).to(device)       # [1,1,T,H,W]

    # normalize like training (fast)
    Xn = _fast_norm_seq_torch(X)
    Yn = _fast_norm_seq_torch(Y)

    # RNA aux (zeros if needed)
    if hasattr(model, "rna_mlp") and hasattr(model.rna_mlp[0], "in_features"):
        aux_dim = model.rna_mlp[0].in_features
        aux_t = torch.zeros((1, aux_dim), device=device)
    else:
        aux_t = None

    # --- predict ---
    Pred = model(Xn, rna=aux_t)  # [1,1,T,H,W]

    # --- to numpy [T,H,W] in [0,1] ---
    x_np = Xn[0,0].detach().cpu().numpy()
    y_np = Yn[0,0].detach().cpu().numpy()
    p_np = Pred[0,0].detach().cpu().numpy()

    # --- contrast stretch prediction frames ---
    p_np = np.stack([_stretch01(frame, lo=2, hi=98) for frame in p_np], axis=0)

    # --- subsample frames for grid ---
    T = x_np.shape[0]
    stride = max(1, sample_every)
    cols = min(max_cols, (T + stride - 1) // stride)
    sel = list(range(0, T, stride))[:cols]

    # --- PNG grid: rows = (Input, Pred, GT), columns = time ---
    fig, axes = plt.subplots(3, cols, figsize=(1.8*cols, 3.2), squeeze=False)
    rows = [("Input", x_np), ("Prediction", p_np), ("Ground truth", y_np)]
    for r, (title, arr) in enumerate(rows):
        for c, t_idx in enumerate(sel):
            ax = axes[r, c]
            ax.imshow(arr[t_idx], cmap="gray", vmin=0, vmax=1)
            ax.axis("off")
            if c == 0:
                ax.set_ylabel(title, fontsize=10, labelpad=4)
            if r == 0:
                ax.set_title(f"t={t_idx}", fontsize=9)
    plt.tight_layout()
    png_path = out_dir / f"triptych_{idx}_stretched.png"
    fig.savefig(png_path, dpi=150)
    plt.close(fig)

    # --- multipage TIFF: [Input|Pred|GT] per frame ---
    tiff_path = out_dir / f"triptych_{idx}_stretched.tif"
    with tiff.TiffWriter(str(tiff_path), bigtiff=True) as tw:
        for t in range(T):
            panel = np.concatenate([
                _to_uint8(x_np[t]),
                _to_uint8(p_np[t]),
                _to_uint8(y_np[t]),
            ], axis=1).astype(np.uint8)
            tw.write(panel, photometric='minisblack')

    # --- optional GIF ---
    gif_path = None
    if make_gif:
        try:
            import imageio.v2 as imageio
            frames = []
            for t in range(T):
                panel = np.concatenate([
                    _to_uint8(x_np[t]),
                    _to_uint8(p_np[t]),
                    _to_uint8(y_np[t]),
                ], axis=1).astype(np.uint8)
                frames.append(panel)
            gif_path = out_dir / f"triptych_{idx}_stretched.gif"
            imageio.mimsave(gif_path, frames, duration=0.2)
        except Exception as e:
            print(f"[viz] GIF skipped: {e}")

    print(f"[viz] wrote:\n - {png_path}\n - {tiff_path}" + (f"\n - {gif_path}" if gif_path else ""))
    return png_path, tiff_path, gif_path


In [60]:
device_viz = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Try first few samples to check visibility
for i in range(3):   # visualize 3 examples
    predict_and_viz_triptych(
        model, dataset, idx=i, device=device_viz,
        out_dir="viz_stretched",
        sample_every=2,   # show every 2nd frame
        max_cols=10,      # up to 10 columns
        make_gif=False
    )


[viz] wrote:
 - viz_stretched/triptych_0_stretched.png
 - viz_stretched/triptych_0_stretched.tif
[viz] wrote:
 - viz_stretched/triptych_1_stretched.png
 - viz_stretched/triptych_1_stretched.tif
[viz] wrote:
 - viz_stretched/triptych_2_stretched.png
 - viz_stretched/triptych_2_stretched.tif


In [61]:
X, Y, _, _ = dataset[0]
print("Input shape:", X.shape)  # expected [1, T, H, W]

Input shape: torch.Size([1, 1, 512, 512])


In [62]:
predict_and_viz_triptych(
    model, dataset, idx=0,
    sample_every=1,    # show every frame
    max_cols=20,       # up to 20 columns (frames)
    make_gif=False
)

[viz] wrote:
 - viz/triptych_0_stretched.png
 - viz/triptych_0_stretched.tif


(PosixPath('viz/triptych_0_stretched.png'),
 PosixPath('viz/triptych_0_stretched.tif'),
 None)

In [63]:
predict_and_viz_triptych(
    model, dataset, idx=0,
    sample_every=1,    # show every frame
    max_cols=20,       # allow up to 20 frames side-by-side
    make_gif=False
)

[viz] wrote:
 - viz/triptych_0_stretched.png
 - viz/triptych_0_stretched.tif


(PosixPath('viz/triptych_0_stretched.png'),
 PosixPath('viz/triptych_0_stretched.tif'),
 None)

In [70]:
# --- Parse ISA-Tab and build trajectories compatible with our loaders ---

import pandas as pd
import re
from pathlib import Path

# Path to your uploaded file (you said it's /mnt/data/s_OSD-627.txt)
ISA_PATH = Path("s_OSD-627.txt")

# 1) Read + normalize columns we need
df = pd.read_csv(ISA_PATH, sep="\t", dtype=str).fillna("")

# Canonical, shorter column names
COL_MAP = {
    "Source Name": "source_id",
    "Sample Name": "sample_id",
    "Factor Value[Spaceflight]": "spaceflight",                    # Ground / Space Flight
    "Factor Value[Growth Environment]": "material",                # SS316, LIS, Silicone, etc.
    "Factor Value[Time]": "time",                                  # 1 / 2 / 3
    "Parameter Value[Sample Media Information]": "medium",         # e.g., LB broth ... KNO3, Cellulose etc.
}
for old, new in COL_MAP.items():
    if old in df.columns:
        df.rename(columns={old: new}, inplace=True)

# Keep only what we need
keep_cols = ["source_id", "sample_id", "spaceflight", "material", "time", "medium"]
for c in keep_cols:
    if c not in df.columns:
        df[c] = ""

df["time_num"] = pd.to_numeric(df["time"].str.extract(r"(\d+)")[0], errors="coerce").astype("Int64")

# 2) Clean/id helpers so we can match TIFF names robustly
def clean_id(x: str) -> str:
    # Keep alnum + . + _
    return re.sub(r"[^A-Za-z0-9._-]", "", x or "").strip()

df["source_id_clean"] = df["source_id"].apply(clean_id)
df["sample_id_clean"] = df["sample_id"].apply(clean_id)

# A flexible set of "tokens" that might appear in TIFF filenames
def id_tokens(row):
    toks = []
    if row["sample_id_clean"]:
        toks.append(row["sample_id_clean"])
    if row["source_id_clean"]:
        toks.append(row["source_id_clean"])
    # Also common short tokens like "G1.1" (prefix of many filenames)
    short_source = row["source_id_clean"].split("_")[0]
    if short_source and short_source not in toks:
        toks.append(short_source)
    return toks

df["tokens"] = df.apply(id_tokens, axis=1)

# 3) Group into trajectories by (spaceflight, material, medium), ordered by time 1→2→3
def keyify(s):
    return (str(s["spaceflight"]).strip() or "NA",
            str(s["material"]).strip() or "NA",
            str(s["medium"]).strip() or "NA")

trajectories = {}
for gkey, gdf in df.groupby(df.apply(keyify, axis=1)):
    bucket = {1: [], 2: [], 3: []}
    for _, r in gdf.iterrows():
        t = int(r["time_num"]) if pd.notna(r["time_num"]) else None
        if t in (1, 2, 3):
            # store both IDs; downstream will try both when matching TIFFs
            bucket[t].append({
                "source_id": r["source_id_clean"],
                "sample_id": r["sample_id_clean"],
                "tokens": r["tokens"],
            })
    # only keep if at least one day exists
    if any(len(v) for v in bucket.values()):
        trajectories[gkey] = bucket

# Human-friendly flat list of complete (has at least 2 days; prefer 1→2→3) trajectories
trajectory_samples = []
for gkey, days in trajectories.items():
    if sum(1 for d in (1,2,3) if len(days[d])>0) >= 2:
        trajectory_samples.append((gkey, days.get(1, []), days.get(2, []), days.get(3, [])))

print(f"[isa] groups: {len(trajectories)} | usable trajectories (>=2 days present): {len(trajectory_samples)}")

# 4) TIFF matching helper — works for both folder-based and zip-based setups

def build_name_index_from_folder(tif_paths):
    """Return a list of searchable names (relative) and a fast lookup string for matching."""
    # Normalize as posix paths for consistent substring search
    names = [str(p).replace("\\", "/") for p in tif_paths]
    return names

def build_name_index_from_zip(zip_namelist):
    names = [n.replace("\\", "/") for n in zip_namelist]
    # Filter to tiff-like
    names = [n for n in names if n.lower().endswith((".tif", ".tiff"))]
    return names

# Decide which source of TIFF names you have:
# - If you have a folder list: set TIFF_NAME_INDEX = build_name_index_from_folder(tif_files)
# - If you have a zip namelist: set TIFF_NAME_INDEX = build_name_index_from_zip(tif_files)  # where tif_files came from zf.namelist()

TIFF_NAME_INDEX = None  # <-- you must set this after your discovery code runs.

def match_tifs_for_ids(id_record, name_index, max_per_id=8):
    """Given one ISA record {'source_id','sample_id','tokens'}, return all matching TIFF paths (strings)."""
    if name_index is None:
        return []
    toks = [t for t in id_record.get("tokens", []) if t]
    hits = []
    for nm in name_index:
        nm_lower = nm.lower()
        # require every token to appear? Too strict. We'll use "any token" match, but prefer stronger matches later.
        if any(t.lower() in nm_lower for t in toks):
            hits.append(nm)
            if len(hits) >= max_per_id:
                break
    # De-duplicate, keep stable order
    out = []
    seen = set()
    for h in hits:
        if h not in seen:
            out.append(h); seen.add(h)
    return out

# 5) Example: build a compact preview of first few trajectories and their TIFF matches (once TIFF_NAME_INDEX is set)
def preview_first_n_trajectories(n=3):
    if TIFF_NAME_INDEX is None:
        print("[preview] Please set TIFF_NAME_INDEX after you enumerate TIFF files (folder or zip).")
        return
    shown = 0
    for gkey, d1, d2, d3 in trajectory_samples:
        print(f"\nGroup: {gkey}  (days present: {[d for d in (1,2,3) if len({1:d1,2:d2,3:d3}[d])>0]})")
        for day, bucket in ((1,d1),(2,d2),(3,d3)):
            if not bucket:
                continue
            print(f"  Day {day}:")
            for rec in bucket[:2]:   # show up to 2 records/day
                hits = match_tifs_for_ids(rec, TIFF_NAME_INDEX)
                print(f"    ids={rec['source_id']}/{rec['sample_id']}  ->  {len(hits)} tif(s) e.g. {hits[:2]}")
        shown += 1
        if shown >= n:
            break


[isa] groups: 12 | usable trajectories (>=2 days present): 12


In [71]:
TIFF_NAME_INDEX = build_name_index_from_zip(tif_files)

In [72]:
preview_first_n_trajectories(n=5)


Group: ('Ground', 'Cellulose Membrane', 'mAUMg-hi Pi')  (days present: [1, 2, 3])
  Day 1:
    ids=G2.4/G2.5003  ->  5 tif(s) e.g. ['microscopy/LSDS-55_microscopy_G2.4.tif', 'microscopy/LSDS-55_microscopy_G2.4003.tif']
    ids=G2.6/G2.6  ->  2 tif(s) e.g. ['microscopy/LSDS-55_microscopy_G2.6.tif', 'microscopy/LSDS-55_microscopy_G2.6001.tif']
  Day 2:
    ids=G8.5/G8.5002  ->  2 tif(s) e.g. ['microscopy/LSDS-55_microscopy_G8.5002.tif', 'microscopy/LSDS-55_microscopy_G8.5003.tif']
    ids=G8.5/G8.5003  ->  2 tif(s) e.g. ['microscopy/LSDS-55_microscopy_G8.5002.tif', 'microscopy/LSDS-55_microscopy_G8.5003.tif']
  Day 3:
    ids=G14.5/G14.5002  ->  2 tif(s) e.g. ['microscopy/LSDS-55_microscopy_G14.5002.tif', 'microscopy/LSDS-55_microscopy_G14.5003.tif']
    ids=G14.5/G14.5003  ->  2 tif(s) e.g. ['microscopy/LSDS-55_microscopy_G14.5002.tif', 'microscopy/LSDS-55_microscopy_G14.5003.tif']

Group: ('Ground', 'Lubricant Impregnated Surface (LIS)', 'LB broth (Lennox) supplemented with KNO3')  (d

In [73]:
# ==== Build a visualization dataset from your trajectories + tif_map, then render ====
import numpy as np, torch
from torch.utils.data import Dataset
from pathlib import Path

# --- helper: reduce TIFF arrays to a single 2D frame for that "day"
def _to_frame2d(arr):
    a = np.asarray(arr)
    if a.ndim == 2:
        return a.astype(np.float32)
    a = np.squeeze(a)
    if a.ndim == 2:
        return a.astype(np.float32)
    if a.ndim == 3:
        # If [T,H,W] or [Z,H,W], use mean projection (change to [0], max, etc. if you prefer)
        return a.mean(axis=0).astype(np.float32)
    # fallback: try flatten-wise reshape to last two dims
    return a.reshape(a.shape[-2], a.shape[-1]).astype(np.float32)

# --- helper: pick the first tif that is actually present in tif_map (keys by basename)
def _pick_tif_for_idrecord(id_record, tif_map, name_index=None):
    # id_record contains "tokens" you built earlier
    tokens = [t.lower() for t in id_record.get("tokens", []) if t]
    if not tokens:
        return None
    # check available basenames in tif_map
    for k in tif_map.keys():
        stem = Path(k).stem.lower()
        if any(t in stem for t in tokens):
            return k
    # optionally fall back to name_index (zip/folder namelist) then map to basename
    if name_index:
        for nm in name_index:
            stem = Path(nm).stem.lower()
            if any(t in stem for t in tokens):
                base = Path(nm).name
                if base in tif_map:
                    return base
    return None

class TrajectoryVizDataset(Dataset):
    """
    One sample per (spaceflight, material, medium) trajectory with >=3 days.
    X: [1, T_win, H, W] (e.g., days [d1, d2])
    Y: [1, T_win, H, W] (e.g., days [d2, d3])
    """
    def __init__(self, trajectories, tif_map, tiff_name_index=None, T_win=2):
        self.samples = []     # list of dicts with: key, days, frames2d
        self.T_win = T_win
        self.tif_map = tif_map

        for gkey, days_dict in trajectories.items():
            # collect ordered (day, frame2d) with one representative frame per day
            ordered = []
            for d in sorted([d for d in (1,2,3) if d in days_dict]):
                bucket = days_dict[d]
                if not bucket:
                    continue
                # choose first resolvable tif from this day's ids
                chosen = None
                for rec in bucket:
                    tk = _pick_tif_for_idrecord(rec, tif_map, name_index=tiff_name_index)
                    if tk is not None:
                        chosen = tk
                        break
                if chosen is None:
                    continue
                frame = _to_frame2d(tif_map[chosen])
                ordered.append((d, frame))
            if len(ordered) < (self.T_win + 1):
                continue

            # enforce common HxW within this group
            # choose most frequent size
            sizes = {}
            for _, fr in ordered:
                sizes[fr.shape] = sizes.get(fr.shape, 0) + 1
            ref = max(sizes.items(), key=lambda kv: kv[1])[0]
            ordered = [(d, fr) for (d, fr) in ordered if fr.shape == ref]
            if len(ordered) < (self.T_win + 1):
                continue

            # build sliding windows over days (for 3 days and T_win=2 → exactly one window)
            frames = [fr for (_, fr) in ordered]
            daylist = [d for (d, _) in ordered]
            T = len(frames)
            for t0 in range(0, T - self.T_win):
                if t0 + self.T_win >= T:
                    break
                self.samples.append({
                    "key": gkey,
                    "days": daylist[t0 : t0 + self.T_win + 1],  # e.g., [1,2,3]
                    "frames": frames,   # keep full to slice on getitem
                    "t0": t0
                })

    def __len__(self): return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        t0 = s["t0"]; T_win = self.T_win
        frames = s["frames"]
        X = np.stack(frames[t0 : t0 + T_win], axis=0)           # [T_win,H,W] (e.g., Day1, Day2)
        Y = np.stack(frames[t0 + 1 : t0 + 1 + T_win], axis=0)   # [T_win,H,W] (e.g., Day2, Day3)
        X = torch.from_numpy(X)[None, ...]  # [1,T,H,W]
        Y = torch.from_numpy(Y)[None, ...]
        meta = {"key": s["key"], "days": s["days"], "t0": t0, "T": T_win}
        return X, Y, None, meta

# --- Build the viz dataset from the trajectories you already have
# trajectories: { (sf,mat,med): {1:[idrecs], 2:[...], 3:[...]} }
# tif_map: {basename -> np.ndarray}
# If you have a zip/folder name index, pass it as tiff_name_index; otherwise omit.
viz_dataset = TrajectoryVizDataset(trajectories, tif_map, tiff_name_index=TIFF_NAME_INDEX, T_win=2)
print("Viz samples:", len(viz_dataset))
if len(viz_dataset) > 0:
    X, Y, _, m = viz_dataset[0]
    print("X:", X.shape, "Y:", Y.shape, "meta:", m)

# --- Render a few samples with your stretched triptych visualizer ---
device_viz = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device_viz).eval()

for i in range(min(5, len(viz_dataset))):
    try:
        predict_and_viz_triptych(
            model, viz_dataset, idx=i, device=device_viz,
            out_dir="viz_trajectories", sample_every=1, max_cols=8, make_gif=False
        )
    except Exception as e:
        print(f"[viz] sample {i} skipped: {e}")


Viz samples: 10
X: torch.Size([1, 2, 256, 256]) Y: torch.Size([1, 2, 256, 256]) meta: {'key': ('Ground', 'Cellulose Membrane', 'mAUMg-hi Pi'), 'days': [1, 2, 3], 't0': 0, 'T': 2}
[viz] wrote:
 - viz_trajectories/triptych_0_stretched.png
 - viz_trajectories/triptych_0_stretched.tif
[viz] wrote:
 - viz_trajectories/triptych_1_stretched.png
 - viz_trajectories/triptych_1_stretched.tif
[viz] wrote:
 - viz_trajectories/triptych_2_stretched.png
 - viz_trajectories/triptych_2_stretched.tif
[viz] wrote:
 - viz_trajectories/triptych_3_stretched.png
 - viz_trajectories/triptych_3_stretched.tif
[viz] wrote:
 - viz_trajectories/triptych_4_stretched.png
 - viz_trajectories/triptych_4_stretched.tif


In [74]:
predict_and_viz_triptych(
    model, viz_dataset, idx=0,
    out_dir="viz_trajectories_bright",
    sample_every=1, max_cols=8,
    make_gif=False
)

[viz] wrote:
 - viz_trajectories_bright/triptych_0_stretched.png
 - viz_trajectories_bright/triptych_0_stretched.tif


(PosixPath('viz_trajectories_bright/triptych_0_stretched.png'),
 PosixPath('viz_trajectories_bright/triptych_0_stretched.tif'),
 None)