# scGPT
Custom implemention of the scGPT model

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import numpy as np
import scanpy as sc
import time
import copy
from typing import List, Tuple, Dict, Union, Optional
from scipy.sparse import issparse
from sklearn.model_selection import train_test_split

from utils import set_seed, AttrDict
from myTorchtext import Vocab
from preprocess import Preprocessor
from tokenizer import tokenize_and_pad_batch, retrieve_tfs, random_mask_value
# from model import TransformerModel
from model_bioformer import TransformerModel
from loss import masked_mse_loss, masked_relative_error, criterion_neg_log_bernoulli

In [2]:
config = AttrDict({
    "run_name": "",
    "dataset_name": "HYPOXIA_9K",
    "seed": None,
    "d_model": 64,
    "nhead": 4,
    "nlayers": 1,
    "n_bins": 51,
    "n_hvg": 100,
    "include_zero_gene": False,
    "mask_single_value": False,
    "dropout": 0.2,
    "batch_size": 32,
    "epochs": 1,
    "lr": 0.0001,
    "amp": True,
    "schedule_ratio": 0.9,
    "GEPC": False,   # If Gene Expression Prediction for Cell Modelling objective (MLM from <cls> only) TODO in model.py
    "explicit_zero_prob": True, # if modelling gene expression also as bern var
    "do_train": True,
    "log_interval": 100,
    "wandb": False,
})

if config.seed:
    set_seed(config.seed)

if config.wandb:
    import wandb
    wandb.login()
    run = wandb.init(
        project='BioFormer',
        config = config,
        name = config.run_name if config.run_name else None
    )

In [3]:
# Pre-processing
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
mask_value = -1 # in the value vector corresponding to msk token (!= msk token index in vocab)
pad_value = -2  # in the value vector corresponding to pad token (!= pad token index in vocab)
n_input_bins = config.n_bins
include_zero_gene = config.include_zero_gene
n_hvg = config.n_hvg
max_seq_len = n_hvg + 1

In [4]:
# Import data
dataset_name = config.dataset_name

if dataset_name == 'BREAST_25K':
    adata = sc.read_h5ad('./data/breast_25k.h5ad')
    data_is_raw = True

elif dataset_name == 'BREAST_12K':
    adata = sc.read_h5ad('./data/breast_12k.h5ad')
    data_is_raw = True

elif dataset_name == 'DERMAL_100K':
    adata = sc.read_h5ad('./data/dermal_100k.h5ad')
    adata.var["gene_name"] = adata.var.feature_name.tolist()
    data_is_raw = True

elif dataset_name == 'HYPOXIA_9K':
    adata = sc.read_h5ad('./data/scsHypoxiaTimeSub.h5ad')
    adata.X = adata.layers['raw_count']
    adata.var['gene_name'] = adata.var.index.tolist()
    data_is_raw = True

print(dataset_name)
print(adata)

HYPOXIA_9K
AnnData object with n_obs × n_vars = 9234 × 19046
    obs: 'nCount_RNA', 'nFeature_RNA', 'SampleTags', 'percent.mt', 'HypoxicState', 'TimePoint', 'nCount_SCT', 'nFeature_SCT', 'S.Score', 'G2M.Score', 'Phase', 'seurat_clusters', 'SampleTagsShort', 'active_ident'
    var: 'variable_gene', 'gene_name'
    uns: 'active_ident_colors', 'seurat_clusters_colors'
    obsm: 'X_pca', 'X_umap'
    layers: 'raw_count'


In [5]:
# Pre-process adata
preprocessor = Preprocessor(
    use_key="X",  # the key in adata.layers to use as raw data
    filter_gene_by_counts=3,  # step 1
    filter_cell_by_counts=False,  # step 2
    normalize_total=1e4,  # 3. whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=data_is_raw,  # 4. whether to log1p the normalized data
    result_log1p_key="X_log1p",
    subset_hvg=config.n_hvg,  # 5. whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    binning=config.n_bins,  # 6. whether to bin the raw data and to what number of bins
    result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)

preprocessor(adata, batch_key=None)


Filtering genes by counts ...
Normalizing total counts ...
Log1p transforming ...
Subsetting highly variable genes ...
No batch_key is provided, will use all cells for HVG selection.
Binning data ...


In [6]:
input_layer_key = "X_binned"
all_counts = (
    adata.layers[input_layer_key].toarray()
    if issparse(adata.layers[input_layer_key])
    else adata.layers[input_layer_key]
)
genes = adata.var["gene_name"].tolist()

train_data, valid_data = train_test_split(all_counts, test_size=0.1, shuffle=True)

# Vocab
stoi = {s:i for i, s in enumerate(genes + special_tokens)}
itos = {i:s for i, s in enumerate(genes + special_tokens)}
vocab = Vocab(stoi, itos)
vocab.set_default_index(vocab["<pad>"]) # index to return if token not found in vocab
gene_ids = np.array(vocab(genes), dtype=int)
print(f'Vocab of size: {len(vocab)} --> {len(genes)} genes, {len(special_tokens)} special tokens {special_tokens}')

Vocab of size: 103 --> 100 genes, 3 special tokens ['<pad>', '<cls>', '<eoc>']


In [7]:
tokenized_train = tokenize_and_pad_batch(
    train_data,
    gene_ids,
    max_len=max_seq_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,  # append <cls> token at the beginning
    include_zero_gene=include_zero_gene,
)
tokenized_valid = tokenize_and_pad_batch(
    valid_data,
    gene_ids,
    max_len=max_seq_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,
    include_zero_gene=include_zero_gene,
)
print(f"Train samples: {tokenized_train['genes'].shape[0]}")
print(f"Valid samples: {tokenized_valid['genes'].shape[0]}")
print(f"Input length: {tokenized_valid['genes'].shape[1]}")

Train samples: 8310
Valid samples: 924
Input length: 72


In [8]:
def prepare_data(use_condition_labels = False):
    
    masked_values_train = random_mask_value(
        tokenized_train["values"],
        mask_value=mask_value,
        pad_value=pad_value,
        mask_single_value = config.mask_single_value
    )
    masked_values_valid = random_mask_value(
        tokenized_valid["values"],
        mask_value=mask_value,
        pad_value=pad_value,
        mask_single_value = config.mask_single_value
    )

    print(f"random masking at epoch {epoch}, ratio of masked values in train: {(masked_values_train == mask_value).sum() / (masked_values_train - pad_value).count_nonzero():.4f}")

    input_gene_ids_train, input_gene_ids_valid = (
        tokenized_train["genes"],
        tokenized_valid["genes"],
    )
    input_values_train, input_values_valid = masked_values_train, masked_values_valid
    target_values_train, target_values_valid = (
        tokenized_train["values"],
        tokenized_valid["values"],
    )

    train_data_pt = {
        "gene_ids": input_gene_ids_train,
        "values": input_values_train,
        "target_values": target_values_train,
    }
    valid_data_pt = {
        "gene_ids": input_gene_ids_valid,
        "values": input_values_valid,
        "target_values": target_values_valid,
    }

    # if use_condition_labels:
    #     train_data_pt['conditions'] = retrieve_tfs(
    #         input_gene_ids_train,
    #         input_values_train,     # masked
    #         tf = tf                                  
    #     )
    #     valid_data_pt['conditions'] = retrieve_tfs(
    #         input_gene_ids_valid,
    #         input_values_valid,      # masked
    #         tf = tf                                  
    #     )

    return train_data_pt, valid_data_pt

# dataset
class SeqDataset(Dataset):
    def __init__(self, data: Dict[str, torch.Tensor]):
        self.data = data

    def __len__(self):
        return self.data["gene_ids"].shape[0]

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items()}


# data_loader
def prepare_dataloader(
    data_pt: Dict[str, torch.Tensor],
    batch_size: int,
    shuffle: bool = False,
    drop_last: bool = False,
    num_workers: int = 0,
) -> DataLoader:
    dataset = SeqDataset(data_pt)

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=True,
    )
    
    return data_loader


In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ntoken = len(vocab)  # size of vocabulary
model = TransformerModel(
    ntoken=ntoken,
    d_model=config.d_model,
    nhead=config.nhead,
    d_hid=config.d_model,
    nlayers=config.nlayers,
    vocab=vocab,
    dropout=config.dropout,
    pad_token=pad_token,
    # pad_value=pad_value,
) 

model.to(device)

print(f'''
device: {device} | d_model: {config.d_model} | nhead: {config.nhead} | nlayers: {config.nlayers} | tot. params: {sum(p.numel() for p in model.parameters())/1e6:.0f}M
''')



device: cpu | d_model: 64 | nhead: 4 | nlayers: 1 | tot. params: 0M



In [10]:
criterion = masked_mse_loss
criterion_dab = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    model.parameters(), lr=config.lr, eps=1e-4 if config.amp else 1e-8
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=config.schedule_ratio)
scaler = torch.cuda.amp.GradScaler(enabled=config.amp)



In [11]:
def train(model: nn.Module, loader: DataLoader) -> None:
    """
    Train the model for one epoch.
    """
    model.train()
    total_loss, total_mse, total_gepc = 0.0, 0.0, 0.0
    total_mre = 0.0
    log_interval = config.log_interval
    start_time = time.time()

    num_batches = len(loader)
    for batch, batch_data in enumerate(loader):
        input_gene_ids = batch_data["gene_ids"].to(device)
        input_values = batch_data["values"].to(device)
        target_values = batch_data["target_values"].to(device)

        # ---------- FORWARD PASS -------------------
        with torch.cuda.amp.autocast(enabled=config.amp):

            output_dict = model(input_gene_ids, input_values)
            masked_positions = input_values.eq(mask_value)  # the postions to predict
            loss = loss_mse = criterion(output_dict["mlm_output"], target_values, masked_positions)
            
            metrics_to_log = {"train/mse": loss_mse.item()}
            
            if config.explicit_zero_prob:
                loss_zero_log_prob = criterion_neg_log_bernoulli(output_dict["mlm_zero_probs"], target_values, masked_positions)
                loss += loss_zero_log_prob
                metrics_to_log.update({"train/nzlp": loss_zero_log_prob.item()})
            
            if config.GEPC:
                loss_gepc = criterion(output_dict["mvc_output"], target_values, masked_positions)
                loss += loss_gepc
                metrics_to_log.update({"train/mvc": loss_gepc.item()})
            
            if config.GEPC and config.explicit_zero_prob:
                loss_gepc_zero_log_prob = criterion_neg_log_bernoulli(output_dict["mvc_zero_probs"], target_values, masked_positions)
                loss = loss + loss_gepc_zero_log_prob
                metrics_to_log.update({"train/mvc_nzlp": loss_gepc_zero_log_prob.item()})

        # ---------- BACKWARD PASS ------------------
        model.zero_grad()
        scaler.scale(loss).backward()   # training via the aggregated loss
        scaler.unscale_(optimizer)
        scaler.step(optimizer)
        scaler.update()
        # -------------------------------------------
        
        if config.wandb:
            wandb.log(metrics_to_log)

        # Compute MRE for validation
        with torch.no_grad():
            mre = masked_relative_error(output_dict["mlm_output"], target_values, masked_positions)

        total_loss += loss.item()                               # sum of all losses
        total_mse += loss_mse.item()                            # MSE alone
        total_gepc += loss_gepc.item() if config.GEPC else 0.0  # MSE from GEPC alone
        total_mre += mre.item()                                 # MRE alone
        
        # Avg of loss across all log_interval batches (i.e., log_interval = 10, avg loss every 10 batches)
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            cur_mse = total_mse / log_interval
            cur_gepc = total_gepc / log_interval if config.GEPC else 0.0
            cur_mre = total_mre / log_interval
            
            print(f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | lr {lr:05.4f} | ms/batch {ms_per_batch:5.2f} | train/loss {cur_loss:5.2f} | train/mse {cur_mse:5.2f} |" + (f"train/gepc {cur_gepc:5.2f} |" if config.GEPC else "") + f"train/mre {cur_mre:5.2f} |" )
            
            total_loss = 0
            total_mse = 0
            total_gepc = 0
            total_mre = 0
            start_time = time.time()



In [12]:
def define_wandb_metrics():
    wandb.define_metric("valid/mse", summary="min", step_metric="epoch")
    wandb.define_metric("valid/mre", summary="min", step_metric="epoch")

In [13]:
def evaluate(model: nn.Module, loader: DataLoader) -> float:
    """
    Evaluate the model on the evaluation data.
    """
    model.eval()
    total_loss = 0.0
    total_mre = 0.0
    total_num = 0
    with torch.no_grad():
        for batch_data in loader:
            input_gene_ids = batch_data["gene_ids"].to(device)
            input_values = batch_data["values"].to(device)
            target_values = batch_data["target_values"].to(device)

            with torch.cuda.amp.autocast(enabled=config.amp):
                output_dict = model(input_gene_ids, input_values)
                output_values = output_dict["mlm_output"]

                masked_positions = input_values.eq(mask_value)
                loss = criterion(output_values, target_values, masked_positions)

            total_loss += loss.item() * len(input_gene_ids)
            total_mre += masked_relative_error(output_values, target_values, masked_positions).item() * len(input_gene_ids)
            total_num += len(input_gene_ids)

    if config.wandb:
        wandb.log({ 
            "valid/mse": total_loss / total_num,
            "valid/mre": total_mre / total_num,
            "epoch": epoch
            })

    return total_loss / total_num, total_mre / total_num


In [14]:
best_val_loss = float("inf")
best_avg_bio = 0.0
best_model = None
if config.wandb:
    define_wandb_metrics()

for epoch in range(1, config.epochs + 1):
    epoch_start_time = time.time()
    
    train_data_pt, valid_data_pt = prepare_data()
    
    train_loader = prepare_dataloader(
        train_data_pt,
        batch_size=config.batch_size,
        shuffle=False,
        drop_last=False,
    )
    valid_loader = prepare_dataloader(
        valid_data_pt,
        batch_size=config.batch_size,
        shuffle=False,
        drop_last=False,
    )

    # TRAINING      --> over all batches in the train_loader
    if config.do_train:
        train(model, loader=train_loader)

    # VALIDATION    --> avg loss across all batches in valid_loader
    val_loss, val_mre = evaluate(model, loader=valid_loader)
    
    
    # Some epoch-related stats
    elapsed = time.time() - epoch_start_time
    print("-" * 89)
    print(f"| end of epoch {epoch:3d} | runtime: {elapsed:5.2f}s | valid/mse {val_loss:5.4f} | valid/mre {val_mre:5.4f}")
    print("-" * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        best_model_epoch = epoch
        print(f"Best model with valid/mse {best_val_loss:5.4f}")

    scheduler.step()

random masking at epoch 1, ratio of masked values in train: 0.1383




| epoch   1 | 100/260 batches | lr 0.0001 | ms/batch 1561.17 | train/loss 710.12 | train/mse 710.12 |train/mre 62695.87 |
| epoch   1 | 200/260 batches | lr 0.0001 | ms/batch 1496.91 | train/loss 231.35 | train/mse 231.35 |train/mre 629675.65 |
-----------------------------------------------------------------------------------------
| end of epoch   1 | runtime: 431.66s | valid/mse 212.9661 | valid/mre 560529.2197
-----------------------------------------------------------------------------------------
Best model with valid/mse 212.9661


In [15]:
if config.wandb:
    run.finish()

# BioFormer
Adaptation of AF2 modules to work with RNA-seq

In [16]:
raise Warning("BioFormer Section")

Warning: BioFormer Section

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import numpy as np
import scanpy as sc
import time
import copy
from typing import List, Tuple, Dict, Union, Optional
from scipy.sparse import issparse
from sklearn.model_selection import train_test_split

from utils import set_seed, AttrDict
from myTorchtext import Vocab
from preprocess import Preprocessor
from tokenizer import tokenize_and_pad_batch, retrieve_tfs, random_mask_value
# from model import TransformerModel
from model_bioformer import TransformerModel
from loss import masked_mse_loss, masked_relative_error, criterion_neg_log_bernoulli
B, r = 64, 18

d_model= 512
nhead = 4
nlayers = 1
vocab = None
pad_token = "<pad>"
# m = torch.rand((B, r, c_in ))
# z = torch.rand((B, r, r, c_in))
# get Vocab
adata = sc.read_h5ad('./data/scsHypoxiaTimeSub.h5ad')
adata.X = adata.layers['raw_count']
adata.var['gene_name'] = adata.var.index.tolist()
data_is_raw = True
special_tokens = [pad_token, "<cls>", "<eoc>"]
preprocessor = Preprocessor(
    use_key="X",  # the key in adata.layers to use as raw data
    filter_gene_by_counts=3,  # step 1
    filter_cell_by_counts=False,  # step 2
    normalize_total=1e4,  # 3. whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=data_is_raw,  # 4. whether to log1p the normalized data
    result_log1p_key="X_log1p",
    subset_hvg=100,  # 5. whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    binning=51,  # 6. whether to bin the raw data and to what number of bins
    result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)
preprocessor(adata, batch_key=None)
# input_layer_key = "X_binned"
# all_counts = (
#     adata.layers[input_layer_key].toarray()
#     if issparse(adata.layers[input_layer_key])
#     else adata.layers[input_layer_key]
# )
genes = adata.var["gene_name"].tolist()
# train_data, valid_data = train_test_split(all_counts, test_size=0.1, shuffle=True)
stoi = {s:i for i, s in enumerate(genes + special_tokens)}
itos = {i:s for i, s in enumerate(genes + special_tokens)}
vocab = Vocab(stoi, itos)
vocab.set_default_index(vocab["<pad>"]) # index to return if token not found in vocab
gene_ids = np.array(vocab(genes), dtype=int)
print(f'Vocab of size: {len(vocab)} --> {len(genes)} genes, {len(special_tokens)} special tokens {special_tokens}')
ntoken = len(vocab)  # size of vocabulary

model = TransformerModel(
                ntoken=ntoken,
                d_model=d_model,
                nhead=nhead,
                d_hid=d_model,
                nlayers=nlayers,
                vocab=vocab,
                pad_token=pad_token,
)

# m_out, z_out  = model(m, z)
print(f'tot. params: {sum(p.numel() for p in model.parameters())/1e6:.2f}M')

Filtering genes by counts ...
Normalizing total counts ...
Log1p transforming ...
Subsetting highly variable genes ...
No batch_key is provided, will use all cells for HVG selection.
Binning data ...
Vocab of size: 103 --> 100 genes, 3 special tokens ['<pad>', '<cls>', '<eoc>']
tot. params: 144.25M


In [None]:
from bioformer import BioFormerStack

model = BioFormerStack(
    c_m=d_model,
    c_z=d_model,
    c_hidden=d_model,
    no_heads=nhead,
    no_blocks=1
)

print(f'tot. params: {sum(p.numel() for p in model.parameters())/1e6:.2f}M')

tot. params: 142.09M


In [None]:
model.blocks[0].opm

OuterProductMean(
  (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (linear_1): Linear(in_features=512, out_features=512, bias=True)
  (linear_2): Linear(in_features=512, out_features=512, bias=True)
  (linear_out): Linear(in_features=262144, out_features=512, bias=True)
)

In [None]:
print(model.blocks)
print(f'{sum(p.numel() for p in model.blocks.parameters())/1e6:.2f}M')

print(model.blocks[0].opm)
print(f'{sum(p.numel() for p in model.blocks[0].opm.parameters())/1e6:.2f}M')

print(model.blocks[0].attn)
print(f'{sum(p.numel() for p in model.blocks[0].attn.parameters())/1e6:.2f}M')

print(model.blocks[0].trans)
print(f'{sum(p.numel() for p in model.blocks[0].trans.parameters())/1e6:.2f}M')

ModuleList(
  (0): BioFormerBlock(
    (opm): OuterProductMean(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (linear_1): Linear(in_features=512, out_features=512, bias=True)
      (linear_2): Linear(in_features=512, out_features=512, bias=True)
      (linear_out): Linear(in_features=262144, out_features=512, bias=True)
    )
    (attn): RowAttentionWithPairBias(
      (layer_norm_m): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (layer_norm_z): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (linear_z): Linear(in_features=512, out_features=4, bias=False)
      (linear_q): Linear(in_features=512, out_features=2048, bias=False)
      (linear_k): Linear(in_features=512, out_features=2048, bias=False)
      (linear_v): Linear(in_features=512, out_features=2048, bias=False)
      (linear_g): Linear(in_features=512, out_features=2048, bias=True)
      (linear_o): Linear(in_features=2048, out_features=512, bias=True)
    )
    (tran