In [1]:
import os
import sys
import types
import importlib
from pathlib import Path

import torch
import torch.nn.functional as F

# Request that xformers be skipped, unless the caller already set the flag.
# NOTE(review): the recorded run of this notebook still logs
# "Using xformers for FlexDualVirTues", so this variable may not be the one
# the mmVIRTUES layers consult — confirm.
os.environ.setdefault('XFORMERS_DISABLED', '1')

# Register placeholder `flash_attn` modules so that importing
# `flash_attn.flash_attn_interface` succeeds without the compiled extension.
# Nothing is registered when the interface module is already in sys.modules.
if 'flash_attn.flash_attn_interface' not in sys.modules:
    def flash_attn_func(*args, **kwargs):
        """Placeholder kernel entry point: accepts anything and always raises ImportError."""
        raise ImportError('flash_attn disabled for this smoke test')

    # Sub-module exposing only the placeholder function.
    flash_attn_interface = types.ModuleType('flash_attn.flash_attn_interface')
    setattr(flash_attn_interface, 'flash_attn_func', flash_attn_func)

    # Parent package that points at the sub-module.
    flash_attn = types.ModuleType('flash_attn')
    setattr(flash_attn, 'flash_attn_interface', flash_attn_interface)

    # Publish both entries so either import spelling resolves to the stubs.
    sys.modules.update({
        'flash_attn': flash_attn,
        'flash_attn.flash_attn_interface': flash_attn_interface,
    })

# Import the project modules and reload each one, which re-executes the module
# source even if this kernel had already imported it
# (presumably to pick up on-disk edits while iterating — confirm intent).
import cellvit.models.cell_segmentation.backbones_mmvirtues as backbones_mmvirtues
importlib.reload(backbones_mmvirtues)
# NOTE(review): backbones is reloaded first, presumably because
# cellvit_mmvirtues depends on it — confirm.
import cellvit.models.cell_segmentation.cellvit_mmvirtues as cellvit_mmvirtues
importlib.reload(cellvit_mmvirtues)
# Bind the class only after the reload so the name refers to the reloaded definition.
from cellvit.models.cell_segmentation.cellvit_mmvirtues import CellViTMMVirtues

--------------------------------------------------------------------------------

  CuPy may not function correctly because multiple CuPy packages are installed
  in your environment:

    cupy, cupy-cuda12x

  Follow these steps to resolve this issue:

    1. For all packages listed above, run the following command to remove all
       existing CuPy installations:

         $ pip uninstall <package_name>

      If you previously installed CuPy via conda, also run the following:

         $ conda uninstall cupy

    2. Install the appropriate CuPy package.
       Refer to the Installation Guide for detailed instructions.

         https://docs.cupy.dev/en/stable/install.html

--------------------------------------------------------------------------------



In [2]:
# --- mmVIRTUES paths ---
# Both locations can be overridden through environment variables
# (MMVIRTUES_ROOT, MMVIRTUES_WEIGHTS); otherwise the defaults below are used.
default_root = Path('/scratch/mmvirtues_orion_dataset/virtues_example')

root_setting = os.getenv('MMVIRTUES_ROOT', str(default_root))
mmvirtues_root = Path(root_setting).resolve()

# The weights default is derived from the already-resolved root.
weights_setting = os.getenv('MMVIRTUES_WEIGHTS', str(mmvirtues_root / 'mmvirtues_weights'))
weights_dir = Path(weights_setting).resolve()

print('mmvirtues_root:', mmvirtues_root)
print('weights_dir:', weights_dir)

# Fail fast if any required file or directory is missing. Entries are checked
# in order and the first missing one raises AssertionError with its message.
required_paths = [
    (mmvirtues_root, f'MMVIRTUES_ROOT not found: {mmvirtues_root}'),
    (mmvirtues_root / 'datasets_loading', f'Expected datasets_loading/ under {mmvirtues_root}'),
    (mmvirtues_root / 'esm2_t30_150M_UR50D', f'Expected esm2_t30_150M_UR50D/ under {mmvirtues_root}'),
    (weights_dir / 'config.yaml', f'Missing config.yaml in {weights_dir}'),
    (weights_dir / 'teacher_checkpoint.pth', f'Missing teacher_checkpoint.pth in {weights_dir}'),
]
for required_path, failure_message in required_paths:
    assert required_path.exists(), failure_message

mmvirtues_root: /scratch/mmvirtues_orion_dataset/virtues_example
weights_dir: /scratch/mmvirtues_orion_dataset/virtues_example/mmvirtues_weights


In [3]:
# --- Synthetic PanNuke-like batch (dims + mask keys) ---
# Builds a random RGB batch `x` plus a `masks` dict with the keys
# instance_map / nuclei_type_map / nuclei_binary_map / hv_map.
B, H, W = 2, 256, 256

# Draw the fake image batch from a dedicated, seeded generator so that
# Restart & Run All reproduces the same input. A local generator is used so
# the global torch RNG state is left untouched.
SEED = 0
rng = torch.Generator().manual_seed(SEED)
x = torch.rand(B, 3, H, W, dtype=torch.float32, generator=rng)

# Pixel coordinate grids: yy[i, j] == i (row), xx[i, j] == j (column).
yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
instance_map = torch.zeros((B, H, W), dtype=torch.int64)
nuclei_type_map = torch.zeros((B, H, W), dtype=torch.int64)
nuclei_binary_map = torch.zeros((B, H, W), dtype=torch.int64)
# The HV target is left at zero; it is not derived from the disks below.
hv_map = torch.zeros((B, 2, H, W), dtype=torch.float32)

# A few fake nuclei (simple disks) with types in {1..5}.
# Each entry is (center_row, center_col, radius, type); instance ids start at 1.
fake_nuclei = [(64, 64, 18, 1), (160, 120, 24, 3), (120, 200, 16, 5)]
for inst_id, (cy, cx, r, t) in enumerate(fake_nuclei, start=1):
    # The disk mask does not depend on the batch index, so compute it once
    # and write it into every batch element via the leading slice.
    m = (yy - cy) ** 2 + (xx - cx) ** 2 <= r ** 2
    instance_map[:, m] = inst_id
    nuclei_type_map[:, m] = t
    nuclei_binary_map[:, m] = 1

masks = {
    'instance_map': instance_map,
    'nuclei_type_map': nuclei_type_map,
    'nuclei_binary_map': nuclei_binary_map,
    'hv_map': hv_map,
}

print('x:', tuple(x.shape), x.dtype, 'range:', float(x.min()), float(x.max()))
print({k: (tuple(v.shape), v.dtype) for k, v in masks.items()})

x: (2, 3, 256, 256) torch.float32 range: 6.556510925292969e-07 0.9999989867210388
{'instance_map': ((2, 256, 256), torch.int64), 'nuclei_type_map': ((2, 256, 256), torch.int64), 'nuclei_binary_map': ((2, 256, 256), torch.int64), 'hv_map': ((2, 2, 256, 256), torch.float32)}


In [4]:
# --- Build model + forward ---
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('device:', device)

# Instantiate the model from the weights/root locations resolved earlier.
model = CellViTMMVirtues(
    mmvirtues_root=mmvirtues_root,
    mmvirtues_weights_path=weights_dir,
    num_tissue_classes=19,
    num_nuclei_classes=6,  # PanNuke: background + 5 types
    regression_loss=False,
)
model.to(device)
model.eval()

# Single forward pass with gradient tracking switched off.
with torch.no_grad():
    x_dev = x.to(device)
    out = model(x_dev)

# Summarise the outputs: tensor entries are shown by shape, others by type.
print({
    name: (tuple(value.shape) if torch.is_tensor(value) else type(value))
    for name, value in out.items()
})
print('marker_embeddings_dir:', model.encoder.marker_embeddings_dir)

device: cuda


[32m2025-12-16 16:32:31.572[0m | [1mINFO    [0m | [36mmodules.mmvirtues.layers[0m:[36m<module>[0m:[36m11[0m - [1mUsing xformers for FlexDualVirTues[0m
[32m2025-12-16 16:32:31.695[0m | [1mINFO    [0m | [36mmodules.mmvirtues.flex_dual_mmvirtues[0m:[36m__init__[0m:[36m103[0m - [1mUsing protein embedding: esm with shape torch.Size([213, 640])[0m
[32m2025-12-16 16:32:31.701[0m | [1mINFO    [0m | [36mmodules.mmvirtues.flex_dual_mmvirtues[0m:[36m__init__[0m:[36m117[0m - [1mUsing protein fusion type: add[0m


_IncompatibleKeys(missing_keys=[], unexpected_keys=['dino_head.mlp.0.weight', 'dino_head.mlp.0.bias', 'dino_head.mlp.2.weight', 'dino_head.mlp.2.bias', 'dino_head.mlp.4.weight', 'dino_head.mlp.4.bias', 'dino_head.last_layer.weight_g', 'dino_head.last_layer.weight_v'])
{'tissue_types': (2, 19), 'nuclei_binary_map': (2, 2, 256, 256), 'hv_map': (2, 2, 256, 256), 'nuclei_type_map': (2, 6, 256, 256)}
marker_embeddings_dir: /scratch/mmvirtues_orion_dataset/virtues_example/marker_embeddings_symlink


In [5]:
# --- Optional: 1-step backward (freeze mmVIRTUES encoder to keep it light) ---
# Exclude every encoder parameter from gradient computation; only parameters
# that still require gradients are handed to the optimizer.
for param in model.encoder.parameters():
    param.requires_grad_(False)

model.train()
optimizer = torch.optim.AdamW(
    [param for param in model.parameters() if param.requires_grad],
    lr=1e-4,
)

x_train = x.to(device)
masks_train = {name: mask.to(device) for name, mask in masks.items()}

out = model(x_train)


def _resize_like_target(pred, size):
    """Return `pred` as-is when its last two dims equal `size`, else resize it bilinearly."""
    if pred.shape[-2:] == size:
        return pred
    return F.interpolate(pred, size=size, mode='bilinear', align_corners=False)


# Targets, cast to the dtypes the loss functions below are called with.
nb_target = masks_train['nuclei_binary_map'].long()
nt_target = masks_train['nuclei_type_map'].long()
hv_target = masks_train['hv_map'].float()

# Bring every prediction to the spatial size of the binary target.
H, W = nb_target.shape[-2:]
nb_logits = _resize_like_target(out['nuclei_binary_map'], (H, W))
nt_logits = _resize_like_target(out['nuclei_type_map'], (H, W))
hv_pred = _resize_like_target(out['hv_map'], (H, W))

# Per-branch losses; class targets are clamped to [0, 1] and [0, 5] respectively.
loss_nb = F.cross_entropy(nb_logits, torch.clamp(nb_target, min=0, max=1))
loss_nt = F.cross_entropy(nt_logits, torch.clamp(nt_target, min=0, max=5))
loss_hv = F.l1_loss(hv_pred, hv_target)
loss = loss_nb + loss_nt + loss_hv

# One optimisation step on the unfrozen parameters.
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()

print('loss_total:', loss.item())
print('loss_nb:', loss_nb.item(), 'loss_nt:', loss_nt.item(), 'loss_hv:', loss_hv.item())

loss_total: 2.7761454582214355
loss_nb: 0.7834705114364624 loss_nt: 1.7265998125076294 loss_hv: 0.26607513427734375
