In [9]:
# Development setup: automatically reload edited Python modules (e.g. the helical
# package) before each cell runs, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

from accelerate import Accelerator
import anndata as ad
import numpy as np

In [28]:
# Model hyper-parameters and special token indices for the UCE checkpoint at model_loc.
model_config = {
    "model_loc": "./model_files/33l_8ep_1024t_1280.torch",
    "batch_size": 5,  # previously 25
    "pad_length": 1536,
    "pad_token_idx": 0,
    "chrom_token_left_idx": 1,
    "chrom_token_right_idx": 2,
    "cls_token_idx": 3,
    "CHROM_TOKEN_OFFSET": 143574,
    "sample_size": 1024,
    "CXG": True,
    "n_layers": 33,
    "output_dim": 1280,  # width of each returned cell embedding
    "d_hid": 5120,
    "token_dim": 5120,
    "multi_gpu": False,
}

# Auxiliary files shipped alongside the checkpoint.
files_config = {
    "spec_chrom_csv_path": "./model_files/species_chrom.csv",
    "token_file": "./model_files/all_tokens.torch",
    "protein_embeddings_dir": "./model_files/protein_embeddings/",
    "offset_pkl_path": "./model_files/species_offsets.pkl",
}

data_config = {
    # Keep in sync with the path passed to ad.read_h5ad in the data-loading cell;
    # previously "../data/..." here while "../../data/..." was actually loaded.
    "adata_path": "../../data/full_cells_macaca_obs_sum_v3.h5ad",
    "dir": "./",
    "species": "macaca_fascicularis",  # "human" was also tried previously
    "filter": False,
    "skip": True,
}

## Testing Data

In [29]:
# Load the full macaque single-cell dataset (path relative to the working directory).
# Alternative dataset used during development: "../../data/10k_pbmcs_proc.h5ad"
macaque_h5ad_path = '../../data/full_cells_macaca_obs_sum_v3.h5ad'
ann_data = ad.read_h5ad(macaque_h5ad_path)

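## Creating Perturbations

Build candidate expression profiles by copying single-gene counts from one selected cell into another.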

In [30]:
# Select cells: batch 1, dose3, male, duration 15 or -9.
filter_indiv = (
    (ann_data.obs['batch'] == 1)
    & (ann_data.obs['dose.share'] == 'dose3')
    & (ann_data.obs['sex'] == 'M')
    & ann_data.obs['duration'].isin([15, -9])
)
# NOTE(review): assumes ann_data.X is a dense ndarray (np.nonzero on a row must
# yield a single index array) — a sparse X would need .toarray() first.
gene_counts = ann_data.X[filter_indiv.to_numpy()]
assert gene_counts.shape[0] >= 2, "perturbation needs at least two matching cells"

source_cell = gene_counts[0]
target_cell = gene_counts[1]
gene_idx_nonzero, = np.nonzero(target_cell)

# Seeded so the chosen genes are reproducible on Restart & Run All.
perturb_rng = np.random.default_rng(0)
n_perturb = min(5, gene_idx_nonzero.size)  # choice(replace=False) fails if size > population
perturbation_idx = perturb_rng.choice(gene_idx_nonzero, size=n_perturb, replace=False)

# Row 0: source cell, row 1: unperturbed target cell, then one row per perturbed gene.
perturbation_candidates = [source_cell, target_cell]
for idx in perturbation_idx:
    # .copy() is essential: target_cell is a view into gene_counts, so writing into it
    # directly would mutate the baseline row and every previously appended candidate.
    candidate = target_cell.copy()
    candidate[idx] = source_cell[idx]
    perturbation_candidates.append(candidate)

perturbation_candidates = np.vstack(perturbation_candidates)

In [31]:
# By default this REPLACES the perturbation candidates built above with the first
# N_TEST_CELLS cells of the dataset; set EMBED_PERTURBATIONS = True to embed those instead.
EMBED_PERTURBATIONS = False
N_TEST_CELLS = 10
if not EMBED_PERTURBATIONS:
    perturbation_candidates = ann_data.X[:N_TEST_CELLS]

In [32]:
# Wrap the selected expression rows in a standalone AnnData for the model.
test = ad.AnnData(X=perturbation_candidates)
# Copy: AnnData stores the assigned DataFrame as-is, and the var_names setter below
# rewrites its index in place, which would otherwise also rename ann_data's genes.
test.var = ann_data.var.copy()
# Lower-case gene identifiers; NOTE(review): presumably what the UCE gene lookup
# expects — confirm against the model's token files.
test.var_names = ann_data.var_names.str.lower()

In [15]:
# Sanity check: lower-cased var index alongside the original gene_symbols column.
test.var.loc[:, 'gene_symbols']

samd11      SAMD11
plekhn1    PLEKHN1
hes4          HES4
isg15        ISG15
agrn          AGRN
            ...   
mt-atp8    MT-ATP8
mt-atp6    MT-ATP6
mt-co3      MT-CO3
mt-nd4      MT-ND4
mt-nd6      MT-ND6
Name: gene_symbols, Length: 12000, dtype: category
Categories (12000, object): ['A1BG', 'A2M', 'A4GALT', 'AAAS', ..., 'ZYG11B', 'ZYX', 'ZZEF1', 'ZZZ3']

In [6]:
from helical.models.uce.uce import UCE

# Accelerator whose project directory is data_config["dir"].
accelerator = Accelerator(
    project_dir=data_config["dir"],
)

In [7]:
# Instantiate UCE from the model, data and file configs plus the accelerator.
helical_model = UCE(
    model_config,
    data_config,
    files_config,
    accelerator=accelerator,
)

In [34]:
# Prepare the test AnnData for inference with get_embeddings.
# NOTE(review): an earlier variant passed species="human" here; data_config sets
# species to "macaca_fascicularis" — confirm which one process_data uses.
data = helical_model.process_data(test)

In [26]:
# Run inference; one embedding row is expected per cell in `test`.
embeddings = helical_model.get_embeddings(data)
# Guard against a config/checkpoint mismatch: width must equal the configured output_dim.
assert embeddings.shape[1] == model_config["output_dim"], embeddings.shape
embeddings.shape

2024-03-18, 08:43:03.020 UCE-Model INFO Inference started
100%|██████████| 2/2 [00:30<00:00, 15.18s/it]


(10, 1280)