# Fine-tuning on Pre-trained Model for Multiomic Integration
In this tutorial, we demonstrate how to fine-tune a pre-trained model for multiomic integration. We use the BMMC dataset as an example for RNA and protein abundance data integration from CITE-seq. The BMMC dataset consists of multiple batches by donor origin. In the paired integration setting, each cell contains both RNA and protein measurements. We fine-tune on the pre-trained whole-body model.

We summarize the fine-tuning pipeline in the following steps, which can be used as a general recipe for finetuning on integration tasks and beyond: 

     1. Specify hyper-parameter setup for integration task
     
     2. Load and pre-process data
     
     3. Load the pre-trained scGPT model
     
     4. Finetune scGPT with task-specific objectives
     
     5. Evaluate fine-tuned scGPT

In [1]:
import copy
import gc
import json
import os
from pathlib import Path
import sys
import time
import traceback
#from typing import List, Tuple, Dict, Union, Optional
import warnings

import torch
from anndata import AnnData
import anndata
import scanpy as sc
# import scvi
import numpy as np
import wandb
from scipy.sparse import issparse
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)
import pandas as pd
import scipy.sparse as sp
from sklearn import preprocessing

# sys.path.insert(0, "../")
sys.path.append('/ailab/user/liuxinyuan/projects/scmbench')
from scgpt1 import prepare_data, prepare_dataloader, define_wandb_metrcis, evaluate, eval_testdata, train, eval_testdata_save_embed
from scgpt1.tokenizer import tokenize_and_pad_batch
from scgpt1.model import MultiOmicTransformerModel

import scgpt1 as scg
from scgpt1.tokenizer.gene_tokenizer import GeneVocab
from scgpt1.tokenizer import random_mask_value
from scgpt1.loss import (
    masked_mse_loss,
    masked_relative_error,
    criterion_neg_log_bernoulli,
)
from scgpt1.preprocess import Preprocessor
from scgpt1.utils import set_seed, category_str2int, eval_scib_metrics

# Use a 4x4 figure size for the scanpy plots produced in this notebook.
sc.set_figure_params(figsize=(4, 4))
# Set KMP_WARNINGS to "off" for this process (presumably silences OpenMP
# runtime warnings — TODO confirm).
os.environ["KMP_WARNINGS"] = "off"
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

  IPython.display.set_matplotlib_formats(*ipython_format)


## Step1: Specify hyper-parameter setup for integration task
Here we provide some hyper-parameter recommendations for the multiomic integration task. Note that the BMMC dataset contains multiple batches to be integrated. Therefore, in addition to the default gene modelling objectives, we also turn on the DAR objective specifically to facilitate batch integration. We also turn on the use_mod argument by default to ensure that the model is modality-aware during training.

In [2]:
# Hyper-parameters for this fine-tuning run. The whole mapping is handed to
# wandb.init(config=...) and read back through `config` afterwards.
hyperparameter_defaults = {
    # --- run identity / data source ---
    "task": 'multiomic',
    "seed": 42,
    "dataset_name": 'SCMBench',  # selects the loading branch (e.g. "BMMC")
    "do_train": True,  # whether model parameters are updated during training
    "load_model": "/ailab/user/liuxinyuan/projects/foundation_models/scGPT/whole_human",  # pre-trained model directory
    "freeze": True,  # NOTE(review): not read anywhere in the visible notebook
    # --- training objectives ---
    "GEP": True,  # gene expression modelling
    "GEPC": True,  # gene expression modelling for cell objective
    "CLS": False,
    "ESC": False,
    "DAR": True,  # DAR objective for batch correction
    "DSBN": False,  # domain-specific batchnorm
    "mask_ratio": 0.4,  # default mask ratio
    "explicit_zero_prob": False,  # whether to use an explicit Bernoulli for zeros
    "ecs_thres": 0,  # elastic cell similarity objective, 0.0 to 1.0, 0.0 to disable
    "dab_weight": 1.0,
    "use_batch_labels": True,
    "use_mod": True,
    "per_seq_batch_sample": False,
    "epochs": 10,  # number of fine-tuning epochs
    # --- pre-processing / tokenisation ---
    "input_layer_key": "X_binned",  # layer holding the binned expression values
    "n_bins": 51,  # number of bins for value binning
    "n_hvg": 1200,  # number of highly variable genes
    "n_hvp": 4000,
    # 4001 equals n_hvp + 1 here; the earlier note "n_hvg+1" does not match
    # n_hvg = 1200.
    "max_seq_len": 4001,
    # --- optimisation / architecture ---
    "lr": 1e-3,  # learning rate for fine-tuning
    "batch_size": 16,  # batch size for fine-tuning
    # The upstream note says batch_size, layer_size, nlayers and nhead are
    # ignored when a model is loaded; in this notebook layer_size, nhead and
    # nlayers are read from config when the model dimensions are set.
    "layer_size": 512,
    "nlayers": 8,
    "nhead": 8,
    "dropout": 0.2,  # dropout rate during fine-tuning
    "schedule_ratio": 0.95,  # learning-rate decay factor
    "save_eval_interval": 5,  # epochs between checkpoint/evaluation rounds
    "log_interval": 100,  # log interval
    "fast_transformer": True,  # NOTE(review): the model is built with use_fast_transformer=False
    "pre_norm": False,
    "amp": True,  # automatic mixed precision
    # --- special values ---
    "pad_token": "<pad>",
    "mask_value": -1,
    "pad_value": -2,
    "include_zero_gene": False,
}

In [3]:
# Start a wandb run seeded with the defaults above. `config` is the run's
# configuration object and is what every later cell reads.
wandb_settings = wandb.Settings(start_method="fork")
run = wandb.init(
    project="scGPT",
    config=hyperparameter_defaults,
    settings=wandb_settings,
    reinit=True,
)
config = wandb.config
print(config)

# Project helper; presumably seeds the Python/NumPy/torch RNGs — TODO confirm.
set_seed(config.seed)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mliuxinyuan[0m ([33mliuxinyuan-chinese-university-of-hong-kong-shenzhen[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112512699845764, max=1.0…

{'task': 'multiomic', 'seed': 42, 'dataset_name': 'SCMBench', 'do_train': True, 'load_model': '/ailab/user/liuxinyuan/projects/foundation_models/scGPT/whole_human', 'freeze': True, 'GEP': True, 'GEPC': True, 'CLS': False, 'ESC': False, 'DAR': True, 'DSBN': False, 'mask_ratio': 0.4, 'explicit_zero_prob': False, 'ecs_thres': 0, 'dab_weight': 1.0, 'use_batch_labels': True, 'use_mod': True, 'per_seq_batch_sample': False, 'epochs': 10, 'input_layer_key': 'X_binned', 'n_bins': 51, 'n_hvg': 1200, 'n_hvp': 4000, 'max_seq_len': 4001, 'lr': 0.001, 'batch_size': 16, 'layer_size': 512, 'nlayers': 4, 'nhead': 8, 'dropout': 0.2, 'schedule_ratio': 0.95, 'save_eval_interval': 5, 'log_interval': 100, 'fast_transformer': True, 'pre_norm': False, 'amp': True, 'pad_token': '<pad>', 'mask_value': -1, 'pad_value': -2, 'include_zero_gene': False}


In [4]:
# Special tokens used by this notebook: the configured padding token plus
# "<cls>" and "<eoc>". They are appended to the gene vocabulary if missing and
# are also added to the modality vocabulary.
special_tokens = [config.pad_token, "<cls>", "<eoc>"]

In [5]:
# Display the resolved wandb configuration for this run.
config

{'task': 'multiomic', 'seed': 42, 'dataset_name': 'SCMBench', 'do_train': True, 'load_model': '/ailab/user/liuxinyuan/projects/foundation_models/scGPT/whole_human', 'freeze': True, 'GEP': True, 'GEPC': True, 'CLS': False, 'ESC': False, 'DAR': True, 'DSBN': False, 'mask_ratio': 0.4, 'explicit_zero_prob': False, 'ecs_thres': 0, 'dab_weight': 1.0, 'use_batch_labels': True, 'use_mod': True, 'per_seq_batch_sample': False, 'epochs': 10, 'input_layer_key': 'X_binned', 'n_bins': 51, 'n_hvg': 1200, 'n_hvp': 4000, 'max_seq_len': 4001, 'lr': 0.001, 'batch_size': 16, 'layer_size': 512, 'nlayers': 4, 'nhead': 8, 'dropout': 0.2, 'schedule_ratio': 0.95, 'save_eval_interval': 5, 'log_interval': 100, 'fast_transformer': True, 'pre_norm': False, 'amp': True, 'pad_token': '<pad>', 'mask_value': -1, 'pad_value': -2, 'include_zero_gene': False}

In [6]:
# Create a timestamped output directory for this run and mirror the scGPT
# logger into a run.log file inside it.
dataset_name = config.dataset_name
run_stamp = time.strftime('%b%d-%H-%M')
save_dir = Path("/ailab/user/liuxinyuan/projects/scmbench/save") / f"dev_{dataset_name}-{run_stamp}"
save_dir.mkdir(parents=True, exist_ok=True)
print(f"save to {save_dir}")

logger = scg.logger
scg.utils.add_file_handler(logger, save_dir / "run.log")

save to /ailab/user/liuxinyuan/projects/scmbench/save/dev_SCMBench-Aug31-13-19


## Step 2: Load and pre-process data

### 2.1 Load the BMMC data
Please download the data from https://drive.google.com/file/d/10RxboePS5p2Jj2Sfq1Ghzgqnl6nqPv5V/view?usp=sharing

For the purpose of this tutorial, we selected a 12K+ cell subset from the BMMC dataset (originally 90K+ cells) to speed up training. More specifically, we filtered on the first 3 donors (batches) with B, Mono and T cell subtypes as shown in the data loading step below.

In [7]:
# Display the output directory as a plain string.
str(save_dir)

'/ailab/user/liuxinyuan/projects/scmbench/save/dev_SCMBench-Aug31-13-19'

In [8]:
# Resolve the input .h5ad files and the output embedding CSVs for the selected
# dataset. Other dataset names (e.g. 'BMMC') define none of these variables.
if dataset_name == 'SCMBench':
    dataset_dir_name = '10x-Multiome-Pbmc10k'
    # dataset_dir_name='Chen-2019-small'
    # dataset_dir_name='Ma-2020-batch-56-small'
    data_root = '/ailab/user/liuxinyuan/projects/scmbench/10XPBMC/'
    input_rna_dir = f'{data_root}{dataset_dir_name}-RNA.h5ad'
    input_atac_dir = f'{data_root}{dataset_dir_name}-ATAC.h5ad'

    # Embeddings are written next to the run log inside save_dir.
    output_rna_dir = f'{save_dir}/{dataset_dir_name}-rna.csv'
    output_atac_dir = f'{save_dir}/{dataset_dir_name}-atac.csv'
elif dataset_name == 'SCMBench-simi':
    # Simulated data: inputs and outputs share one directory and file stem.
    simulated_stem = (
        '/mnt/nas/user/yixuan/scGPT/tutorials/scmdata/'
        'simulated_num_cell_5000_num_batch_3_num_gene_3000_effect_0.5'
    )
    input_rna_dir = f'{simulated_stem}-RNA.h5ad'
    input_atac_dir = f'{simulated_stem}-ATAC.h5ad'

    output_rna_dir = f'{simulated_stem}-rna.csv'
    output_atac_dir = f'{simulated_stem}-atac.csv'

In [9]:
# Sanity-check the resolved RNA input path and RNA embedding output path.
# Both names exist only if one of the SCMBench branches above ran.
print(input_rna_dir)
print(output_rna_dir)

/ailab/user/liuxinyuan/projects/scmbench/10XPBMC/10x-Multiome-Pbmc10k-RNA.h5ad
/ailab/user/liuxinyuan/projects/scmbench/save/dev_SCMBench-Aug31-13-19/10x-Multiome-Pbmc10k-rna.csv


In [10]:
# Load the two modalities and attach the obs/var columns used later:
# obs["celltype"], var["gene_name"], obs["batch_id"], obs["str_batch"].
# `adata` holds RNA; `adata_protein` holds the second modality (ADT features
# for BMMC, the ATAC file for the SCMBench datasets).
if dataset_name == 'BMMC':
    adata = sc.read('scmdata/BMMC_processed.h5ad')
    # Keep three donors and the first 17 entries of the sorted unique
    # cell-type names (the original note describes these as the B, Mono and
    # T cell subtypes).
    donor_mask = adata.obs.DonorID.isin([10886, 11466, 12710])
    kept_cell_types = np.unique(adata.obs.cell_type.values)[:17]
    adata = adata[donor_mask & adata.obs.cell_type.isin(kept_cell_types)]
    adata.obs["celltype"] = adata.obs["cell_type"].astype(str).astype('category')
    adata.var["gene_name"] = adata.var.index.tolist()
    le = preprocessing.LabelEncoder()
    encoded_batch = le.fit_transform(adata.obs['batch'].values)
    adata.obs["batch_id"] = encoded_batch
    adata.obs["str_batch"] = adata.obs["batch_id"].astype('category')
    # Split the annotated object by feature type; protein features get a
    # 'p_' prefix on their names.
    is_adt = adata.var.feature_types.isin(['ADT'])
    adata_protein = adata[:, is_adt].copy()
    adata_protein.var.index = ['p_' + i for i in adata_protein.var.index]
    adata = adata[:, adata.var.feature_types.isin(['GEX'])].copy()
    data_is_raw = False
elif dataset_name in ('SCMBench', 'SCMBench-simi'):
    # The two SCMBench variants differ only in which obs column is encoded
    # as the batch label.
    batch_column = 'domain' if dataset_name == 'SCMBench' else 'batch'
    adata = anndata.read_h5ad(input_rna_dir)
    adata_protein = anndata.read_h5ad(input_atac_dir)
    adata.obs["celltype"] = adata.obs["cell_type"].astype(str).astype('category')
    adata.var["gene_name"] = adata.var.index.tolist()
    le = preprocessing.LabelEncoder()
    encoded_batch = le.fit_transform(adata.obs[batch_column].values)
    adata.obs["batch_id"] = encoded_batch
    adata.obs["str_batch"] = adata.obs["batch_id"].astype('category')
    data_is_raw = False
    print('#')

#


In [11]:
# Inspect the per-cell annotations, including the columns added above.
adata.obs

Unnamed: 0_level_0,domain,protocol,dataset,orig.ident,nCount_RNA,nFeature_RNA,percent.mt,nCount_ATAC,nFeature_ATAC,nCount_SCT,nFeature_SCT,SCT.weight,ATAC.weight,wsnn_res.0.8,seurat_clusters,sub.cluster,cell_type,celltype,batch_id,str_batch
cells,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
AAACAGCCAATCCCTT-1,scRNA-seq,10x Multiome,10x-Multiome-Pbmc10k-RNA,SeuratProject,3771.0,1896.0,10.527711,20485.0,7247.0,3781.0,1895.0,0.507959,0.492041,4.0,4.0,4,CD4 TCM,CD4 TCM,0,0
AAACAGCCAATGCGCT-1,scRNA-seq,10x Multiome,10x-Multiome-Pbmc10k-RNA,SeuratProject,6876.0,2904.0,6.457243,16674.0,6528.0,4686.0,2769.0,0.425481,0.574519,1.0,1.0,1,CD4 Naive,CD4 Naive,0,0
AAACAGCCACCAACCG-1,scRNA-seq,10x Multiome,10x-Multiome-Pbmc10k-RNA,SeuratProject,5415.0,2282.0,6.500462,7658.0,3323.0,4486.0,2278.0,0.471465,0.528535,2.0,2.0,2,CD8 Naive,CD8 Naive,0,0
AAACAGCCAGGATAAC-1,scRNA-seq,10x Multiome,10x-Multiome-Pbmc10k-RNA,SeuratProject,2759.0,1353.0,6.922798,10355.0,4267.0,3332.0,1352.0,0.538544,0.461456,1.0,1.0,1,CD4 Naive,CD4 Naive,0,0
AAACAGCCAGTTTACG-1,scRNA-seq,10x Multiome,10x-Multiome-Pbmc10k-RNA,SeuratProject,3633.0,1691.0,4.183870,20516.0,7241.0,3676.0,1691.0,0.461856,0.538144,4.0,4.0,4,CD4 TCM,CD4 TCM,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTTGGTGACATGC-1,scRNA-seq,10x Multiome,10x-Multiome-Pbmc10k-RNA,SeuratProject,3285.0,1586.0,10.898021,15264.0,5715.0,3489.0,1585.0,0.399426,0.600574,2.0,2.0,2,CD8 Naive,CD8 Naive,0,0
TTTGTTGGTGTTAAAC-1,scRNA-seq,10x Multiome,10x-Multiome-Pbmc10k-RNA,SeuratProject,2866.0,1525.0,11.060712,17865.0,6628.0,3310.0,1524.0,0.795064,0.204936,8.0,8.0,8,CD8 Naive,CD8 Naive,0,0
TTTGTTGGTTAGGATT-1,scRNA-seq,10x Multiome,10x-Multiome-Pbmc10k-RNA,SeuratProject,4148.0,2024.0,9.932498,10556.0,4266.0,4081.0,2022.0,0.499490,0.500510,7.0,7.0,7,NK,NK,0,0
TTTGTTGGTTGGTTAG-1,scRNA-seq,10x Multiome,10x-Multiome-Pbmc10k-RNA,SeuratProject,2976.0,1620.0,5.342742,17053.0,6366.0,3341.0,1618.0,0.598785,0.401215,3.0,3.0,3,CD4 TCM,CD4 TCM,0,0


In [12]:
# Build a feature -> modality lookup table covering both modalities. The
# second modality is labelled 'Protein' whichever dataset was loaded.
if config.use_mod:
    rna_features = adata.var.index.tolist()
    second_mod_features = adata_protein.var.index.tolist()
    gene_rna_df = pd.DataFrame(index=rna_features).assign(mod='RNA')
    gene_protein_df = pd.DataFrame(index=second_mod_features).assign(mod='Protein')
    gene_loc_df = pd.concat([gene_rna_df, gene_protein_df])
    gene_loc_df['mod'] = gene_loc_df['mod'].astype('category')

### 2.2 Cross-check gene set with the pre-trained model 
Note that we retain the common gene set between the data and the pre-trained model for further fine-tuning.

In [26]:
# Locate the pre-trained checkpoint files, load its vocabulary and flag which
# features of `adata` are present in it. Nothing is filtered here: the flag
# is stored in adata.var["id_in_vocab"] and the match count is logged.
if config.load_model is not None:
    model_dir = Path(config.load_model)
    model_config_file, model_file, vocab_file = (
        model_dir / file_name
        for file_name in ("args.json", "best_model.pt", "vocab.json")
    )

    vocab = GeneVocab.from_file(vocab_file)
    for token in special_tokens:
        if token not in vocab:
            vocab.append_token(token)

    # 1 marks a feature found in the vocabulary, -1 a missing one.
    in_vocab_flags = [1 if name in vocab else -1 for name in adata.var["gene_name"]]
    adata.var["id_in_vocab"] = in_vocab_flags
    gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
    n_matched = np.sum(gene_ids_in_vocab >= 0)
    logger.info(
        f"match {n_matched}/{len(gene_ids_in_vocab)} genes "
        f"in vocabulary of size {len(vocab)}."
    )
    old_vocab = vocab

# Model dimensions come from the run configuration regardless of load_model.
embsize = d_hid = config.layer_size
nhead = config.nhead
nlayers = config.nlayers

scGPT - INFO - match 749/108405 genes in vocabulary of size 60697.


### 2.3 Pre-process the data
We follow the standardized pipeline of depth normalization, log normalization, and highly variable gene (HVG) selection for data pre-processing. We further introduce value binning to obtain the relative expressions of each HVG. Given multiple sequencing modalities, we perform the pre-processing steps on each individual modality first, and then combine them into multi-modal sequences as model input.

#### 2.3.1 Pre-process the RNA data

In [14]:
# RNA pre-processing options; the step numbers follow the original notes.
rna_preprocess_options = dict(
    use_key="X",  # which data to treat as the raw input
    filter_gene_by_counts=1,  # step 1
    filter_cell_by_counts=1,  # step 2
    normalize_total=1e4,  # step 3: normalise and to what sum
    result_normed_key="X_normed",  # layer receiving the normalised data
    log1p=data_is_raw,  # step 4: log1p only if the data is raw
    result_log1p_key="X_log1p",
    subset_hvg=config.n_hvg,  # step 5: number of highly variable genes to keep
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    binning=config.n_bins,  # step 6: number of bins
    result_binned_key="X_binned",  # layer receiving the binned data
)
preprocessor = Preprocessor(**rna_preprocess_options)
# NOTE(review): cell filtering is requested here but not for the second
# modality; the later concatenation needs both to keep the same cells in the
# same order — confirm no cells are dropped.
preprocessor(adata, batch_key=None)

scGPT - INFO - Filtering genes by counts ...


scGPT - INFO - Filtering cells by counts ...
scGPT - INFO - Normalizing total counts ...
scGPT - INFO - Subsetting highly variable genes ...
scGPT - INFO - Binning data ...


#### 2.3.2 Pre-process the Protein data

In [15]:
# Second-modality pre-processing (`adata_protein`): filtering, normalisation,
# log1p and HVG selection are all switched off, so only binning is requested.
protein_preprocess_options = dict(
    use_key="X",  # which data to treat as the raw input
    filter_gene_by_counts=0,  # step 1 (disabled)
    filter_cell_by_counts=False,  # step 2 (disabled)
    normalize_total=False,  # step 3 (disabled)
    result_normed_key="X_normed",
    log1p=False,  # step 4 (disabled)
    result_log1p_key="X_log1p",
    subset_hvg=False,  # step 5 (disabled)
    hvg_flavor=None,
    binning=config.n_bins,  # step 6: number of bins
    result_binned_key="X_binned",  # layer receiving the binned data
)
preprocessor_protein = Preprocessor(**protein_preprocess_options)
preprocessor_protein(adata_protein, batch_key=None)

scGPT - INFO - Binning data ...


#### 2.3.3 Combine RNA and Protein data

In [16]:
# Concatenate the binned matrices of both modalities feature-wise and wrap
# them in a fresh AnnData that reuses the RNA object's obs table. This
# assumes both objects hold the same cells in the same order.
rna_binned = adata.layers["X_binned"]
second_mod_binned = adata_protein.layers["X_binned"]
data_combined = np.concatenate([rna_binned, second_mod_binned], axis=1)

combined_feature_names = adata.var_names.tolist() + adata_protein.var_names.tolist()
adata = AnnData(
    X=data_combined,
    obs=adata.obs,
    var=pd.DataFrame(index=combined_feature_names),
    layers={"X_binned": data_combined},
)
adata.var["gene_name"] = adata.var.index.tolist()

In [17]:
# `adata_sorted` is defined only when per_seq_batch_sample is enabled; the
# training cell reads it under the same flag.
if config.per_seq_batch_sample:
    # sort the adata by batch_id in advance
    adata_sorted = adata[adata.obs["batch_id"].argsort()].copy()

In [18]:
# Inspect the combined multi-modal AnnData object.
adata

AnnData object with n_obs × n_vars = 9631 × 108405
    obs: 'domain', 'protocol', 'dataset', 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'percent.mt', 'nCount_ATAC', 'nFeature_ATAC', 'nCount_SCT', 'nFeature_SCT', 'SCT.weight', 'ATAC.weight', 'wsnn_res.0.8', 'seurat_clusters', 'sub.cluster', 'cell_type', 'celltype', 'batch_id', 'str_batch', 'n_counts'
    var: 'gene_name'
    layers: 'X_binned'

### 2.4 Tokenize the input data for model fine-tuning

In [19]:
# Gather the dense input matrix, feature names and per-cell labels.
input_layer = adata.layers[config.input_layer_key]
# Densify with .toarray(): the `.A` shorthand on sparse matrices is deprecated
# in recent SciPy releases, while .toarray() returns the same ndarray.
all_counts = input_layer.toarray() if issparse(input_layer) else input_layer
genes = adata.var["gene_name"].tolist()

# Cell-type labels are taken as stored in obs["cell_type"]; they are not
# re-encoded to integers here.
celltypes_labels = adata.obs["cell_type"].tolist()
num_types = len(set(celltypes_labels))
celltypes_labels = np.array(celltypes_labels)

# Integer batch ids produced by the LabelEncoder in the loading cell.
batch_ids = adata.obs["batch_id"].tolist()
num_batch_types = len(set(batch_ids))
batch_ids = np.array(batch_ids)

In [20]:
# Build the modality vocabulary and map every feature to its modality id.
if config.use_mod:
    # Modality name ('RNA' / 'Protein') of every feature, in `genes` order.
    mod_type = np.array([gene_loc_df.loc[g, 'mod'] for g in genes])
    mod_vocab_tokens = np.unique(gene_loc_df['mod']).tolist() + special_tokens
    vocab_mod = Vocab(VocabPybind(mod_vocab_tokens, None))
    # Use the configured pad token rather than a hard-coded "<pad>":
    # special_tokens is built from config.pad_token, so a literal would miss
    # the vocabulary entry whenever the configured token is changed.
    vocab_mod.set_default_index(vocab_mod[config.pad_token])
    mod_type = np.array(vocab_mod(list(mod_type)), dtype=int)
    ntokens_mod = len(vocab_mod)

In [21]:
# Hold out 10% of the cells for validation. The split is seeded from the run
# configuration so that re-running this cell (or the whole notebook) yields
# the same partition instead of depending on the global NumPy RNG state.
(
    train_data,
    valid_data,
    train_celltype_labels,
    valid_celltype_labels,
    train_batch_labels,
    valid_batch_labels,
) = train_test_split(
    all_counts,
    celltypes_labels,
    batch_ids,
    test_size=0.1,
    shuffle=True,
    random_state=config.seed,
)


In [22]:
# Summary statistics of the training matrix. The non-zero mask is computed
# once and reused: this avoids a Python-level loop over rows and avoids
# materialising np.nonzero index arrays twice for a very large matrix. The
# printed values are unchanged.
nonzero_mask = train_data != 0
num_of_non_zero_genes = nonzero_mask.sum(axis=1)  # non-zero features per cell
nonzero_values = train_data[nonzero_mask]  # same row-major order as np.nonzero
del nonzero_mask  # free the boolean mask before computing quantiles

print(f"max num of non_zero genes: {np.max(num_of_non_zero_genes)}")
print(f"min num of non_zero genes: {np.min(num_of_non_zero_genes)}")
print(f"average num of non_zero genes: {np.mean(num_of_non_zero_genes)}")
print(
    f"99% quantile num of non_zero genes: {np.quantile(num_of_non_zero_genes, 0.99)}"
)
print(f"max original values: {np.max(train_data)}")
print(f"average original non_zero values: {np.mean(nonzero_values)}")
print(
    f"99% quantile original non_zero values: {np.quantile(nonzero_values, 0.99)}"
)
print(f"num of celltypes: {num_types}")

max num of non_zero genes: 20884
min num of non_zero genes: 2243
average num of non_zero genes: 7383.912887965847
99% quantile num of non_zero genes: 14499.680000000008
max original values: 50
average original non_zero values: 25.481826774776753
99% quantile original non_zero values: 49.0
num of celltypes: 19


In [23]:
# Build the vocabulary used for tokenisation and the integer id of every
# feature in `genes`.
if config.load_model is None:
    vocab = Vocab(VocabPybind(genes + special_tokens, None))
    # Look up the configured pad token; special_tokens is built from
    # config.pad_token, so a hard-coded "<pad>" breaks if it is changed.
    vocab.set_default_index(vocab[config.pad_token])
    gene_ids = np.array(vocab(genes), dtype=int)
else:
    # Partition all tokens into those known to the pre-trained vocabulary and
    # those that are new. Known tokens come first, so their rows can be
    # copied into the head of the expanded embedding matrix later.
    candidate_tokens = genes + special_tokens
    pretrained_genes = [g for g in candidate_tokens if g in old_vocab]
    new_genes = [g for g in candidate_tokens if g not in old_vocab]
    # Row indices of the known tokens inside the pre-trained vocabulary.
    gene_ids_pretrained = np.array(old_vocab(pretrained_genes), dtype=int)
    # https://discuss.pytorch.org/t/expand-an-existing-embedding-and-linear-layer-nan-loss-value/55670/2
    vocab = Vocab(VocabPybind(pretrained_genes + new_genes, None))
    vocab.set_default_index(vocab[config.pad_token])
    gene_ids = np.array(vocab(genes), dtype=int)

In [24]:
# Tokenise and pad both splits with identical settings.
tokenize_options = dict(
    max_len=config.max_seq_len,
    vocab=vocab,
    pad_token=config.pad_token,
    pad_value=config.pad_value,
    append_cls=True,  # append <cls> token at the beginning
    include_zero_gene=config.include_zero_gene,
    mod_type=mod_type if config.use_mod else None,
    vocab_mod=vocab_mod if config.use_mod else None,
)
tokenized_train = tokenize_and_pad_batch(train_data, gene_ids, **tokenize_options)
tokenized_valid = tokenize_and_pad_batch(valid_data, gene_ids, **tokenize_options)

# Report the resulting number of samples and sequence length per split.
for split_name, tokenized_split in (
    ("train", tokenized_train),
    ("valid", tokenized_valid),
):
    n_samples = tokenized_split['genes'].shape[0]
    seq_length = tokenized_split['genes'].shape[1]
    logger.info(
        f"{split_name} set number of samples: {n_samples}, "
        f"\n\t feature length: {seq_length}"
    )

scGPT - INFO - train set number of samples: 8667, 
	 feature length: 4001
scGPT - INFO - valid set number of samples: 964, 
	 feature length: 4001


## Step 3: Load the pre-trained scGPT model
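Note that for multiomic integration, since the pre-trained model does not include the ATAC and protein tokens, we expand the embedding layer by adding these new tokens. We inherit the gene embedding layer from the pre-trained model; the code below additionally copies the pre-trained value-encoder and transformer-encoder layer weights whose parameter names match the new model, and the rest of the model is trained from scratch.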

In [25]:
# Check whether a CUDA device is visible; the model cell falls back to CPU
# when this is False.
torch.cuda.is_available()

True

In [29]:
# Override the layer count read from config earlier: the model below is built
# with 8 transformer layers, matching the eight layers (0-7) whose pre-trained
# weights are copied in the next cell.
nlayers=8

In [30]:
# Build the multi-omic model and transfer the compatible pre-trained weights.
device_number = 0
if torch.cuda.is_available():
    device = torch.device(f"cuda:{device_number}")
else:
    device = torch.device("cpu")

# Load the checkpoint onto the CPU so this also works on hosts without the
# GPU it was saved from; the model is moved to `device` once it is assembled.
# NOTE(security): torch.load unpickles the file — only load trusted checkpoints.
model_dict = torch.load(model_file, map_location="cpu")
ntokens = len(vocab)  # size of the expanded vocabulary
model = MultiOmicTransformerModel(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    vocab=vocab,
    dropout=config.dropout,
    pad_token=config.pad_token,
    pad_value=config.pad_value,
    do_mvc=config.GEPC,
    do_dab=config.DAR,
    use_batch_labels=config.use_batch_labels,
    num_batch_labels=num_batch_types,
    domain_spec_batchnorm=config.DSBN,
    n_input_bins=config.n_bins,
    ecs_threshold=config.ecs_thres,
    explicit_zero_prob=config.explicit_zero_prob,
    # NOTE(review): hard-coded to False; config.fast_transformer is not used.
    use_fast_transformer=False,
    pre_norm=config.pre_norm,
    use_mod=config.use_mod,
    ntokens_mod=ntokens_mod if config.use_mod else None,
    vocab_mod=vocab_mod if config.use_mod else None,
)

with torch.no_grad():
    # Tokens known to the pre-trained vocabulary occupy the first
    # len(pretrained_genes) rows of the new vocabulary, so their embedding
    # rows are copied into the head of the new embedding matrix.
    pretrained_emb_weights = model_dict['encoder.embedding.weight'][gene_ids_pretrained, :]
    model.encoder.embedding.weight.data[:len(pretrained_genes), :] = pretrained_emb_weights
    # NOTE(review): only the LayerNorm weight is copied, not its bias —
    # confirm this is intended.
    model.encoder.enc_norm.weight.data = model_dict['encoder.enc_norm.weight']

    model_state_dict = model.state_dict()

    # Transfer the value encoder and every transformer layer the new model
    # actually has. The trailing "." keeps "layers.1" from also matching
    # layers 10, 11, ...; range(nlayers) follows the model depth instead of a
    # fixed list of layers 0-7.
    transferable_prefixes = ('value_encoder',) + tuple(
        f'transformer_encoder.layers.{layer_idx}.' for layer_idx in range(nlayers)
    )
    pretrained_state_dict = {
        key: value
        for key, value in model_dict.items()
        if key in model_state_dict and key.startswith(transferable_prefixes)
    }

    model_state_dict.update(pretrained_state_dict)
    model.load_state_dict(model_state_dict)

# model.to('cuda')
model.to(device)
# wandb.watch(model)
print(model)

MultiOmicTransformerModel(
  (encoder): GeneEncoder(
    (embedding): Embedding(60697, 512, padding_idx=60694)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): ContinuousValueEncoder(
    (dropout): Dropout(p=0.2, inplace=False)
    (linear1): Linear(in_features=1, out_features=512, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (batch_encoder): BatchLabelEncoder(
    (embedding): Embedding(1, 512)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (mod_encoder): BatchLabelEncoder(
    (embedding): Embedding(5, 512, padding_idx=2)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQ

In [31]:
# Loss functions are defined only for the objectives enabled in the config.
if config.GEP and config.GEPC:
    criterion_gep_gepc = masked_mse_loss
if config.DAR:
    criterion_dab = nn.CrossEntropyLoss()
if config.CLS:
    criterion_cls = nn.CrossEntropyLoss()

# Adam over all model parameters; a larger epsilon is used when mixed
# precision is enabled.
adam_eps = 1e-4 if config.amp else 1e-8
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, eps=adam_eps)
# Multiply the learning rate by schedule_ratio on every scheduler step.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=1, gamma=config.schedule_ratio
)
scaler = torch.cuda.amp.GradScaler(enabled=config.amp)

## Step 4: Finetune scGPT with task-specific objectives

In [None]:
# Fine-tuning loop: each epoch rebuilds the data tensors and loaders,
# optionally trains, evaluates on the validation split, keeps a deep copy of
# the best model so far, and periodically saves checkpoints, embeddings and
# UMAP figures.
best_val_loss = float("inf")
best_avg_bio = 0.0
best_model = None
# Set together with best_model. Initialised here so the checkpoint step can
# detect that no model has been selected yet (e.g. every loss was NaN).
best_model_epoch = None
define_wandb_metrcis()

# The AnnData object that is embedded at every checkpoint. Its obs_names are
# also used to index the saved embeddings so the rows stay aligned when the
# batch-sorted copy is the one being evaluated.
eval_adata = adata_sorted if config.per_seq_batch_sample else adata

# Only the criteria of enabled objectives exist; the others are passed as None.
criteria = dict(
    criterion_gep_gepc=criterion_gep_gepc if config.GEP and config.GEPC else None,
    criterion_dab=criterion_dab if config.DAR else None,
    criterion_cls=criterion_cls if config.CLS else None,
)

for epoch in range(1, config.epochs + 1):
    epoch_start_time = time.time()
    train_data_pt, valid_data_pt = prepare_data(
        tokenized_train=tokenized_train,
        tokenized_valid=tokenized_valid,
        train_batch_labels=train_batch_labels,
        valid_batch_labels=valid_batch_labels,
        config=config,
        epoch=epoch,
        sort_seq_batch=config.per_seq_batch_sample,
    )

    train_loader = prepare_dataloader(
        train_data_pt,
        batch_size=config.batch_size,
        shuffle=True,
        intra_domain_shuffle=False,
        drop_last=False,
        per_seq_batch_sample=config.per_seq_batch_sample,
    )
    valid_loader = prepare_dataloader(
        valid_data_pt,
        batch_size=config.batch_size,
        shuffle=False,
        intra_domain_shuffle=False,
        drop_last=False,
        per_seq_batch_sample=config.per_seq_batch_sample,
    )

    if config.do_train:
        train(
            model=model,
            loader=train_loader,
            vocab=vocab,
            scaler=scaler,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            config=config,
            logger=logger,
            epoch=epoch,
            **criteria,
        )
    val_loss = evaluate(
        model=model,
        loader=valid_loader,
        vocab=vocab,
        device=device,
        config=config,
        epoch=epoch,
        **criteria,
    )
    elapsed = time.time() - epoch_start_time
    logger.info("-" * 89)
    logger.info(
        f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
        f"valid loss {val_loss:5.4f} | "
    )
    logger.info("-" * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        best_model_epoch = epoch
        logger.info(f"Best model with score {best_val_loss:5.4f}")

    is_checkpoint_epoch = (
        epoch % config.save_eval_interval == 0 or epoch == config.epochs
    )
    if is_checkpoint_epoch and best_model is None:
        # No epoch has improved on the initial loss (e.g. the validation loss
        # is NaN), so there is nothing to save or evaluate yet.
        logger.warning(
            f"No best model selected by epoch {epoch}; skipping checkpoint and evaluation"
        )
    elif is_checkpoint_epoch:
        logger.info(f"Saving model to {save_dir}")
        torch.save(best_model.state_dict(), save_dir / f"model_e{best_model_epoch}.pt")

        # Embed every cell with the best model and collect the metrics and
        # UMAP figures for the "cls" embedding.
        results, cell_embeddings = eval_testdata_save_embed(
            model=best_model,
            adata_t=eval_adata,
            gene_ids=gene_ids,
            vocab=vocab,
            config=config,
            logger=logger,
            include_types=["cls"],
        )

        # Each cell has a single joint embedding, so the same table is
        # written to both the RNA and the ATAC output files.
        # NOTE(review): assumes the embeddings come back in eval_adata row order.
        embeddings_df = pd.DataFrame(cell_embeddings, index=eval_adata.obs_names)
        embeddings_df.to_csv(output_rna_dir, header=False)
        embeddings_df.to_csv(output_atac_dir, header=False)

        batch_umap_png = save_dir / f"embeddings_batch_umap[cls]_e{best_model_epoch}.png"
        celltype_umap_png = save_dir / f"embeddings_celltype_umap[cls]_e{best_model_epoch}.png"
        results["batch_umap"].savefig(batch_umap_png, dpi=300)
        results["celltype_umap"].savefig(celltype_umap_png, dpi=300)

        # Log every result under "test/", replacing the two figure objects
        # with the saved images.
        metrics_to_log = {"test/" + k: v for k, v in results.items()}
        metrics_to_log["test/batch_umap"] = wandb.Image(
            str(batch_umap_png),
            caption=f"celltype avg_bio epoch {best_model_epoch}",
        )
        metrics_to_log["test/celltype_umap"] = wandb.Image(
            str(celltype_umap_png),
            caption=f"celltype avg_bio epoch {best_model_epoch}",
        )
        metrics_to_log["test/best_model_epoch"] = best_model_epoch
        wandb.log(metrics_to_log)
        wandb.log({"avg_bio": results.get("avg_bio", 0.0)})

    scheduler.step()

In [32]:
# save the best model
# `best_model` is the deep copy taken at the epoch with the lowest validation
# loss; it is still None if the training cell never selected a model.
torch.save(best_model.state_dict(), save_dir / "best_model.pt")

wandb: Network error (ProxyError), entering retry loop.


In [None]:
# Upload the saved best-model weights as a wandb artifact, then close the run
# and trigger a garbage-collection pass.
best_model_path = str(save_dir / "best_model.pt")
artifact = wandb.Artifact("best_model", type="model")
artifact.add_file(best_model_path)
run.log_artifact(artifact)

run.finish()
wandb.finish()
gc.collect()