# Classification fine-tuning using Helical

## Cell type classification task

In [1]:
from helical.models.geneformer.geneformer_config import GeneformerConfig
from helical.models.geneformer.fine_tuning_model import GeneformerFineTuningModel
from helical.models.geneformer.model import Geneformer
from helical.models.scgpt.fine_tuning_model import scGPTFineTuningModel
from helical.models.scgpt.model import scGPT, scGPTConfig
from helical.models.uce.model import UCE, UCEConfig
from helical.models.uce.fine_tuning_model import UCEFineTuningModel
from helical.utils.dataset_to_anndata import get_anndata_from_hf_dataset
import torch
import numpy as np


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

### Load the dataset

In [3]:
from datasets import load_dataset
ds = load_dataset("helical-ai/yolksac_human", trust_remote_code=True, download_mode="reuse_cache_if_exists")


In [4]:
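# Use a small subset (10 training cells, 2 test cells) so the notebook runs quickly.
# The logged outputs below appear to come from the full splits; drop the slicing to reproduce them.
train_dataset = get_anndata_from_hf_dataset(ds["train"])[:10]
test_dataset = get_anndata_from_hf_dataset(ds["test"])[:2]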



## Prepare training labels

- For this classification task we want to predict the cell type stored in the `LVL1` column of `obs`
- So we extract these cell types as a Python list

In [5]:
# The AnnData objects are already subset above, so no further slicing is needed here.
cell_types_train = train_dataset.obs["LVL1"].tolist()
cell_types_test = test_dataset.obs["LVL1"].tolist()

- We convert these string labels into integer class IDs for training

In [6]:
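# Sort the labels so the label -> ID mapping is identical on every run
# (the iteration order of a set of strings can change between Python sessions).
# Test labels are included so a class missing from the small training subset does not raise a KeyError.
label_set = sorted(set(cell_types_train) | set(cell_types_test))
class_id_dict = {label: i for i, label in enumerate(label_set)}

cell_types_train = [class_id_dict[label] for label in cell_types_train]
cell_types_test = [class_id_dict[label] for label in cell_types_test]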
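The fine-tuned models work with the integer IDs, so predictions have to be mapped back to cell-type names at the end. A minimal sketch of inverting `class_id_dict` for that purpose; the labels and predicted IDs below are hypothetical stand-ins, not values from the yolk-sac dataset or from Helical's output.

```python
# Hypothetical stand-ins: in the notebook, class_id_dict is built from the LVL1 labels
# and the predicted IDs would come from a fine-tuned model.
label_set = sorted({"Erythroid", "Myeloid", "Lymphoid"})
class_id_dict = {label: i for i, label in enumerate(label_set)}

# Invert the mapping once so integer predictions can be turned back into names.
id_to_label = {i: label for label, i in class_id_dict.items()}

predicted_ids = [2, 0, 1]  # hypothetical model output
predicted_labels = [id_to_label[i] for i in predicted_ids]
print(predicted_labels)  # ['Myeloid', 'Erythroid', 'Lymphoid']
```

Because `label_set` is sorted, the same inverse mapping can be rebuilt later from the label names alone.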



## Geneformer Fine-Tuning

Load the pretrained Geneformer model with the desired configuration

In [7]:
geneformer_config = GeneformerConfig(device=device, batch_size=5, model_name="gf-6L-30M-i2048")
geneformer = Geneformer(configurer = geneformer_config)

INFO:helical.services.downloader:File: '/home/matthew/.cache/helical/models/geneformer/v1/gene_median_dictionary.pkl' exists already. File is not overwritten and nothing is downloaded.
INFO:helical.services.downloader:File saved to: '/home/matthew/.cache/helical/models/geneformer/v1/gene_median_dictionary.pkl'
INFO:helical.services.downloader:File: '/home/matthew/.cache/helical/models/geneformer/v1/token_dictionary.pkl' exists already. File is not overwritten and nothing is downloaded.
INFO:helical.services.downloader:File saved to: '/home/matthew/.cache/helical/models/geneformer/v1/token_dictionary.pkl'
INFO:helical.services.downloader:File: '/home/matthew/.cache/helical/models/geneformer/v1/ensembl_mapping_dict.pkl' exists already. File is not overwritten and nothing is downloaded.
INFO:helical.services.downloader:File saved to: '/home/matthew/.cache/helical/models/geneformer/v1/ensembl_mapping_dict.pkl'

Process the data so it is in the correct form for Geneformer

In [8]:
geneformer_train_dataset = geneformer.process_data(train_dataset)
geneformer_test_dataset = geneformer.process_data(test_dataset)

INFO:pyensembl.sequence_data:Loaded sequence dictionary from /home/matthew/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.cdna.all.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /home/matthew/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.ncrna.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /home/matthew/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.pep.all.fa.gz.pickle
INFO:helical.services.mapping:Mapped 21359 genes to Ensembl IDs from a total of 37318 genes.
... storing 'LVL1' as categorical
... storing 'LVL2' as categorical
... storing 'LVL3' as categorical
... storing 'ensembl_id' as categorical
100%|██████████| 50/50 [00:31<00:00,  1.60it/s]
INFO:helical.models.geneformer.geneformer_tokenizer:/tmp/tmpyjug5tfe.h5ad has no column attribute 'filter_pass'; tokenizing all cells.
INFO:helical.models.geneformer.geneformer_tokenizer:Creating dataset.



INFO:pyensembl.sequence_data:Loaded sequence dictionary from /home/matthew/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.cdna.all.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /home/matthew/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.ncrna.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /home/matthew/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.pep.all.fa.gz.pickle
INFO:helical.services.mapping:Mapped 21359 genes to Ensembl IDs from a total of 37318 genes.
... storing 'LVL1' as categorical
... storing 'LVL2' as categorical
... storing 'LVL3' as categorical
... storing 'ensembl_id' as categorical
100%|██████████| 13/13 [00:08<00:00,  1.55it/s]
INFO:helical.models.geneformer.geneformer_tokenizer:/tmp/tmpkyxawh14.h5ad has no column attribute 'filter_pass'; tokenizing all cells.
INFO:helical.models.geneformer.geneformer_tokenizer:Creating dataset.



Geneformer's processed data is a Hugging Face `Dataset`, so the labels are added to it as a column

In [9]:
geneformer_train_dataset = geneformer_train_dataset.add_column("LVL1", cell_types_train)
geneformer_test_dataset = geneformer_test_dataset.add_column("LVL1", cell_types_test)

Define the Geneformer fine-tuning model from the Helical package, which automatically appends a fine-tuning head chosen from the list of available heads
- Define the task type, which in this case is classification
- Define the output size, which is the number of unique labels for classification

In [10]:
geneformer_fine_tune = GeneformerFineTuningModel(geneformer_model=geneformer, fine_tuning_head="classification", output_size=len(label_set))
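This notebook does not show how Helical implements its heads. As a generic sketch only, a classification head typically maps the model's cell embedding to `output_size` logits with a linear layer and is trained with cross-entropy; the toy embedding, weights and bias below are made up, and cross-entropy is an assumption about the loss, not something the notebook states.

```python
import math


def linear(embedding, weights, bias):
    """One logit per class: logits[k] = sum_j W[k][j] * x[j] + b[k]."""
    return [sum(w * x for w, x in zip(row, embedding)) + b for row, b in zip(weights, bias)]


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Negative log-probability assigned to the true class."""
    return -math.log(softmax(logits)[target])


# Toy numbers: a 3-dimensional "cell embedding" and output_size=2 classes.
embedding = [1.0, 0.5, -1.0]
weights = [[0.2, 0.4, 0.0],   # class 0
           [-0.1, 0.0, 0.5]]  # class 1
bias = [0.0, 0.1]

logits = linear(embedding, weights, bias)
predicted_id = max(range(len(logits)), key=logits.__getitem__)  # argmax over classes
loss = cross_entropy(logits, target=0)
print(predicted_id, round(loss, 4))
```

The number of rows in `weights` is what `output_size=len(label_set)` controls: one logit per class.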

Fine-tune the model

In [11]:
geneformer_fine_tune.train(train_dataset=geneformer_train_dataset, validation_dataset=geneformer_test_dataset, label="LVL1")

INFO:helical.models.geneformer.fine_tuning_model:Freezing the first 2 encoder layers of the Geneformer model during fine-tuning.
INFO:helical.models.geneformer.fine_tuning_model:Starting Fine-Tuning
Fine-Tuning: epoch 1/1: 100%|██████████| 5069/5069 [06:17<00:00, 13.43it/s, loss=0.06]  
Fine-Tuning Validation: 100%|██████████| 1268/1268 [00:44<00:00, 28.42it/s, accuracy=0.989]
INFO:helical.models.geneformer.fine_tuning_model:Fine-Tuning Complete. Epochs: 1
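The step counts in the progress bars follow from the split size and `batch_size`: one step per batch, rounding up for a final partial batch (assuming that batch is kept rather than dropped). The logged counts match the full splits of 25,344 training and 6,336 test cells, not the 10/2-cell subset in the code cells. A small worked check:

```python
def steps_per_epoch(n_cells: int, batch_size: int) -> int:
    # Integer ceiling division: a final partial batch still counts as one step.
    return (n_cells + batch_size - 1) // batch_size


# Split sizes taken from the logged run.
n_train, n_validation = 25344, 6336
print(steps_per_epoch(n_train, 5), steps_per_epoch(n_validation, 5))    # 5069 1268
print(steps_per_epoch(n_train, 10), steps_per_epoch(n_validation, 10))  # 2535 634
```

The first line matches the Geneformer and UCE runs (`batch_size=5`); the second matches the scGPT run (`batch_size=10`).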


## scGPT Fine-Tuning

Now repeat the procedure with scGPT
- Load the model with the desired configuration

In [12]:
scgpt_config = scGPTConfig(batch_size=10, device=device)
scgpt = scGPT(configurer=scgpt_config)

INFO:helical.services.downloader:File: '/home/matthew/.cache/helical/models/scgpt/scGPT_CP/vocab.json' exists already. File is not overwritten and nothing is downloaded.
INFO:helical.services.downloader:File saved to: '/home/matthew/.cache/helical/models/scgpt/scGPT_CP/vocab.json'
INFO:helical.services.downloader:File: '/home/matthew/.cache/helical/models/scgpt/scGPT_CP/best_model.pt' exists already. File is not overwritten and nothing is downloaded.
INFO:helical.services.downloader:File saved to: '/home/matthew/.cache/helical/models/scgpt/scGPT_CP/best_model.pt'
INFO:helical.models.scgpt.model:Model finished initializing.


scGPT does not use the Hugging Face `Dataset` class, so its data is prepared slightly differently and the labels are kept separate
- Process the training and test AnnData objects separately to obtain the training and validation inputs

In [13]:
dataset = scgpt.process_data(train_dataset)
validation_dataset = scgpt.process_data(test_dataset)

INFO:helical.models.scgpt.model:Filtering out 10801 genes to a total of 26517 genes with an id in the scGPT vocabulary.


INFO:helical.models.scgpt.model:Filtering out 10801 genes to a total of 26517 genes with an id in the scGPT vocabulary.


Define the scGPT fine-tuning model with the desired head and number of classes

In [14]:
scgpt_fine_tune = scGPTFineTuningModel(scGPT_model=scgpt, fine_tuning_head="classification", output_size=len(label_set))

For scGPT fine-tuning, the labels are passed in as a separate list
- This applies to both the training and validation sets
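Since the labels travel separately from the inputs, they are presumably matched to cells by position, so the list needs one valid class ID per cell. A small sanity check along those lines; `label_problems` is a hypothetical helper, not part of Helical.

```python
def label_problems(n_cells, labels, n_classes):
    """Return a list of problems with a label list passed separately from the model inputs."""
    problems = []
    if len(labels) != n_cells:
        problems.append(f"{len(labels)} labels for {n_cells} cells")
    bad_ids = sorted({y for y in labels if not 0 <= y < n_classes})
    if bad_ids:
        problems.append(f"label IDs outside [0, {n_classes}): {bad_ids}")
    return problems


print(label_problems(4, [0, 2, 1, 0], n_classes=3))  # []
print(label_problems(4, [0, 3, 1], n_classes=3))
```

In the notebook it could be called as `label_problems(train_dataset.n_obs, cell_types_train, len(label_set))` before training, where `n_obs` is the AnnData cell count.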

In [15]:
scgpt_fine_tune.train(train_input_data=dataset, train_labels=cell_types_train, validation_input_data=validation_dataset, validation_labels=cell_types_test)

INFO:helical.models.scgpt.fine_tuning_model:Starting Fine-Tuning
Fine-Tuning: epoch 1/1: 100%|██████████| 2535/2535 [01:59<00:00, 21.17it/s, loss=0.187]
Fine-Tuning Validation: 100%|██████████| 634/634 [00:10<00:00, 60.23it/s, accuracy=0.989]
INFO:helical.models.scgpt.fine_tuning_model:Fine-Tuning Complete. Epochs: 1
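The validation `accuracy` in the progress bars is, presumably, the fraction of validation cells whose predicted class ID equals the true label. A minimal sketch of that metric on hypothetical predictions:

```python
def accuracy(predicted_ids, true_ids):
    """Fraction of cells whose predicted class ID equals the true class ID."""
    if len(predicted_ids) != len(true_ids):
        raise ValueError("predictions and labels must have the same length")
    if not true_ids:
        raise ValueError("accuracy is undefined for an empty label list")
    correct = sum(p == t for p, t in zip(predicted_ids, true_ids))
    return correct / len(true_ids)


# Hypothetical predictions for four validation cells.
print(accuracy([0, 2, 1, 1], [0, 2, 1, 0]))  # 0.75
```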


## UCE Fine-Tuning

In [16]:
uce_config = UCEConfig(batch_size=5, device=device)
uce = UCE(configurer=uce_config)

INFO:helical.services.downloader:File: '/home/matthew/.cache/helical/models/uce/all_tokens.torch' exists already. File is not overwritten and nothing is downloaded.
INFO:helical.services.downloader:File saved to: '/home/matthew/.cache/helical/models/uce/all_tokens.torch'
INFO:helical.services.downloader:File: '/home/matthew/.cache/helical/models/uce/4layer_model.torch' exists already. File is not overwritten and nothing is downloaded.
INFO:helical.services.downloader:File saved to: '/home/matthew/.cache/helical/models/uce/4layer_model.torch'
INFO:helical.services.downloader:File: '/home/matthew/.cache/helical/models/uce/species_chrom.csv' exists already. File is not overwritten and nothing is downloaded.
INFO:helical.services.downloader:File saved to: '/home/matthew/.cache/helical/models/uce/species_chrom.csv'
INFO:helical.services.downloader:File: '/home/matthew/.cache/helical/models/uce/species_offsets.pkl' exists already. File is not overwritten and nothing is downloaded.

Prepare the data the same way as for scGPT
- Give each dataset a distinct `name`: UCE stores the processed counts in a `.npz` file named after it (e.g. `./train_counts.npz` in the log below), so each dataset needs its own file

In [17]:
dataset = uce.process_data(train_dataset, name="train")
validation_dataset = uce.process_data(test_dataset, name="validation")

INFO:helical.models.uce.gene_embeddings:Finished loading gene embeddings for {'human'} from /home/matthew/.cache/helical/models/uce/protein_embeddings
INFO:helical.models.uce.gene_embeddings:Filtered out 19355 genes to a total of 17963 genes with embeddings.
INFO:helical.models.uce.uce_utils:Passed the gene expressions (with shape=(25344, 37318) and max gene count data 31635.0) to ./train_counts.npz
INFO:helical.models.uce.model:Successfully prepared the UCE Dataset.
INFO:helical.models.uce.gene_embeddings:Finished loading gene embeddings for {'human'} from /home/matthew/.cache/helical/models/uce/protein_embeddings
INFO:helical.models.uce.gene_embeddings:Filtered out 19355 genes to a total of 17963 genes with embeddings.
INFO:helical.models.uce.uce_utils:Passed the gene expressions (with shape=(6336, 37318) and max gene count data 20639.0) to ./validation_counts.npz
INFO:helical.models.uce.model:Successfully prepared the UCE Dataset.
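The processing logs report how much of the 37,318-gene input each model can use: genes mapped to Ensembl IDs for Geneformer, genes in the scGPT vocabulary, and genes with protein embeddings for UCE. Turning the logged counts into fractions makes them easier to compare. These figures describe input coverage only; the notebook does not establish whether they explain the accuracy differences.

```python
# Gene counts copied from the processing logs (37,318 genes in the input data).
total_genes = 37318
kept = {
    "Geneformer (mapped to Ensembl IDs)": 21359,
    "scGPT (in vocabulary)": 26517,
    "UCE (with protein embeddings)": 17963,
}

coverage = {model: n_kept / total_genes for model, n_kept in kept.items()}
for model, fraction in coverage.items():
    print(f"{model}: {kept[model]}/{total_genes} = {fraction:.1%}")
```

This gives roughly 57%, 71% and 48% of the input genes respectively.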


Define the fine-tuning model

In [18]:
uce_fine_tune = UCEFineTuningModel(uce_model=uce, fine_tuning_head="classification", output_size=len(label_set))

Fine-tune the model

In [19]:
uce_fine_tune.train(train_input_data=dataset, train_labels=cell_types_train, validation_input_data=validation_dataset, validation_labels=cell_types_test)

INFO:helical.models.uce.fine_tuning_model:Starting Fine-Tuning
Fine-Tuning: epoch 1/1: 100%|██████████| 5069/5069 [12:54<00:00,  6.55it/s, loss=1.13]
Fine-Tuning Validation: 100%|██████████| 1268/1268 [01:16<00:00, 16.49it/s, accuracy=0.473]
INFO:helical.models.uce.fine_tuning_model:Fine-Tuning Complete. Epochs: 1
