In [1]:
import os
# Run from the UCE repository root: the relative "./model_files/..." paths used below are
# resolved against the working directory, and the local modules imported next (model,
# evaluate, utils, eval_data) are presumably found via the working directory on sys.path.
# Set the UCE_DIR environment variable to point elsewhere instead of editing this path.
os.chdir(os.environ.get("UCE_DIR", "/oak/stanford/groups/akundaje/kobbad/UCE"))
from model import *
import scanpy as sc
from tqdm.auto import tqdm
from torch import nn, Tensor
from evaluate import get_ESM2_embeddings
from utils import get_ESM2_embeddings_x
import argparse
from accelerate import Accelerator
from evaluate import AnndataProcessor
from eval_data import MultiDatasetSentences, MultiDatasetSentenceCollator
import pickle
from torch.utils.data import DataLoader
import torch
import numpy as np

In [2]:
def str2bool(value):
    """Convert a command-line string such as "True"/"false"/"1"/"no" to a bool.

    argparse's ``type=bool`` must not be used for flags: it calls ``bool(str)``,
    which is True for every non-empty string, so ``--filter False`` would
    silently yield True.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling (argparse reports it as a usage error).
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(
    description='Embed a single anndata using UCE.')

# Anndata Processing Arguments
parser.add_argument('--adata_path', type=str,
                    default=None,
                    help='Full path to the anndata you want to embed.')
parser.add_argument('--dir', type=str,
                    default="./",
                    help='Working folder where all files will be saved.')
parser.add_argument('--species', type=str, default="human",
                    help='Species of the anndata.')
parser.add_argument('--filter', type=str2bool, default=True,
                    help='Additional gene/cell filtering on the anndata.')
parser.add_argument('--skip', type=str2bool, default=True,
                    help='Skip datasets that appear to have already been created.')

# Model Arguments
parser.add_argument('--model_loc', type=str,
                    default=None,
                    help='Location of the model.')
parser.add_argument('--batch_size', type=int, default=25,
                    help='Batch size.')
parser.add_argument('--pad_length', type=int, default=1536,
                    help='Length to which each cell sentence is padded.')
parser.add_argument("--pad_token_idx", type=int, default=0,
                    help="PAD token index")
parser.add_argument("--chrom_token_left_idx", type=int, default=1,
                    help="Chrom token left index")
parser.add_argument("--chrom_token_right_idx", type=int, default=2,
                    help="Chrom token right index")
parser.add_argument("--cls_token_idx", type=int, default=3,
                    help="CLS token index")
parser.add_argument("--CHROM_TOKEN_OFFSET", type=int, default=143574,
                    help="Offset index, tokens after this mark are chromosome identifiers")
parser.add_argument('--sample_size', type=int, default=1024,
                    help='Number of genes sampled for cell sentence')
parser.add_argument('--CXG', type=str2bool, default=True,
                    help='Use CXG model.')
parser.add_argument('--nlayers', type=int, default=4,
                    help='Number of transformer layers.')
parser.add_argument('--output_dim', type=int, default=1280,
                    help='Output dimension.')
parser.add_argument('--d_hid', type=int, default=5120,
                    help='Hidden dimension.')
parser.add_argument('--token_dim', type=int, default=5120,
                    help='Token dimension.')
parser.add_argument('--multi_gpu', type=str2bool, default=False,
                    help='Use multiple GPUs')

# Misc Arguments
parser.add_argument("--spec_chrom_csv_path",
                    default="./model_files/species_chrom.csv", type=str,
                    help="CSV Path for species genes to chromosomes and start locations.")
parser.add_argument("--token_file",
                    default="./model_files/all_tokens.torch", type=str,
                    help="Path for token embeddings.")
parser.add_argument("--protein_embeddings_dir",
                    default="./model_files/protein_embeddings/", type=str,
                    help="Directory where protein embedding .pt files are stored.")
parser.add_argument("--offset_pkl_path",
                    default="./model_files/species_offsets.pkl", type=str,
                    help="PKL file which contains offsets for each species.")

# masking arguments
parser.add_argument("--genes_to_mask",
                    nargs='+',
                    default=None,
                    type=str,
                    help="List of genes to mask")

parser.add_argument("--genes_to_pe_idx",
                    default="./model_files/gene_to_pe_index.pkl", type=str,
                    help="Path to gene to protein embedding index mapping")
# Trailing ';' keeps Jupyter from echoing the returned _StoreAction.
parser.add_argument("--num_knockin",
                    default=1,
                    type=int,
                    help="Number of times to knockin gene");

_StoreAction(option_strings=['--num_knockin'], dest='num_knockin', nargs=None, const=None, default=1, type=<class 'int'>, choices=None, help='Number of times to knockin gene', metavar=None)

In [3]:
import sys  # local to this cell: only used to detect whether we run inside a Jupyter kernel

# Inside a Jupyter kernel, sys.argv holds the kernel's own launch arguments
# (e.g. "-f <file>" or "--f=<connection file>"). argparse would either reject them or,
# through prefix abbreviation, silently assign them to an option such as --filter.
# Parse an empty argument list there so every option takes its default. Outside Jupyter
# (e.g. this notebook exported to a script), fall back to the real command line.
args = parser.parse_args([] if "ipykernel" in sys.modules else None)

# Gene(s) to knock in; set directly because the parser defines no such option.
args.gene_to_knockin = ["STAT2"]

args.gene_to_knockin

['STAT2']

In [4]:
# Default to the 4-layer checkpoint under ./model_files, resolved against the working
# directory like the other "./model_files/..." defaults, unless a path was given on the
# command line.
if args.model_loc is None:
    args.model_loc = "./model_files/4layer_model.torch"

In [5]:
#### Set up the model ####
# Architecture hyperparameters. They must match the checkpoint at args.model_loc,
# because its state dict is loaded with strict=True below.
token_dim = args.token_dim  # width of each row of the token (protein) embedding table
emsize = 1280  # embedding dimension
d_hid = args.d_hid  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = args.nlayers  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 20  # number of heads in nn.MultiheadAttention
dropout = 0.05  # dropout probability
# Number of rows in the checkpoint's token-embedding table (pe_embedding).
n_checkpoint_tokens = 145469

model = TransformerModel(token_dim=token_dim, d_model=emsize, nhead=nhead,
                         d_hid=d_hid,
                         nlayers=nlayers, dropout=dropout,
                         output_dim=args.output_dim)

# Give pe_embedding the checkpoint's shape (zeros for now) so that
# load_state_dict(strict=True) finds a matching tensor to fill.
# nn.Embedding.from_pretrained freezes the weights by default.
empty_pe = torch.zeros(n_checkpoint_tokens, token_dim)
empty_pe.requires_grad = False
model.pe_embedding = nn.Embedding.from_pretrained(empty_pe)
model.load_state_dict(torch.load(args.model_loc, map_location="cpu"),
                      strict=True)

# Load in the real token embeddings.
all_pe = get_ESM2_embeddings(args)

# Only replace the checkpoint's table when the loaded one has a different number of rows.
# Presumably a same-sized table adds no new tokens, so the checkpoint's own embeddings are kept.
# NOTE(review): confirm this intent against get_ESM2_embeddings.
if all_pe.shape[0] != n_checkpoint_tokens:
    all_pe.requires_grad = False
    model.pe_embedding = nn.Embedding.from_pretrained(all_pe)

model = model.eval()  # inference mode (e.g. disables dropout)
accelerator = Accelerator(project_dir=args.dir)
model = accelerator.prepare(model)  # by default also places the model on accelerator.device
batch_size = args.batch_size

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [6]:
# Preprocess the input AnnData, then generate the index files the dataset below reads
# (presumably those at processor.pe_idx_path / chroms_path / starts_path).
# Per the output, steps whose results already exist on disk are skipped. With
# --adata_path unset, the processor used a sample 10k PBMC dataset.
processor = AnndataProcessor(args, accelerator)
processor.preprocess_anndata()
processor.generate_idxs()
processor.adata

Using sample AnnData: 10k pbmcs dataset
10k_pbmcs_proc already processed. Skipping
PE Idx, Chrom and Starts files already created


AnnData object with n_obs × n_vars = 11990 × 10809
    obs: 'n_counts', 'batch', 'labels', 'str_labels', 'cell_type', 'n_genes'
    var: 'gene_symbols', 'n_counts-0', 'n_counts-1', 'n_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'n_cells'
    uns: 'cell_types', 'hvg'
    obsm: 'design', 'normalized_qc', 'qc_pc', 'raw_qc'

In [7]:
# Read back the dataset shapes dictionary written during preprocessing.
# NOTE(review): pickle.load can execute arbitrary code; only load files this pipeline produced.
with open(processor.shapes_dict_path, mode="rb") as shapes_file:
    shapes_dict = pickle.load(shapes_file)

In [8]:
# Wrap the single processed dataset for sentence sampling, pointing it at the
# on-disk files produced by the processor above.
dataset = MultiDatasetSentences(
    sorted_dataset_names=[processor.name],
    shapes_dict=shapes_dict,
    args=args,
    npzs_dir=args.dir,
    dataset_to_protein_embeddings_path=processor.pe_idx_path,
    datasets_to_chroms_path=processor.chroms_path,
    datasets_to_starts_path=processor.starts_path,
)

In [9]:
# Length of the protein-embedding index tensor for the first dataset; here it equals
# processor.adata.n_vars (10809), presumably one index per gene.
dataset.dataset_to_protein_embeddings[dataset.datasets[0]].shape

torch.Size([10809])

In [10]:
# Names of the datasets wrapped by `dataset` (just the processed input here).
dataset.datasets

['10k_pbmcs_proc']

In [11]:
# Per-gene protein-embedding indices for the first (only) dataset, presumably rows into the
# protein-embedding table. Look the dataset up by position rather than by a hard-coded name,
# and check that there is one index per gene of the processed AnnData.
first_dataset_pe_idx = dataset.dataset_to_protein_embeddings[dataset.datasets[0]]
assert first_dataset_pe_idx.shape[0] == processor.adata.n_vars, \
    "expected one protein-embedding index per gene"
first_dataset_pe_idx.shape

torch.Size([10809])

In [12]:
# The index values themselves (presumably row indices into the protein-embedding table).
dataset.dataset_to_protein_embeddings[dataset.datasets[0]]

tensor([28041, 26112, 20529,  ..., 23672, 23677, 23680])

In [13]:
# Collator that assembles dataset items into model batches; it is configured entirely from
# args (presumably including the knock-in settings inspected in the next cell).
multi_dataset_sentence_collator = MultiDatasetSentenceCollator(args)

In [14]:
# Token index list the collator will knock in; presumably the protein-embedding index of
# args.gene_to_knockin (["STAT2"]). TODO confirm in MultiDatasetSentenceCollator.
multi_dataset_sentence_collator.idx_to_knockin

[29700]

In [15]:
# Batch the dataset in its stored order (no shuffling) using the knock-in collator;
# num_workers=0 loads batches in the main process.
dataloader = DataLoader(
    dataset,
    collate_fn=multi_dataset_sentence_collator,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

In [16]:
# Pull one batch to walk through the forward pass step by step below.
# NOTE(review): a fresh iterator is created each time, so (with shuffle=False) this is always the first batch.
x = next(iter(dataloader))

In [29]:
# Make sure the model sits on the accelerator's device. accelerator.prepare in the setup cell
# normally placed it there already; Module.to returns the module itself.
model = model.to(accelerator.device)
# The batch is deliberately NOT passed through accelerator.prepare. prepare() only wraps
# models, optimizers, dataloaders and schedulers and returns any other object unchanged,
# so it would not move these tensors. They are moved with .to(accelerator.device) in the
# next cell.

In [35]:
# Move the first two batch elements (token indices and mask) onto the accelerator's device.
bs, msk = (part.to(accelerator.device) for part in (x[0], x[1]))

In [36]:
# Reproduce the model's input embedding step by hand: token indices -> frozen embeddings.
batch_sentences, mask = bs, msk
# Swap the two axes of the index tensor (presumably batch-first -> sequence-first).
batch_sentences = batch_sentences.permute(1, 0)

# unwrap_model strips any distributed wrapper (e.g. DDP with --multi_gpu) that
# accelerator.prepare may have added, so pe_embedding is reachable in both cases.
batch_sentences = accelerator.unwrap_model(model).pe_embedding(batch_sentences.long())

In [37]:
# L2-normalise each token embedding along the embedding axis (dim 2) before the transformer.
batch_sentences = nn.functional.normalize(batch_sentences, p=2.0, dim=2)

In [38]:
# Run the transformer on the pre-embedded batch; the second output is kept as the
# cell embedding. This is inference only, so autograd is disabled. eval() does not turn off
# requires_grad, so without no_grad a graph would be recorded through the model's parameters,
# wasting memory and returning an `embedding` that requires grad (e.g. .numpy() would fail).
with torch.no_grad():
    _, embedding = model.forward(batch_sentences, mask=mask)

In [39]:
# (batch_size, output_dim) = torch.Size([25, 1280]) here, presumably one embedding per cell.
embedding.shape

torch.Size([25, 1280])