# Classification fine-tuning using Helical

## Cell type classification task

In [None]:
from helical.models.geneformer.geneformer_config import GeneformerConfig
from helical.models.geneformer.fine_tuning_model import GeneformerFineTuningModel
from helical.models.geneformer.model import Geneformer
from helical.models.scgpt.fine_tuning_model import scGPTFineTuningModel
from helical.models.scgpt.model import scGPT,scGPTConfig
from helical.models.uce.model import UCE, UCEConfig
from helical.models.uce.fine_tuning_model import UCEFineTuningModel
import torch
import anndata as ad

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

## Load the dataset

In [3]:
ann_data = ad.read_h5ad("./10k_pbmcs_proc.h5ad")
ann_data

AnnData object with n_obs × n_vars = 11990 × 12000
    obs: 'n_counts', 'batch', 'labels', 'str_labels', 'cell_type'
    var: 'gene_symbols', 'n_counts-0', 'n_counts-1', 'n_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: 'cell_types', 'hvg'
    obsm: 'design', 'normalized_qc', 'qc_pc', 'raw_qc'

## Prepare training labels

- For this classification task we want to predict cell type classes
- So we save the cell types as a list

In [4]:
cell_types = list(ann_data.obs.cell_type)
cell_types[:10]

['CD4 T cells',
 'CD4 T cells',
 'CD14+ Monocytes',
 'CD14+ Monocytes',
 'CD8 T cells',
 'CD4 T cells',
 'CD14+ Monocytes',
 'CD4 T cells',
 'CD4 T cells',
 'CD14+ Monocytes']

- We convert these string labels into unique integer classes for training

In [5]:
label_set = set(cell_types)
class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))

for i in range(len(cell_types)):
    cell_types[i] = class_id_dict[cell_types[i]]

cell_types[:10]

[8, 8, 5, 5, 2, 8, 5, 8, 8, 5]

## Geneformer Fine-Tuning

Load the desired pretrained Geneformer model and desired configs

In [None]:
geneformer_config = GeneformerConfig(device=device, batch_size=5, model_name="gf-6L-30M-i2048")
geneformer = Geneformer(configurer = geneformer_config)

Process the data so it is in the correct form for Geneformer

In [None]:
dataset = geneformer.process_data(ann_data)

Geneformer makes use of the Hugging Face dataset class and so we need to add the labels as a column to this dataset

In [8]:
dataset = dataset.add_column('cell_types', cell_types)
dataset

Dataset({
    features: ['input_ids', 'length', 'cell_types'],
    num_rows: 11990
})

Define the Geneformer Fine-Tuning Model from the Helical package which appends a fine-tuning head automatically from the list of available heads
- Define the task type, which in this case is classification
- Defined the output size, which is the number of unique labels for classification

In [9]:
geneformer_fine_tune = GeneformerFineTuningModel(geneformer_model=geneformer, fine_tuning_head="classification", output_size=len(label_set))

Shuffle and split our dataset into training and validation sets

In [10]:
dataset = dataset.shuffle()
dataset = dataset.train_test_split(test_size=0.2)

Fine-tune the model

In [11]:
geneformer_fine_tune.train(train_dataset=dataset["train"], validation_dataset=dataset["test"])

Freezing the first 2 encoder layers of the Geneformer model during fine-tuning.


Fine-Tuning:   0%|          | 0/1919 [00:00<?, ?it/s]

Fine-Tuning Validation:   0%|          | 0/480 [00:00<?, ?it/s]

## scGPT Fine-Tuning

Now the same procedure with scGPT
- Loading the model and setting desired configs

In [None]:
scgpt_config=scGPTConfig(batch_size=10, device=device)
scgpt = scGPT(configurer=scgpt_config)

A slightly different methodology for getting the dataset for scGPT since it does not make use of the Hugging Face Dataset class
- Split the data into a train and validation set

In [None]:
dataset = scgpt.process_data(ann_data[:int(len(ann_data)*0.8)])
validation_dataset = scgpt.process_data(ann_data[int(len(ann_data)*0.8):int(len(ann_data))])

Do the same for the cell type labels

In [14]:
cell_types = list(ann_data.obs.cell_type[:int(len(ann_data)*0.8)])
val_cell_types = list(ann_data.obs.cell_type[int(len(ann_data)*0.8):int(len(ann_data))])

Convert them into integer class labels

In [15]:
for i in range(len(cell_types)):
    cell_types[i] = class_id_dict[cell_types[i]]
for i in range(len(val_cell_types)):
    val_cell_types[i] = class_id_dict[val_cell_types[i]]


Define the scGPT fine-tuning model with the desired head and number of classes

In [16]:
scgpt_fine_tune = scGPTFineTuningModel(scGPT_model=scgpt, fine_tuning_head="classification", output_size=len(label_set))

For scGPT fine tuning we have to pass in the labels as a separate list
- This is the same for the validation and training sets

In [17]:
scgpt_fine_tune.train(train_input_data=dataset, train_labels=cell_types, validation_input_data=validation_dataset, validation_labels=val_cell_types)

Fine-Tuning: epoch 1/1: 100%|██████████| 960/960 [00:54<00:00, 17.58it/s, loss=0.953]
Fine-Tuning Validation: 100%|██████████| 240/240 [00:04<00:00, 57.46it/s, accuracy=0.902]


## UCE Fine-Tuning

In [None]:
uce_config=UCEConfig(batch_size=5, device=device)
uce = UCE(configurer=uce_config)

Prepare data the same way as for scGPT
- Add names for each dataset, as datasets are stored as .npz files and separate files are needed

In [None]:
dataset = uce.process_data(ann_data[:int(len(ann_data)*0.8)], name="train")
validation_dataset = uce.process_data(ann_data[int(len(ann_data)*0.8):int(len(ann_data))], name="validation")

cell_types = list(ann_data.obs.cell_type[:int(len(ann_data)*0.8)])
val_cell_types = list(ann_data.obs.cell_type[int(len(ann_data)*0.8):int(len(ann_data))])

Class to integer conversion

In [20]:
for i in range(len(cell_types)):
    cell_types[i] = class_id_dict[cell_types[i]]
for i in range(len(val_cell_types)):
    val_cell_types[i] = class_id_dict[val_cell_types[i]]

Define the fine-tuning model

In [21]:
uce_fine_tune = UCEFineTuningModel(uce_model=uce, fine_tuning_head="classification", output_size=len(label_set))

Fine-tune the model

In [22]:
uce_fine_tune.train(train_input_data=dataset, train_labels=cell_types, validation_input_data=validation_dataset, validation_labels=val_cell_types)

Fine-Tuning: epoch 1/1: 100%|██████████| 1919/1919 [04:59<00:00,  6.40it/s, loss=1.72]
Fine-Tuning Validation: 100%|██████████| 480/480 [00:28<00:00, 16.67it/s, accuracy=0.412]
