In [1]:
import os
import sys
import types
import importlib
import torch
from pathlib import Path

# Disable xformers to avoid GPU attention bias mismatch
os.environ["XFORMERS_DISABLED"] = "1"

# Stub flash_attn to avoid binary import on CPU
if 'flash_attn.flash_attn_interface' not in sys.modules:
    flash_attn_interface = types.ModuleType('flash_attn.flash_attn_interface')
    def flash_attn_func(*args, **kwargs):
        raise ImportError('flash_attn disabled for this test')
    flash_attn_interface.flash_attn_func = flash_attn_func
    flash_attn = types.ModuleType('flash_attn')
    flash_attn.flash_attn_interface = flash_attn_interface
    sys.modules['flash_attn'] = flash_attn
    sys.modules['flash_attn.flash_attn_interface'] = flash_attn_interface

import cellvit.models.cell_segmentation.backbones_mmvirtues as backbones_mmvirtues
importlib.reload(backbones_mmvirtues)
import cellvit.models.cell_segmentation.cellvit_mmvirtues as cellvit_mmvirtues
importlib.reload(cellvit_mmvirtues)
from cellvit.models.cell_segmentation.cellvit_mmvirtues import CellViTMMVirtues

--------------------------------------------------------------------------------

  CuPy may not function correctly because multiple CuPy packages are installed
  in your environment:

    cupy, cupy-cuda12x

  Follow these steps to resolve this issue:

    1. For all packages listed above, run the following command to remove all
       existing CuPy installations:

         $ pip uninstall <package_name>

      If you previously installed CuPy via conda, also run the following:

         $ conda uninstall cupy

    2. Install the appropriate CuPy package.
       Refer to the Installation Guide for detailed instructions.

         https://docs.cupy.dev/en/stable/install.html

--------------------------------------------------------------------------------



In [2]:
weights_dir = Path('/scratch/mmvirtues_orion_dataset/virtues_example/mmvirtues_weights')
mmvirtues_root = Path('/scratch/mmvirtues_orion_dataset/virtues_example')

In [3]:
model = CellViTMMVirtues(
    mmvirtues_weights_path=weights_dir,
    mmvirtues_root=mmvirtues_root,
    num_nuclei_classes=6,
    num_tissue_classes=19,
    regression_loss=False,
)

[32m2025-12-16 15:33:21.113[0m | [1mINFO    [0m | [36mmodules.mmvirtues.layers[0m:[36m<module>[0m:[36m11[0m - [1mUsing xformers for FlexDualVirTues[0m
[32m2025-12-16 15:33:21.572[0m | [1mINFO    [0m | [36mmodules.mmvirtues.flex_dual_mmvirtues[0m:[36m__init__[0m:[36m103[0m - [1mUsing protein embedding: esm with shape torch.Size([213, 640])[0m
[32m2025-12-16 15:33:21.576[0m | [1mINFO    [0m | [36mmodules.mmvirtues.flex_dual_mmvirtues[0m:[36m__init__[0m:[36m117[0m - [1mUsing protein fusion type: add[0m


_IncompatibleKeys(missing_keys=[], unexpected_keys=['dino_head.mlp.0.weight', 'dino_head.mlp.0.bias', 'dino_head.mlp.2.weight', 'dino_head.mlp.2.bias', 'dino_head.mlp.4.weight', 'dino_head.mlp.4.bias', 'dino_head.last_layer.weight_g', 'dino_head.last_layer.weight_v'])


In [4]:
model.eval()

CellViTMMVirtues(
  (encoder): MMVirtuesEncoder(
    (model): DinoVisionTransformer(
      (he_embed_layer): PatchEmbed(
        (proj): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))
        (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
      (mx_embed_layer): FlexDualVirTuesEncoder(
        (protein_encoder): Linear(in_features=640, out_features=1024, bias=True)
        (he_patch_encoder): Linear(in_features=588, out_features=1024, bias=True)
        (multiplex_patch_encoder): Linear(in_features=196, out_features=1024, bias=True)
        (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (encoder): ModuleList(
          (0): MarkerAttentionEncoderBlock(
            (encoder_layer): TransformerEncoder(
              (layers): ModuleList(
                (0-1): 2 x TransformerEncoderBlock(
                  (multi_head_attention): MHAwithPosEmb(
                    (W_q): Linear(in_features=1024, out_features=1024, bias=True)
 

In [5]:
x = torch.randn(1, 3, 224, 224)

In [6]:
# Debug: run encoder only first to confirm token shapes
with torch.no_grad():
    logits, _, z = model.encoder(x)
print("tissue logits:", logits.shape)
print("intermediate token shapes:", [t.shape for t in z])

# Full forward (segmentation heads)
with torch.no_grad():
    out = model(x)
print({k: v.shape if torch.is_tensor(v) else type(v) for k, v in out.items()})
print("marker_embeddings_dir:", model.encoder.marker_embeddings_dir)

tissue logits: torch.Size([1, 19])
intermediate token shapes: [torch.Size([1, 257, 1024]), torch.Size([1, 257, 1024]), torch.Size([1, 257, 1024]), torch.Size([1, 257, 1024])]
{'tissue_types': torch.Size([1, 19]), 'nuclei_binary_map': torch.Size([1, 2, 256, 256]), 'hv_map': torch.Size([1, 2, 256, 256]), 'nuclei_type_map': torch.Size([1, 6, 256, 256])}
marker_embeddings_dir: /scratch/mmvirtues_orion_dataset/virtues_example/marker_embeddings_symlink


In [7]:
print({k: v.shape if torch.is_tensor(v) else type(v) for k, v in out.items()})

{'tissue_types': torch.Size([1, 19]), 'nuclei_binary_map': torch.Size([1, 2, 256, 256]), 'hv_map': torch.Size([1, 2, 256, 256]), 'nuclei_type_map': torch.Size([1, 6, 256, 256])}
