# Small Notebook to Prepare your $f_g$ and $f_c$ with hydra

In [2]:
## loading in libraries
import scanpy as sc
import anndata as ad
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import hydra
import pandas as pd
from omegaconf import OmegaConf

from Heimdall.cell_representations import Cell_Representation
%load_ext autoreload
%autoreload 2

## autoreload prevents locally edited files from being cached, easier for development

In [3]:
#####
# an example of some custom fg/fcs
#####
def identity_fg(adata_var):
    """
    identity_fg is an fg that returns a token id for each gene, effectively each gene
    is its own word.

    args:
        - adata_var: takes in the var dataframe, in this case, it expects the index to have the gene names

    output:
        - the output is a dictionary map between the gene names, and their corersponding token id for nn.embedding
    """
    print("> Performing the f_g identity, desc: each gene is its own token")
    gene_df = adata_var
    gene_mapping = {label: idx for idx, label in enumerate(gene_df.index.unique(), start=0)}
    return gene_mapping


def geneformer_fc(fg, adata):
    """
    geneformer_fc is a fc that will reprocess each cell by ordering them by their gene expression value,
    and replace each gene name by their corresponding representation, either token_id or a different vector

    right now this only supports token_id

    args:
        - fg: dictionary that maps gene names to token ids
        - adata: the whole, already processed, anndata object with the CellxGene Matrix

    output:
        - output: dataset, a numpy object that is dimension CellxGene where the position has the token denoting what gene it is
    """

    assert all(isinstance(value, (int)) for value in fg.values()), \
            "Current geneformer_fc only supports token ids"

    print("> Performing the f_c using rank-based values, as seen in geneformer")
    df = pd.DataFrame(adata.X, columns=fg.keys())

    dataset = []
    for i in tqdm(range(len(df))):
        cell = df.iloc[i]
        sorted_cell = cell.sort_values(ascending=False).index
        cell_w_gene_ids = [fg[gene] for gene in sorted_cell]
        dataset.append(cell_w_gene_ids)

    dataset = np.array(dataset)
    return dataset


In [4]:
with hydra.initialize(version_base=None, config_path="config"):
    config = hydra.compose(config_name="config") ## setting up a default experiment
    print(OmegaConf.to_yaml(config))



#####
# For more details please check out the Cell_Representation object and the corresponding functions below
#####

CR = Cell_Representation(config) ## takes in the whole config from hydra
CR.preprocess_anndata() ## standard sc preprocessing can be done here
CR.preprocess_f_g(identity_fg) ## takes in the identity f_g specified above
CR.preprocess_f_c(geneformer_fc) ## takes in the geneformer f_c specified above
CR.prepare_labels() ## prepares the labels

## we can take this out here now and pass this into a PyTorch dataloader and separately create the model
X = CR.cell_representation
y = CR.labels

print(f"Cell representation X: {X.shape}")
print(f"Cell labels y: {y.shape}")



model:
  type: transformer
  args:
    hidden_size: 128
    num_hidden_layers: 2
    num_attention_heads: 32
    hidden_act: gelu
    hidden_dropout_prob: 0.1
    attention_probs_dropout_prob: 0.1
    max_position_embeddings: 1024
    use_flash_attn: false
    pooling: cls_pooling
dataset:
  dataset_name: cell_type_classification
  preprocess_args:
    data_path: data/sc_sub_nick.h5ad
    top_n_genes: 1000
    normalize: true
    log_1p: true
    scale_data: true
  task_args:
    label_col_name: class
    metric_name: MCC
    train_split: 0.8
scheduler:
  name: cosine
  lr_schedule_type: cosine
  warmup_ratio: 0.1
  num_epochs: 20
trainer:
  accelerator: cuda
  precision: 32-true
  random_seed: 11111
  per_device_batch_size: 64
  accumulate_grad_batches: 1
  num_epochs: 20
optimizer:
  name: adamW
  learning_rate: 0.002
  end_learning_rate: 1.0e-05
  grad_norm_clip: 1.0
  weight_decay: 0.1
  beta1: 0.9
  beta2: 0.95
f_c:
  name: identity
  args:
    output_type: ids
f_g:
  name: identi

  view_to_actual(adata)


> Finished Processing Anndata Object
> Performing the f_g identity, desc: each gene is its own token
> Finished calculating f_g with identity
> Performing the f_c using rank-based values, as seen in geneformer


100%|██████████| 26553/26553 [00:11<00:00, 2409.79it/s]


> Finished calculating f_c with identity
> Finished extracting labels, self.labels.shape: (26553,)
Cell representation X: (26553, 1000)
Cell labels y: (26553,)
