# Benchmarking tokenizers for cross-tissue generalization

Here, we demonstrate how to use <span style="font-variant:small-caps; font-size: 16px">Heimdall</span> for benchmarking the impact of tokenizer choice on cell-type annotation performance in a challenging cross-tissue evaluation setting. For this evaluation, we use a subset of the [**scTab** dataset](https://www.nature.com/articles/s41467-024-51059-5).

In [41]:
import hydra
import Heimdall

In [50]:
from matplotlib import pyplot as plt
import matplotlib
import seaborn as sns

In [51]:
import scanpy as sc

In [52]:
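# Embed fonts as TrueType (Type 42) so text in saved PDFs stays editable
plt.rcParams['pdf.fonttype'] = 42
# Keep SVG text as text elements instead of converting it to paths
plt.rcParams['svg.fonttype'] = 'none'

sc.set_figure_params(figsize=(6, 6), frameon=False)
sns.set_theme(style="white")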

## Recap: the <span style="font-variant:small-caps; font-size: 28px">Heimdall</span> framework

We are interested in implementing various popular tokenizers for single-cell foundation models within a standardized framework, which enables us to isolate the impact of the tokenizer _itself_ on downstream performance.

<figure>
  <img src="../_static/heimdall_framework.png"/>
  <figcaption><b>Fig 1A.</b> Modular conceptualization of scFMs in <span style="font-variant:small-caps; font-size: 16px">Heimdall</span>. Gene identities and expression values from each single-cell input
are processed by a cell tokenization scheme (tokenizer) that generates a sequence-based “cell sentence”. The tokenizer is decomposed into three modules: a gene identity encoder ($F_\textbf{G}$), an expression encoder ($F_\textbf{E}$), and a cell constructor ($F_\textbf{C}$). The tokenizer output is then passed to a sequence-based model (e.g., a transformer).</figcaption>
</figure>

## Experiment setup (via `hydra`)

We use [`hydra`](https://github.com/facebookresearch/hydra) to configure each <span style="font-variant:small-caps; font-size: 16px">Heimdall</span> run. First, we detail the configuration files that are shared across all tokenizers for this cross-tissue generalization experiment.

Let's take a closer look at the most important of these configs.

### `experiments`
Top-level experiment config that specifies other essential configs.


```yaml
defaults:
  - override /dataset: new_sctab  # dataset for the relevant task
  - override /tasks: new_sctab_split  # task definition, including the dataset splits
  - override /model: transformer  # the chosen model
  - override /scheduler: cosine  # contains scheduler
  - override /trainer: default  # contains optimizer and trainer details
  - override /optimizer: AdamW
  - override /fg: random
  - override /fe: noop
  - override /fc: geneformer
  
seed: 55 # random seed for reproducibility

project_name: new_sctab_split1  # project name for WandB
```

### `dataset`
Specifies the path to the dataset, as well as preprocessing arguments.

```yaml
dataset_name: new_sctab

preprocess_args:
  data_path: ${data_path}/sctab/tissue_splits_spencer/scTab_GItract_train.h5ad
  top_n_genes: false
  normalize: true
  log_1p: true
  scale_data: false
  species: human
```
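To make the `normalize` and `log_1p` flags concrete, here is a minimal sketch of what such preprocessing typically does to one cell's count vector. The function name and the `target_sum` of 10,000 are assumptions for illustration; the actual preprocessing in Heimdall may use a different target or delegate to other library routines.

```python
import math


def normalize_and_log1p(counts, target_sum=1e4):
    """Scale one cell's counts to a fixed total, then apply log(1 + x).

    `target_sum` is an assumed value chosen for illustration only.
    """
    total = sum(counts)
    if total == 0:
        # An empty cell stays all zeros
        return [0.0 for _ in counts]
    return [math.log1p(c * target_sum / total) for c in counts]


# Toy count vector for a single cell with four genes
cell = [0, 3, 1, 6]
processed = normalize_and_log1p(cell)
print(processed)
```

Zero counts remain zero after both steps, and the relative ordering of genes within a cell is preserved, which matters for rank-based tokenizers.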

### `task`
Specifies the task used for training the model, including the metrics, dataset splits, the architecture of the task "head" (used for predicting the task outputs), and the loss function. In this case, the predefined dataset splits (stored in the `adata.obs["split3"]` column) reserve gastrointestinal tract cells exclusively for training, while brain cells are used exclusively for validation and testing.

```yaml
type: Heimdall.task.SingleInstanceTask

args:
  task_type: multiclass
  label_col_name: cell_type
  metrics: [Accuracy, MatthewsCorrCoef, ConfusionMatrix]
  track_metric: MatthewsCorrCoef
  splits:
    type: predefined
    col: split3
    keys_:
      train: train
      val: val
      test: test
  early_stopping_patience: 5
  early_stopping: true
  shuffle: true
  batchsize: 32
  epochs: 50
  dataset_config:
    type: Heimdall.datasets.SingleInstanceDataset
  head_config:
    type: Heimdall.models.LinearCellPredHead
    args:
  loss_config:
    type: Heimdall.losses.FlattenCrossEntropyLoss

cell_rep_config:
  type: Heimdall.cell_representations.CellRepresentation
```
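The `predefined` split can be pictured as grouping cell indices by the labels stored in one column. The sketch below uses hypothetical toy lists in place of `adata.obs["split3"]` and a tissue column, and it assumes that `keys_` maps each split name to the label used in that column; it is not the Heimdall implementation.

```python
from collections import defaultdict


def indices_by_split(split_column, keys):
    """Group row indices by split, using the column labels listed in `keys`."""
    label_to_split = {label: split for split, label in keys.items()}
    groups = defaultdict(list)
    for i, label in enumerate(split_column):
        if label in label_to_split:
            groups[label_to_split[label]].append(i)
    return dict(groups)


# Hypothetical toy data standing in for adata.obs["split3"] and a tissue column
split3 = ["train", "train", "val", "test", "train", "test"]
tissue = ["GI tract", "GI tract", "brain", "brain", "GI tract", "brain"]

keys = {"train": "train", "val": "val", "test": "test"}
splits = indices_by_split(split3, keys)

# In a cross-tissue setting, training and evaluation tissues should not overlap
train_tissues = {tissue[i] for i in splits["train"]}
eval_tissues = {tissue[i] for i in splits["val"] + splits["test"]}
print(splits, train_tissues, eval_tissues)
```

A disjointness check like the one at the end is a cheap way to confirm that a split column really encodes a cross-tissue evaluation.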

### `model`
Specifies the model architecture used for the scFM.

```yaml
type: Heimdall.models.Transformer
name: transformer

args:
  d_model: 128
  pos_enc: BERT
  num_encoder_layers: 2
  nhead: 4
  hidden_act: gelu
  hidden_dropout_prob: 0.1
  use_flash_attn: false
  pooling: cls_pooling # or "mean_pooling"
```
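The two `pooling` options differ in how the per-token outputs of the transformer are collapsed into one cell embedding. The following sketch assumes that the CLS token sits at position 0 and ignores padding; both are simplifications, and the function bodies are illustrative rather than taken from Heimdall.

```python
def cls_pooling(token_embeddings):
    """Return the embedding of the first token, assumed here to be the CLS token."""
    return list(token_embeddings[0])


def mean_pooling(token_embeddings):
    """Average the embeddings of all tokens, dimension by dimension."""
    n_tokens = len(token_embeddings)
    dim = len(token_embeddings[0])
    return [sum(tok[j] for tok in token_embeddings) / n_tokens for j in range(dim)]


# Three tokens with two-dimensional embeddings
tokens = [[1.0, 0.0], [3.0, 2.0], [5.0, 4.0]]
cls_vector = cls_pooling(tokens)
mean_vector = mean_pooling(tokens)
print(cls_vector, mean_vector)
```

With CLS pooling the model has to learn to route cell-level information into a single position, whereas mean pooling weights every token equally.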

## Modular reimplementation of the Geneformer tokenizer

Having configured everything except the tokenizer, we now focus on implementing the tokenizer. For practice, let's implement the [Geneformer](https://www.nature.com/articles/s41586-023-06139-9) tokenizer.

### `fg` - the gene identity encoder ($F_\textbf{G}$)
Specifies the `Fg` implementation for this tokenizer, as well as the `torch.nn.Module` used for providing the embeddings of the genes. Here, we use the `random` implementation, which assigns a randomly-initialized embedding vector of dimensionality `model.args.d_model` to each gene in the cell.

```yaml
type: Heimdall.fg.IdentityFg

args:
  embedding_parameters:
    type: Heimdall.embedding.FlexibleTypeEmbedding
    args:
      num_embeddings: vocab_size
      embedding_dim: ${fg.args.d_embedding}
  d_embedding: ${model.args.d_model}
  frozen: false
```
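Conceptually, a randomly initialized gene identity encoder is a lookup table with one row per gene. The sketch below illustrates that idea with the standard library only; the helper names are hypothetical, the Gaussian initialization is an assumption, and the table here is not trainable, unlike the embedding configured above with `frozen: false`.

```python
import random


def make_random_embeddings(vocab_size, d_model, seed=55):
    """Create one random vector per gene (assumed standard normal entries)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(d_model)] for _ in range(vocab_size)]


def embed_gene_ids(gene_ids, table):
    """Look up the embedding of each gene index in the table."""
    return [table[g] for g in gene_ids]


table = make_random_embeddings(vocab_size=5, d_model=4)
cell_gene_ids = [3, 0, 3]
embedded = embed_gene_ids(cell_gene_ids, table)
print(embedded)
```

The same gene always maps to the same vector, so any information about gene identity must be learned during training rather than supplied by a pretrained embedding.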

### `fe` - the gene expression encoder ($F_\textbf{E}$)
Specifies the `Fe` implementation for this tokenizer, as well as the `torch.nn.Module` used for providing the embeddings of the genes' expression levels. Here, we use the `noop` implementation, which simply outputs a vector of zeros of dimensionality `model.args.d_model` for each gene in the cell, regardless of the gene's expression level.

```yaml
type: Heimdall.fe.IdentityFe
name: Heimdall.fe.IdentityFe

args:
  embedding_parameters:
    type: Heimdall.embedding.ZeroBroadcast
    args:
      out_features: ${fe.args.d_embedding}
  d_embedding: ${model.args.d_model}
  drop_zeros: true
```

### `fc` - the single-cell representation function ($F_\textbf{C}$)
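Specifies the `Fc` implementation for this tokenizer, which combines the outputs of $F_\textbf{G}$ and $F_\textbf{E}$ into a single "cell sentence". The config below sets the maximum input length, along with how genes are ordered (`order_config`), how the sequence is tailored to that length (`tailor_config`), and how the gene identity and expression embeddings are combined (`reduce_config`).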

```yaml
type: Heimdall.fc.Fc

args:
  max_input_length: 2048
  embedding_parameters:
    type: torch.nn.Module  # Should throw an error if called
  tailor_config:
    type: Heimdall.tailor.ReorderTailor
  order_config:
    type: Heimdall.order.ExpressionOrder
  reduce_config:
    type: Heimdall.reduce.SumReduce
```
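A rough picture of what this configuration amounts to: order the expressed genes by decreasing expression, truncate to the maximum input length, and sum the identity and expression embeddings per token. The sketch below is a reading of the component names (`ExpressionOrder`, `ReorderTailor`, `SumReduce`), not their actual code, and it omits the per-gene median normalization that the original Geneformer applies before ranking.

```python
def build_cell_sentence(gene_ids, expression, max_input_length):
    """Order expressed genes by decreasing expression and truncate the sequence."""
    # Keep only genes with nonzero expression
    expressed = [(g, x) for g, x in zip(gene_ids, expression) if x > 0]
    # Highest expression first; Python's sort is stable, so ties keep input order
    expressed.sort(key=lambda pair: pair[1], reverse=True)
    return [g for g, _ in expressed[:max_input_length]]


def sum_reduce(identity_embeddings, expression_embeddings):
    """Combine the two embeddings of each token by element-wise addition."""
    return [
        [a + b for a, b in zip(ident, expr)]
        for ident, expr in zip(identity_embeddings, expression_embeddings)
    ]


gene_ids = [10, 11, 12, 13, 14]
expression = [0.0, 2.5, 0.7, 4.1, 0.0]
sentence = build_cell_sentence(gene_ids, expression, max_input_length=2)

# With an all-zero expression embedding, the sum leaves the identity embedding unchanged
identity = [[0.1, 0.2], [0.3, 0.4]]
zeros = [[0.0, 0.0] for _ in sentence]
tokens = sum_reduce(identity, zeros)
print(sentence, tokens)
```

In this scheme the expression values influence the model only through the order of the tokens, which is why a zero-valued expression encoder is sufficient.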

## Putting it all together

In [58]:
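from omegaconf import OmegaConf

# Compose the full config from the shared experiment config plus the
# tokenizer-specific choices for Fg, Fe, and Fc
with hydra.initialize(version_base=None, config_path="../Heimdall/config"):
    config = hydra.compose(
        config_name="config",
        overrides=[
            "+experiments=sctab_split1_all",
            "fg=random",
            "fe=noop",
            "fc=geneformer",
        ],
    )

    # Resolve ${...} interpolations in place (e.g., ${model.args.d_model})
    OmegaConf.resolve(config)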

In [69]:
print(OmegaConf.to_yaml(config))

project_name: new_sctab_split1
run_name: Heimdall.fg.IdentityFg_Heimdall.fe.IdentityFe_Heimdall.fc.Fc_Heimdall.models.Transformer_lr0.002_bz32
work_dir: new_sctab_split1_results/Heimdall.fg.IdentityFg_Heimdall.fe.IdentityFe_Heimdall.fc.Fc_new_sctab_lr0.002_bz32_seed55_agTrue
run_wandb: true
float_dtype: float32
seed: 55
data_path: /work/magroup/shared/Heimdall/data
ensembl_dir: /work/magroup/shared/Heimdall/data
cache_preprocessed_dataset_dir: /scratch/heimdall/shared/cache
entity: Heimdall
only_preprocess_data: false
model:
  type: Heimdall.models.Transformer
  name: transformer
  args:
    d_model: 128
    pos_enc: BERT
    num_encoder_layers: 2
    nhead: 4
    hidden_act: gelu
    hidden_dropout_prob: 0.1
    use_flash_attn: false
    pooling: cls_pooling
dataset:
  dataset_name: new_sctab
  preprocess_args:
    data_path: /work/magroup/shared/Heimdall/data/sctab/tissue_splits_spencer/scTab_GItract_train.h5ad
    top_n_genes: false
    normalize: true
    log_1p: true
    scale_dat

## Training the model

In [1]:
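from Heimdall.trainer import setup_trainer


def training_loop(config):
    """Build a trainer from the composed config and fit the model."""
    trainer = setup_trainer(config, cpu=config.trainer.cpu)
    # Only fit if `setup_trainer` returned a trainer
    if trainer is not None:
        trainer.fit()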

In [71]:
from accelerate import notebook_launcher

In [None]:
# Launch the training loop from within the notebook in a single process
args = (config,)
notebook_launcher(training_loop, args, num_processes=1)