In [None]:
# Move up one directory so that the relative paths used below (e.g. 'configs/...',
# 'assets/...') resolve; presumably this notebook lives one level below the
# repository root — confirm.
# NOTE: the change is relative, so re-running this cell moves up one more level each time.
%cd ..

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from omegaconf import OmegaConf

from safetensors import safe_open
from huggingface_hub import hf_hub_download

from modules.multiplex_virtues import MultiplexVirtues
from datasets.multiplex_dataset import MultiplexDataset
from utils.utils import load_marker_embeddings
from utils.masking import generate_mask

### Reconstruction and Inpainting
In this notebook, we demonstrate how to use VirTues to reconstruct partially or fully masked channels.

#### 1. Model Initialization

To get started, instantiate the VirTues model and load its pretrained weights.

A default configuration file is provided at `configs/base_config.yaml`. This file contains all parameters required for the released VirTues model.

In addition, you must specify a directory containing the embeddings for all markers used. Each embedding should be saved as a `.pt` file, named according to its respective UniProt ID.

In [None]:
# Directory holding the marker embeddings (one `.pt` file per marker).
PATH_MARKER_EMBEDDINGS = 'assets/example_dataset/marker_embeddings'

# Default configuration shipped with the repository.
conf = OmegaConf.load('configs/base_config.yaml')

In [None]:
# Read the marker embeddings from disk; they are handed to the model as prior bias embeddings.
marker_embeddings = load_marker_embeddings(PATH_MARKER_EMBEDDINGS)

# All architecture hyperparameters come from the `model` section of the config.
model_conf = conf.model

model = MultiplexVirtues(
    use_default_config=False,
    custom_config=None,
    # marker embedding prior
    prior_bias_embeddings=marker_embeddings,
    prior_bias_embedding_type='esm',
    prior_bias_embedding_fusion_type='add',
    # architecture
    patch_size=model_conf.patch_size,
    model_dim=model_conf.model_dim,
    feedforward_dim=model_conf.feedforward_dim,
    encoder_pattern=model_conf.encoder_pattern,
    num_encoder_heads=model_conf.num_encoder_heads,
    decoder_pattern=model_conf.decoder_pattern,
    num_decoder_heads=model_conf.num_decoder_heads,
    num_hidden_layers=model_conf.num_decoder_hidden_layers,
    positional_embedding_type=model_conf.positional_embedding_type,
    dropout=model_conf.dropout,
    group_layers=model_conf.group_layers,
    norm_after_encoder_decoder=model_conf.norm_after_encoder_decoder,
    verbose=False,
)

We provide model weights of our pretrained VirTues instance on Hugging Face Hub. These can be downloaded via `hf_hub_download` as follows.

In [None]:
# Download the released checkpoint from the Hugging Face Hub into CACHE_DIR.
CACHE_DIR = 'assets/checkpoints'
hf_hub_download(repo_id='bunnelab/virtues', filename='model.safetensors', local_dir=CACHE_DIR)

# Read every tensor of the checkpoint onto the CPU, keyed by its name.
checkpoint_path = os.path.join(CACHE_DIR, 'model.safetensors')
with safe_open(checkpoint_path, framework="pt", device='cpu') as f:
    weights = {k: f.get_tensor(k) for k in f.keys()}
model.load_state_dict(weights)

# Move the model to the GPU and switch it to evaluation mode.
model = model.cuda().eval()

#### 2. Dataset Initialization
Next, let us instantiate a dataset. We provide a simple example dataset at `assets/example_dataset` consisting of a single tissue image, which we can access using the class `MultiplexDataset`.

In [None]:
# Settings of the example dataset, taken from its entry in the dataset config file.
ds_conf = OmegaConf.load('configs/datasets/example_config.yaml')['datasets']['example_dataset']

dataset = MultiplexDataset(
    # locations and index files come from the dataset-specific config
    tissue_dir=ds_conf.tissue_dir,
    crop_dir=ds_conf.crop_dir,
    mask_dir=ds_conf.mask_dir,
    tissue_index=ds_conf.tissue_index,
    crop_index=ds_conf.crop_index,
    channels_file=ds_conf.channels_file,
    quantiles_file=ds_conf.quantiles_file,
    means_file=ds_conf.means_file,
    stds_file=ds_conf.stds_file,
    # same marker embedding directory that was used to build the model
    marker_embedding_dir=PATH_MARKER_EMBEDDINGS,
    split='test',
    # cropping / patching / masking settings come from the base config
    crop_size=conf.data.crop_size,
    patch_size=conf.model.patch_size,
    masking_ratio=conf.data.masking_ratio,
    channel_fraction=conf.data.channel_fraction,
)

Using this dataset class, we can load a tissue image along with the indices of the marker embeddings.\
These indices specify both the identity and the ordering of the markers present in the image, allowing the model to correctly interpret the measurement channels.

In [None]:
# Load the example tissue and the marker indices describing its channels.
x = dataset.get_tissue('cords24_ocmzljpb_1')
midxs = dataset.get_marker_indices()

# Keep all channels (first axis) and cut the same 128-element window, indices 10..137,
# out of each of the two remaining axes.
window = slice(10, 138)
crop = x[:, window, window]

#### 3. Reconstructing a Partially Masked Channel
We can reconstruct a masked image using a simple forward pass with VirTues. As a first step, we test reconstruction under independent partial masking applied to each channel. To generate these per-channel masks, you can use the helper function `utils.masking.generate_mask`.

In [None]:
# Generate one mask per channel. The mask is defined on a grid that is 8x coarser than
# the crop in both spatial axes (8 is presumably the model patch size,
# conf.model.patch_size — confirm).
mask = generate_mask(C=crop.shape[0], H=crop.shape[1]//8, W=crop.shape[2]//8, masking_ratio=(0.6,1.0))

# Move every model input to the GPU, where the model was placed.
crop = crop.cuda()
# Rebind `midxs` itself: this is the name passed to the model below and in later cells.
# (Previously the GPU copy was stored in an unused variable and the CPU tensor was passed.)
midxs = midxs.cuda()
mask = mask.cuda()

# Inference only: disable gradient tracking and run under CUDA autocast.
with torch.no_grad():
    with torch.amp.autocast(device_type='cuda'):
        output = model.forward([crop], [midxs], [mask])
recon = output.decoded_multiplex

In [None]:
channel_idx = 5

# The selected channel as a NumPy image; shown both masked (left) and unmasked (right).
channel_img = crop[channel_idx].cpu().numpy()

fig, ax = plt.subplots(1, 3, figsize=(12, 4))

# Left panel: the input channel with a white 8x8 square drawn over every masked grid cell.
ax[0].imshow(channel_img, cmap='inferno')
for i, j in torch.nonzero(mask[channel_idx]).tolist():
    ax[0].add_patch(plt.Rectangle((j*8, i*8), 8, 8, color='white'))
ax[0].set_title('Masked Input')

# Middle panel: the model output for the same channel (first and only item of the batch).
ax[1].imshow(recon[0][channel_idx].cpu().numpy(), cmap='inferno')
ax[1].set_title('Reconstruction')

# Right panel: the untouched channel for comparison.
ax[2].imshow(channel_img, cmap='inferno')
ax[2].set_title('Original Crop')

#### 4. Reconstructing a Fully Masked Channel
VirTues also supports inpainting of fully masked channels. To do this, simply generate a 3D mask where the channel to be inpainted is set to True.

In [None]:
# Channel to be inpainted.
channel_idx = 3

# Start from an all-zero mask shaped like the previous one, then mark every
# position of the chosen channel as masked.
mask = torch.zeros_like(mask)
mask[channel_idx, ...] = True

In [None]:
# Inference only: disable gradient tracking and run under CUDA autocast.
with torch.no_grad():
    with torch.amp.autocast(device_type='cuda'):
        # Make sure the marker indices sit on the same device as the crop before the
        # forward pass; `.to` returns the tensor unchanged if it is already there.
        output = model.forward([crop], [midxs.to(crop.device)], [mask])
recon = output.decoded_multiplex

In [None]:
# The selected channel as a NumPy image; shown both masked (left) and unmasked (right).
channel_img = crop[channel_idx].cpu().numpy()

fig, ax = plt.subplots(1, 3, figsize=(12, 4))

# Left panel: the input channel with a white 8x8 square drawn over every masked grid cell.
ax[0].imshow(channel_img, cmap='inferno')
for i, j in torch.nonzero(mask[channel_idx]).tolist():
    ax[0].add_patch(plt.Rectangle((j*8, i*8), 8, 8, color='white'))
ax[0].set_title('Masked Input')

# Middle panel: the model output for the same channel (first and only item of the batch).
ax[1].imshow(recon[0][channel_idx].cpu().numpy(), cmap='inferno')
ax[1].set_title('Reconstruction')

# Right panel: the untouched channel for comparison.
ax[2].imshow(channel_img, cmap='inferno')
ax[2].set_title('Original Crop')