In [27]:
import os
from anndata import AnnData
from dotenv import load_dotenv

import warnings
# Installs a blanket "ignore" filter: from this point on every warning category
# is suppressed, including deprecation warnings raised by the imports below.
warnings.filterwarnings('ignore')

# Initialize constants
# load_dotenv() runs before the os.getenv lookups — presumably so that keys
# configured in a local .env file are visible to them; confirm against the
# project setup.
load_dotenv()
# os.getenv returns None for an unset variable, so either constant can be None
# here without any error being raised at this point.
CONDITION_KEY, CELL_TYPE_KEY = os.getenv('CONDITION_KEY'), os.getenv('CELL_TYPE_KEY')

from load_data import get_adata
from sc_condition_prediction import create_and_train_vae_model, evaluate_r2, N_INPUT, N_LAYERS, N_HIDDEN, N_LATENT, BATCH_SIZE
from utils import remove_stimulated_for_celltype

# Load data
# NOTE(review): the saved output of this cell is `KeyError: 'controlled'`, so as
# committed the cell does not run to completion. Which call raises it is not
# visible from this notebook — check get_adata / remove_stimulated_for_celltype.
train_adata = get_adata(train=True, verbose=True)
# Presumably the training data with the stimulated CD4T cells dropped (inferred
# from the helper's name — confirm in utils). It is not referenced by the cells
# visible in this section.
train_adata_no_cd4t = remove_stimulated_for_celltype(train_adata, celltype="CD4T")

KeyError: 'controlled'

In [None]:
# Convert the training data to a DataFrame and display its summary statistics.
train_adata.to_df().describe()

In [None]:
def make_subsets_from_adata(adata: AnnData, verbose=False):
    """Return every non-empty combination of the per-cell-type slices of ``adata``.

    ``adata`` is split into one slice per category of ``adata.obs[CELL_TYPE_KEY]``
    (the column must be categorical, since its ``.cat`` accessor is used), and
    all combinations of those slices of size 1 up to the number of categories
    are enumerated.

    Args:
        adata: Annotated data whose ``obs`` contains the ``CELL_TYPE_KEY`` column.
        verbose: When true, print the category labels and the per-type slices.

    Returns:
        A list of tuples of ``adata[mask]`` slices, ordered by combination size
        and, within a size, in ``itertools.combinations`` order. For ``n``
        categories the list has ``2**n - 1`` entries.
    """
    from itertools import combinations

    labels = adata.obs[CELL_TYPE_KEY].cat.categories.values
    if verbose:
        print(f"Unique cell types: {labels}")

    # One slice per category, kept in category order.
    slices_per_type = []
    for label in labels:
        slices_per_type.append(adata[adata.obs[CELL_TYPE_KEY] == label])
    if verbose:
        print(f"AnnData objects by cell types: {slices_per_type}")

    # Sizes start at 1 so the empty combination is never produced.
    return [
        combo
        for size in range(1, len(slices_per_type) + 1)
        for combo in combinations(slices_per_type, size)
    ]

In [None]:
# Enumerate every non-empty combination of the per-cell-type slices of the
# training data; verbose=True also prints the cell types and the slices.
subsets_adata_by_cell_type = make_subsets_from_adata(train_adata, verbose=True)

In [None]:
# Train one VAE per cell-type subset and keep the evaluation scores of every run.
# `r2` / `r2_diff_genes` are rebound on each iteration, so without this list only
# the scores of the final subset would survive the loop.
subset_r2_results = []
for i, subset_adata_by_cell_type in enumerate(subsets_adata_by_cell_type):
    # NOTE(review): only the first AnnData of the subset is used for training; the
    # remaining members are ignored, so subsets that share a first element train
    # on the same data. Confirm whether the members were meant to be combined.
    adata_sample = subset_adata_by_cell_type[0]
    # The subset index is part of the file name, so each run gets its own file.
    params_filename = os.path.join("models", "subsets_test", f"{i}_autoencoder.pt")
    create_and_train_vae_model(adata_sample,
                               epochs=15,
                               save_params_to_filename=params_filename)
    # Presumably evaluates the model stored at params_filename — confirm in
    # sc_condition_prediction.
    r2, r2_diff_genes = evaluate_r2(params_filename)
    subset_r2_results.append(
        {
            "subset_index": i,
            "subset_size": len(subset_adata_by_cell_type),
            "params_filename": params_filename,
            "r2": r2,
            "r2_diff_genes": r2_diff_genes,
        }
    )