In [1]:
import os
os.chdir("/oak/stanford/groups/akundaje/kobbad/UCE")
from model import *
import scanpy as sc
from tqdm.auto import tqdm
from torch import nn, Tensor
from evaluate import get_ESM2_embeddings
from utils import get_ESM2_embeddings_x
import argparse
from accelerate import Accelerator
from evaluate import AnndataProcessor
from eval_data import MultiDatasetSentences, MultiDatasetSentenceCollator
import pickle
from torch.utils.data import DataLoader
import torch
import numpy as np

In [2]:
parser = argparse.ArgumentParser(
    description='Embed a single anndata using UCE.')

# Anndata Processing Arguments
parser.add_argument('--adata_path', type=str,
                    default=None,
                    help='Full path to the anndata you want to embed.')
parser.add_argument('--dir', type=str,
                    default="./",
                    help='Working folder where all files will be saved.')
parser.add_argument('--species', type=str, default="human",
                    help='Species of the anndata.')
parser.add_argument('--filter', type=bool, default=True,
                    help='Additional gene/cell filtering on the anndata.')
parser.add_argument('--skip', type=bool, default=True,
                    help='Skip datasets that appear to have already been created.')

# Model Arguments
parser.add_argument('--model_loc', type=str,
                    default=None,
                    help='Location of the model.')
parser.add_argument('--batch_size', type=int, default=25,
                    help='Batch size.')
parser.add_argument('--pad_length', type=int, default=1536,
                    help='Batch size.')
parser.add_argument("--pad_token_idx", type=int, default=0,
                    help="PAD token index")
parser.add_argument("--chrom_token_left_idx", type=int, default=1,
                    help="Chrom token left index")
parser.add_argument("--chrom_token_right_idx", type=int, default=2,
                    help="Chrom token right index")
parser.add_argument("--cls_token_idx", type=int, default=3,
                    help="CLS token index")
parser.add_argument("--CHROM_TOKEN_OFFSET", type=int, default=143574,
                    help="Offset index, tokens after this mark are chromosome identifiers")
parser.add_argument('--sample_size', type=int, default=1024,
                    help='Number of genes sampled for cell sentence')
parser.add_argument('--CXG', type=bool, default=True,
                    help='Use CXG model.')
parser.add_argument('--nlayers', type=int, default=4,
                    help='Number of transformer layers.')
parser.add_argument('--output_dim', type=int, default=1280,
                    help='Output dimension.')
parser.add_argument('--d_hid', type=int, default=5120,
                    help='Hidden dimension.')
parser.add_argument('--token_dim', type=int, default=5120,
                    help='Token dimension.')
parser.add_argument('--multi_gpu', type=bool, default=False,
                    help='Use multiple GPUs')

# Misc Arguments
parser.add_argument("--spec_chrom_csv_path",
                    default="./model_files/species_chrom.csv", type=str,
                    help="CSV Path for species genes to chromosomes and start locations.")
parser.add_argument("--token_file",
                    default="./model_files/all_tokens.torch", type=str,
                    help="Path for token embeddings.")
parser.add_argument("--protein_embeddings_dir",
                    default="./model_files/protein_embeddings/", type=str,
                    help="Directory where protein embedding .pt files are stored.")
parser.add_argument("--offset_pkl_path",
                    default="./model_files/species_offsets.pkl", type=str,
                        help="PKL file which contains offsets for each species.")

# masking arguments
parser.add_argument("--genes_to_mask", 
                    nargs='+',
                    default= None,
                    type=str,
                    help="List of genes to mask")

parser.add_argument("--genes_to_pe_idx",
                    default="./model_files/gene_to_pe_index.pkl",type=str,
                    help="Path to gene to protein embedding index mapping")

_StoreAction(option_strings=['--genes_to_pe_idx'], dest='genes_to_pe_idx', nargs=None, const=None, default='./model_files/gene_to_pe_index.pkl', type=<class 'str'>, choices=None, help='Path to gene to protein embedding index mapping', metavar=None)

In [3]:
args = parser.parse_args()

In [4]:
args.genes_to_mask = ["STAT2"]

args.genes_to_mask

['STAT2']

In [5]:
args.model_loc = "/oak/stanford/groups/akundaje/kobbad/UCE/model_files/4layer_model.torch"

In [6]:
#### Set up the model ####
token_dim = args.token_dim
emsize = 1280  # embedding dimension
d_hid = args.d_hid  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = args.nlayers  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 20  # number of heads in nn.MultiheadAttention
dropout = 0.05  # dropout probability
model = TransformerModel(token_dim=token_dim, d_model=emsize, nhead=nhead,
                            d_hid=d_hid,
                            nlayers=nlayers, dropout=dropout,
                            output_dim=args.output_dim)

In [7]:
# intialize as empty
empty_pe = torch.zeros(145469, 5120)
empty_pe.requires_grad = False
model.pe_embedding = nn.Embedding.from_pretrained(empty_pe)
model.load_state_dict(torch.load(args.model_loc, map_location="cpu"),
                        strict=True)
# Load in the real token embeddings
all_pe = get_ESM2_embeddings(args)

In [8]:
if all_pe.shape[0] != 145469: 
        all_pe.requires_grad = False
        model.pe_embedding = nn.Embedding.from_pretrained(all_pe)

In [9]:
model = model.eval()
accelerator = Accelerator(project_dir=args.dir)
model = accelerator.prepare(model)
batch_size = args.batch_size

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [10]:
processor = AnndataProcessor(args, accelerator)

Using sample AnnData: 10k pbmcs dataset


In [11]:
processor.preprocess_anndata()

10k_pbmcs_proc already processed. Skipping


In [12]:
processor.generate_idxs()

PE Idx, Chrom and Starts files already created


In [13]:
processor.adata

AnnData object with n_obs × n_vars = 11990 × 10809
    obs: 'n_counts', 'batch', 'labels', 'str_labels', 'cell_type', 'n_genes'
    var: 'gene_symbols', 'n_counts-0', 'n_counts-1', 'n_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'n_cells'
    uns: 'cell_types', 'hvg'
    obsm: 'design', 'normalized_qc', 'qc_pc', 'raw_qc'

In [14]:
with open(processor.shapes_dict_path, "rb") as f:
            shapes_dict = pickle.load(f)

In [15]:
dataset = MultiDatasetSentences(sorted_dataset_names=[processor.name],
                                shapes_dict=shapes_dict,
                                args=args, npzs_dir=args.dir,
                                dataset_to_protein_embeddings_path=processor.pe_idx_path,
                                datasets_to_chroms_path=processor.chroms_path,
                                datasets_to_starts_path=processor.starts_path
                                )

In [16]:
dataset.dataset_to_protein_embeddings[dataset.datasets[0]].shape

torch.Size([10809])

In [17]:
dataset.datasets

['10k_pbmcs_proc']

In [18]:
# seems to have the mapping from gene to protein embeddings
dataset.dataset_to_protein_embeddings['10k_pbmcs_proc'].shape

torch.Size([10809])

In [19]:
dataset.dataset_to_starts['10k_pbmcs_proc'].shape

(10809,)

In [25]:
dataset.dataset_to_protein_embeddings[dataset.datasets[0]]

tensor([28041, 26112, 20529,  ..., 23672, 23677, 23680])

In [20]:
args.genes_to_mask

['STAT2']

In [21]:
multi_dataset_sentence_collator = MultiDatasetSentenceCollator(args)

In [22]:
multi_dataset_sentence_collator.idx_to_mask

[29700]

In [26]:
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            collate_fn=multi_dataset_sentence_collator,
                            num_workers=0)

In [48]:
x = next(iter(dataloader))

Masking done
Masking done


In [52]:
torch.nonzero(x[0] == 29700)

tensor([[  6, 928],
        [  6, 929],
        [  8, 801]], device='cuda:0')

In [56]:
x[1][8, 801]

tensor(0., device='cuda:0')

In [None]:
dataloader = accelerator.prepare(dataloader)
pbar = tqdm(dataloader, disable=not accelerator.is_local_main_process)
dataset_embeds = []
with torch.no_grad():
    for batch in pbar:
        batch_sentences, mask, idxs = batch[0], batch[1], batch[2]
        batch_sentences = batch_sentences.permute(1, 0)
        batch_sentences_save = batch_sentences.permute(1, 0)
        if args.multi_gpu:
            batch_sentences = model.module.pe_embedding(batch_sentences.long())
        else:
            batch_sentences = model.pe_embedding(batch_sentences.long())
        batch_sentences = nn.functional.normalize(batch_sentences,
                                                    dim=2)  # Normalize token outputs now
        pred, embedding = model.forward(batch_sentences, mask=mask)
        # Fix for duplicates in last batch
        accelerator.wait_for_everyone()
        embeddings = accelerator.gather_for_metrics((embedding))
        if accelerator.is_main_process:
            dataset_embeds.append(embeddings.detach().cpu().numpy())

In [154]:
torch.nonzero(batch_sentences_save[0] == 144155)

tensor([[1]], device='cuda:0')

In [158]:
mask[0,torch.nonzero(batch_sentences_save[0] == 2)]

tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]], device='cuda:0')

In [89]:
# cast to int
np.sort(batch_sentences_save[0].int().cpu().numpy())[::-1][0:100]


array([144186, 144185, 144159, 144158, 144157, 144156, 144155, 144154,
       144153, 144152, 144151, 144150, 144149, 144148, 144147, 144146,
       144145, 144144, 144143, 144142, 144141, 144140, 144139, 144138,
        33086,  33065,  32974,  32974,  32847,  32804,  32804,  32729,
        32729,  32728,  32600,  32600,  32599,  32558,  32511,  32511,
        32511,  32511,  32494,  32419,  32419,  32411,  32411,  32400,
        32400,  32386,  32384,  32352,  32352,  32314,  32240,  32236,
        32236,  32236,  32226,  32150,  32105,  32098,  32098,  32098,
        32098,  32094,  32087,  32046,  32046,  32012,  32012,  31938,
        31902,  31868,  31780,  31780,  31780,  31774,  31774,  31745,
        31745,  31745,  31721,  31721,  31701,  31696,  31696,  31696,
        31687,  31633,  31633,  31621,  31611,  31611,  31611,  31420,
        31419,  31419,  31419,  31387], dtype=int32)

In [92]:
mask.shape

torch.Size([25, 1075])

In [23]:
# these are the gene idxs which are mapped above
batch_sentences_save[0,0]

tensor(3., device='cuda:0')

In [24]:
model.pe_embedding(torch.tensor(3).to('cuda'))

tensor([ 1.2758,  0.8424,  0.2497,  ...,  1.1378,  1.6793, -1.8231],
       device='cuda:0')

In [25]:
nn.functional.normalize(all_pe[144146].unsqueeze(0), dim=1)

tensor([[ 0.0070, -0.0088,  0.0030,  ...,  0.0112, -0.0101, -0.0255]])

In [26]:
all_pe[144150]

tensor([ 1.3609, -0.1782,  0.8830,  ...,  0.5165, -0.9954,  0.5466])

In [27]:
batch_sentences[:,0,:][1]

tensor([-0.0072, -0.0125,  0.0050,  ...,  0.0057, -0.0163, -0.0150],
       device='cuda:0')

In [28]:
batch_sentences_save[0][2]

tensor(14948., device='cuda:0')

In [29]:
batch_sentences.shape

torch.Size([1075, 25, 5120])

In [30]:
model.gene_embedding_layer(batch_sentences)[0].shape

torch.Size([25, 1280])

In [31]:
embedding.shape

torch.Size([25, 1280])

In [32]:
y = model.predict(embedding, batch_sentences[3])

In [33]:
y

tensor([[-4.3502],
        [-4.8144],
        [-5.8105],
        [-2.8324],
        [-1.6371],
        [-1.5982],
        [-2.2231],
        [-6.7526],
        [-3.2576],
        [-4.5858],
        [-1.4756],
        [-4.5196],
        [-1.8379],
        [-2.2768],
        [-3.8854],
        [-4.0805],
        [-5.2052],
        [-2.9048],
        [-1.4073],
        [-3.0574],
        [-2.7870],
        [-3.2805],
        [-3.4762],
        [-1.9117],
        [-1.6429]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [34]:
# read in anndata 10k_pbmcs_proc_uce_adata.h5ad
adata = sc.read("10k_pbmcs_proc_uce_adata.h5ad")

In [35]:
adata.var[adata.var["gene_symbols"]=="FOXO3"]

Unnamed: 0,gene_symbols,n_counts-0,n_counts-1,n_counts,highly_variable,highly_variable_rank,means,variances,variances_norm,n_cells
FOXO3,FOXO3,865.0,417.0,1241.0,True,6208.0,0.103503,0.109813,0.955795,1150


In [36]:
adata.var["gene_symbols"]

SAMD11      SAMD11
PLEKHN1    PLEKHN1
HES4          HES4
ISG15        ISG15
AGRN          AGRN
            ...   
MT-ATP8    MT-ATP8
MT-ATP6    MT-ATP6
MT-CO3      MT-CO3
MT-ND4      MT-ND4
MT-ND6      MT-ND6
Name: gene_symbols, Length: 10809, dtype: category
Categories (10809, object): ['A1BG', 'A2M', 'AAAS', 'AACS', ..., 'ZYG11B', 'ZYX', 'ZZEF1', 'ZZZ3']

In [37]:
adata.obs['cell_type'].unique() 

['CD4 T cells', 'CD14+ Monocytes', 'CD8 T cells', 'B cells', 'Other', 'Dendritic Cells', 'FCGR3A+ Monocytes', 'NK cells', 'Megakaryocytes']
Categories (9, object): ['B cells', 'CD4 T cells', 'CD8 T cells', 'CD14+ Monocytes', ..., 'FCGR3A+ Monocytes', 'Megakaryocytes', 'NK cells', 'Other']

In [38]:
adata.obsm['X_uce'].shape

(11990, 1280)

In [39]:
adata

AnnData object with n_obs × n_vars = 11990 × 10809
    obs: 'n_counts', 'batch', 'labels', 'str_labels', 'cell_type', 'n_genes'
    var: 'gene_symbols', 'n_counts-0', 'n_counts-1', 'n_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'n_cells'
    uns: 'cell_types', 'hvg'
    obsm: 'X_uce', 'design', 'normalized_qc', 'qc_pc', 'raw_qc'

In [40]:
adata.var[adata.var['gene_symbols']=='MTOR']

Unnamed: 0,gene_symbols,n_counts-0,n_counts-1,n_counts,highly_variable,highly_variable_rank,means,variances,variances_norm,n_cells
MTOR,MTOR,197.0,108.0,298.0,True,6917.0,0.024854,0.025239,0.949541,292


In [41]:
counts = np.memmap("10k_pbmcs_proc_counts.npz", dtype='int64', mode='r', shape = (11990,10809))

In [45]:
tokens, counts = np.unique(batch_sentences_save[0].cpu().numpy(), return_counts=True)

counts = dict(zip(tokens, counts))

# sort by counts
sorted(counts.items(), key=lambda x: x[1], reverse=True)

counts

{0.0: 2,
 2.0: 24,
 3.0: 1,
 13466.0: 2,
 13534.0: 2,
 13554.0: 1,
 13598.0: 1,
 13600.0: 1,
 13624.0: 3,
 13671.0: 3,
 13674.0: 3,
 13691.0: 1,
 13708.0: 2,
 13798.0: 1,
 13833.0: 3,
 13855.0: 3,
 13924.0: 1,
 13933.0: 2,
 13990.0: 1,
 14017.0: 2,
 14075.0: 1,
 14149.0: 2,
 14155.0: 1,
 14241.0: 2,
 14248.0: 1,
 14249.0: 2,
 14270.0: 2,
 14332.0: 4,
 14354.0: 2,
 14364.0: 1,
 14375.0: 1,
 14432.0: 1,
 14474.0: 4,
 14537.0: 3,
 14548.0: 2,
 14576.0: 3,
 14639.0: 2,
 14773.0: 2,
 14906.0: 1,
 14948.0: 1,
 14996.0: 1,
 15023.0: 3,
 15044.0: 1,
 15125.0: 3,
 15179.0: 4,
 15252.0: 1,
 15523.0: 1,
 15540.0: 1,
 15594.0: 1,
 15660.0: 1,
 15706.0: 2,
 15711.0: 1,
 15736.0: 1,
 15755.0: 1,
 16005.0: 1,
 16035.0: 1,
 16066.0: 2,
 16068.0: 5,
 16069.0: 1,
 16073.0: 1,
 16078.0: 1,
 16086.0: 4,
 16088.0: 1,
 16092.0: 3,
 16097.0: 1,
 16144.0: 1,
 16200.0: 1,
 16202.0: 2,
 16222.0: 1,
 16227.0: 1,
 16271.0: 1,
 16335.0: 2,
 16429.0: 3,
 16477.0: 1,
 16496.0: 3,
 16594.0: 2,
 16757.0: 1,
 16785.0: 

In [46]:
np.argwhere(batch_sentences_save[0].cpu().numpy() == 13491.0)

array([], shape=(0, 1), dtype=int64)

In [47]:
nn.functional.normalize(all_pe[13491].unsqueeze(0), dim=1)

tensor([[ 0.0246,  0.0061,  0.0222,  ..., -0.0028, -0.0150, -0.0100]])

In [48]:
nn.functional.normalize(batch_sentences[871,0,:].unsqueeze(0), dim=1)

tensor([[ 0.0341,  0.0098,  0.0111,  ..., -0.0083, -0.0130, -0.0190]],
       device='cuda:0')

In [49]:
np.argwhere(batch_sentences[0].cpu().numpy() == 13491)

array([], shape=(0, 2), dtype=int64)

In [50]:
batch_sentences[1072,0,:]

tensor([-0.0224, -0.0055, -0.0064,  ..., -0.0043, -0.0310,  0.0039],
       device='cuda:0')

In [51]:
mask[0].sum()

tensor(1073., device='cuda:0')

In [52]:
batch_sentences_save[4]

tensor([3.0000e+00, 1.4418e+05, 2.3674e+04,  ..., 2.0000e+00, 0.0000e+00,
        0.0000e+00], device='cuda:0')

In [53]:
dataset.npzs_dir

'./'

In [54]:
cts = np.memmap(dataset.npzs_dir + f"{dataset.datasets[0]}_counts.npz", dtype='int64', mode='r', shape=dataset.shapes_dict[dataset.datasets[0]])

In [55]:
counts = cts[0]

In [56]:
counts = torch.tensor(counts).unsqueeze(0)
weights = torch.log1p(counts)
weights = (weights / torch.sum(weights))

In [57]:
dataset_obj = dataset

In [58]:
batch_weights = weights
dataset_to_protein_embeddings = dataset.dataset_to_protein_embeddings
dataset_to_chroms = dataset.dataset_to_chroms
dataset_to_starts = dataset.dataset_to_starts
dataset = dataset.datasets[0]

In [59]:
dataset_idxs = dataset_to_protein_embeddings[dataset] # get the dataset specific protein embedding idxs
cell_sentences = torch.zeros((counts.shape[0], args.pad_length)) # init the cell representation as 0s
mask = torch.zeros((counts.shape[0], args.pad_length)) # start of masking the whole sequence
chroms = dataset_to_chroms[dataset] # get the dataset specific chroms for each gene
starts = dataset_to_starts[dataset] # get the dataset specific genomic start locations for each gene

longest_seq_len = 0 # we need to keep track of this so we can subset the batch at the end


In [60]:
for c, cell in enumerate(counts):
    weights = batch_weights[c].numpy()
    weights = weights / sum(weights)  # RE NORM after mask
    
    # randomly choose the genes that will make up the sample, weighted by expression, with replacement
    choice_idx = np.random.choice(np.arange(len(weights)),
                                    size=args.sample_size, p=weights,
                                    replace=True)
    choosen_chrom = chroms[choice_idx] # get the sampled genes chromosomes
    # order the genes by chromosome
    chrom_sort = np.argsort(choosen_chrom)  
    choice_idx = choice_idx[chrom_sort]

    # sort the genes by start
    new_chrom = chroms[choice_idx]
    choosen_starts = starts[choice_idx]

    ordered_choice_idx = np.full((args.pad_length),
                                    args.cls_token_idx)  # start with cls
    # i= 0 first token is CLS
    i = 1  # continue on to the rest of the sequence with left bracket being assumed.
    # Shuffle the chroms now, there's no natural order to chromosomes
    uq_chroms = np.unique(new_chrom)
    np.random.shuffle(uq_chroms) # shuffle
    
    # This loop is actually just over one cell
    for chrom in uq_chroms:
        # Open Chrom token
        ordered_choice_idx[i] = int(chrom) + args.CHROM_TOKEN_OFFSET # token of this chromosome # i = 1 next token is a chrom open
        i += 1
        # now sort the genes by start order within the chroms
        loc = np.where(new_chrom == chrom)[0]
        sort_by_start = np.argsort(
            choosen_starts[loc])  # start locations for this chromsome

        to_add = choice_idx[loc[sort_by_start]]
        ordered_choice_idx[i:(i + len(to_add))] = dataset_idxs[to_add]
        i += len(to_add)
        ordered_choice_idx[i] = args.chrom_token_right_idx # add the chrom sep again
        i += 1  # add the closing token again

    longest_seq_len = max(longest_seq_len, i)
    remainder_len = (args.pad_length - i)

    cell_mask = torch.concat((torch.ones(i),
                                # pay attention to all of these tokens, ignore the rest!
                                torch.zeros(remainder_len)))

    mask[c, :] = cell_mask

    ordered_choice_idx[i:] = args.pad_token_idx # the remainder of the sequence
    cell_sentences[c, :] = torch.from_numpy(ordered_choice_idx)

In [61]:
cell_sentences_pe = cell_sentences.long() # token indices

In [62]:
cell_sentences_pe

tensor([[     3, 144144,  16509,  ...,      0,      0,      0]])

In [93]:
dataset_obj.dataset_to_protein_embeddings[dataset].numpy().max()

33255

In [118]:
adata.var['gene_symbols'].to_list().index

['SAMD11',
 'PLEKHN1',
 'HES4',
 'ISG15',
 'AGRN',
 'C1orf159',
 'TTLL10',
 'TNFRSF18',
 'TNFRSF4',
 'SDF4',
 'B3GALT6',
 'SCNN1D',
 'ACAP3',
 'PUSL1',
 'CPTP',
 'TAS1R3',
 'MXRA8',
 'ANKRD65',
 'ATAD3C',
 'ATAD3B',
 'ATAD3A',
 'MIB2',
 'MMP23B',
 'CDK11B',
 'SLC35E2B',
 'CDK11A',
 'NADK',
 'GNB1',
 'TMEM52',
 'PRKCZ',
 'SKI',
 'PEX10',
 'PLCH2',
 'PANK4',
 'MMEL1',
 'MEGF6',
 'TPRG1L',
 'WRAP73',
 'SMIM1',
 'LRRC47',
 'CEP104',
 'DFFB',
 'C1orf174',
 'AJAP1',
 'NPHP4',
 'KCNAB2',
 'RNF207',
 'ICMT',
 'GPR153',
 'ACOT7',
 'HES2',
 'ESPN',
 'TNFRSF25',
 'PLEKHG5',
 'NOL9',
 'ZBTB48',
 'KLHL21',
 'PHF13',
 'THAP3',
 'DNAJC11',
 'VAMP3',
 'PER3',
 'UTS2',
 'TNFRSF9',
 'PARK7',
 'SLC45A1',
 'RERE',
 'ENO1',
 'CA6',
 'SLC2A5',
 'GPR157',
 'H6PD',
 'SPSB1',
 'SLC25A33',
 'PIK3CD',
 'CLSTN1',
 'CTNNBIP1',
 'LZIC',
 'NMNAT1',
 'RBP7',
 'UBE4B',
 'KIF1B',
 'PGD',
 'DFFA',
 'PEX14',
 'CASZ1',
 'C1orf127',
 'TARDBP',
 'SRM',
 'EXOSC10',
 'MTOR',
 'UBIAD1',
 'FBXO2',
 'FBXO44',
 'FBXO6',
 'MAD2L2'

In [97]:
dataset_obj.dataset_to_protein_embeddings[dataset][10349]

tensor(32975)

In [102]:
nn.functional.normalize(model.pe_embedding(dataset_obj.dataset_to_protein_embeddings[dataset][10349].to('cuda')).unsqueeze(0), dim=1)

tensor([[ 0.0219,  0.0010,  0.0236,  ...,  0.0084, -0.0243, -0.0175]],
       device='cuda:0')

In [111]:
batch_sentences.shape

torch.Size([1075, 25, 5120])

In [106]:
batch_sentences[2,0,:]

tensor([ 0.0226, -0.0027,  0.0209,  ..., -0.0062, -0.0228, -0.0036],
       device='cuda:0')

In [112]:
batch_sentences_save[0,2]

tensor(23557., device='cuda:0')

In [127]:
adata.var['gene_symbols'].to_list().index('STAT2')

6760

In [132]:
# create dictionary from gene symbol to index 
gene_to_idx = dict(zip(adata.var['gene_symbols'], range(adata.var.shape[0])))

pe_index = dataset_obj.dataset_to_protein_embeddings[dataset].numpy()

gene_to_pe_index = { g: pe_index[i] for i, g in enumerate(adata.var['gene_symbols'].to_list())}


In [134]:
# write this to a pickle
with open("./model_files/gene_to_pe_index.pkl", "wb") as f:
    pickle.dump(gene_to_pe_index, f)

In [133]:
gene_to_pe_index['STAT2']

29700

In [123]:
np.where(dataset_obj.dataset_to_protein_embeddings[dataset].numpy() == 23557)

(array([2834]),)

In [131]:
dataset_obj.dataset_to_protein_embeddings[dataset].numpy()

array([28041, 26112, 20529, ..., 23672, 23677, 23680])

In [129]:
model.pe_embedding(torch.tensor(29700).to('cuda'))

tensor([ 0.1137,  0.0827,  0.1744,  ..., -0.0165,  0.0019, -0.1174],
       device='cuda:0')