In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
import torch
from scetm_utils import read_aws_h5ad

from scETM import UnsupervisedTrainer, scETM
from scETM.batch_sampler import CellSampler

sys.path.append("..")
from utils import (
    filter_noisy_genes,
    generate_k_fold,
    write_adata_to_s3,
)

sc.set_figure_params(
    dpi=120, dpi_save=250, fontsize=10, figsize=(10, 10), facecolor="white"
)

In [None]:
# use anndata generate by ..data_processing/inhouse_prior_graph_preprocessing.ipynb
unfilterd_adata = read_aws_h5ad("path to preprocessed h5ad")
adata = filter_noisy_genes(unfilterd_adata)
adata.layers["logcounts"] = adata.X.copy()
adata.X = adata.X.todense()
device = torch.device("cuda:0")
gene_network = adata.uns["sparse_gene_network"].todense()

In [6]:
# powered perturbations
adata.obs["condition"] = adata.obs["condition"].astype(str)
adata.obs["Treatment"] = adata.obs["Treatment"].astype(str)
adata.obs["pert_treat"] = adata.obs["condition"] + "+" + adata.obs["Treatment"]
obs_df = pd.DataFrame(adata.obs["pert_treat"])
category_counts = obs_df["pert_treat"].value_counts()
filtered_categories = category_counts[category_counts >= 50].index
adata = adata[adata.obs["pert_treat"].isin(filtered_categories)]

In [7]:
adata

View of AnnData object with n_obs × n_vars = 65648 × 4997
    obs: 'num_features', 'feature_call', 'num_umis', 'target_gene_name', 'SampleIndex', 'ssid', 'Treatment', 'assigned_archetype', 'node_centrality', 'clusters', 'condition', 'control', 'pert_treat'
    var: 'gene_symbol', 'feature_types', 'genome', 'gene_id', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'gene_symbols'
    uns: 'hvg', 'metadata', 'obsm_annot', 'sparse_gene_network', 'varm_annot'
    obsm: 'ACTION', 'ACTION_B', 'ACTION_normalized', 'C_stacked', 'C_unified', 'H_stacked', 'H_unified', 'X_ACTIONet2D', 'X_ACTIONet3D', 'X_ACTIONred', 'X_denovo_color', 'archetype_footprint'
    varm: 'ACTION_A', 'ACTION_V', 'unified_feature_profile', 'unified_feature_specificity'
    layers: 'counts', 'logcounts'
    obsp: 'ACTIONet'

In [None]:
adata.X = adata.layers["counts"].todense()
sc.pp.normalize_total(adata, target_sum=1e4)
adata.X = np.array(adata.X)

In [9]:
adata.obs["cell_types"] = ["A549" for _ in range(adata.shape[0])]

  adata.obs['cell_types'] = ['A549' for _ in range(adata.shape[0])]


In [36]:
train_idx, val_idx, test_idx = generate_k_fold(
    adata, adata.X, adata.obs["condition"], fold_idx=4
)

In [37]:
adata_train = ad.AnnData(np.array(adata[train_idx].X))
adata_train.obs["condition"] = list(adata[train_idx].obs["condition"])
adata_train.obs["Treatment"] = list(adata[train_idx].obs["Treatment"])
adata_train.obs["cell_types"] = ["A549" for _ in range(adata_train.shape[0])]
adata_test = ad.AnnData(np.array(adata[test_idx].X))
adata_test.obs["condition"] = list(adata[test_idx].obs["condition"])
adata_test.obs["Treatment"] = list(adata[test_idx].obs["Treatment"])
adata_test.obs["cell_types"] = ["A549" for _ in range(adata_test.shape[0])]
# for scETM, subset to TNFA+ for better signal
adata_train = adata_train[adata_train.obs["Treatment"] == "TNFA+"]
adata_test = adata_test[adata_test.obs["Treatment"] == "TNFA+"]

In [39]:
inhouse_model = scETM(
    adata_train.n_vars,
    adata_train.obs.condition.nunique(),
    n_topics=200,
    trainable_gene_emb_dim=400,
)
trainer = UnsupervisedTrainer(inhouse_model, adata_train, test_ratio=0.2, seed=0)

[2024-11-13 00:25:06,897] INFO - scETM.src.scETM.logging_utils: scETM.__init__(4997, 151, n_topics = 200, trainable_gene_emb_dim = 400)
[2024-11-13 00:25:06,927] INFO - scETM.src.scETM.logging_utils: UnsupervisedTrainer.__init__(scETM(
  (q_delta): Sequential(
    (0): Linear(in_features=4997, out_features=128, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.1, inplace=False)
  )
  (mu_q_delta): Linear(in_features=128, out_features=200, bias=True)
  (logsigma_q_delta): Linear(in_features=128, out_features=200, bias=True)
  (rho_trainable_emb): PartlyTrainableParameter2D(height=400, fixed=0, trainable=4997)
), View of AnnData object with n_obs × n_vars = 25701 × 4997
    obs: 'condition', 'Treatment', 'cell_types', test_ratio = 0.2, seed = 0)
[2024-11-13 00:25:06,928] INFO - scETM.src.scETM.trainers.trainer_utils: Set seed to 0.
[2024-11-13 00:25:06,940] INFO - scETM.src.scETM.trainers.trainer_util

In [None]:
import time

start = time.time()
trainer.train(
    n_epochs=12000,
    eval_every=2000,
    batch_col="condition",
    eval_kwargs=dict(batch_col="condition"),
    save_model_ckpt=False,
)
end = time.time()
print(f"Training time: {end-start}")

In [None]:
# retrieve reconstructed gene expression
recon = []
theta = []
for i in range(2):
    adata_sub = adata_test[i * 10000 : min((i + 1) * 10000, len(adata))]
    sampler = CellSampler(
        adata_sub, 10000, sample_batch_id=True, n_epochs=1, batch_col="condition"
    )
    dataloader = iter(sampler)
    data_dict = {k: v.to(torch.device("cuda:0")) for k, v in next(dataloader).items()}
    out = inhouse_model.forward(data_dict=data_dict, hyper_param_dict={"decode": True})
    recon.append(out["recon_log"].clone().detach().cpu().numpy())
    theta.append(out["theta"].clone().detach().cpu().numpy())
all_recon = np.concatenate(recon)
all_theta = np.concatenate(theta)
assert len(adata_test) == all_recon.shape[0]

In [None]:
# save model parameters
adata_test.uns["topics"] = inhouse_model.alpha.clone().detach().cpu().numpy()
adata_test.uns["gene_emb"] = (
    inhouse_model.rho_trainable_emb.trainable.clone().detach().cpu().numpy()
)
adata_test.uns["cell_emb"] = all_theta
adata_test.uns["recon"] = all_recon

In [None]:
# save to s3
write_adata_to_s3(
    s3_url="s3://pert-spectra/scETM_checkpoints/scETM_inhouse/",
    adata_name="fold_4",
    adata=adata_test,
)