In [1]:
## loading in libraries
import scanpy as sc
import anndata as ad
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import hydra
import pandas as pd
from omegaconf import OmegaConf

from sklearn.model_selection import train_test_split
from datasets import Dataset
from transformers import SchedulerType, get_scheduler

## initialize the model
from Heimdall.models import Heimdall_Transformer, TransformerConfig

## Cell representation tools from heimdall
from Heimdall.cell_representations import Cell_Representation
from Heimdall.f_g import identity_fg
from Heimdall.f_c import geneformer_fc
from Heimdall.utils import heimdall_collate_fn
from Heimdall.trainer import Heimdall_Trainer

%load_ext autoreload
%autoreload 2



# The Cell Representation Object

- This is where you choose the `f_g` and the `f_c` that you want to use. This example uses the
pre-made ones stored in the modules `Heimdall.f_g` and `Heimdall.f_c`.

- Follow the README and the Notion page for guidance on how to design your own `f_g` and `f_c`.

In [3]:
# Compose the default experiment configuration from the local `config` directory.
# (To inspect the resolved config, run: print(OmegaConf.to_yaml(config)))
with hydra.initialize(version_base=None, config_path="config"):
    config = hydra.compose(config_name="config")

# Build the cell representation in stages; the constructor receives the whole hydra config.
CR = Cell_Representation(config)
CR.preprocess_anndata()           # standard single-cell preprocessing can be done here
CR.preprocess_f_g(identity_fg)    # apply the identity f_g imported at the top of the notebook
CR.preprocess_f_c(geneformer_fc)  # apply the geneformer f_c imported at the top of the notebook
CR.prepare_labels()               # prepare the labels

# Pull the arrays out of the object so they can be fed to a PyTorch dataloader,
# independently of how the model is created later.
X, y = CR.cell_representation, CR.labels

print(f"Cell representation X: {X.shape}")
print(f"Cell labels y: {y.shape}")



> Finished Loading in data/sc_sub_nick.h5ad
> Normalizing anndata...
> Log Transforming anndata...
> Using highly variable subset... top 1000 genes
> Scaling the data...


  view_to_actual(adata)


> Finished Processing Anndata Object
> Performing the f_g identity, desc: each gene is its own token
> Finished calculating f_g with identity
> Performing the f_c using rank-based values, as seen in geneformer


100%|██████████| 26553/26553 [00:06<00:00, 4321.80it/s]


> Finished calculating f_c with identity
> Finished extracting labels, self.labels.shape: (26553,)
Cell representation X: (26553, 1000)
Cell labels y: (26553,)


# Dataset Preparation

In [5]:
########
# PREPARE THE DATASET
# The split and the dataloader construction are spelled out here for completeness.
# The two helpers below replace the per-split copy-pasted lines and could be moved
# into a library module later.
########

SPLIT_SEED = 42  # shared by both splits so the partition is reproducible


def build_split_dataset(inputs, labels):
    """Wrap one split as a `datasets.Dataset` in the layout the dataloader expects.

    The dataset must contain the columns `inputs` and `labels`; every other column
    is placed under `conditional` by the collate function. If you do not want
    conditional tokens, drop the two `conditional_tokens_*` entries.

    Args:
        inputs: cell representation rows for this split.
        labels: labels aligned with `inputs`.

    Returns:
        A `Dataset` with `inputs`, `labels` and two conditional-token columns
        (both conditional columns reuse `inputs`, as a demonstration).
    """
    return Dataset.from_dict({
        "inputs": inputs,
        "labels": labels,
        "conditional_tokens_1": inputs,
        "conditional_tokens_2": inputs,
    })


def build_dataloader(dataset, shuffle):
    """Create a `DataLoader` over `dataset` using the configured batch size.

    Args:
        dataset: the `Dataset` produced by `build_split_dataset`.
        shuffle: whether the loader reshuffles the samples on each pass.

    Returns:
        A `DataLoader` that batches with `heimdall_collate_fn`.
    """
    return DataLoader(
        dataset,
        batch_size=int(config.dataset.task_args.batchsize),
        shuffle=shuffle,
        collate_fn=heimdall_collate_fn,
    )


# 80% train, then the held-out 20% is halved into test and validation.
train_x, test_val_x, train_y, test_val_y = train_test_split(X, y, test_size=0.2, random_state=SPLIT_SEED)
test_x, val_x, test_y, val_y = train_test_split(test_val_x, test_val_y, test_size=0.5, random_state=SPLIT_SEED)

print(f"> Cell representation X: {X.shape}")
print(f"> Cell labels y: {y.shape}")
print(f"> train_x.shape {train_x.shape}")
print(f"> validation_x.shape {val_x.shape}")
print(f"> test_x.shape {test_x.shape}")

ds_train = build_split_dataset(train_x, train_y)
ds_valid = build_split_dataset(val_x, val_y)
ds_test = build_split_dataset(test_x, test_y)

## this can probably be rolled into the train functionality itself, but let's keep it outside to be easier to debug
# Only the training loader follows the configured shuffle flag; the validation and test
# loaders iterate in a fixed order so evaluation batches are the same on every pass.
dataloader_train = build_dataloader(ds_train, shuffle=config.dataset.task_args.shuffle)
dataloader_val = build_dataloader(ds_valid, shuffle=False)
dataloader_test = build_dataloader(ds_test, shuffle=False)


> Cell representation X: (26553, 1000)
> Cell labels y: (26553,)
> train_x.shape (21242, 1000)
> validation_x.shape (2656, 1000)
> test_x.shape (2655, 1000)


In [6]:
## Demonstration of the dataset contents: fetch the first training batch and display it
batch = next(iter(dataloader_train))
batch

{'inputs': tensor([[647, 686, 365,  ...,  56, 759, 136],
         [489, 811, 223,  ..., 554, 508, 498],
         [919, 159,  60,  ..., 581, 645, 908],
         ...,
         [137, 181, 782,  ..., 554,  38, 498],
         [211, 761, 745,  ..., 524, 872, 645],
         [962, 251, 603,  ..., 640, 645, 554]]),
 'labels': tensor([ 4, 11,  6, 16,  0,  4,  0,  0,  4,  2,  0,  0,  0,  0,  0,  0,  2, 16,
          4,  6,  0,  3,  0,  6,  3,  4, 16, 16,  4, 11,  0,  0]),
 'conditional_tokens': {'conditional_tokens_1': tensor([[647, 686, 365,  ...,  56, 759, 136],
          [489, 811, 223,  ..., 554, 508, 498],
          [919, 159,  60,  ..., 581, 645, 908],
          ...,
          [137, 181, 782,  ..., 554,  38, 498],
          [211, 761, 745,  ..., 524, 872, 645],
          [962, 251, 603,  ..., 640, 645, 554]]),
  'conditional_tokens_2': tensor([[647, 686, 365,  ...,  56, 759, 136],
          [489, 811, 223,  ..., 554, 508, 498],
          [919, 159,  60,  ..., 581, 645, 908],
          ...,


# Model Instantiation Example

In [None]:
########
# Create the model and declare the kinds of inputs it may use.
## `type` can either be `learned`, which means integer tokens with learned nn.Embedding tables,
##  or `predefined`, which expects the dataset to prepare batchsize x length x hidden_dim
#######
# `Heimdall_Transformer` / `TransformerConfig` and the autoreload extension are already
# set up in the first cell of the notebook, so they are not repeated here.

# Model dimensions, named once instead of being repeated as literals.
VOCAB_SIZE = 1000      # presumably one token per highly-variable gene (top 1000) — TODO confirm
MAX_SEQ_LENGTH = 1000  # the cell representation printed earlier has 1000 columns per cell
PREDICTION_DIM = 20    # presumably the number of label classes — TODO confirm against `y`

# Both conditional inputs are integer tokens with their own learned embeddings.
conditional_input_types = {
    name: {"type": "learned", "vocab_size": VOCAB_SIZE}
    for name in ("conditional_tokens_1", "conditional_tokens_2")
}

## model config based on your specifications
transformer_config = TransformerConfig(
    vocab_size=VOCAB_SIZE,
    max_seq_length=MAX_SEQ_LENGTH,
    prediction_dim=PREDICTION_DIM,
)
model = Heimdall_Transformer(
    config=transformer_config,
    input_type="learned",
    conditional_input_types=conditional_input_types,
)

## optimizer — hyperparameters come from the hydra config
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.optimizer.learning_rate,
    weight_decay=config.optimizer.weight_decay,
    betas=(config.optimizer.beta1, config.optimizer.beta2),
    foreach=False,  # `foreach=False` is due to a distributed bug with the cosine scheduler
)

# Display the model architecture as the cell output.
model

# Trainer

In [8]:
# Hand the config, model, optimizer and the three dataloaders to the trainer.
# `run_wandb=True` requests a Weights & Biases run for this training session.
trainer = Heimdall_Trainer(
    config=config,
    model=model,
    optimizer=optimizer,
    dataloader_train=dataloader_train,
    dataloader_val=dataloader_val,
    dataloader_test=dataloader_test,
    run_wandb=True,
)

# Show the trainer object as the cell output.
trainer

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


> Using Device: cuda
==> Starting a new WANDB run


[34m[1mwandb[0m: Currently logged in as: [33mnih121[0m ([33mHeimdall[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.0111126235117101, max=1.0))…

==> Initialized Run
 !!! Remember that config batchsize here is GLOBAL Batchsize !!!
> global batchsize: 32
> num_devices: 1
> total_samples: 21242
> warmup_step: 663
> total_steps: 6630
> per_device_batch_size: 32
> Finished Wrapping the model, optimizer, and dataloaders in accelerate
> run Heimdall_Trainer.train() to begin training


<train.Heimdall_Trainer at 0x7f102fa6c250>

In [9]:
# Start training with the trainer constructed in the previous cell.
# NOTE(review): the recorded output of this cell ends with
# `NameError: name 'run_wandb' is not defined` — presumably raised inside
# Heimdall_Trainer rather than in this notebook; confirm and fix before re-running.
trainer.train()

100%|██████████| 83/83 [00:05<00:00, 14.83it/s]
100%|██████████| 83/83 [00:04<00:00, 18.86it/s]
Epoch: 0, Step 664, Loss: 0.7718, LR: 2.0e-03: 100%|██████████| 664/664 [01:00<00:00, 10.91it/s]
100%|██████████| 83/83 [00:04<00:00, 20.61it/s]
100%|██████████| 83/83 [00:04<00:00, 19.17it/s]
Epoch: 1, Step 1328, Loss: 0.1148, LR: 1.9e-03: 100%|██████████| 664/664 [01:03<00:00, 10.48it/s]
100%|██████████| 83/83 [00:04<00:00, 20.67it/s]
100%|██████████| 83/83 [00:04<00:00, 19.73it/s]
Epoch: 2, Step 1992, Loss: 0.0108, LR: 1.8e-03: 100%|██████████| 664/664 [01:01<00:00, 10.84it/s]
100%|██████████| 83/83 [00:03<00:00, 22.41it/s]
100%|██████████| 83/83 [00:04<00:00, 19.18it/s]
Epoch: 3, Step 2656, Loss: 0.0007, LR: 1.5e-03: 100%|██████████| 664/664 [01:02<00:00, 10.65it/s]
100%|██████████| 83/83 [00:03<00:00, 24.69it/s]
100%|██████████| 83/83 [00:04<00:00, 17.88it/s]
Epoch: 4, Step 3320, Loss: 0.0535, LR: 1.2e-03: 100%|██████████| 664/664 [01:02<00:00, 10.57it/s]
100%|██████████| 83/83 [00:04<0

NameError: name 'run_wandb' is not defined