In [1]:
import json
import os
import sys
import time
import copy
from pathlib import Path
from typing import Iterable, List, Tuple, Dict, Union, Optional
import warnings
from scipy.sparse import issparse

import torch
import numpy as np
import matplotlib
from torch import nn
from torch.nn import functional as F
from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)
from torch_geometric.loader import DataLoader
from gears import PertData, GEARS
from gears.inference import compute_metrics, deeper_analysis, non_dropout_analysis
from gears.utils import create_cell_graph_dataset_for_prediction

sys.path.insert(0, "../../../scGPT/")

import scgpt as scg
from scgpt.model import TransformerModel
from scgpt.loss import (
    masked_mse_loss,
    criterion_neg_log_bernoulli,
    masked_relative_error,
)
from scgpt.tokenizer import tokenize_batch, pad_batch, tokenize_and_pad_batch, random_mask_value
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.utils import set_seed, map_raw_id_to_vocab_id

import pandas as pd

matplotlib.rcParams["savefig.transparent"] = False
warnings.filterwarnings("ignore")

set_seed(42)

%reload_ext autoreload
%autoreload 2

  import pkg_resources
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(pkg)
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(pkg)
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(parent)
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(pkg)
Implementing implicit namespac

In [15]:
def random_mask_value(
    values: Union[torch.Tensor, np.ndarray],
    mask_ratio: float = 0.15,
    mask_value: int = -1,
    do_not_pad_index: Union[torch.Tensor, np.ndarray]=np.array([]),
    pad_value: int = 0,
) -> torch.Tensor:
    """
    Randomly mask a batch of data.

    Args:
        values (array-like):
            A batch of tokenized data, with shape (batch_size, n_features).
        mask_ratio (float): The ratio of genes to mask, default to 0.15.
        mask_value (int): The value to mask with, default to -1.
        pad_value (int): The value of padding in the values, will be kept unchanged.

    Returns:
        torch.Tensor: A tensor of masked data.
    """
    if isinstance(values, torch.Tensor):
        # it is crutial to clone the tensor, otherwise it changes the original tensor
        values = values.clone().detach().numpy()
    else:
        values = values.copy()

    for i in range(len(values)):
        row = values[i]
        non_padding_idx = np.nonzero(row - pad_value)[0]
        non_padding_idx = np.setdiff1d(non_padding_idx, do_not_pad_index)
        n_mask = int(len(non_padding_idx) * mask_ratio)
        mask_idx = np.random.choice(non_padding_idx, n_mask, replace=False)
        row[mask_idx] = mask_value
    return torch.from_numpy(values).float()

In [16]:
torch.cuda.is_available()

True

In [17]:
! nvidia-smi

Mon Nov 20 17:44:42 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.06              Driver Version: 545.23.06    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA RTX A4500               On  | 00000000:B3:00.0 Off |                  Off |
| 30%   26C    P8              13W / 200W |    411MiB / 20470MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

# setting up

In [14]:
# settings for data prcocessing
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
pad_value = 0  # for padding values
#pert_pad_id = 2

n_hvg = 0  # number of highly variable genes
include_zero_gene = "all"  # include zero expr genes in training input, "all", "batch-wise", "row-wise", or False
max_seq_len = 10_000

# settings for training
MLM = True  # whether to use masked language modeling, currently it is always on.
CLS = True  # celltype classification objective
CCE = False  # Contrastive cell embedding objective
MVC = False  # Masked value prediction for cell embedding
ECS = False  # Elastic cell similarity objective
cell_emb_style = "cls"
mvc_decoder_style = "inner product, detach"
amp = True
load_model = "../../data/scGPT_human"
load_param_prefixs: list[str] = [
    "encoder",
    "value_encoder",
    "transformer_encoder",
]

# settings for optimizer
lr = 1e-4  # or 1e-4
batch_size = 8
eval_batch_size = 8
epochs = 5
schedule_interval = 1
early_stop = 5

n_input_bins=30

# settings for the model
embsize = 512  # embedding dimension
d_hid = 512  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 12  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 8  # number of heads in nn.MultiheadAttention
n_layers_cls = 3
dropout = 0.2  # dropout probability
use_fast_transformer = True  # whether to use fast transformer

# logging
log_interval = 100
data_name = "cellxgene_rand"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [16]:
save_dir = Path(f"../../data/temp/dev_perturb_{data_name}-{time.strftime('%b%d-%H-%M')}/")
save_dir.mkdir(parents=True, exist_ok=True)
print(f"saving to {save_dir}")

logger = scg.logger
scg.utils.add_file_handler(logger, save_dir / "run.log")
# log running date and current git commit
logger.info(f"Running on {time.strftime('%Y-%m-%d %H:%M:%S')}")

saving to ../../data/temp/dev_perturb_cellxgene_rand-Nov21-07-49
scGPT - INFO - Running on 2023-11-21 07:49:58


# loading adata

In [3]:
import cellxgene_census
import scanpy as sc
from anndata import AnnData, concat

## trial adata

In [21]:
# cellxgene_census.download_source_h5ad(dataset_id="015c230d-650c-4527-870d-8a805849a382", to_path="../../data/temp/015c230d-650c-4527-870d-8a805849a382.h5ad", census_version="latest")

In [22]:
# adata = sc.read_h5ad("../../data/temp/4da324f0-f993-4015-9ac5-16f7d26c51a6.h5ad")
# adata

In [23]:
#adata = adata[np.random.choice(adata.shape[1], 10_000, replace=False)]
#adata.obs['self_reported_ethnicity_ontology_term_id'].value_counts()

## real adata

In [24]:
census = cellxgene_census.open_soma(census_version = "latest")

In [25]:
value_filter="nnz>8000 and is_primary_data == True and raw_sum>40000"

In [26]:
#obs = census["census_data"]["homo_sapiens"].obs.read(column_names=['nnz', 'raw_sum', 'assay_ontology_term_id', 'dataset_id'], value_filter=value_filter).concat().to_pandas()
# obs

NameError: name 'obs' is not defined

In [None]:
#datasets = [i for i,n in obs.dataset_id.value_counts().items() if n>10_000 and n<40_000]
#datasets

['a5d5c529-8a1f-40b5-bda3-35208970070d',
 '52ea546e-9229-40ef-b048-a2e694dd73e8',
 '92161459-9103-4379-ae34-73a38eee1d1d',
 'be401db3-d732-408a-b0c4-71af0458b8ab']

In [22]:
datasets = ['a5d5c529-8a1f-40b5-bda3-35208970070d',
 '52ea546e-9229-40ef-b048-a2e694dd73e8',
 '92161459-9103-4379-ae34-73a38eee1d1d',
 'be401db3-d732-408a-b0c4-71af0458b8ab']

In [23]:
#for i in datasets:
#    cx_adata = cellxgene_census.get_anndata(
#        census=census,
#        organism="Homo sapiens",
#        obs_value_filter=value_filter+" and dataset_id == '"+i+"'",
        
#    )
#    cx_adata.write_h5ad('../../data/temp/all_high_reads_cxg_'+str(i)+'.h5ad')

In [24]:
import anndata as ad

In [25]:
cx_adata = ad.concat([sc.read_h5ad('../../data/temp/all_high_reads_cxg_'+str(i)+'.h5ad') for i in datasets], merge="first", index_unique="_")
cx_adata

AnnData object with n_obs × n_vars = 52712 × 60664
    obs: 'soma_joinid', 'dataset_id', 'assay', 'assay_ontology_term_id', 'cell_type', 'cell_type_ontology_term_id', 'development_stage', 'development_stage_ontology_term_id', 'disease', 'disease_ontology_term_id', 'donor_id', 'is_primary_data', 'self_reported_ethnicity', 'self_reported_ethnicity_ontology_term_id', 'sex', 'sex_ontology_term_id', 'suspension_type', 'tissue', 'tissue_ontology_term_id', 'tissue_general', 'tissue_general_ontology_term_id', 'raw_sum', 'nnz', 'raw_mean_nnz', 'raw_variance_nnz', 'n_measured_vars'
    var: 'soma_joinid', 'feature_id', 'feature_name', 'feature_length', 'nnz', 'n_measured_obs'

## add the validation adata

In [26]:
adata = AnnData((2**pd.read_csv('../../data/GroundTruth/remisdata/scRNA/liu_rna_filtered_log2.tsv', sep='\t'))-1).T
adata

AnnData object with n_obs × n_vars = 72 × 23153

In [27]:
#sc.pp.neighbors(adata)
#sc.tl.umap(adata)
#sc.tl.leiden(adata, key_added='leiden_1.0', resolution=1.0)

In [28]:
#sc.pl.umap(adata, color=['leiden_1.0'], legend_loc='on data')

In [29]:
adata.var

A1BG
A1BG-AS1
A1CF
A2M
A2ML1
...
LINC01422
LINC01481
LINC01505
RAET1E-AS1
RF00017


In [30]:
cx_adata.obs = cx_adata.obs.rename(columns={'assay_ontology_term_id': 'batch'})

In [31]:
adata.obs['dataset_id'] = "rm_validation"
adata.obs['cell_type'] = "stem cell"
adata.obs['batch'] = "EFO:0008931"

In [32]:
cx_adata.var

Unnamed: 0,soma_joinid,feature_id,feature_name,feature_length,nnz,n_measured_obs
0,0,ENSG00000000003,TSPAN6,4536,3583586,65158440
1,1,ENSG00000000005,TNMD,1476,191895,53300590
2,2,ENSG00000000419,DPM1,9276,15772355,65615631
3,3,ENSG00000000457,SCYL3,6883,8510567,65445350
4,4,ENSG00000000460,C1orf112,5970,5907435,65092683
...,...,...,...,...,...,...
60659,60659,ENSG00000288719,RP4-669P10.21,4252,2826,1248980
60660,60660,ENSG00000288720,RP11-852E15.3,7007,99,1248980
60661,60661,ENSG00000288721,RP5-973N23.5,7765,0,0
60662,60662,ENSG00000288723,RP11-553N16.6,1015,18,1248980


In [33]:
cx_adata.var = cx_adata.var.set_index('feature_name')

In [34]:
adata = concat([cx_adata, adata])

In [41]:
adata

AnnData object with n_obs × n_vars = 52784 × 17725
    obs: 'dataset_id', 'batch', 'cell_type'

# load model

In [17]:
model_dir = Path(load_model)
# https://drive.google.com/drive/folders/1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y
model_config_file = model_dir / "args.json"
model_file = model_dir / "best_model.pt"
vocab_file = model_dir / "vocab.json"

vocab = GeneVocab.from_file(vocab_file)
for s in special_tokens:
    if s not in vocab:
        vocab.append_token(s)

# model
with open(model_config_file, "r") as f:
    model_configs = json.load(f)
logger.info(
    f"Resume model from {model_file}, the model args will override the "
    f"config {model_config_file}."
)
embsize = model_configs["embsize"]
nhead = model_configs["nheads"]
d_hid = model_configs["d_hid"]
nlayers = model_configs["nlayers"]
n_layers_cls = model_configs["n_layers_cls"]

scGPT - INFO - Resume model from ../../data/scGPT_human/best_model.pt, the model args will override the config ../../data/scGPT_human/args.json.


In [43]:
import wandb

In [None]:
#run = wandb.init(
#    config=dict(
#    seed=0),
#    project="scGPT",
#    reinit=True,
#    settings=wandb.Settings(start_method="fork"),
#)
#config = wandb.config
#print(config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjkobject[0m ([33mml4ig[0m). Use [1m`wandb login --relogin`[0m to force relogin


{'seed': 0}
Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7f99203eeb30>> (for post_run_cell), with arguments args (<ExecutionResult object at 7f99201f52a0, execution_count=27 error_before_exec=None error_in_exec=None info=<ExecutionInfo object at 7f99201f6830, raw_cell="run = wandb.init(
    config=dict(
    seed=0),
  .." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2Bperso/home/ml4ig1/Documents%20code/scPRINT/notebooks/assessments/attention_GRN_finetuning.ipynb#Y242sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


TypeError: _WandbInit._pause_backend() takes 1 positional argument but 2 were given

In [45]:
ntokens = len(vocab)  # size of vocabulary
model = TransformerModel(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=n_layers_cls,
    n_cls=len(adata.obs.cell_type.unique()),
    vocab=vocab,
    dropout=dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    n_input_bins=n_input_bins,
    do_mvc=MVC,
    cell_emb_style=cell_emb_style,
    mvc_decoder_style=mvc_decoder_style,
    use_fast_transformer=use_fast_transformer,
)
if load_param_prefixs is not None and load_model is not None:
    # only load params that start with the prefix
    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_file, map_location=device)
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if any([k.startswith(prefix) for prefix in load_param_prefixs])
    }
    for k, v in pretrained_dict.items():
        logger.info(f"Loading params {k} with shape {v.shape}")
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
elif load_model is not None:
    try:
        model.load_state_dict(torch.load(model_file))
        logger.info(f"Loading all model params from {model_file}")
    except:
        # only load params that are in the model and match the size
        model_dict = model.state_dict()
        pretrained_dict = torch.load(model_file)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        for k, v in pretrained_dict.items():
            logger.info(f"Loading params {k} with shape {v.shape}")
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

pre_freeze_param_count = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values())

# Freeze all pre-decoder weights
for name, para in model.named_parameters():
    if "encoder" in name and "transformer_encoder" not in name:
    # if config.freeze and "encoder" in name:
        print(f"freezing weights for: {name}")
        para.requires_grad = False

post_freeze_param_count = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values())

logger.info(f"Total Pre freeze Params {(pre_freeze_param_count )}")
logger.info(f"Total Post freeze Params {(post_freeze_param_count )}")
#wandb.log(
#        {
#            "info/pre_freeze_param_count": pre_freeze_param_count,
#            "info/post_freeze_param_count": post_freeze_param_count,
#        },
#)
#wandb.watch(model)
model.to(device)


hey
scGPT - INFO - Loading params encoder.embedding.weight with shape torch.Size([60697, 512])
scGPT - INFO - Loading params encoder.enc_norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params encoder.enc_norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear1.weight with shape torch.Size([512, 1])
scGPT - INFO - Loading params value_encoder.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params value_encoder.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.Wqkv.weight with shape torch.Size([1536, 512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.Wqkv.bias with shape torch.Size(

TransformerModel(
  (encoder): GeneEncoder(
    (embedding): Embedding(60697, 512, padding_idx=60694)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): ContinuousValueEncoder(
    (dropout): Dropout(p=0.2, inplace=False)
    (linear1): Linear(in_features=1, out_features=512, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x FlashTransformerEncoderLayer(
        (self_attn): FlashMHA(
          (Wqkv): Linear(in_features=512, out_features=1536, bias=True)
          (inner_attn): FlashAttention()
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features

## preprocessing

In [7]:
from scgpt.preprocess import Preprocessor

In [8]:
data_is_raw=True

In [35]:
# set up the preprocessor, use the args to config the workflow
preprocessor = Preprocessor(
    use_key="X",  # the key in adata.layers to use as raw data
    filter_gene_by_counts=3,  # step 1
    filter_cell_by_counts=10,  # step 2
    normalize_total=1e4,  # 3. whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=data_is_raw,  # 4. whether to log1p the normalized data
    result_log1p_key="X_log1p",
    subset_hvg=max_seq_len,  # 5. whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    binning=n_input_bins,  # 6. whether to bin the raw data and to what number of bins
    result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)
preprocessor(adata, batch_key="batch")

scGPT - INFO - Filtering genes by counts ...
scGPT - INFO - Filtering cells by counts ...
scGPT - INFO - Normalizing total counts ...
scGPT - INFO - Log1p transforming ...
scGPT - INFO - Subsetting highly variable genes ...
scGPT - INFO - Binning data ...


In [49]:
from sklearn.model_selection import train_test_split

In [50]:
vocab

GeneVocab()

In [10]:
adata.var["id_in_vocab"] = [
    1 if gene in vocab else -1 for gene in adata.var.index
]
gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
logger.info(
    f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
    f"in vocabulary of size {len(vocab)}."
)
genes = adata.var.index.tolist()
vocab.set_default_index(vocab["<pad>"])
gene_ids = np.array(
    [vocab[gene] if gene in vocab else vocab["<pad>"] for gene in genes], dtype=int
)
n_genes = len(genes)

NameError: name 'vocab' is not defined

In [52]:
input_layer_key = "X_binned"
all_counts = (
    adata.layers[input_layer_key].A
    if issparse(adata.layers[input_layer_key])
    else adata.layers[input_layer_key]
)
genes = adata.var.index.tolist()
adata.obs["cell_type_code"] = adata.obs["cell_type"].astype("category").cat.codes.values
adata.obs["batch_code"] = adata.obs["batch"].astype("category").cat.codes.values

celltypes_labels = adata.obs["cell_type_code"].tolist()  # make sure count from 0
num_types = len(set(celltypes_labels))
celltypes_labels = np.array(celltypes_labels)

batch_ids = adata.obs["batch_code"].tolist()
num_batch_types = len(set(batch_ids))
batch_ids = np.array(batch_ids)

(
    train_data,
    valid_data,
    train_celltype_labels,
    valid_celltype_labels,
    train_batch_labels,
    valid_batch_labels,
) = train_test_split(
    all_counts, celltypes_labels, batch_ids, test_size=0.1, shuffle=True
)

In [53]:
tokenized_train = tokenize_and_pad_batch(
    train_data,
    gene_ids,
    max_len=max_seq_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,  # append <cls> token at the beginning
    include_zero_gene=include_zero_gene,
)
tokenized_valid = tokenize_and_pad_batch(
    valid_data,
    gene_ids,
    max_len=max_seq_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,
    include_zero_gene=include_zero_gene,
)
logger.info(
    f"train set number of samples: {tokenized_train['genes'].shape[0]}, "
    f"\n\t feature length: {tokenized_train['genes'].shape[1]}"
)
logger.info(
    f"valid set number of samples: {tokenized_valid['genes'].shape[0]}, "
    f"\n\t feature length: {tokenized_valid['genes'].shape[1]}"
)

scGPT - INFO - train set number of samples: 47505, 
	 feature length: 10000
scGPT - INFO - valid set number of samples: 5279, 
	 feature length: 10000


# training

In [54]:
mask_ratio = 0.6

In [97]:
def fileToList(filename, strconv=lambda x: x):
    """
    loads an input file with a\\n b\\n.. into a list [a,b,..]
    """
    with open(filename) as f:
        return [strconv(val[:-1]) for val in f.readlines()]

In [98]:
TFs = fileToList('../../data/TF.txt')
TFs = list(set(TFs) & set(adata.var.index))
len(TFs)

856

In [99]:
# do not pad the TFs
tf_ids = np.array([vocab[gene] for gene in TFs], dtype=int)

In [58]:
from torch.utils.data import Dataset, DataLoader

In [59]:
mask_value= -1

In [60]:
def prepare_data(sort_seq_batch=False) -> Tuple[Dict[str, torch.Tensor]]:
    masked_values_train = random_mask_value(
        tokenized_train["values"],
        mask_ratio=mask_ratio,
        do_not_pad_index=tf_ids,
        pad_value=pad_value,
    )
    masked_values_valid = random_mask_value(
        tokenized_valid["values"],
        mask_ratio=mask_ratio,
        do_not_pad_index=tf_ids,
        pad_value=pad_value,
    )
    print(
        f"random masking at epoch {epoch:3d}, ratio of masked values in train: ",
        f"{(masked_values_train == mask_value).sum() / (masked_values_train - pad_value).count_nonzero():.4f}",
    )

    input_gene_ids_train, input_gene_ids_valid = (
        tokenized_train["genes"],
        tokenized_valid["genes"],
    )
    input_values_train, input_values_valid = masked_values_train, masked_values_valid
    target_values_train, target_values_valid = (
        tokenized_train["values"],
        tokenized_valid["values"],
    )

    tensor_batch_labels_train = torch.from_numpy(train_batch_labels).long()
    tensor_batch_labels_valid = torch.from_numpy(valid_batch_labels).long()

    tensor_celltype_labels_train = torch.from_numpy(train_celltype_labels).long()
    tensor_celltype_labels_valid = torch.from_numpy(valid_celltype_labels).long()

    train_data_pt = {
        "gene_ids": input_gene_ids_train,
        "values": input_values_train,
        "target_values": target_values_train,
        "batch_labels": tensor_batch_labels_train,
        "celltype_labels": tensor_celltype_labels_train,
    }
    valid_data_pt = {
        "gene_ids": input_gene_ids_valid,
        "values": input_values_valid,
        "target_values": target_values_valid,
        "batch_labels": tensor_batch_labels_valid,
        "celltype_labels": tensor_celltype_labels_valid,
    }

    return train_data_pt, valid_data_pt

# dataset
class SeqDataset(Dataset):
    def __init__(self, data: Dict[str, torch.Tensor]):
        self.data = data

    def __len__(self):
        return self.data["gene_ids"].shape[0]

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items()}


# data_loader
def prepare_dataloader(
    data_pt: Dict[str, torch.Tensor],
    batch_size: int,
    shuffle: bool = False,
    intra_domain_shuffle: bool = False,
    drop_last: bool = False,
    num_workers: int = 0,
    per_seq_batch_sample: bool = False,
) -> DataLoader:
    if num_workers == 0:
        num_workers = min(len(os.sched_getaffinity(0)), batch_size // 2)

    dataset = SeqDataset(data_pt)

    if per_seq_batch_sample:
        # find the indices of samples in each seq batch
        subsets = []
        batch_labels_array = data_pt["batch_labels"].numpy()
        for batch_label in np.unique(batch_labels_array):
            batch_indices = np.where(batch_labels_array == batch_label)[0].tolist()
            subsets.append(batch_indices)
        data_loader = DataLoader(
            dataset=dataset,
            batch_sampler=SubsetsBatchSampler(
                subsets,
                batch_size,
                intra_subset_shuffle=intra_domain_shuffle,
                inter_subset_shuffle=shuffle,
                drop_last=drop_last,
            ),
            num_workers=num_workers,
            pin_memory=True,
        )
        return data_loader

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=True,
    )
    return data_loader

In [65]:
criterion = masked_mse_loss
criterion_cls = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, schedule_interval, gamma=0.9)
scaler = torch.cuda.amp.GradScaler(enabled=amp)


def train(model: nn.Module, train_loader: torch.utils.data.DataLoader) -> None:
    """
    Train the model for one epoch.
    """
    model.train()
    total_loss, total_mse, total_cls = 0.0, 0.0, 0.0
    start_time = time.time()

    num_batches = len(train_loader)
    for batch, batch_data in enumerate(train_loader):
            # src_key_padding_mask = mapped_input_gene_ids.eq(vocab[pad_token])
        if batch==1500:
            break
        src_key_padding_mask = torch.zeros_like(
            batch_data['values'], dtype=torch.bool, device=device
        )

        with torch.cuda.amp.autocast(enabled=amp):
            output_dict = model(
                batch_data['gene_ids'].to(device),
                batch_data['values'].to(device),
                src_key_padding_mask=src_key_padding_mask.to(device),
                CLS=CLS,
                CCE=CCE,
                MVC=MVC,
                ECS=ECS,
            )
            output_values = output_dict["mlm_output"]

            masked_positions = torch.ones_like(
                batch_data['values'], dtype=torch.bool, device=device
            )  # Use all
            loss=0
            metrics_to_log = {}
            if MLM:
                loss_mse = criterion(
                    output_dict["mlm_output"], batch_data['target_values'].to(device), masked_positions
                )
                #MLM task IS 2X MORE IMPORTANT
                loss += 2*loss_mse
                metrics_to_log.update({"train/mse": loss_mse.item()})
            if CLS:
                loss_cls = criterion_cls(output_dict["cls_output"], batch_data['celltype_labels'].to(device))
                loss += loss_cls
                metrics_to_log.update({"train/cls": loss_cls.item()})

                error_rate = 1 - (
                    (output_dict["cls_output"].argmax(1) == batch_data['celltype_labels'].to(device))
                    .sum()
                    .item()
                ) / batch_data['celltype_labels'].size(0)

        model.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings("always")
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                1.0,
                error_if_nonfinite=False if scaler.is_enabled() else True,
            )
            if len(w) > 0:
                logger.warning(
                    f"Found infinite gradient. This may be caused by the gradient "
                    f"scaler. The current scale is {scaler.get_scale()}. This warning "
                    "can be ignored if no longer occurs after autoscaling of the scaler."
                )
        scaler.step(optimizer)
        scaler.update()

        # torch.cuda.empty_cache()
        #wandb.log(metrics_to_log)
        total_loss += loss.item()
        total_mse += loss_mse.item() if MLM else 0.0
        total_cls += loss_cls.item() if CLS else 0.0
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            cur_mse = total_mse / log_interval
            # ppl = math.exp(cur_loss)
            logger.info(
                f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
                f"lr {lr:05.4f} | ms/batch {ms_per_batch:5.2f} | "
                f"loss {cur_loss:5.2f} | mse {cur_mse:5.2f} |"
            )
            total_loss = 0
            total_mse = 0
            start_time = time.time()


def evaluate(model: nn.Module, val_loader: torch.utils.data.DataLoader) -> float:
    """
    Evaluate the model on the evaluation data.
    """
    model.eval()
    total_loss = 0.0
    total_error = 0.0

    with torch.no_grad():
        for batch, batch_data in enumerate(val_loader):
            # src_key_padding_mask = mapped_input_gene_ids.eq(vocab[pad_token])
            if batch==150:
                break
            src_key_padding_mask = torch.zeros_like(
                batch_data['values'], dtype=torch.bool, device=device
            )
            with torch.cuda.amp.autocast(enabled=amp):
                output_dict = model(
                    batch_data['gene_ids'].to(device),
                    batch_data['values'].to(device),
                    src_key_padding_mask=src_key_padding_mask,
                    CLS=CLS,
                    CCE=CCE,
                    MVC=MVC,
                    ECS=ECS,
                    do_sample=True,
                )
                output_values = output_dict["mlm_output"]
                output_class = output_dict["cls_output"]

                masked_positions = torch.ones_like(
                    batch_data['values'], dtype=torch.bool, device=device
                )
                loss = 2*criterion(output_values, batch_data['target_values'].to(device), masked_positions.to(device))
                loss += criterion_cls(output_values, batch_data['celltype_labels'].to(device))
            total_loss += loss.item()
            total_error += masked_relative_error(
                output_values, batch_data['target_values'].to(device), masked_positions
            ).item()
    #wandb.log(
    #    {
    #        "valid/loss": total_loss,
    #        "valid/error": total_error,
    #        "epoch": epoch,
    #    },
    #)
    return total_loss / len(val_loader), total_error / len(val_loader)

#def define_wandb_metrcis():
#    wandb.define_metric("train/cls", summary="min", step_metric="epoch")
#    wandb.define_metric("train/mse", summary="min", step_metric="epoch")
#    wandb.define_metric("valid/dab", summary="min", step_metric="epoch")
#    wandb.define_metric("valid/sum_mse_dab", summary="min", step_metric="epoch")
#    wandb.define_metric("test/avg_bio", summary="max")

In [64]:
best_val_loss = float("inf")
best_model = None
patience = 0
epochs=3

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train_data_pt, valid_data_pt = prepare_data()

    train_loader = prepare_dataloader(
        train_data_pt,
        batch_size=batch_size,
        shuffle=True,
        intra_domain_shuffle=True,
        drop_last=False,
    )
    valid_loader = prepare_dataloader(
        valid_data_pt,
        batch_size=eval_batch_size,
        shuffle=True,
        intra_domain_shuffle=False,
        drop_last=False,
    )

    train(
        model,
        train_loader,
    )
    val_loss, val_mre = evaluate(
        model,
        valid_loader,
    )
    elapsed = time.time() - epoch_start_time
    logger.info("-" * 89)
    logger.info(
        f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
        f"valid loss/mse {val_loss:5.4f} |"
    )
    logger.info("-" * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        logger.info(f"Best model with score {best_val_loss:5.4f}")
        patience = 0
    else:
        patience += 1
        if patience >= early_stop:
            logger.info(f"Early stop at epoch {epoch}")
            break

    torch.save(
        model.state_dict(),
        save_dir / f"model_{epoch}.pt",
    )

    scheduler.step()

random masking at epoch   1, ratio of masked values in train:  0.5916
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   1 | time: 50.37s | valid loss/mse 1.2407 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - Best model with score 1.2407
random masking at epoch   2, ratio of masked values in train:  0.5916
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   2 | time: 51.90s | valid loss/mse 1.0182 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - Best model with score 1.0182
random masking at epoch   3, ratio of masked values in train:  0.5916
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epo

In [None]:
## TODO: place more weights on the MLM tasks. try decreasing the amount of genes. freeze the gene encoding learning

## save

In [None]:
torch.save(
    model.state_dict(),
    save_dir / f"model_{epoch}.pt",
)

## validation

In [12]:
model_dir = Path(load_model)

NotADirectoryError: [Errno 20] Not a directory: '../../data/temp/dev_perturb_cellxgene_rand-Nov20-20-15/model_2.pt/vocab.json'

In [19]:
model_file = Path("../../data/temp/dev_perturb_cellxgene_rand-Nov20-20-15/model_2.pt")

In [21]:
model_file

PosixPath('../../data/temp/dev_perturb_cellxgene_rand-Nov20-20-15/model_2.pt')

In [20]:
ntokens = len(vocab)  # size of vocabulary
model = TransformerModel(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=n_layers_cls,
    n_cls=len(adata.obs.cell_type.unique()),
    vocab=vocab,
    dropout=dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    n_input_bins=n_input_bins,
    do_mvc=MVC,
    cell_emb_style=cell_emb_style,
    mvc_decoder_style=mvc_decoder_style,
    use_fast_transformer=use_fast_transformer,
)
if load_param_prefixs is not None and load_model is not None:
    # only load params that start with the prefix
    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_file, map_location=device)
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if any([k.startswith(prefix) for prefix in load_param_prefixs])
    }
    for k, v in pretrained_dict.items():
        logger.info(f"Loading params {k} with shape {v.shape}")
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
elif load_model is not None:
    try:
        model.load_state_dict(torch.load(model_file))
        logger.info(f"Loading all model params from {model_file}")
    except:
        # only load params that are in the model and match the size
        model_dict = model.state_dict()
        pretrained_dict = torch.load(model_file)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        for k, v in pretrained_dict.items():
            logger.info(f"Loading params {k} with shape {v.shape}")
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

model.to(device)


hey
scGPT - INFO - Loading params encoder.embedding.weight with shape torch.Size([60697, 512])
scGPT - INFO - Loading params encoder.enc_norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params encoder.enc_norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear1.weight with shape torch.Size([512, 1])
scGPT - INFO - Loading params value_encoder.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params value_encoder.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.Wqkv.weight with shape torch.Size([1536, 512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.Wqkv.bias with shape torch.Size(

TransformerModel(
  (encoder): GeneEncoder(
    (embedding): Embedding(60697, 512, padding_idx=60694)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): ContinuousValueEncoder(
    (dropout): Dropout(p=0.2, inplace=False)
    (linear1): Linear(in_features=1, out_features=512, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x FlashTransformerEncoderLayer(
        (self_attn): FlashMHA(
          (Wqkv): Linear(in_features=512, out_features=1536, bias=True)
          (inner_attn): FlashAttention()
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features

In [124]:
test_adata = adata[adata.obs.cell_type=="glutamatergic neuron"][:20,:]
test_adata

## Step 2: Retrieve scGPT's attention weights


### 2.1 Prepare model input


In [125]:
input_layer_key = "X_binned"
all_counts = (
    test_adata.layers[input_layer_key].A
    if issparse(test_adata.layers[input_layer_key])
    else test_adata.layers[input_layer_key]
)

genes = test_adata.var.index.tolist()
gene_ids = np.array(vocab(genes), dtype=int)

tokenized_all = tokenize_and_pad_batch(
        all_counts,
        gene_ids,
        max_len=len(genes) + 1,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=True,
    )

### 2.1 Retrieve attention weights

Note that since the flash-attn package does not output attention scores, we
manually calculate q @ k.T to extract the attention weights. Users may specify
which layer to extract the attention weights from. In the manuscript, we used
the attention weights from the last (12th) layer.


In [127]:
all_gene_ids, all_values = tokenized_all["genes"], tokenized_all["values"]
src_key_padding_mask = all_gene_ids.eq(vocab[pad_token])
gene_vocab_idx = all_gene_ids[0].clone().detach().cpu().numpy()

In [55]:
from tqdm import tqdm
import scipy as sp
from einops import rearrange
from torch.nn.functional import softmax

In [128]:
def get_attention(layer_num=11, batch_size=4):

    torch.cuda.empty_cache()
    dict_sum_condition = {}
    model.eval()
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
        M = all_gene_ids.size(1)
        N = all_gene_ids.size(0)
        device = next(model.parameters()).device
        for i in tqdm(range(0, N, batch_size)):
            batch_size = all_gene_ids[i : i + batch_size].size(0)
            outputs = np.zeros((batch_size, M, M), dtype=np.float32)
            # Replicate the operations in model forward pass
            src_embs = model.encoder(
                torch.tensor(all_gene_ids[i : i + batch_size], dtype=torch.long).to(device)
            )
            val_embs = model.value_encoder(
                torch.tensor(all_values[i : i + batch_size], dtype=torch.float).to(device)
            )
            total_embs = src_embs + val_embs
            # total_embs = model.layer(total_embs.permute(0, 2, 1)).permute(0, 2, 1)
            # Send total_embs to attention layers for attention operations
            # Retrieve the output from second to last layer
            for layer in model.transformer_encoder.layers[:layer_num]:
                total_embs = layer(
                    total_embs,
                    src_key_padding_mask=src_key_padding_mask[i : i + batch_size].to(
                        device
                    ),
                )
            # Send total_embs to the last layer in flash-attn
            # https://github.com/HazyResearch/flash-attention/blob/1b18f1b7a133c20904c096b8b222a0916e1b3d37/flash_attn/flash_attention.py#L90
            qkv = model.transformer_encoder.layers[layer_num].self_attn.Wqkv(
                total_embs
            )
            # Retrieve q, k, and v from flast-attn wrapper
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=8)
            q = qkv[:, :, 0, :, :]
            k = qkv[:, :, 1, :, :]
            v = qkv[:, :, 2, :, :]
            # https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a
            # q = [batch, gene, n_heads, n_hid]
            # k = [batch, gene, n_heads, n_hid]
            # attn_scores = [batch, n_heads, gene, gene]
            attn_scores = q.permute(0, 2, 1, 3) @ k.permute(0, 2, 3, 1)
            # apply softmax to get attention weights
            attn_scores = softmax(attn_scores, dim=-1)
            if i == 0:
                sm_attn_scores = attn_scores.sum(0).detach().cpu().numpy()
            else:
                # take the sum
                sm_attn_scores += attn_scores.sum(0).detach().cpu().numpy()
    return sm_attn_scores
    # return [pd.DataFrame(data=sm_attn_scores[i], columns=vocab.lookup_tokens(gene_vocab_idx), index=vocab.lookup_tokens(gene_vocab_idx)) for i in range(0,8)]

In [129]:
att = get_attention(batch_size=1)

  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 20/20 [00:43<00:00,  2.17s/it]


In [130]:
sm_attn_scores = [pd.DataFrame(data=att[i], columns=vocab.lookup_tokens(gene_vocab_idx), index=vocab.lookup_tokens(gene_vocab_idx)) for i in range(0,8)]
sm_attn_scores[0][(sm_attn_scores[0]>2).sum(0)>500].index

Index([], dtype='object')

### enrichment

In [67]:
(att.sum(0)>2).sum(0)

array([3, 0, 0, ..., 0, 0, 0])

8906

In [None]:
sm_attn_scores_h5.lo

In [149]:
sm_attn_scores[3].sum(0).sort_values(ascending=False).head(20)

RLN3           12494.658203
CALM2          10566.224609
LINC01189       8788.467773
PARD3B          8107.293457
MEF2C           6857.473145
BFSP1           6132.723145
SLIT2           5984.465820
CBLN2           4144.154785
OR2G3           3413.811279
PDZD2           2937.051514
OR2L13          2845.773926
NFYC-AS1        2768.364990
FENDRR          2482.460449
CELF2           2275.022705
SLC2A13         2228.815918
RGS6            1768.902588
SLC10A5         1679.303101
DGKB            1435.802979
NECAB1          1364.238159
MAP3K20-AS1     1308.543457
dtype: float32

In [139]:
sm_attn_scores[4].sum(0).loc[TFs].mean(), sm_attn_scores[4].sum(0).mean()

(13.884732, 19.99997)

In [140]:
val = sm_attn_scores[0].sum(0)


In [65]:
import gseapy as gp
from gseapy.plot import dotplot

In [144]:
pre_res = gp.prerank(rnk=val[np.argsort(val.tolist())[::-1][:2000]], # or rnk = rnk,
                    gene_sets="ENCODE_TF_ChIP-seq_2014",
                    #"ENCODE_TF_ChIP-seq_2014",
                    #'GO_Molecular_Function_2015', #
                    min_size=5,
                    max_size=1000,
                    permutation_num=200, # reduce number to speed up testing
                    outdir=None, # don't write to disk
                    seed=6,
                    verbose=True, # see what's going on behind the scenes
            )

2023-11-21 11:26:58,552 Parsing data files for GSEA.............................


2023-11-21 11:26:59,219 Enrichr library gene sets already downloaded in: /home/ml4ig1/.gseapy, use local file
2023-11-21 11:27:57,209 0029 gene_sets have been filtered out when max_size=1000 and min_size=5
2023-11-21 11:27:57,258 0469 gene_sets used for further statistical testing.....
2023-11-21 11:27:57,259 Start to run GSEA...Might take a while..................
2023-11-21 11:29:22,362 Start to generate gseapy reports, and produce figures...
2023-11-21 11:29:22,460 Congratulations. GSEApy runs successfully................



In [92]:
# KEGG_2016 H5
pre_res.res2d[pre_res.res2d['fdr']<0.05].sort_values(by=["nes"], ascending=False)

Unnamed: 0_level_0,es,nes,pval,fdr,geneset_size,matched_size,genes,ledge_genes
Term,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1


In [105]:
# ENCODE_TF_ChIP-seq_2014 H5
pre_res.res2d[pre_res.res2d['fdr']<0.05].sort_values(by=["nes"], ascending=False)

Unnamed: 0_level_0,es,nes,pval,fdr,geneset_size,matched_size,genes,ledge_genes
Term,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1


In [145]:
# GO MF h0
pre_res.res2d[pre_res.res2d['fdr']<0.05].sort_values(by=["nes"], ascending=False)

Unnamed: 0_level_0,es,nes,pval,fdr,geneset_size,matched_size,genes,ledge_genes
Term,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1


### bertview

In [None]:
from bertviz import head_view
import numpy as np


In [None]:
SIZE = 70
LOC = 1000

In [None]:

head_view([torch.FloatTensor([sm_attn_scores[:,LOC:LOC+SIZE,LOC:LOC+SIZE]])], sm_attn_scores_h0.index.tolist()[LOC:LOC+SIZE])

<IPython.core.display.Javascript object>