# Implementing a novel scFM tokenizer 

Beyond benchmarking existing tokenizers, <span style="font-variant:small-caps; font-size: 16px">Heimdall</span> also enables users to introduce, evaluate and share novel tokenizer designs. Here, we demonstrate several examples of this. As in the first demo, we use a subset of the [**scTab** dataset](https://www.nature.com/articles/s41467-024-51059-5) for evaluation.

```python
import hydra
import Heimdall

from matplotlib import pyplot as plt
import matplotlib
import seaborn as sns

import scanpy as sc

# Embed TrueType (Type 42) fonts in PDFs and keep text as text in SVGs, so figures stay editable.
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['svg.fonttype'] = 'none'

sc.set_figure_params(figsize=(6, 6), frameon=False)
sns.set_theme(style="white")
```

We are interested in implementing various popular tokenizers for single-cell foundation models within a standardized framework, which enables us to isolate the impact of the tokenizer _itself_ on downstream performance.

## Recap: <span style="font-variant:small-caps; font-size: 28px">Heimdall</span> modularizes scFM tokenizers

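The key to <span style="font-variant:small-caps; font-size: 16px">Heimdall</span>'s modularity is the compartmentalization of the tokenizer into $F_\textbf{G}$, $F_\textbf{E}$ and $F_\textbf{C}$ components. Thus, to implement a novel tokenizer, it suffices to implement a new version of one of these components and combine it with existing implementations of the others.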

<figure>
  <img src="../_static/scfm_breakdown.png"/>
  <figcaption><b>Fig 1C (bottom).</b> Application of <span style="font-variant:small-caps; font-size: 16px">Heimdall</span> for systematic design and evaluation of scFMs. The highlighted yellow rectangle indicates introduction of a novel $F_\textbf{E}$, thereby constituting a new tokenizer implementation. </figcaption>
</figure>
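To make the decomposition concrete, the toy sketch below composes three small functions into a tokenizer. It is not <span style="font-variant:small-caps; font-size: 16px">Heimdall</span>'s API: the names `toy_fg`, `toy_fe` and `toy_fc`, the gene names and the embedding size are all hypothetical. We assume that $F_\textbf{G}$ encodes gene identity, $F_\textbf{E}$ encodes expression values, and $F_\textbf{C}$ assembles the per-gene encodings into a cell's token sequence; here $F_\textbf{C}$ uses a simplified Geneformer-style ranking by expression, and $F_\textbf{E}$ simply returns zeros.

```python
import numpy as np

# Hypothetical toy vocabulary and model width, for illustration only.
GENES = ["GENE_A", "GENE_B", "GENE_C", "GENE_D"]
D_MODEL = 8
rng = np.random.default_rng(0)

# Toy F_G: a fixed lookup table from gene identity to a vector.
# In the HyenaDNA example, this table would come from precomputed embeddings.
GENE_TABLE = {g: rng.normal(size=D_MODEL) for g in GENES}


def toy_fg(gene_names):
    """Encode gene identity: one vector per gene name."""
    return np.stack([GENE_TABLE[g] for g in gene_names])


def toy_fe(expression):
    """Encode expression values; this 'zero' variant contributes nothing."""
    return np.zeros((len(expression), D_MODEL))


def toy_fc(gene_names, expression, max_len=3):
    """Build a cell's tokens: keep expressed genes, rank them by expression
    (a simplified Geneformer-style ordering) and truncate to `max_len`."""
    expressed = [i for i, x in enumerate(expression) if x > 0]
    order = sorted(expressed, key=lambda i: -expression[i])[:max_len]
    names = [gene_names[i] for i in order]
    values = np.array([expression[i] for i in order])
    # Each token is the sum of its gene-identity and expression encodings.
    return toy_fg(names) + toy_fe(values), names


tokens, order = toy_fc(GENES, np.array([0.0, 5.0, 2.0, 7.0]))
print(order)
print(tokens.shape)
```

Swapping `toy_fg` for a different lookup table changes the tokenizer without touching the other two functions, which is the property this notebook relies on.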

## Implementing a new $F_\textbf{G}$

In our experiments, we use [HyenaDNA](https://arxiv.org/abs/2306.15794) to create a novel gene-identity encoding module. Specifically, we feed each gene's DNA sequence to the HyenaDNA model and save the resulting embeddings as a PyTorch tensor file (`.pt`). Because we already provide a base class that loads gene embeddings from a `.pt` file, the implementation is straightforward once the embeddings have been extracted.
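The sketch below shows the file format that `TorchTensorFg.load_embeddings` (defined next) reads: a dictionary mapping gene names to 1-D tensors, saved with `torch.save`. Running HyenaDNA is out of scope here, so `placeholder_encoder` is a hypothetical stand-in (the mean of one-hot nucleotide vectors), and the gene names and sequences are made up. In practice, each vector would be derived from HyenaDNA's outputs for that gene's sequence.

```python
import os
import tempfile

import torch

# Hypothetical gene -> DNA sequence map; real sequences would come from a reference genome.
gene_sequences = {
    "GENE_A": "ACGTACGTAA",
    "GENE_B": "GGGCCCATAT",
}

NUCLEOTIDES = "ACGT"


def placeholder_encoder(sequence: str) -> torch.Tensor:
    """Hypothetical stand-in for HyenaDNA: mean of one-hot nucleotide vectors."""
    idx = torch.tensor([NUCLEOTIDES.index(base) for base in sequence])
    return torch.nn.functional.one_hot(idx, num_classes=4).float().mean(dim=0)


# Build the {gene name: 1-D tensor} dictionary and save it as a .pt file.
gene_embeddings = {g: placeholder_encoder(s) for g, s in gene_sequences.items()}
out_path = os.path.join(tempfile.mkdtemp(), "hyenaDNA.pt")
torch.save(gene_embeddings, out_path)

# The file loads back with weights_only=True, as in `TorchTensorFg.load_embeddings`.
reloaded = torch.load(out_path, weights_only=True)
print(sorted(reloaded), reloaded["GENE_A"].shape)
```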

### Python code

```python
import torch

from Heimdall.fg import PretrainedFg


class TorchTensorFg(PretrainedFg):  # This is actually already provided in `Heimdall.fg`
    """Mapping of gene names to pretrained embeddings stored as PyTorch
    tensors."""

    def load_embeddings(self):
        # The .pt file holds a dict of {gene name: embedding tensor}.
        raw_gene_embedding_map = torch.load(self.embedding_filepath, weights_only=True)

        # Convert each tensor to a NumPy array on the CPU.
        raw_gene_embedding_map = {
            gene_name: embedding.detach().cpu().numpy()
            for gene_name, embedding in raw_gene_embedding_map.items()
        }

        return raw_gene_embedding_map


class HyenaDNAFg(TorchTensorFg):
    """Mapping of gene names to pretrained HyenaDNA embeddings."""
    # No extra code needed: embedding loading is inherited from TorchTensorFg.
```

### `fg/hyenadna.yaml` config
The final step is to write a config file for this $F_\textbf{G}$:

```yaml
type: Heimdall.fg.HyenaDNAFg

args:
  embedding_parameters:
    type: Heimdall.embedding.FlexibleTypeEmbedding
    constructor: from_pretrained
    args:
      embeddings: gene_embeddings
  embedding_filepath: ${data_path}/pretrained_embeddings/hyenaDNA.pt
  d_embedding: ${model.args.d_model}
  frozen: true
```
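The `embedding_parameters` block requests a `from_pretrained` constructor and `frozen: true`. We do not reproduce `Heimdall.embedding.FlexibleTypeEmbedding` here; as a rough analogue in plain PyTorch, the sketch below stacks a gene-to-vector map (hypothetical values, shaped like the output of `load_embeddings`) into a weight matrix and wraps it with `torch.nn.Embedding.from_pretrained(..., freeze=True)`, so the pretrained vectors receive no gradient updates.

```python
import numpy as np
import torch

# Hypothetical output of `load_embeddings`: gene name -> NumPy vector.
gene_embedding_map = {
    "GENE_A": np.array([0.1, 0.2, 0.3], dtype=np.float32),
    "GENE_B": np.array([0.4, 0.5, 0.6], dtype=np.float32),
}

# Fix an ordering so each gene gets a stable integer index.
gene_names = sorted(gene_embedding_map)
gene_to_index = {g: i for i, g in enumerate(gene_names)}
weights = torch.from_numpy(np.stack([gene_embedding_map[g] for g in gene_names]))

# freeze=True keeps the pretrained vectors fixed, analogous to `frozen: true`.
embedding = torch.nn.Embedding.from_pretrained(weights, freeze=True)

token_ids = torch.tensor([gene_to_index["GENE_B"], gene_to_index["GENE_A"]])
print(embedding(token_ids))
```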

## Putting it all together

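```python
from omegaconf import OmegaConf

# Compose the full config, swapping in the new F_G alongside existing F_E and F_C configs.
with hydra.initialize(version_base=None, config_path="../Heimdall/config"):
    config = hydra.compose(
        config_name="config",
        overrides=[
            "+experiments=sctab_split1_all",
            "fg=hyenadna",  # the `fg/hyenadna.yaml` config defined above
            "fe=zero",
            "fc=geneformer",
        ],
    )

    # Resolve interpolations such as ${data_path} and ${model.args.d_model} in place.
    OmegaConf.resolve(config)
```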

## Training the model

```python
from Heimdall.trainer import setup_trainer


def training_loop(config):
    # Build the trainer from the resolved config; run on CPU if the config requests it.
    trainer = setup_trainer(config, cpu=config.trainer.cpu)
    if trainer is not None:
        trainer.fit()
```

```python
from accelerate import notebook_launcher

# Launch the training loop from within the notebook on a single process.
args = (config,)
notebook_launcher(training_loop, args, num_processes=1)
```