In [2]:
# Automatically reload imported modules (including the local helper modules
# imported in the next cell) before each cell executes, so edits to them are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
import torch
from scetm_utils import read_aws_h5ad

from scETM import UnsupervisedTrainer, scETM
from scETM.batch_sampler import CellSampler

sys.path.append("..")
from utils import filter_noisy_genes, split_data_by_cell, write_adata_to_s3

# Notebook-wide scanpy/matplotlib figure defaults (on-screen dpi, saved-figure
# dpi, font size, figure size and background colour).
sc.set_figure_params(
    dpi=120,
    dpi_save=250,
    fontsize=10,
    figsize=(10, 10),
    facecolor="white",
)

In [None]:
# Use the AnnData generated by ../data_processing/replogle_prior_graph_preprocessing.ipynb
unfiltered_adata = read_aws_h5ad("path to h5ad")
adata = filter_noisy_genes(unfiltered_adata)
# Keep the loaded matrix as a layer before X is densified and later replaced
# (presumably log-normalised values, per the layer name — confirm).
adata.layers["logcounts"] = adata.X.copy()
adata.X = adata.X.todense()
# Run on the first GPU when one is present, otherwise fall back to CPU instead
# of failing at the first .to(device) call.
# NOTE(review): scETM places the model on its own device; keep the two in sync.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Prior gene network stored by the preprocessing notebook, densified.
# NOTE(review): not used by any later cell in this notebook.
gene_network = adata.uns["sparse_gene_network"].todense()

In [6]:
# Keep only perturbations (obs["gene"]) observed in at least MIN_CELLS_PER_PERT cells.
MIN_CELLS_PER_PERT = 50
pert_cell_counts = adata.obs["gene"].value_counts()
well_sampled_perts = pert_cell_counts.index[pert_cell_counts >= MIN_CELLS_PER_PERT]
adata = adata[adata.obs["gene"].isin(well_sampled_perts)]

In [7]:
# Restrict to the perturbations present in the svae-filtered Replogle AnnData
# (see ../data for instructions on generating it), plus SKP2, CUL1 and UBE2N,
# which are always kept.
filtered_replogle = read_aws_h5ad("path to svae filtered replogle h5ad")
filtered_perts = set(filtered_replogle.obs["gene"].unique()) | {"SKP2", "CUL1", "UBE2N"}
adata = adata[adata.obs["gene"].isin(filtered_perts)]

In [8]:
# Replace X with raw counts normalised to 1e4 total counts per cell.
# Copy the layer: assigning it directly can make X and layers["counts"] the same
# array, and normalize_total may divide X in place, silently rewriting the
# raw counts layer as well.
adata.X = adata.layers["counts"].copy()
sc.pp.normalize_total(adata, target_sum=1e4)
# Densify X. np.asarray on a scipy.sparse matrix yields a 0-d object array
# rather than a dense matrix, so use .toarray() when X is sparse.
adata.X = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)

In [10]:
# Split cells into train / validation / test index sets, 20% each for test and
# validation (presumably stratified by the perturbation labels passed in —
# confirm in utils.split_data_by_cell).
# NOTE(review): val_idx is not used below; scETM's trainer carves its own
# held-out set out of adata_train via test_ratio.
split_indices = split_data_by_cell(
    adata.X,
    adata.obs["gene"],
    val_size=0.2,
    test_size=0.2,
)
train_idx, val_idx, test_idx = split_indices

In [11]:
def _make_split_adata(source, idx, cell_type="K562"):
    """Build a standalone AnnData for one split of ``source``.

    Parameters
    ----------
    source : AnnData
        Full dataset; ``source.obs["gene"]`` holds the perturbation label.
    idx
        Cell selector for the split (as returned by ``split_data_by_cell``).
    cell_type : str
        Constant written to ``obs["cell_types"]`` for every cell.

    Returns
    -------
    AnnData
        Dense copy of ``X`` for the selected cells, ``obs`` columns ``gene`` and
        ``cell_types``, and the source gene names as ``var_names``.
    """
    subset = source[idx]  # index once, not once per attribute
    split = ad.AnnData(
        np.array(subset.X),
        # Keep gene identities so the columns of X (and of any reconstruction
        # derived from this object) remain interpretable.
        var=pd.DataFrame(index=source.var_names.copy()),
    )
    # Positional assignment: obs rows follow the order of ``subset``.
    split.obs["gene"] = subset.obs["gene"].to_numpy()
    split.obs["cell_types"] = cell_type
    return split


adata_train = _make_split_adata(adata, train_idx)
adata_test = _make_split_adata(adata, test_idx)

In [13]:
# Seed torch/numpy BEFORE constructing the model. UnsupervisedTrainer(seed=...)
# seeds only during its own __init__ (see the "Set seed to 0" log line), which
# runs after scETM's weights have already been randomly initialised.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Arguments: number of genes, and number of batches. Each distinct perturbation
# in obs["gene"] is treated as its own batch (batch_col="gene" during training).
model = scETM(
    adata_train.n_vars,
    adata_train.obs["gene"].nunique(),
    n_topics=200,
    trainable_gene_emb_dim=400,
)
trainer = UnsupervisedTrainer(model, adata_train, test_ratio=0.1, seed=SEED)

[2024-11-13 07:15:56,365] INFO - scETM.src.scETM.logging_utils: scETM.__init__(4935, 517, n_topics = 200, trainable_gene_emb_dim = 400)
[2024-11-13 07:15:57,358] INFO - scETM.src.scETM.logging_utils: UnsupervisedTrainer.__init__(scETM(
  (q_delta): Sequential(
    (0): Linear(in_features=4935, out_features=128, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.1, inplace=False)
  )
  (mu_q_delta): Linear(in_features=128, out_features=200, bias=True)
  (logsigma_q_delta): Linear(in_features=128, out_features=200, bias=True)
  (rho_trainable_emb): PartlyTrainableParameter2D(height=400, fixed=0, trainable=4935)
), AnnData object with n_obs × n_vars = 67803 × 4935
    obs: 'gene', 'cell_types', test_ratio = 0.1, seed = 0)
[2024-11-13 07:15:57,361] INFO - scETM.src.scETM.trainers.trainer_utils: Set seed to 0.
[2024-11-13 07:15:57,422] INFO - scETM.src.scETM.trainers.trainer_utils: Keeping 6780 cells (0.1

In [None]:
# Train scETM with each perturbation (obs["gene"]) as a batch, evaluating every
# 2000 epochs. save_model_ckpt=False presumably disables writing model
# checkpoints during training.
train_kwargs = dict(
    n_epochs=12000,
    eval_every=2000,
    batch_col="gene",
    eval_kwargs={"batch_col": "gene"},
    save_model_ckpt=False,
)
trainer.train(**train_kwargs)

In [20]:
# Retrieve reconstructed expression (recon_log) and topic proportions (theta)
# for every test cell. Cells are processed in fixed-size chunks to bound GPU
# memory; the chunk count follows from the test-set size instead of being
# hard-coded.
INFER_CHUNK_SIZE = 10000

# CellSampler derives batch codes from obs["gene"] of the AnnData it is given.
# A chunk that lacks some perturbations (e.g. the small last chunk) would
# therefore map codes to the wrong perturbations. Pin every chunk to the full
# set of training perturbations instead.
# NOTE(review): assumes training-time codes are the sorted unique genes of
# adata_train (pandas' default category order) — confirm against CellSampler.
pert_categories = sorted(adata_train.obs["gene"].unique())

# The encoder contains Dropout and BatchNorm, so switch to inference mode.
# This disables dropout and uses BatchNorm running stats; no autograd graph is
# needed either.
model.eval()
recon = []
theta = []
with torch.no_grad():
    for start in range(0, adata_test.n_obs, INFER_CHUNK_SIZE):
        stop = min(start + INFER_CHUNK_SIZE, adata_test.n_obs)
        chunk_obs = adata_test.obs.iloc[start:stop].copy()
        chunk_obs["gene"] = pd.Categorical(chunk_obs["gene"], categories=pert_categories)
        assert not chunk_obs["gene"].isna().any(), "test perturbation missing from training set"
        # Standalone AnnData (not a view of adata_test), so the unused
        # categories are kept.
        adata_sub = ad.AnnData(np.asarray(adata_test.X[start:stop]), obs=chunk_obs)
        # batch_size >= chunk size, so one draw yields the whole chunk.
        sampler = CellSampler(
            adata_sub, INFER_CHUNK_SIZE, sample_batch_id=True, n_epochs=1, batch_col="gene"
        )
        batch = next(iter(sampler))
        data_dict = {k: v.to(device) for k, v in batch.items()}
        out = model.forward(data_dict=data_dict, hyper_param_dict={"decode": True})
        recon.append(out["recon_log"].detach().cpu().numpy())
        theta.append(out["theta"].detach().cpu().numpy())
all_recon = np.concatenate(recon)
all_theta = np.concatenate(theta)
assert adata_test.n_obs == all_recon.shape[0]
assert adata_test.n_obs == all_theta.shape[0]

In [21]:
# Attach learned model parameters and test-set outputs to adata_test for export.
def _param_to_numpy(tensor):
    """Detached, CPU-resident NumPy copy of a model tensor."""
    return tensor.clone().detach().cpu().numpy()


adata_test.uns["topics"] = _param_to_numpy(model.alpha)
adata_test.uns["gene_emb"] = _param_to_numpy(model.rho_trainable_emb.trainable)
adata_test.uns["cell_emb"] = all_theta  # per-cell theta from the inference cell
adata_test.uns["recon"] = all_recon  # per-cell recon_log from the inference cell
adata_test.X = np.array(adata_test.X)

In [22]:
# Upload the annotated test AnnData to S3.
S3_OUTPUT_URL = "s3://pert-spectra/scETM_checkpoints/scETM_replogle/"
write_adata_to_s3(
    s3_url=S3_OUTPUT_URL,
    adata_name="scETM_replogle",
    adata=adata_test,
)