In [1]:
import os
import anndata as ad
import numpy as np
import pandas as pd
import matplotlib as plt

from anndata import AnnData
from dotenv import load_dotenv

import warnings
warnings.filterwarnings('ignore')

# Initialize constants
load_dotenv()
CONDITION_KEY, CELL_TYPE_KEY = os.getenv('CONDITION_KEY'), os.getenv('CELL_TYPE_KEY')

from load_data import get_adata
from sc_condition_prediction import create_and_train_vae_model, evaluate_r2, N_INPUT, N_LAYERS, N_HIDDEN, N_LATENT, BATCH_SIZE
from utils import remove_stimulated_for_celltype

# Load data
train_adata = get_adata(train=True, verbose=True)
train_adata_no_cd4t = remove_stimulated_for_celltype(train_adata, celltype="CD4T")

AnnData object with n_obs × n_vars = 16893 × 6998
    obs: 'condition', 'n_counts', 'n_genes', 'mt_frac', 'cell_type'
    var: 'gene_symbol', 'n_cells'
    uns: 'cell_type_colors', 'condition_colors', 'neighbors'
    obsm: 'X_pca', 'X_tsne', 'X_umap'
    obsp: 'connectivities', 'distances'


In [2]:
train_adata.to_df().describe()

index,AL627309.1,RP11-206L10.9,LINC00115,NOC2L,KLHL17,HES4,ISG15,TNFRSF18,TNFRSF4,SDF4,...,C21orf67,FAM207A,ADARB1,POFUT2,COL18A1,SLC19A1,COL6A2,FTCD,DIP2A,S100B
count,16893.0,16893.0,16893.0,16893.0,16893.0,16893.0,16893.0,16893.0,16893.0,16893.0,...,16893.0,16893.0,16893.0,16893.0,16893.0,16893.0,16893.0,16893.0,16893.0,16893.0
mean,0.000203,0.000442,0.0033,0.056011,0.000991,0.11463,1.799,0.053575,0.052044,0.067051,...,0.000333,0.032279,0.003885,0.006159,0.004496,0.002654,0.001596,0.000152,0.013051,0.009183
std,0.01274,0.01863,0.050151,0.204129,0.027527,0.316232,1.666201,0.230176,0.226953,0.216867,...,0.017223,0.154578,0.054166,0.067945,0.059533,0.046549,0.036505,0.011595,0.099866,0.098072
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,0.0,0.0,0.0,0.0,0.0,0.0,1.736927,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,0.0,0.0,0.0,0.0,0.0,0.0,3.120183,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
max,1.052949,1.027135,2.020838,2.251955,1.069814,2.507464,5.472819,2.349544,2.472356,2.879868,...,1.468089,1.657621,2.230935,1.538911,1.213284,1.51467,1.494599,0.996147,1.607041,2.438123


In [3]:
def make_subsets_from_adata(adata: AnnData, verbose=False):
    cell_types = adata.obs[CELL_TYPE_KEY].cat.categories.values
    if verbose:
        print(f"Unique cell types: {cell_types}")
    adata_by_cell_type = [ 
        adata[adata.obs[CELL_TYPE_KEY] == cell_type]
        for cell_type in cell_types
    ]
    if verbose:
        print(f"AnnData objects by cell types: {adata_by_cell_type}")
    from itertools import combinations
    adata_subsets = []
    for i in range(1, len(adata_by_cell_type) + 1):  # to get all subsets: from 1 to size (omitting 0 subset)
        for adata_subset in combinations(adata_by_cell_type, i):
            adata_subsets.append(adata_subset)
    return adata_subsets

In [4]:
subsets_adata_by_cell_type = make_subsets_from_adata(train_adata, verbose=True)

Unique cell types: ['CD4T' 'CD14+Mono' 'B' 'CD8T' 'NK' 'FCGR3A+Mono' 'Dendritic']
AnnData objects by cell types: [View of AnnData object with n_obs × n_vars = 5564 × 6998
    obs: 'condition', 'n_counts', 'n_genes', 'mt_frac', 'cell_type'
    var: 'gene_symbol', 'n_cells'
    uns: 'cell_type_colors', 'condition_colors', 'neighbors'
    obsm: 'X_pca', 'X_tsne', 'X_umap'
    obsp: 'connectivities', 'distances', View of AnnData object with n_obs × n_vars = 2561 × 6998
    obs: 'condition', 'n_counts', 'n_genes', 'mt_frac', 'cell_type'
    var: 'gene_symbol', 'n_cells'
    uns: 'cell_type_colors', 'condition_colors', 'neighbors'
    obsm: 'X_pca', 'X_tsne', 'X_umap'
    obsp: 'connectivities', 'distances', View of AnnData object with n_obs × n_vars = 1811 × 6998
    obs: 'condition', 'n_counts', 'n_genes', 'mt_frac', 'cell_type'
    var: 'gene_symbol', 'n_cells'
    uns: 'cell_type_colors', 'condition_colors', 'neighbors'
    obsm: 'X_pca', 'X_tsne', 'X_umap'
    obsp: 'connectivities', 'd

In [9]:
model_directory = os.path.join("models", "subsets_test")

if not os.path.exists(model_directory):
    os.makedirs(model_directory)

n_steps = len(subsets_adata_by_cell_type)

df_results = pd.DataFrame(data=np.zeros((n_steps, 2)), index=np.arange(n_steps), columns=['r2', 'r2_diff_genes'])

for i, subset_adata_by_cell_type in enumerate(subsets_adata_by_cell_type):
    adata_sample = ad.concat(list(subset_adata_by_cell_type), join="outer")
    params_filename = os.path.join(model_directory, f"{i}_autoencoder.pt")
    create_and_train_vae_model(adata_sample,
                               epochs=15,
                               save_params_to_filename=params_filename)
    r2, r2_diff_genes = evaluate_r2(params_filename)
    df_results.loc[i, ["r2", "r2_diff_genes"]] = [r2, r2_diff_genes]

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 200])

In [None]:
df_results.plot(y=['r2', 'r2_diff_genes'], kind='line');
plt.title('Impact of cell type subsetting on model score')
plt.xlabel('step')
plt.ylabel('$R^2$');