In [1]:
from accelerate import Accelerator
import anndata as ad
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint location and architecture settings handed to UCE(...) below.
model_config = dict(
    model_loc="./model_files/33l_8ep_1024t_1280.torch",
    batch_size=5,  # alternative value used before: 25
    pad_length=1536,
    pad_token_idx=0,
    chrom_token_left_idx=1,
    chrom_token_right_idx=2,
    cls_token_idx=3,
    CHROM_TOKEN_OFFSET=143574,
    sample_size=1024,
    CXG=True,
    n_layers=33,
    output_dim=1280,
    d_hid=5120,
    token_dim=5120,
    multi_gpu=False,
)

# Paths to auxiliary model files (presumably loaded by UCE at construction — confirm).
files_config = dict(
    spec_chrom_csv_path="./model_files/species_chrom.csv",
    token_file="./model_files/all_tokens.torch",
    protein_embeddings_dir="./model_files/protein_embeddings/",
    offset_pkl_path="./model_files/species_offsets.pkl",
)

# Dataset options.
# NOTE(review): adata_path points at ../data/, but the "Testing Data" cell reads
# ../../data/ — one of the two is stale; confirm which location is current.
data_config = dict(
    adata_path="../data/full_cells_macaca_obs_sum_v3.h5ad",
    dir="./",
    species="macaca_fascicularis",  # alternative value used before: "human"
    filter=False,
    skip=True,
)

## Load Test Data

In [3]:
# NOTE(review): this path differs from data_config["adata_path"] ("../data/...");
# one of them is stale — confirm which location is correct.
adata_file = '../../data/full_cells_macaca_obs_sum_v3.h5ad'
ann_data = ad.read_h5ad(adata_file)



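## Create Perturbations

Take two cells from the same individual and build perturbed copies of the second cell. In each copy, one randomly chosen expressed gene takes the count observed in the first cell.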

In [4]:
# Select the cells of one individual (batch 1, dose3, male) at the two sampled
# durations (15 and -9), then build single-gene perturbations of the second cell.
# NOTE(review): which selected row corresponds to which duration depends on the
# row order in ann_data — confirm before interpreting the embeddings.
SEED = 42  # fixed seed so the randomly chosen genes are reproducible across runs
N_PERTURBATIONS = 5
rng = np.random.default_rng(SEED)

obs = ann_data.obs
filter_indiv = (
    (obs['batch'] == 1)
    & (obs['dose.share'] == 'dose3')
    & (obs['sex'] == 'M')
    & ((obs['duration'] == 15) | (obs['duration'] == -9))
).to_numpy(dtype=bool, na_value=False)  # missing annotations count as "no match"

gene_counts = ann_data.X[filter_indiv]
if hasattr(gene_counts, "toarray"):  # AnnData.X may be stored as a scipy sparse matrix
    gene_counts = gene_counts.toarray()
assert gene_counts.shape[0] >= 2, (
    f"expected at least 2 matching cells, got {gene_counts.shape[0]}"
)

source_cell, target_cell = gene_counts[0], gene_counts[1]
gene_idx_nonzero, = np.nonzero(target_cell)
assert gene_idx_nonzero.size >= N_PERTURBATIONS, (
    f"only {gene_idx_nonzero.size} expressed genes; need {N_PERTURBATIONS}"
)
perturbation_idx = rng.choice(gene_idx_nonzero, size=N_PERTURBATIONS, replace=False)

# Rows: source cell, target cell, then one row per perturbation in which a single
# gene of the target cell is replaced by the source cell's value.
perturbation_candidates = [source_cell.copy(), target_cell.copy()]
for idx in perturbation_idx:
    # .copy() is essential: indexing a row returns a view, so assigning into it
    # directly would mutate the target row and make every perturbation cumulative.
    candidate = target_cell.copy()
    candidate[idx] = source_cell[idx]
    perturbation_candidates.append(candidate)

perturbation_candidates = np.vstack(perturbation_candidates)

In [5]:
# Wrap the perturbation matrix in an AnnData carrying the original gene annotations.
# A copy of .var is assigned so that anything done to `test.var` (e.g. by
# preprocessing) cannot leak back into `ann_data.var` through a shared DataFrame.
test = ad.AnnData(X=perturbation_candidates)
test.var = ann_data.var.copy()

In [6]:
# UCE model wrapper from the helical package.
from helical.models.uce.uce import UCE

# Accelerator whose project directory is the configured working directory.
project_dir = data_config["dir"]
accelerator = Accelerator(project_dir=project_dir)

In [7]:
# Instantiate UCE from the three config dicts, sharing the accelerator created above.
helical_model = UCE(
    model_config,
    data_config,
    files_config,
    accelerator=accelerator,
)

In [8]:
# Convert the perturbation AnnData into the input object consumed by get_embeddings.
data = helical_model.process_data(test)

In [9]:
# Embed every row of the perturbation matrix with UCE.
embeddings = helical_model.get_embeddings(data)
# Expect exactly one embedding per input profile (nothing dropped during processing).
assert embeddings.shape[0] == perturbation_candidates.shape[0], (
    f"expected {perturbation_candidates.shape[0]} embeddings, got {embeddings.shape[0]}"
)
embeddings.shape

2024-03-16, 19:56:22.839 UCE-Model INFO Inference started
100%|██████████| 2/2 [00:13<00:00,  6.62s/it]


(7, 1280)