# CellViT + mmVIRTUES — Smoke test (portable)

Ce notebook charge `CellViTMMVirtues` en utilisant uniquement `mmvirtues_root` + `mmvirtues_weights_path`.

## Variables d'environnement (optionnel)
- `MMVIRTUES_ROOT` : dossier `.../virtues_example` (contient `datasets_loading/`, `modules/`, `esm2_t30_150M_UR50D/`, etc.)
- `MMVIRTUES_WEIGHTS` : dossier `.../virtues_example/mmvirtues_weights` (contient `config.yaml` + `teacher_checkpoint.pth`)

In [1]:
import os
import sys
import types
import importlib
from pathlib import Path

import torch

# Ask xformers-aware code to fall back to the default attention path (to avoid
# the GPU attention-bias mismatch).
# NOTE(review): the recorded log of the model cell still reports
# "Using xformers for FlexDualVirTues", so modules.mmvirtues.layers may not
# honour this variable — confirm.
os.environ["XFORMERS_DISABLED"] = "1"


def _make_stub_module(name, *, is_package=False):
    """Return an empty module object that looks like a normally imported module.

    ``types.ModuleType`` alone leaves ``__spec__`` set to ``None``; once such a
    module is in ``sys.modules``, ``importlib.util.find_spec(name)`` raises
    ``ValueError`` instead of returning a spec. Attaching a ``ModuleSpec`` keeps
    availability probes working. Packages additionally get an empty
    ``__path__`` so Python treats them as packages.
    """
    import importlib.machinery

    module = types.ModuleType(name)
    module.__spec__ = importlib.machinery.ModuleSpec(name, loader=None, is_package=is_package)
    if is_package:
        module.__path__ = []
    return module


# Stub flash_attn so that importing it never loads the compiled extension
# (not usable on CPU). Actually calling the kernel fails loudly.
if 'flash_attn.flash_attn_interface' not in sys.modules:
    def flash_attn_func(*args, **kwargs):
        raise ImportError('flash_attn disabled for this test')

    flash_attn_interface = _make_stub_module('flash_attn.flash_attn_interface')
    flash_attn_interface.flash_attn_func = flash_attn_func
    flash_attn = _make_stub_module('flash_attn', is_package=True)
    flash_attn.flash_attn_interface = flash_attn_interface
    sys.modules['flash_attn'] = flash_attn
    sys.modules['flash_attn.flash_attn_interface'] = flash_attn_interface

# Import the mmVIRTUES backbone module and the CellViT wrapper module, then
# reload each one. This picks up edits to their source files without restarting
# the kernel. The backbone is reloaded first; presumably the wrapper imports
# from it, so the wrapper is re-executed against the fresh backbone — confirm.
backbones_mmvirtues = importlib.reload(
    importlib.import_module('cellvit.models.cell_segmentation.backbones_mmvirtues')
)
cellvit_mmvirtues = importlib.reload(
    importlib.import_module('cellvit.models.cell_segmentation.cellvit_mmvirtues')
)
CellViTMMVirtues = cellvit_mmvirtues.CellViTMMVirtues

--------------------------------------------------------------------------------

  CuPy may not function correctly because multiple CuPy packages are installed
  in your environment:

    cupy, cupy-cuda12x

  Follow these steps to resolve this issue:

    1. For all packages listed above, run the following command to remove all
       existing CuPy installations:

         $ pip uninstall <package_name>

      If you previously installed CuPy via conda, also run the following:

         $ conda uninstall cupy

    2. Install the appropriate CuPy package.
       Refer to the Installation Guide for detailed instructions.

         https://docs.cupy.dev/en/stable/install.html

--------------------------------------------------------------------------------



In [2]:
# Resolve the mmVIRTUES code root and weights directory. The environment
# variables MMVIRTUES_ROOT / MMVIRTUES_WEIGHTS take precedence over the
# cluster defaults below; `~` in either value is expanded.
default_root = Path('/scratch/mmvirtues_orion_dataset/virtues_example')
mmvirtues_root = Path(os.environ.get('MMVIRTUES_ROOT', str(default_root))).expanduser()
weights_dir = Path(
    os.environ.get('MMVIRTUES_WEIGHTS', str(mmvirtues_root / 'mmvirtues_weights'))
).expanduser()

print('mmvirtues_root:', mmvirtues_root)
print('weights_dir:', weights_dir)

# Validate the expected layout before building the model. Explicit raises
# (instead of `assert`) keep the checks active under `python -O`. All missing
# entries are reported at once rather than one per re-run.
if not mmvirtues_root.exists():
    raise FileNotFoundError(f'MMVIRTUES_ROOT not found: {mmvirtues_root}')

_missing = [
    str(path)
    for path in (
        mmvirtues_root / 'datasets_loading',
        mmvirtues_root / 'esm2_t30_150M_UR50D',
        weights_dir / 'config.yaml',
        weights_dir / 'teacher_checkpoint.pth',
    )
    if not path.exists()
]
if _missing:
    raise FileNotFoundError(
        'mmVIRTUES layout incomplete (check MMVIRTUES_ROOT / MMVIRTUES_WEIGHTS); missing:\n  '
        + '\n  '.join(_missing)
    )

mmvirtues_root: /scratch/mmvirtues_orion_dataset/virtues_example
weights_dir: /scratch/mmvirtues_orion_dataset/virtues_example/mmvirtues_weights


In [3]:
# Smoke test: build CellViTMMVirtues from the paths resolved above and run the
# encoder alone and the full model on one random image. Print the output
# shapes, then check them against the configured class counts.
import warnings

NUM_NUCLEI_CLASSES = 6
NUM_TISSUE_CLASSES = 19

model = CellViTMMVirtues(
    mmvirtues_weights_path=weights_dir,
    mmvirtues_root=mmvirtues_root,
    num_nuclei_classes=NUM_NUCLEI_CLASSES,
    num_tissue_classes=NUM_TISSUE_CLASSES,
    regression_loss=False,
)
model.eval()

x = torch.randn(1, 3, 224, 224)
batch_size, _, in_height, in_width = x.shape

with torch.no_grad():
    # Encoder yields (tissue logits, <second value unused here>, sequence of
    # intermediate token tensors).
    logits, _, z = model.encoder(x)
    print('tissue logits:', tuple(logits.shape))
    print('intermediate token shapes:', [tuple(t.shape) for t in z])

    out = model(x)
    print({k: (tuple(v.shape) if torch.is_tensor(v) else type(v)) for k, v in out.items()})
print('marker_embeddings_dir:', model.encoder.marker_embeddings_dir)

# Explicit raises instead of `assert` so the checks survive `python -O`.
expected_logits = (batch_size, NUM_TISSUE_CLASSES)
if tuple(logits.shape) != expected_logits:
    raise RuntimeError(
        f'encoder tissue logits: expected {expected_logits}, got {tuple(logits.shape)}'
    )
if tuple(out['tissue_types'].shape) != expected_logits:
    raise RuntimeError(
        f"tissue_types: expected {expected_logits}, got {tuple(out['tissue_types'].shape)}"
    )
if out['nuclei_type_map'].shape[1] != NUM_NUCLEI_CLASSES:
    raise RuntimeError(
        f"nuclei_type_map: expected {NUM_NUCLEI_CLASSES} channels, "
        f"got {out['nuclei_type_map'].shape[1]}"
    )

# NOTE(review): in the recorded run a 224x224 input produced 256x256 maps. The
# 257 tokens suggest a 16x16 grid (224 / 16 = 14-pixel patches) plus one extra
# token, and 256 = 16 * 16, i.e. the decoder upsamples by 16. Confirm whether it
# assumes 16-pixel patches. Warn rather than fail until that is resolved.
for key in ('nuclei_binary_map', 'hv_map', 'nuclei_type_map'):
    map_hw = tuple(out[key].shape[-2:])
    if map_hw != (in_height, in_width):
        warnings.warn(
            f'{key}: spatial size {map_hw} differs from input size {(in_height, in_width)}'
        )

[32m2025-12-16 15:35:36.047[0m | [1mINFO    [0m | [36mmodules.mmvirtues.layers[0m:[36m<module>[0m:[36m11[0m - [1mUsing xformers for FlexDualVirTues[0m
[32m2025-12-16 15:35:36.214[0m | [1mINFO    [0m | [36mmodules.mmvirtues.flex_dual_mmvirtues[0m:[36m__init__[0m:[36m103[0m - [1mUsing protein embedding: esm with shape torch.Size([213, 640])[0m
[32m2025-12-16 15:35:36.220[0m | [1mINFO    [0m | [36mmodules.mmvirtues.flex_dual_mmvirtues[0m:[36m__init__[0m:[36m117[0m - [1mUsing protein fusion type: add[0m


_IncompatibleKeys(missing_keys=[], unexpected_keys=['dino_head.mlp.0.weight', 'dino_head.mlp.0.bias', 'dino_head.mlp.2.weight', 'dino_head.mlp.2.bias', 'dino_head.mlp.4.weight', 'dino_head.mlp.4.bias', 'dino_head.last_layer.weight_g', 'dino_head.last_layer.weight_v'])
tissue logits: (1, 19)
intermediate token shapes: [(1, 257, 1024), (1, 257, 1024), (1, 257, 1024), (1, 257, 1024)]
{'tissue_types': (1, 19), 'nuclei_binary_map': (1, 2, 256, 256), 'hv_map': (1, 2, 256, 256), 'nuclei_type_map': (1, 6, 256, 256)}
marker_embeddings_dir: /scratch/mmvirtues_orion_dataset/virtues_example/marker_embeddings_symlink
