In [1]:
%pip install monai

Collecting monai
  Downloading monai-1.5.1-py3-none-any.whl.metadata (13 kB)
Downloading monai-1.5.1-py3-none-any.whl (2.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m58.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: monai
Successfully installed monai-1.5.1


In [2]:
# Authenticate Google account and set up project access
from google.colab import auth
auth.authenticate_user()

!gcloud config set account secasarr@ucsc.edu
!gcloud config set project clinimcl

# Confirm GPU availability
!nvidia-smi

import torch, monai
print(f"[device] {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu'} | torch {torch.__version__} | monai {monai.__version__}")


Updated property [core/account].
Are you sure you wish to set property [core/project] to clinimcl?

Do you want to continue (Y/n)?  Y

Updated property [core/project].
Sun Nov  9 22:31:48 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:00:05.0 Off |                    0 |
| N/A   31C    P0             50W /  400W |       0MiB /  81920MiB |      0%      Default |
|                                         |                     

In [3]:
import gcsfs
from torch.serialization import add_safe_globals
from monai.data.meta_tensor import MetaTensor
from monai.utils.enums import TraceKeys
import numpy as np

# --- dynamically add numpy internals (for torch.load) ---
safe_globals = [MetaTensor, TraceKeys]
try:
    import numpy.core.multiarray as multiarray
    safe_globals.append(multiarray._reconstruct)
except Exception as e:
    print(f"[warn] could not import numpy.core.multiarray._reconstruct: {e}")

try:
    import numpy.dtype as np_dtype
    safe_globals.append(np_dtype)
except Exception:
    pass

# Register safe globals with PyTorch
add_safe_globals(safe_globals)
print(f"[torch] registered {len(safe_globals)} safe globals for torch.load")

# --- Connect to GCS using Colab credentials ---
fs = gcsfs.GCSFileSystem(token='google_default')
base_path = "gs://clinimcl-data/OASIS3"
preproc_path = f"{base_path}/preprocessed"
raw_path = f"{base_path}/raw"

# Verify access
files = fs.glob(f"{preproc_path}/**/*.pt")
print(f"[gcs] Found {len(files)} preprocessed .pt files. Showing first 5:")
print("\n".join(files[:5]))


[torch] registered 3 safe globals for torch.load
[gcs] Found 2832 preprocessed .pt files. Showing first 5:
clinimcl-data/OASIS3/preprocessed/OAS30001_d0129.pt
clinimcl-data/OASIS3/preprocessed/OAS30001_d0757.pt
clinimcl-data/OASIS3/preprocessed/OAS30001_d2430.pt
clinimcl-data/OASIS3/preprocessed/OAS30001_d3132.pt
clinimcl-data/OASIS3/preprocessed/OAS30001_d3746.pt


In [4]:
import io, torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random

# ------------------------------------------------------------------
# Load ALL .pt files from your GCS bucket
# ------------------------------------------------------------------
all_files = [f.split("/")[-1] for f in fs.ls(preproc_path) if f.endswith(".pt")]
print(f"[dataset] Using all {len(all_files)} preprocessed files")

# ------------------------------------------------------------------
# Dataset definition
# ------------------------------------------------------------------
class OASIS3Dataset(Dataset):
    def __init__(self, fs, preproc_path, subjects, target_shape=(128,128,128)):
        self.fs = fs
        self.files = [f"{preproc_path}/{s}" for s in subjects]
        self.target_shape = target_shape

    def __len__(self):
        return len(self.files)

    def _pad_or_crop(self, x):
        c, d, h, w = x.shape
        td, th, tw = self.target_shape
        # crop center if too big
        if d > td or h > th or w > tw:
            d0 = (d - td)//2
            h0 = (h - th)//2
            w0 = (w - tw)//2
            x = x[:, d0:d0+td, h0:h0+th, w0:w0+tw]
        # pad if too small
        pad_d = max(0, td - x.shape[1])
        pad_h = max(0, th - x.shape[2])
        pad_w = max(0, tw - x.shape[3])
        if pad_d or pad_h or pad_w:
            x = F.pad(x, (0,pad_w,0,pad_h,0,pad_d))
        return x

    def __getitem__(self, idx):
        path = self.files[idx]
        with self.fs.open(path, "rb") as f:
            buf = io.BytesIO(f.read())
            tensor = torch.load(buf, map_location="cpu", weights_only=False)
        tensor = tensor.float()
        tensor = self._pad_or_crop(tensor)
        return tensor

# ------------------------------------------------------------------
# Build dataset & loader for full training set
# ------------------------------------------------------------------
target_shape = (128,128,128)
dataset = OASIS3Dataset(fs, preproc_path, all_files, target_shape)
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)

# Smoke test
batch = next(iter(loader))
print(f"[✓] Loaded batch of shape {batch.shape}")


[dataset] Using all 2832 preprocessed files
[✓] Loaded batch of shape torch.Size([2, 1, 128, 128, 128])


In [5]:
from monai.networks.nets import resnet50
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProjectionHead(nn.Module):
    def __init__(self, in_dim=2048, proj_dim=128, hidden=2048, use_bn=True):
        super().__init__()
        layers = [nn.Linear(in_dim, hidden), nn.GELU()]
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden))
        layers += [nn.Linear(hidden, proj_dim)]
        self.net = nn.Sequential(*layers)

    def forward(self, x, l2norm=True):
        z = self.net(x)
        return F.normalize(z, dim=-1) if l2norm else z

class ContrastiveModel(nn.Module):
    def __init__(self, proj_dim=128):
        super().__init__()
        self.encoder = resnet50(spatial_dims=3, n_input_channels=1, num_classes=0)
        self.projector = ProjectionHead(in_dim=2048, proj_dim=proj_dim)

    @torch.cuda.amp.autocast(enabled=True)
    def forward(self, x, return_feats=False):
        feats = self.encoder(x)              # [B, 2048]
        z = self.projector(feats)            # [B, proj_dim], L2-normalized
        return (z, feats) if return_feats else z

model = ContrastiveModel(proj_dim=128).cuda()
print(f"[model] total params: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")


  @torch.cuda.amp.autocast(enabled=True)


[model] total params: 50.62M


In [15]:
# ===== Cell 5 (FAST): Longitudinal contrastive training (streaming, subject-subset per epoch, 20–30 epochs, auto-upload) =====
# Goals for ~5h total:
# - IMG_SIZE=96
# - Lighter encoder (base=16)
# - Batch=8 (adjust if OOM)
# - Each epoch trains on ~25% of subjects (rotating subset), ~120 steps/epoch
# - 20–30 epochs total still cover full dataset multiple times

import os, io, re, time, math, random
from collections import defaultdict
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import GradScaler, autocast

import gcsfs

# -----------------------------
# Speed-tuned config
# -----------------------------
device        = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE    = 8                 # try 8 on A100; if OOM, set 6
EPOCHS        = 10                # set 20 (or 30 if you have time)
IMG_SIZE      = 96                # downscale from 128 -> 96 (big speedup)
LR            = 1e-4
TEMP          = 0.07
POS_PROB      = 0.5
PRINT_EVERY   = 50
EPOCH_SUBJECT_FRACTION = 0.25     # ~25% subjects per epoch (~120 steps not 472)
GCS_BUCKET_PREFIX = "gs://clinimcl-data/OASIS3/preprocessed/"
CHECKPOINT_BUCKET = "gs://clinimcl-data/checkpoints/"

print(f"[env] device={device} | epochs={EPOCHS} | batch={BATCH_SIZE} | img={IMG_SIZE}")

# Perf flags
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# ----------------------------------------
# Resize helper (keeps C=1)
# ----------------------------------------
def ensure_size(x: torch.Tensor, side: int = IMG_SIZE) -> torch.Tensor:
    if x.ndim == 3:
        x = x.unsqueeze(0)
    if x.shape[0] != 1:
        x = x[:1]
    if tuple(x.shape[1:]) == (side, side, side):
        return x
    x = x.unsqueeze(0)
    x = F.interpolate(x, size=(side, side, side), mode="trilinear", align_corners=False)
    return x.squeeze(0)

# ----------------------------------------
# Index GCS
# ----------------------------------------
fs = gcsfs.GCSFileSystem(token="google_default")
def _to_gs(p: str) -> str: return p if p.startswith("gs://") else f"gs://{p}"

raw_paths = fs.ls(GCS_BUCKET_PREFIX)
pt_files  = [_to_gs(p) for p in raw_paths if p.lower().endswith(".pt")]
print(f"[data] files={len(pt_files)}")

_fname_re = re.compile(r"(OAS3\d+|OAS\d+)\D*?_d(\d+)\.pt$", re.IGNORECASE)
def parse_subject_day(path: str):
    m = _fname_re.search(os.path.basename(path))
    return (m.group(1), int(m.group(2))) if m else (None, None)

subjects_full = defaultdict(list)
for p in pt_files:
    s, d = parse_subject_day(p)
    if s: subjects_full[s].append((d, p))
for s in subjects_full: subjects_full[s].sort(key=lambda t: t[0])

subjects_with_pairs_full = {s: tps for s, tps in subjects_full.items() if len(tps) >= 2}
all_subject_ids = list(subjects_full.keys())
pair_subject_ids = list(subjects_with_pairs_full.keys())
print(f"[data] subjects={len(all_subject_ids)} | with≥2 tps={len(pair_subject_ids)}")

# ----------------------------------------
# Safe globals for torch.load
# ----------------------------------------
from torch.serialization import add_safe_globals
try:
    from monai.data.meta_tensor import MetaTensor
    from monai.utils.enums import TraceKeys
    add_safe_globals([MetaTensor, TraceKeys, np.ndarray])
except Exception:
    add_safe_globals([np.ndarray])

# ----------------------------------------
# Streaming loader (NO local cache)
# ----------------------------------------
def load_pt(gs_url: str) -> torch.Tensor:
    with fs.open(gs_url, "rb") as f:
        buf = io.BytesIO(f.read())
    vol = torch.load(buf, map_location="cpu", weights_only=False)
    vol = torch.as_tensor(vol, dtype=torch.float32)
    vmin, vmax = float(vol.min()), float(vol.max())
    if vmax > vmin:
        vol = (vol - vmin) / (vmax - vmin + 1e-6)
    return ensure_size(vol).contiguous()

class PairDataset(Dataset):
    """Pairs from a *restricted subject subset* for the current epoch."""
    def __init__(self, subs_all, subs_pairs, allowed_all_ids, allowed_pair_ids, p_pos=0.5):
        self.sa = subs_all
        self.sp = subs_pairs
        self.ids_all  = allowed_all_ids
        self.ids_pair = allowed_pair_ids
        self.p_pos = p_pos
    def __len__(self): return max(2000, len(self.ids_all)*20)
    def __getitem__(self, idx):
        if self.ids_pair and (random.random() < self.p_pos):
            s = random.choice(self.ids_pair)
            (d1,p1),(d2,p2) = random.sample(self.sp[s], 2)
            return load_pt(p1), load_pt(p2), 1, p1, p2
        s1, s2 = random.sample(self.ids_all, 2)
        d1,p1 = random.choice(self.sa[s1])
        d2,p2 = random.choice(self.sa[s2])
        return load_pt(p1), load_pt(p2), 0, p1, p2

def make_loader_for_epoch(epoch_idx: int):
    """Rotate a subject subset each epoch to cover whole dataset across epochs."""
    rng = random.Random(epoch_idx + 12345)
    # choose subset of subjects
    n_all = len(all_subject_ids)
    n_pair = len(pair_subject_ids)
    k_all  = max(1, int(EPOCH_SUBJECT_FRACTION * n_all))
    k_pair = max(1, int(EPOCH_SUBJECT_FRACTION * n_pair))

    allowed_all  = rng.sample(all_subject_ids,  k_all)
    allowed_pair = rng.sample(pair_subject_ids, k_pair)

    ds = PairDataset(subjects_full, subjects_with_pairs_full, allowed_all, allowed_pair, p_pos=POS_PROB)
    # steps ≈ (#timepoints in subset)/BATCH is hard to know exactly; approximate by scaling original steps
    total_steps_full = math.ceil(len(pt_files) / BATCH_SIZE)
    steps_this_epoch = max(60, int(total_steps_full * EPOCH_SUBJECT_FRACTION))  # ~472*0.25 ≈ 118
    loader = DataLoader(
        ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0,
        pin_memory=(device=="cuda"), drop_last=True
    )
    return loader, steps_this_epoch

# ----------------------------------------
# Model (lighter + faster)
# ----------------------------------------
class _ConvBlock(nn.Module):
    def __init__(self, cin, cout):
        super().__init__()
        self.conv = nn.Conv3d(cin, cout, 3, padding=1)
        self.bn   = nn.BatchNorm3d(cout)
        self.act  = nn.ReLU(inplace=True)
    def forward(self, x): return self.act(self.bn(self.conv(x)))

class SafeEncoder3D(nn.Module):
    def __init__(self, in_ch=1, base=16, out_dim=256):  # base=16 (lighter than 32)
        super().__init__()
        self.b1=_ConvBlock(in_ch,base);       self.b2=_ConvBlock(base,base*2)
        self.b3=_ConvBlock(base*2,base*4);    self.b4=_ConvBlock(base*4,base*8)
        self.pool=nn.AdaptiveAvgPool3d(1);    self.fc=nn.Linear(base*8,out_dim)
    def forward(self,x):
        x=self.b1(F.max_pool3d(x,2))  # 96->48
        x=self.b2(F.max_pool3d(x,2))  # 48->24
        x=self.b3(F.max_pool3d(x,2))  # 24->12
        x=self.b4(F.max_pool3d(x,2))  # 12->6
        return self.fc(self.pool(x).flatten(1))

class ContrastiveModel(nn.Module):
    def __init__(self, proj_dim=128):
        super().__init__()
        self.backbone=SafeEncoder3D()
        self.proj=nn.Sequential(nn.Linear(256,256), nn.ReLU(True), nn.Linear(256,proj_dim))
    def forward(self,x):
        h=self.backbone(x)
        z=F.normalize(self.proj(h), dim=1)
        return z,h

model = ContrastiveModel(128).to(device)

def info_nce(z1,z2,temp=0.07):
    z1=F.normalize(z1,dim=1); z2=F.normalize(z2,dim=1)
    logits=(z1@z2.t())/temp
    labels=torch.arange(z1.size(0), device=z1.device)
    return 0.5*(F.cross_entropy(logits,labels)+F.cross_entropy(logits.t(),labels))

opt    = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
scaler = GradScaler("cuda", enabled=(device=="cuda"))

# Quick shape check on a small epoch subset
_loader_preview, _ = make_loader_for_epoch(0)
with torch.inference_mode():
    a,b,y,p1,p2 = next(iter(_loader_preview))
    print(f"[check] batch a={tuple(a.shape)} b={tuple(b.shape)} y={y.tolist()[:8]}")

# ----------------------------------------
# Train with rotating subject subsets
# ----------------------------------------
total_steps_full = math.ceil(len(pt_files) / BATCH_SIZE)
print(f"\n[train] epochs={EPOCHS} | full-steps≈{total_steps_full} | per-epoch fraction={EPOCH_SUBJECT_FRACTION}")

global_step=0; t_all=time.time()
for ep in range(1, EPOCHS + 1):
    loader, steps_per_epoch = make_loader_for_epoch(ep)
    model.train(); running=[]; seen=set(); t0=time.time(); step=0

    for x1, x2, _, p1s, p2s in loader:
        for q in p1s: seen.add(q)
        for q in p2s: seen.add(q)

        x1 = x1.contiguous(memory_format=torch.channels_last_3d).to(device, non_blocking=True)
        x2 = x2.contiguous(memory_format=torch.channels_last_3d).to(device, non_blocking=True)

        opt.zero_grad(set_to_none=True)
        with autocast(device_type="cuda", enabled=(device == "cuda")):
            z1, _ = model(x1)
            z2, _ = model(x2)
            loss = info_nce(z1, z2, temp=TEMP)
        scaler.scale(loss).backward(); scaler.step(opt); scaler.update()

        running.append(loss.item()); global_step += 1; step += 1
        if PRINT_EVERY and (global_step % PRINT_EVERY == 0):
            print(f"  [ep {ep:02d} step {global_step:05d}] loss={np.mean(running):.4f}")
        if step >= steps_per_epoch: break

    cov = 100.0 * len(seen) / len(pt_files) if pt_files else 0.0
    print(f"[ep {ep:02d}] time={time.time()-t0:.1f}s | steps={step} | mean_loss={np.mean(running):.4f} | coverage≈{cov:.2f}% (subset)")

    # ---- SAVE + UPLOAD EACH EPOCH ----
    STAMP = time.strftime("%Y%m%d_%H%M%S")
    CKPT_PATH = f"/content/clinimcl_contrastive_fast_ep{ep:02d}.pth"
    torch.save({
        "epoch": ep,
        "model": model.state_dict(),
        "optimizer": opt.state_dict(),
        "cfg": {
            "IMG_SIZE": IMG_SIZE,
            "TEMP": TEMP,
            "proj_dim": 128,
            "EPOCH_SUBJECT_FRACTION": EPOCH_SUBJECT_FRACTION,
            "BATCH_SIZE": BATCH_SIZE
        }
    }, CKPT_PATH)
    print(f"[save] wrote local {CKPT_PATH}")

    UPLOAD_PATH = f"{CHECKPOINT_BUCKET}contrastive_fast_ep{ep:02d}_{STAMP}.pth"
    ret = os.system(f"gsutil cp {CKPT_PATH} {UPLOAD_PATH}")
    if ret == 0:
        print(f"[upload] uploaded to {UPLOAD_PATH}")
        os.remove(CKPT_PATH)
        print("[save] removed local checkpoint to free space")
    else:
        print("[warn] upload failed; keeping local checkpoint")


print(f"\nTraining finished in {time.time()-t_all:.1f}s for {EPOCHS} epochs (rotating subject subsets)")


[env] device=cuda | epochs=10 | batch=8 | img=96
[data] files=2832
[data] subjects=1376 | with≥2 tps=685
[check] batch a=(8, 1, 96, 96, 96) b=(8, 1, 96, 96, 96) y=[0, 1, 1, 1, 1, 1, 1, 0]

[train] epochs=10 | full-steps≈354 | per-epoch fraction=0.25
  [ep 01 step 00050] loss=2.0392
[ep 01] time=2386.3s | steps=88 | mean_loss=2.0532 | coverage≈25.56% (subset)
[save] wrote local /content/clinimcl_contrastive_fast_ep01.pth
[warn] upload failed; keeping local checkpoint
  [ep 02 step 00100] loss=2.0343
  [ep 02 step 00150] loss=2.0659
[ep 02] time=2452.5s | steps=88 | mean_loss=2.0569 | coverage≈25.56% (subset)
[save] wrote local /content/clinimcl_contrastive_fast_ep02.pth
[upload] uploaded to gs://clinimcl-data/checkpoints/contrastive_fast_ep02_20251110_010525.pth
[save] removed local checkpoint to free space
  [ep 03 step 00200] loss=1.9822
  [ep 03 step 00250] loss=1.9850
[ep 03] time=2284.9s | steps=88 | mean_loss=1.9870 | coverage≈26.06% (subset)
[save] wrote local /content/clinimcl_c