In [None]:
# Move one directory up (presumably from the notebook folder to the repository root)
# so the relative paths used below (configs/, assets/) and the local packages
# imported next (modules/, datasets/, utils/) resolve.
%cd ..

In [None]:
import os
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_context('talk')
from umap import UMAP
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

from omegaconf import OmegaConf
from safetensors import safe_open
from huggingface_hub import hf_hub_download


from modules.multiplex_virtues import MultiplexVirtues
from datasets.multiplex_dataset import MultiplexDataset
from utils.utils import load_marker_embeddings
from utils.cell_tokens import compute_cell_tokens

### Cell Phenotyping with Cell Summary Tokens
In this notebook, we demonstrate how to use VirTues to compute cell summary tokens and apply them to cell phenotyping.

#### 1. Model Initialization

To get started, instantiate the VirTues model and load its pretrained weights.

A default configuration file is provided at `configs/base_config.yaml`. This file contains all parameters required for the released VirTues model.

In addition, you must specify a directory containing the embeddings for all markers used. Each embedding should be saved as a `.pt` file, named according to its respective UniProt ID.

In [None]:
# Directory holding one `.pt` embedding file per marker, named by its UniProt ID.
PATH_MARKER_EMBEDDINGS = 'assets/example_dataset/marker_embeddings'

# Hyperparameters of the released VirTues model.
conf = OmegaConf.load('configs/base_config.yaml')

In [None]:
# Load the marker embeddings from PATH_MARKER_EMBEDDINGS; they are handed to the
# model as prior bias embeddings.
marker_embeddings = load_marker_embeddings(PATH_MARKER_EMBEDDINGS)

# Build the architecture explicitly from `conf.model` so that it matches the
# pretrained checkpoint loaded in the next cell (load_state_dict is strict by default).
model = MultiplexVirtues(
    # Presumably disables any built-in configuration because every hyperparameter
    # is passed explicitly below — confirm against MultiplexVirtues.
    use_default_config = False,
    custom_config = None,
    prior_bias_embeddings=marker_embeddings,
    # NOTE(review): 'esm' presumably refers to ESM protein-language-model embeddings;
    # confirm it matches how the files in PATH_MARKER_EMBEDDINGS were produced.
    prior_bias_embedding_type='esm',
    # How the marker embeddings are combined inside the model; exact semantics of
    # 'add' are defined in MultiplexVirtues.
    prior_bias_embedding_fusion_type='add',
    patch_size=conf.model.patch_size,
    model_dim=conf.model.model_dim,
    feedforward_dim=conf.model.feedforward_dim,
    encoder_pattern=conf.model.encoder_pattern,
    num_encoder_heads=conf.model.num_encoder_heads,
    decoder_pattern=conf.model.decoder_pattern,
    num_decoder_heads=conf.model.num_decoder_heads,
    # Config key (`num_decoder_hidden_layers`) and constructor argument
    # (`num_hidden_layers`) are named differently.
    num_hidden_layers=conf.model.num_decoder_hidden_layers,
    positional_embedding_type=conf.model.positional_embedding_type,
    dropout=conf.model.dropout,
    group_layers=conf.model.group_layers,
    norm_after_encoder_decoder=conf.model.norm_after_encoder_decoder,
    verbose=False
)

We provide model weights of our pretrained VirTues instance on Hugging Face Hub. These can be downloaded via `hf_hub_download` as follows.

In [None]:
CACHE_DIR = 'assets/checkpoints'

# hf_hub_download returns the local path of the downloaded (or already present)
# file, so use it directly instead of re-deriving it from CACHE_DIR.
weights_path = hf_hub_download(repo_id='bunnelab/virtues', filename='model.safetensors', local_dir=CACHE_DIR)

# Read every tensor of the safetensors checkpoint onto the CPU; the model is moved
# to the GPU only after the (strict) state-dict load.
with safe_open(weights_path, framework="pt", device='cpu') as f:
    weights = {k: f.get_tensor(k) for k in f.keys()}
model.load_state_dict(weights)

# NOTE: `.cuda()` requires a CUDA-capable GPU.
model = model.cuda()
model = model.eval()  # inference mode (disables dropout)

#### 2. Dataset Initialization
Next, let us instantiate a dataset. We provide a simple example dataset at `assets/example_dataset` consisting of a single tissue image, which we can access through the `MultiplexDataset` class.

In [None]:
# Dataset-specific settings (directories, index files, normalization statistics)
# for the bundled example dataset.
ds_conf = OmegaConf.load('configs/datasets/example_config.yaml')['datasets']['example_dataset']

# Crop/patch geometry and masking parameters come from the model config so the
# dataset matches the settings of the released model.
dataset = MultiplexDataset(
    tissue_dir=ds_conf.tissue_dir,
    crop_dir=ds_conf.crop_dir,
    mask_dir=ds_conf.mask_dir,
    tissue_index=ds_conf.tissue_index,
    crop_index=ds_conf.crop_index,
    channels_file=ds_conf.channels_file,
    quantiles_file=ds_conf.quantiles_file,
    means_file=ds_conf.means_file,
    stds_file=ds_conf.stds_file,
    marker_embedding_dir=PATH_MARKER_EMBEDDINGS,
    split='test',
    crop_size=conf.data.crop_size,
    patch_size=conf.model.patch_size,
    masking_ratio=conf.data.masking_ratio,
    channel_fraction=conf.data.channel_fraction,
)

Using this dataset class, we can load a tissue image along with its corresponding cell-segmentation mask and the indices of the marker embeddings.\
These indices specify both the identity and the ordering of the markers present in the image, allowing the model to correctly interpret the measurement channels.

In [None]:
# Identifier of the example tissue image analysed in this notebook.
tid = 'cords24_ocmzljpb_1'
# Image without preprocessing; preprocessing is applied later, after zero-padding.
x = dataset.get_tissue(tid, preprocess=False)
# Marker-embedding indices giving the identity and order of the image channels.
midxs = dataset.get_marker_indices()
# Cell-segmentation mask for the same tissue.
mask = dataset.get_segmentation_mask(tid)

#### 3. Computing Cell Summary Tokens
We compute cell summary tokens by first embedding the entire image into patch-level summary tokens using VirTues. These patch tokens are then aggregated according to the segmentation mask, similar to a convolution, to produce cell-level tokens.
We provide the utility function `utils.cell_tokens.compute_cell_tokens` to perform this computation. Importantly, we recommend padding the image with zeros (corresponding to no signal) before preprocessing.

In [None]:
# NOTE(review): 120 px is a magic number; its relation to crop/patch size is not visible here.
pad_size = 120

# Zero-pad the last two dimensions of both the raw image and the mask by `pad_size`
# pixels on each side (F.pad with a 4-tuple pads the last two dims). For the image, 0
# means "no signal"; for the mask it is presumably the background label — TODO confirm.
x, mask = (
    torch.nn.functional.pad(t, pad=(pad_size,) * 4, mode='constant', value=0)
    for t in (x, mask)
)

# Preprocess only after padding, so the padded border goes through the same transform.
x = dataset._preprocess(tissue_id=tid, multiplex=x)

# Embed the image into patch summary tokens and aggregate them per cell using the mask.
cell_ids, cell_tokens, crop_tokens, indices = compute_cell_tokens(model, x, midxs, mask)

#### 4. UMAP visualization of the cell tokens
Let us visualize the structure of our cell tokens as a 2D UMAP projection.
For the sample tissue, we provide cell type annotations at `assets/example_dataset/sce_annotations.csv`. We can use these to color our UMAP.

In [None]:
# Project the cell tokens to 2D; the fixed seed keeps the embedding reproducible.
um = UMAP(random_state=42, n_components=2)
cell_tokens_2d = um.fit_transform(cell_tokens.numpy())

In [None]:
# Per-cell annotations, indexed by cell id so they can be aligned with `cell_ids`
# (no inplace mutation: re-running the cell always starts from the CSV).
annotations = pd.read_csv('assets/example_dataset/sce_annotations.csv').set_index('cell_id')

# One label per cell token, in the same order as `cell_ids`. A single
# `.loc[rows, column]` lookup avoids chained indexing.
labels = annotations.loc[cell_ids, 'cell_category_regen'].values

# Tokens and labels must stay row-aligned for the plot and the classifier below.
assert len(labels) == len(cell_tokens), "labels are not aligned with cell_tokens"

In [None]:
# Preferred display order of the annotated cell categories.
CELL_TYPE_ORDER = ['Tumor', 'Fibroblast', 'Immune', 'T cell', 'Vessel', 'Other']

# seaborn's scatterplot does not show points whose hue value is missing from
# `hue_order`, so append any category present in `labels` but absent from the
# preferred order rather than silently hiding those cells.
hue_order = CELL_TYPE_ORDER + [
    c for c in pd.unique(labels) if pd.notna(c) and c not in CELL_TYPE_ORDER
]

fig, ax = plt.subplots(figsize=(8, 8))
sns.scatterplot(
    x=cell_tokens_2d[:, 0],
    y=cell_tokens_2d[:, 1],
    hue=labels,
    hue_order=hue_order,
    palette='tab10',
    s=20,
    ax=ax,
)
ax.set(title='UMAP of cell summary tokens', xlabel='UMAP 1', ylabel='UMAP 2')
plt.show()

#### 5. Cell Phenotyping using Logistic Regression
Cell tokens can be leveraged for downstream cell phenotyping. A straightforward approach is to train a simple logistic regression model using the token representations as input features.\
For illustration, the example below uses a random train–test split of the loaded cells. In practice, however, splits should be made group-wise by patient identity to avoid data leakage and ensure a proper evaluation.

In [None]:
# Random 80/20 split of the cells, stratified by label so both sets keep the same
# class proportions; fixed seed for reproducibility.
train_X, test_X, train_y, test_y = train_test_split(
    cell_tokens.numpy(),
    labels,
    test_size=0.2,
    stratify=labels,
    random_state=0,
)

In [None]:
# Linear probe on the cell tokens; max_iter is raised above sklearn's default (100)
# to give the solver room to converge. `fit` returns the fitted estimator itself.
lr_model = LogisticRegression(max_iter=1000).fit(train_X, train_y)

In [None]:
# Per-class precision / recall / F1 on the held-out cells, reported with 3 decimals.
print(classification_report(y_true=test_y, y_pred=lr_model.predict(test_X), digits=3))