In [28]:
## loading in libraries
import scanpy as sc
import anndata as ad
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import hydra
import pandas as pd
from omegaconf import OmegaConf

from sklearn.model_selection import train_test_split
from datasets import Dataset
from transformers import SchedulerType, get_scheduler

## initialize the model
from Heimdall.models import HeimdallTransformer, TransformerConfig

## Cell representation tools from heimdall
from Heimdall.cell_representations import CellRepresentation
from Heimdall.f_g import identity_fg
from Heimdall.f_c import geneformer_fc
from Heimdall.f_c import scgpt_fc

from Heimdall.utils import heimdall_collate_fn
from Heimdall.trainer import HeimdallTrainer

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [29]:
# Compose the Hydra config for the default classification experiment,
# then print it as YAML so the settings in effect are visible in the notebook.
with hydra.initialize(version_base=None, config_path="config"):
    config = hydra.compose(
        config_name="config",
        overrides=["+experiments=classification_experiment_dev"],
    )

print(OmegaConf.to_yaml(config))

data_path: /work/magroup/shared/Heimdall/data/cell_type_annotation/processed
ensembl_dir: /work/magroup/shared/Heimdall/data/gene_mapping
cache_preprocessed_dataset_dir: null
entity: Heimdall
model:
  type: transformer
  args:
    hidden_size: 256
    num_hidden_layers: 6
    num_attention_heads: 8
    hidden_act: gelu
    hidden_dropout_prob: 0.1
    attention_probs_dropout_prob: 0.1
    max_position_embeddings: 1024
    use_flash_attn: false
    pooling: cls_pooling
dataset:
  dataset_name: cell_type_classification
  preprocess_args:
    data_path: ${data_path}/pancreas.h5ad
    top_n_genes: 1000
    normalize: true
    log_1p: true
    scale_data: true
    species: human
tasks:
  args:
    task_type: classification
    task_structure: single
    label_col_name: task_celltype
    metrics:
    - Accuracy
    - MatthewsCorrCoef
    train_split: 0.8
    shuffle: true
    batchsize: 32
    epochs: 1
    prediction_dim: 14
scheduler:
  name: cosine
  lr_schedule_type: cosine
  warmup_rati



# The Cell Representation Object

- Here you choose the `f_g` and the `f_c` functions that you want to use. This notebook uses
pre-made ones that ship in the modules `Heimdall.f_g` and `Heimdall.f_c`.

- Follow the README and the Notion page for instructions on how to design your own `f_g` and `f_c`.

In [30]:
# `config` was already composed by the Hydra cell directly above with the same
# config name and overrides, so it is reused here instead of composing it again.

CR = CellRepresentation(config)  ## takes in the whole config from hydra
# CR.preprocess_anndata()  ## standard sc preprocessing can be done here
CR.preprocess_f_g(identity_fg)  ## identity f_g (logged as "each gene is its own token")
CR.preprocess_f_c(scgpt_fc)  ## scGPT f_c (logged as "rank-based values with binning")
# CR.preprocess_f_c(old_geneformer_fc)  ## alternative: the geneformer f_c

# CR.prepare_labels()  ## prepares the labels

## we can take this out here now and pass this into a PyTorch dataloader and separately create the model
X = CR.adata.layers["cell_representation"]
y = CR.labels

print(f"Cell representation X: {X.shape}")
print(f"Cell labels y: {y.shape}")

> Finished Loading in /work/magroup/shared/Heimdall/data/cell_type_annotation/processed/pancreas.h5ad
> Normalizing anndata...
> Log Transforming anndata...
{'data_path': '${data_path}/pancreas.h5ad', 'top_n_genes': 1000, 'normalize': True, 'log_1p': True, 'scale_data': True, 'species': 'human'}
> Using highly variable subset... top 1000 genes
> Scaling the data...


  view_to_actual(adata)


> Finished Processing Anndata Object
> Finished Processing Anndata Object:
AnnData object with n_obs × n_vars = 16382 × 1000
    obs: 'tech', 'celltype', 'size_factors', 'species', 'task_celltype', 'batch'
    var: 'n_cells', 'gene_symbol', 'gene_ensembl', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean', 'std'
    uns: 'batch_order', 'celltype_order', 'gene_mapping:symbol_to_ensembl', 'log1p', 'hvg'
    layers: 'counts'
> Performing the f_g identity, desc: each gene is its own token
> Finished calculating f_g with identity_fg
> Performing the f_c using rank-based values, as seen in geneformer


100%|██████████| 16382/16382 [00:03<00:00, 4637.97it/s]
  self._setup_splits()


> Finished calculating f_c with geneformer_fc
> Finished setting up datasets (and loaders):
	{'full': SingleInstanceDataset(size=16,382) wrapping: <Heimdall.cell_representations.CellRepresentation object at 0x7f2cfe99c2b0>,
	 'test': <torch.utils.data.dataset.Subset object at 0x7f2cfc471cf0>,
	 'train': <torch.utils.data.dataset.Subset object at 0x7f2cfc473250>,
	 'val': <torch.utils.data.dataset.Subset object at 0x7f2cfc473880>}
> Performing the f_g identity, desc: each gene is its own token
> Finished calculating f_g with identity_fg
> Performing the f_c using rank-based values with binning, as seen in scGPT


100%|██████████| 16382/16382 [00:08<00:00, 1963.52it/s]


> Finished calculating f_c with geneformer_fc
Cell representation X: (16382, 1000)
Cell labels y: (16382,)


# Dataset Preparation

In [31]:
########
# PREPARE THE DATASET
# I am including this explicit example here just for completeness, but this can
# easily be rolled into a helper function
########

# The training fraction comes from the experiment config instead of a hard-coded
# 0.2 hold-out, so this manual split follows `tasks.args.train_split`.
# The held-out remainder is then halved into test and validation sets.
# Fixed random_state values keep both splits reproducible.
train_x, test_val_x, train_y, test_val_y = train_test_split(
    X, y, train_size=float(config.tasks.args.train_split), random_state=42
)
test_x, val_x, test_y, val_y = train_test_split(
    test_val_x, test_val_y, test_size=0.5, random_state=42
)

print(f"> Cell representation X: {X.shape}")
print(f"> Cell labels y: {y.shape}")
print(f"> train_x.shape {train_x.shape}")
print(f"> validation_x.shape {val_x.shape}")
print(f"> test_x.shape {test_x.shape}")

# this is how you dynamically process your outputs into the right dataloader format
# if you do not want conditional tokens, just omit those arguments
# what is crucial is that the dataset contains the arguments `inputs` and `labels`, anything else will be put into `conditional`
ds_train = Dataset.from_dict({"inputs": train_x, 'labels': train_y, 'conditional_tokens_1': train_x, 'conditional_tokens_2': train_x})
ds_valid = Dataset.from_dict({"inputs": val_x, 'labels': val_y, 'conditional_tokens_1': val_x, 'conditional_tokens_2': val_x})
ds_test = Dataset.from_dict({"inputs": test_x, 'labels': test_y, 'conditional_tokens_1': test_x, 'conditional_tokens_2': test_x})

## this can probably be rolled into the train functionality itself, but lets keep it outside to be easier to debug
batch_size = int(config.tasks.args.batchsize)

# Only the training loader follows the configured shuffle flag; the evaluation
# loaders keep a fixed order so validation/test passes are deterministic.
dataloader_train = DataLoader(ds_train, batch_size=batch_size, shuffle=config.tasks.args.shuffle, collate_fn=heimdall_collate_fn)
dataloader_val = DataLoader(ds_valid, batch_size=batch_size, shuffle=False, collate_fn=heimdall_collate_fn)
dataloader_test = DataLoader(ds_test, batch_size=batch_size, shuffle=False, collate_fn=heimdall_collate_fn)


> Cell representation X: (16382, 1000)
> Cell labels y: (16382,)
> train_x.shape (13105, 1000)
> validation_x.shape (1639, 1000)
> test_x.shape (1638, 1000)


In [32]:
## Demonstration of the dataset contents
# Fetch just the first batch from the training loader; next(iter(...)) states
# that intent directly instead of a loop that breaks immediately.
batch = next(iter(dataloader_train))

batch

{'inputs': [tensor([489, 393, 986, 489, 489, 489, 489, 640, 640, 428, 986, 640, 847, 489,
          489, 489, 986, 489, 489, 986, 851, 986, 986, 355, 986, 986, 640, 489,
          489, 986, 419, 847]),
  tensor([640, 913, 847, 279, 986, 640, 851, 355, 159, 489, 847, 393, 355, 851,
          851, 986, 489, 851, 986, 419, 355, 847, 640, 393, 847, 847, 986, 640,
          851, 851, 393, 640]),
  tensor([159, 355, 489, 252, 847, 851, 355, 159, 419, 987, 913, 428, 393, 419,
          986, 847, 847, 847, 847, 428,   3,   3,   3, 428, 913, 851, 913, 913,
          355, 847, 159, 851]),
  tensor([369, 851,   3, 428, 640, 419, 640, 428,   3, 159, 987, 851, 159, 428,
          847, 851, 851, 355, 851,   3, 159, 727, 419, 913, 987, 987, 987, 851,
          640, 393,   3, 355]),
  tensor([928, 987, 747, 747, 428, 428, 419, 913, 913,  88,   3, 987,   3, 913,
          640, 428, 432, 640, 419, 987, 432, 143, 913, 987, 640, 393, 489, 987,
          159, 159, 928, 428]),
  tensor([  3, 159, 159, 432, 

# Model Instantiation Example

In [38]:
########
# Create the model and the types of inputs that it may use
## `type` can either be `learned`, which is integer tokens and learned nn.embeddings, 
##  or `predefined`, which expects the dataset to prepare batchsize x length x hidden_dim
#######

# conditional_input_types = {
#     "conditional_tokens_1":{
#         "type": "learned",
#         "vocab_size": 1000
#     },
#     "conditional_tokens_2":{
#         "type": "learned",
#         "vocab_size": 1000
#     }
# }

# No conditional inputs are used for this example model.
conditional_input_types = None

## initialize the model
# HeimdallTransformer / TransformerConfig and the autoreload extension are
# already set up in the first cell, so they are not imported/loaded again here.

## model config based on your specifications
# prediction_dim is read from the experiment config so the decoder width follows
# `tasks.args.prediction_dim` rather than a hard-coded copy of it.
# NOTE(review): vocab_size / max_seq_length of 1000 presumably mirror
# `dataset.preprocess_args.top_n_genes` -- confirm before changing either.
# NOTE(review): the remaining TransformerConfig fields are left at their defaults;
# the printed model shows 128-dim embeddings and 2 layers, whereas `model.args` in
# the Hydra config lists hidden_size 256 and 6 layers -- confirm this is intended.
transformer_config = TransformerConfig(
    vocab_size=1000,
    max_seq_length=1000,
    prediction_dim=int(config.tasks.args.prediction_dim),
)
model = HeimdallTransformer(
    config=transformer_config,
    input_type="learned",
    conditional_input_types=conditional_input_types,
)

model

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


HeimdallTransformer(
  (input_embeddings): Embedding(1000, 128)
  (position_embeddings): Embedding(1001, 128)
  (conditional_embeddings): ModuleDict()
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=512, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): Linear(in_features=128, out_features=14, bias=True)
)

# Trainer

In [39]:
# Disabled alternative, kept for reference: builds the trainer from explicit
# train/val/test dataloaders with run_wandb=True. The live cell that follows
# hands the CellRepresentation to the trainer via `data=CR` instead.
# NOTE(review): confirm HeimdallTrainer still accepts the dataloader_* keyword
# arguments before re-enabling this.
# trainer = HeimdallTrainer(cfg=config, model=model,
#                 dataloader_train = dataloader_train, 
#                 dataloader_val = dataloader_val,
#                 dataloader_test = dataloader_test,
#                 run_wandb = True)

# trainer

In [40]:
# Build the trainer from the Hydra config, the model and the CellRepresentation;
# wandb logging is switched off for this run.
trainer = HeimdallTrainer(
    cfg=config,
    model=model,
    data=CR,
    run_wandb=False,
)


Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


> Using Device: cuda
> Number of Devices: 1
!!! Remember that config batchsize here is GLOBAL Batchsize !!!
> global batchsize: 32
> total_samples: 7863
> Warm Up Steps: 24
> Total Steps: 245
> per_device_batch_size: 32
> Finished Wrapping the model, optimizer, and dataloaders in accelerate
> run HeimdallTrainer.train() to begin training


Optionally, you can specify a custom loss function that takes in `input` and `labels` and returns the loss (make sure it operates on batched tensors).

In [41]:
# Disabled example of passing a custom loss function to the trainer.
# NOTE(review): `input.view(-1, 20)` assumes 20 logits per sample, while this
# notebook's config sets prediction_dim to 14 -- adjust before re-enabling.
# NOTE(review): the dataloader_* keyword arguments differ from the `data=CR`
# call used by the live trainer cell; confirm they are still supported.
# def funky_fella(input, labels):
#     loss = torch.nn.CrossEntropyLoss()
#     return loss(input.view(-1, 20), labels.view(-1))


# trainer = HeimdallTrainer(cfg=config, model=model,
#                 dataloader_train = dataloader_train, 
#                 dataloader_val = dataloader_val,
#                 dataloader_test = dataloader_test,
#                 run_wandb = True,
#                 custom_loss_func = funky_fella)

In [42]:
# Start the training run on the trainer constructed above.
trainer.fit()

100%|██████████| 62/62 [00:00<00:00, 134.39it/s]


{'valid_loss': 2.6007378485894974, 'valid_Accuracy': 13.173957169055939, 'valid_MatthewsCorrCoef': 0, 'Process_mem_rss': 4.490467071533203}


100%|██████████| 205/205 [00:01<00:00, 138.19it/s]


{'test_loss': 2.59528154977938, 'test_Accuracy': 13.871508836746216, 'test_MatthewsCorrCoef': 0, 'Process_mem_rss': 4.4866943359375}


Epoch: 0, Step 66, Loss: 1.8447, LR: 1.8e-03:  27%|██▋       | 66/246 [00:02<00:07, 25.42it/s]


KeyboardInterrupt: 

In [None]:
 # Resolve an optimizer class from its name by attribute lookup on torch.optim.
 # torch is already imported in the first cell, so it is not imported again here.
 optimizer_name = "AdamW"
 optimizer_class = getattr(torch.optim, optimizer_name)

 optimizer_class