# Zero-shot cell type annotation (For Power Users)

NB: we recommend that users interact with the model through the CLI and YAML configs.

Assigning cell type annotations is an important and time-consuming part of single-cell analysis. This tutorial uses `biomed-multi-omic` for cell type annotation: BMFM-RNA simplifies the process by performing not only the annotation itself but also the preprocessing, and it supports visualisation through the embeddings created by the model.

In this tutorial we approach this from a power user's perspective, interfacing directly with the code base and setting up the relevant PyTorch Lightning modules.

In [1]:
import pickle
from pathlib import Path

import scanpy as sc
import torch

from bmfm_targets import config
from bmfm_targets.evaluation.utils import check_gpu, get_label_map, merge_bmfm_adata
from bmfm_targets.tasks.task_utils import (
    instantiate_module_from_checkpoint,
    make_trainer_for_task,
    predict,
)
from bmfm_targets.tokenization import load_tokenizer
from bmfm_targets.training.modules import DataModule

DEVICE = check_gpu()

Using MPS


## Load Example Data

To demonstrate BMFM-RNA's abilities, we use the PBMC dataset published by 10x Genomics (the dataset can be downloaded [here](https://support.10xgenomics.com/single-cell-gene-expression/datasets/1.1.0/pbmc3k)). It contains 3k peripheral blood mononuclear cells (PBMCs) from a healthy donor. The raw counts are used as the model input, but we also extract the cell type annotations from the legacy scanpy workflow to compare BMFM with a classical scRNA-seq analysis.

For more information about how the data was preprocessed, see scanpy's tutorial [here](https://scanpy.readthedocs.io/en/1.11.x/tutorials/basics/clustering-2017.html).

In [2]:
# Create the data directory
data_dir = Path("data")
data_dir.mkdir(parents=True, exist_ok=True)

# Download the raw PBMC3k dataset
adata = sc.datasets.pbmc3k()

# Extract the processed reference data for the downstream comparison
reference_adata = sc.datasets.pbmc3k_processed()
reference_labels = reference_adata.obs[["louvain"]]
reference_obs_index = reference_adata.obs.index.tolist()
reference_vars_index = reference_adata.var.index.tolist()

# Keep only the cells and genes present in the reference; .copy() turns the view into a standalone AnnData
adata = adata[reference_obs_index, reference_vars_index].copy()
adata.write(data_dir / "pbmc3k_raw.h5ad")
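Once predictions are available, one way to compare them with the Louvain clusters in `reference_labels` is to cross-tabulate reference label against predicted label. The sketch below uses only the standard library; `crosstab_labels` and `majority_prediction` are hypothetical helpers (not part of `bmfm_targets`), and the labels in the toy example are made up. If pandas is at hand, `pd.crosstab(reference, predicted)` gives the same counts as a DataFrame.

```python
from collections import Counter, defaultdict
from typing import Dict, Sequence


def crosstab_labels(
    reference: Sequence[str], predicted: Sequence[str]
) -> Dict[str, Counter]:
    """Count how often each predicted label co-occurs with each reference label.

    Returns {reference_label: Counter({predicted_label: count})}.
    """
    if len(reference) != len(predicted):
        raise ValueError("reference and predicted must have the same length")
    table: Dict[str, Counter] = defaultdict(Counter)
    for ref, pred in zip(reference, predicted):
        table[ref][pred] += 1
    return dict(table)


def majority_prediction(table: Dict[str, Counter]) -> Dict[str, str]:
    """For each reference cluster, return its most frequent predicted label."""
    return {ref: counts.most_common(1)[0][0] for ref, counts in table.items()}


# Toy labels for illustration only (not real model output)
ref = ["B cells", "B cells", "NK cells", "NK cells", "NK cells"]
pred = ["B cell", "B cell", "natural killer cell", "natural killer cell", "T cell"]
table = crosstab_labels(ref, pred)
print(majority_prediction(table))
```

A majority vote per cluster is a coarse summary; the full table also shows how often a cluster is split across several predicted types.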

Create a results directory to save predictions and plots.

In [3]:
# Create results directory
results_dir = Path("results")
results_dir.mkdir(parents=True, exist_ok=True)

## Load checkpoint

To load a model checkpoint, you can either provide a local path or a model id from the Hugging Face model repository: https://huggingface.co/ibm-research 

The model used in this tutorial is [ibm-research/bmfm.rna.bert.110m.wced.multitask.v1](https://huggingface.co/ibm-research/biomed.rna.bert.110m.wced.v1).

To load the model, you can use either:
1. Local path: download the model checkpoint and tokenizer from Hugging Face.
2. Hugging Face repo: use the model ID `ibm-research/bmfm.rna.bert.110m.wced.multitask.v1`.
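Both options can be handled by a small dispatch on the input: an existing directory is treated as a local checkpoint, and an `organisation/model` string as a Hugging Face repo ID. `resolve_model_source` is a hypothetical sketch, not part of `bmfm_targets`. For the hub case, one option (assuming `huggingface_hub` is installed) is `huggingface_hub.snapshot_download(repo_id=...)`, which downloads the repository and returns its local directory.

```python
from pathlib import Path
from typing import Tuple, Union


def resolve_model_source(name_or_path: Union[str, Path]) -> Tuple[str, Union[Path, str]]:
    """Classify a model reference as a local directory or a Hugging Face repo ID.

    Hypothetical helper for illustration only.
    """
    path = Path(name_or_path)
    if path.exists():
        return "local", path
    name = str(name_or_path)
    parts = name.split("/")
    # Hub repo IDs have the form "<organisation>/<model-name>"
    if len(parts) == 2 and all(parts):
        return "hub", name
    raise ValueError(f"{name!r} is neither an existing path nor an 'org/model' repo ID")


print(resolve_model_source("ibm-research/bmfm.rna.bert.110m.wced.multitask.v1"))
```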

The model checkpoint has the following keys:

- `epoch`: the epoch number at which the checkpoint was saved.
- `global_step`: the global step number at which the checkpoint was saved.
- `pytorch-lightning_version`: the version of PyTorch Lightning used to save the checkpoint.
- `state_dict`: the model's state dictionary, which contains the model's parameters (weights and biases).
- `loops`: the saved state of the training and evaluation loops, used to resume training.
- `callbacks`: the saved state of each callback used during training.
- `optimizer_states`: the optimizer's state dictionary, which includes its internal state (such as step counts and momentum) and its parameter groups (including the learning rate).
- `lr_schedulers`: the learning rate scheduler's state dictionary, which includes its internal state, such as the current epoch or step.
- `MixedPrecision`: the mixed precision settings used during training, if any.
- `hparams_name`: the name of the hyperparameter configuration used for training.
- `hyper_parameters`: a dictionary of hyperparameters such as `model_config`, `trainer_config` and `label_dict`.
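To see which of these keys your checkpoint actually contains, you can list them once it has been loaded with `torch.load` (as in the next cell). `describe_checkpoint` is a hypothetical standard-library sketch that works on any checkpoint-like dictionary; the stand-in dictionary only mimics a few of the keys listed above.

```python
from typing import Any, Dict, List


def describe_checkpoint(ckpt: Dict[str, Any]) -> Dict[str, List[str]]:
    """Summarise the top-level and hyper-parameter keys of a loaded checkpoint dict."""
    hparams = ckpt.get("hyper_parameters", {})
    return {
        "top_level": sorted(ckpt.keys()),
        # hyper_parameters is normally a dict (or dict subclass); anything else is ignored
        "hyper_parameters": sorted(hparams.keys()) if isinstance(hparams, dict) else [],
    }


# Stand-in dictionary; in the notebook you would pass the object returned by torch.load(...)
fake_ckpt = {
    "epoch": 3,
    "global_step": 1200,
    "state_dict": {},
    "hyper_parameters": {"model_config": None, "trainer_config": None, "label_dict": {}},
}
print(describe_checkpoint(fake_ckpt))
```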


In both cases, the model's configuration will be automatically loaded. If you want to use a specific configuration, you can provide these parameters later in the `task_config`.

In [4]:
# Load the checkpoint locally
model_path = Path("wced_1024_multitask")
checkpoint_path = model_path / "last.ckpt"

cpkt = torch.load(
    checkpoint_path,
    map_location=torch.device(DEVICE),
    weights_only=False,
)

label_dict = cpkt["hyper_parameters"]["label_dict"]
model_config = cpkt["hyper_parameters"]["model_config"]

tokenizer = load_tokenizer(model_path)

## Create Data, Trainer and Task Configs

Once the model checkpoint and tokenizer have been loaded, you need to set up the data loader for your data. The easiest way to do this is to save your data as an H5AD file and provide its path in `dataset_kwargs` under the key `processed_data_source`.

Importantly, because we are using the WCED model checkpoint (`ibm-research/bmfm.rna.bert.110m.wced.multitask.v1`), you need to ensure that `adata.X` contains raw counts. The data module then handles the required transformations for you, including log-normalisation (`log_normalize_transform=True`) and restricting the genes to protein-coding genes (`limit_genes="protein_coding"`).
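A quick sanity check before saving the H5AD file is to confirm that the expression values look like raw counts, i.e. finite, non-negative and integer-valued. `looks_like_raw_counts` is a hypothetical heuristic, not a `bmfm_targets` function. For a sparse `adata.X` you could pass `adata.X.data` (the stored non-zero values); for a dense matrix, `adata.X.ravel()`. Iterating in Python is slow on large matrices, so checking a random subset of values is a reasonable compromise.

```python
import math
from typing import Iterable


def looks_like_raw_counts(values: Iterable[float], tol: float = 1e-6) -> bool:
    """Heuristic check that expression values are non-negative integers (raw counts).

    Log-normalised or scaled data usually contains non-integer or negative values.
    """
    for v in values:
        # isfinite is checked before round() so NaN/inf never reach round()
        if v < 0 or not math.isfinite(v) or abs(v - round(v)) > tol:
            return False
    return True


print(looks_like_raw_counts([0, 1, 5.0, 12]))  # integer-valued counts
print(looks_like_raw_counts([0.0, 1.3863]))    # log1p-transformed values
```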

Finally, because we want to perform zero-shot prediction, we need to set up the data module in predict mode with `data_module.setup("predict")`.

In [5]:
data_module = DataModule(
    data_dir="data",
    transform_datasets=False,
    tokenizer=tokenizer,
    fields=model_config.fields,
    limit_genes="protein_coding",
    mlm=False,
    collation_strategy="multitask",
    batch_size=20,
    pad_zero_expression_strategy={"strategy": "batch_wise"},
    dataset_kwargs={
        "processed_data_source": "data/pbmc3k_raw.h5ad",
        "expose_zeros": "all",
    },
    transform_kwargs={},
    num_workers=1,
    log_normalize_transform=True,
)
data_module.setup("predict")


Next we need to set up the trainer and task configs.

The trainer config needs every label you want to predict listed under the `losses` parameter. Here we predict `cell_type`, but the multitask model can also predict other labels such as `tissue` and `tissue_general`. For the full list of losses, inspect the model's checkpoint:
```python 
cpkt["hyper_parameters"]["trainer_config"].losses
```
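If the checkpoint's losses include further label columns, the `losses` list can be built from a list of column names. This sketch assumes the checkpoint was trained with `tissue` and `tissue_general` heads (check the `.losses` output above first); `make_losses` is a hypothetical helper, and the entry shape simply mirrors the `{"label_column_name": ...}` entries used in the trainer config below.

```python
from typing import Dict, List, Sequence


def make_losses(label_columns: Sequence[str]) -> List[Dict[str, str]]:
    """Build a `losses` list with one entry per label column to predict."""
    return [{"label_column_name": name} for name in label_columns]


losses = make_losses(["cell_type", "tissue", "tissue_general"])
print(losses)
# With bmfm_targets installed (not run here):
# trainer_config = config.TrainerConfig(losses=losses, batch_size=20)
```

Each additional label would presumably be decoded with its own `get_label_map(key=...)` call, as shown for `cell_type` further down.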

The task config needs the checkpoint path defined earlier. `DEVICE` is set based on the hardware detected on your machine. Finally, we create the PyTorch Lightning trainer with `make_trainer_for_task`.

NB: use `precision="16-mixed"` if you are running on CUDA, and `"32"` on CPU or MPS.
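The precision rule above can be expressed as a small function so the same notebook runs unchanged on different hardware. `choose_precision` is a hypothetical helper, and it assumes `DEVICE` is an accelerator name such as `"cuda"`, `"gpu"`, `"mps"` or `"cpu"`; the exact string returned by `check_gpu` may differ.

```python
def choose_precision(device: str) -> str:
    """Return "16-mixed" for CUDA accelerators and "32" otherwise (CPU, MPS)."""
    name = device.lower()
    # Assumption: a CUDA device is reported as "cuda", "cuda:<index>" or "gpu"
    if name.startswith("cuda") or name == "gpu":
        return "16-mixed"
    return "32"


print(choose_precision("mps"))
# In the task config: precision=choose_precision(DEVICE)
```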

In [6]:
trainer_config = config.TrainerConfig(
    losses=[{"label_column_name": "cell_type"}], batch_size=20
)

task_config = config.PredictTaskConfig(
    checkpoint=str(checkpoint_path),
    default_root_dir=".",
    precision="32",  # "16-mixed" if using CUDA rather than CPU or MPS
    accelerator=DEVICE,
    output_embeddings=True,
    output_predictions=True,
    enable_progress_bar=True,
    enable_model_summary=True,
    callbacks=[],
)

pl_trainer = make_trainer_for_task(task_config)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Initialize the model from the checkpoint using the task config, data module, model config and trainer config.

In [7]:
pl_module = instantiate_module_from_checkpoint(
    task_config, data_module, model_config, trainer_config
)

Tie weights not supported for this model


In [8]:
type(pl_module)

bmfm_targets.training.modules.multitask_modeling.MultiTaskTrainingModule

## Run predict

Now that the data module, task config, model config and trainer config have been used to instantiate a `MultiTaskTrainingModule`, we can pass the trainer, the module and the data module to the `predict` function to perform zero-shot annotation.

In [9]:
results = predict(
    pl_trainer=pl_trainer, pl_module=pl_module, pl_data_module=data_module
)

/Users/mattmadgwick/miniforge3/envs/bmfm-dev/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:420: Consider setting `persistent_workers=True` in 'predict_dataloader' to speed up the dataloader worker initialization.



huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


The `get_label_map` helper builds the dictionary that maps the predicted label codes to the actual cell type names:

In [10]:
label_map = get_label_map(key="cell_type", predictions=results, label_dict=label_dict)

# Show the first few labels mapped to cell-types
list(label_map.items())[:4]

[(453, 'luminal epithelial cell of mammary gland')]
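Once you have per-cell predicted codes (how to extract them from `results` depends on its structure, which is not shown here), you can translate them with `label_map` and tally the predicted cell types. `decode_labels` is a hypothetical helper, and the codes and names below are invented for illustration. If the decoded names follow the same cell order as `adata`, they could be stored with `adata.obs["bmfm_cell_type"] = names`.

```python
from collections import Counter
from typing import Dict, Iterable, List


def decode_labels(
    codes: Iterable[int], label_map: Dict[int, str], unknown: str = "unknown"
) -> List[str]:
    """Translate integer label codes into cell type names; unmapped codes become `unknown`."""
    return [label_map.get(int(code), unknown) for code in codes]


# Made-up codes and names (not real model output)
toy_label_map = {0: "B cell", 1: "T cell"}
names = decode_labels([0, 1, 1, 7], toy_label_map)
counts = Counter(names)
print(counts.most_common())
```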

## Save adata object

Finally, you can save the AnnData object with the cell types and BMFM embeddings using `.write_h5ad()`, and save the model results as a pickle file.

In [11]:
adata_merged = merge_bmfm_adata(adata, reference_adata)
adata_merged

KeyError: 'X_umap'

In [None]:
adata_merged.write_h5ad(results_dir / "bmfm_pbmc3k.h5ad")
with open(results_dir / "bmfm_pbmc3k_results.pkl", "wb") as rf:
    pickle.dump(results, rf)
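In a later session, the saved H5AD file can be read back with `sc.read_h5ad`, and the pickled results with `pickle.load`. The sketch below wraps the latter in a hypothetical `load_results` helper and demonstrates a round trip through a temporary file with a stand-in object. Unpickling the real `results` requires the same libraries (e.g. torch or numpy) that created its contents, and you should only unpickle files you created or trust.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any


def load_results(path: Path) -> Any:
    """Reload an object previously saved with pickle.dump.

    Only unpickle trusted files: unpickling can execute arbitrary code.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Self-contained demo: round-trip a small stand-in object through a temporary file
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "demo_results.pkl"
    with open(demo_path, "wb") as fh:
        pickle.dump({"cell_type": [0, 1, 1]}, fh)
    reloaded = load_results(demo_path)

print(reloaded)
# For the notebook's output: results = load_results(results_dir / "bmfm_pbmc3k_results.pkl")
```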