In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
pip install git+https://github.com/theislab/drvi.git@main

In [None]:
pip install gprofiler

In [None]:
import warnings

In [None]:
# Suppress every Python warning from here on.
# NOTE(review): this also hides deprecation and numerical warnings raised by
# the libraries used below; consider narrowing the filter by category or module.
warnings.filterwarnings("ignore")

In [None]:
import anndata as ad
import scanpy as sc
from matplotlib import pyplot as plt
from IPython.display import display
from gprofiler import gprofiler
import torch
import numpy as np
import drvi
from drvi.model import DRVI
from drvi.utils.misc import hvg_batch

In [None]:
# Configure plotting once. Every `set_figure_params` call re-applies the
# defaults for arguments that are not passed, so splitting the settings over
# several calls silently reset `frameon=False` back to its default.
sc.set_figure_params(dpi=100, frameon=False, figsize=(3, 3))
# Kept explicit so matplotlib-only figures use the same resolution and size.
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.figsize"] = (3, 3)

In [None]:
# Load the processed HTAPP dataset and show its summary.
# NOTE(review): absolute, machine-specific path; consider a configurable data directory.
adata = sc.read("/work/DRVI/fixed_data/HTAPP/HTAPP_997_processed_raw_FINAL.h5ad")
adata

In [None]:
# Inspect the cell-type annotation used below to define the held-out group.
adata.obs["cell_type"]

In [None]:
# Leave-one-out split: hepatocytes of replicate "1" are held out of training.
subset_mask = adata.obs["replicate"].eq("1") & adata.obs["cell_type"].eq("hepatocyte")

# Held-out cells (serve as ground truth later on).
subset_adata = adata[subset_mask].copy()
# Everything that is not held out is used for training.
train_adata = adata[~subset_mask].copy()

In [None]:
# Summary of the held-out subset (replicate "1" hepatocytes).
subset_adata

In [None]:
# Summary of the training data (all cells except the held-out subset).
train_adata

In [None]:
# Keep a copy of X in the "counts" layer, which is the layer handed to DRVI.setup_anndata below.
# NOTE(review): assumes X holds raw counts at this point — TODO confirm.
train_adata.layers["counts"] = train_adata.X.copy()

In [None]:
# Display the stored counts layer.
train_adata.layers["counts"]

In [None]:
# Report the container type and the (cells, genes) shape of the counts layer.
counts_layer = train_adata.layers["counts"]
print(type(counts_layer))
print(counts_layer.shape)


In [None]:
print(train_adata.layers["counts"][:50, :50])  # View the top-left 50×50 corner

In [None]:
# Covariates the model is conditioned on; the same keys are used for
# registration and for model construction.
covariate_keys = ["cell_type", "Phase", "replicate", "compartments", "cnv_pass_mal"]

# Register the training data with DRVI.
DRVI.setup_anndata(
    train_adata,
    # DRVI expects count data by default; change gene_likelihood when
    # providing non-count data.
    layer="counts",
    # Must be a list. DRVI can accept multiple covariates.
    categorical_covariate_keys=list(covariate_keys),
    # Set to False for log-normalized data with a normal distribution (mse loss).
    is_count_data=True,
)

# Build the model.
model = DRVI(
    train_adata,
    # The covariate keys are provided once again here. Refer to advanced usages for more options.
    categorical_covariates=list(covariate_keys),
    n_latent=128,
    # Encoder and decoder dims are lists of integers.
    encoder_dims=[128, 128],
    decoder_dims=[128, 128],
)
model

In [None]:
# Accelerator selection (extra arguments to `model.train`):
#   - CPU:                accelerator="cpu", devices=1
#   - Apple silicon (mps): accelerator="mps", devices=1
#   - GPU:                no additional parameter
# More details here: https://lightning.ai/docs/pytorch/stable/accelerators/gpu_basic.html

n_epochs = 400

# KL warm-up spans the whole training run.
# No need to provide `plan_kwargs` if n_epochs >= 400.
kl_warmup_plan = {"n_epochs_kl_warmup": n_epochs}

# Train the model (GPU by default: no accelerator arguments are passed).
model.train(
    max_epochs=n_epochs,
    early_stopping=False,
    early_stopping_patience=20,
    plan_kwargs=kl_warmup_plan,
)

# Reference runtimes:
#   CPU laptop (M1):  208 minutes
#   Macbook GPU (M1):  64 minutes
#   GPU (A100):        17 minutes

In [None]:
# Persist the trained model to disk (path is relative to the working directory).
model.save("DRVI/LOA_DRVI/LOA_models_new_fixed/HTAPP_fixed_LOA_last")

In [None]:
# Reload the model that was trained and saved in this notebook, attached to
# the data it was set up with. Calling `load` on the class (not on an existing
# instance) also works on a fresh kernel where training was skipped.
# NOTE(review): this previously loaded "trained_models/DRVI_HH_train_2/" with the
# full `adata`, i.e. a different checkpoint and a dataset that still contains
# the held-out cells — confirm that was not intentional.
model = DRVI.load("DRVI/LOA_DRVI/LOA_models_new_fixed/HTAPP_fixed_LOA_last", adata=train_adata)

In [None]:
def predict(model, adata, batch_size=128):
    """Reconstruct expression for every cell of ``adata`` with a trained model.

    The data is passed through the model batch by batch without shuffling, so
    the rows of the result stay aligned with the rows of ``adata``. The mean of
    the generative distribution ``px`` is used as the prediction.

    Parameters
    ----------
    model
        Trained model exposing ``_validate_anndata``, ``_make_data_loader`` and
        a ``module`` with ``eval`` and ``forward``.
    adata
        Data to reconstruct. It is not modified.
    batch_size
        Number of cells per forward pass (default 128).

    Returns
    -------
    A copy of ``adata`` whose ``X`` holds the predicted means.
    """
    model._validate_anndata(adata)
    model.module.eval()

    loader = model._make_data_loader(
        adata=adata, indices=None, batch_size=batch_size, shuffle=False
    )

    batch_means = []
    # Inference only: without autograd no graph is built or kept per batch.
    with torch.no_grad():
        for tensors in loader:
            _, generative_outputs = model.module.forward(
                tensors,
                compute_loss=False,
            )
            # Replace non-finite means: NaN and -inf become 0, +inf becomes 100.
            px_mean = torch.nan_to_num(
                generative_outputs["px"].mean, nan=0, neginf=0, posinf=100
            )
            batch_means.append(px_mean.detach().cpu().numpy())

    out_adata = adata.copy()
    out_adata.X = np.concatenate(batch_means, axis=0)
    return out_adata


#model._validate_anndata(subset_adata)
#ec = predict(model, subset_adata)

In [None]:
# Source cells for the counterfactual: replicate "2" hepatocytes from the training data.
rep2_hepatocyte_mask = (
    train_adata.obs["replicate"].eq("2")
    & train_adata.obs["cell_type"].eq("hepatocyte")
)
rep_2_base = train_adata[rep2_hepatocyte_mask].copy()

In [None]:
# Baseline: reconstruct the replicate "2" hepatocytes with their original covariates.
rec_rep2 = predict(model, rep_2_base)

In [None]:
# Counterfactual input: the same cells, with only the replicate label switched to "1".
cf = rep_2_base.copy()
cf.obs["replicate"] = "1"

In [None]:
# Counterfactual prediction: expression of the relabelled cells.
rec_cf = predict(model, cf)

In [None]:
def _gene_means(matrix):
    """Per-gene mean over cells; sparse input is densified first."""
    dense = matrix.toarray() if hasattr(matrix, "toarray") else matrix
    return dense.mean(axis=0)


# Per-gene ground-truth means.
# Replicate 2: the hepatocytes that were part of the training data.
rep_2_gt = _gene_means(rep_2_base.X)

# Replicate 1: the held-out hepatocytes.
rep_1_gt = _gene_means(subset_adata.X)

In [None]:
# Per-gene ground-truth means of replicate "2" hepatocytes.
rep_2_gt

In [None]:
# Per-gene ground-truth means of the held-out replicate "1" hepatocytes.
rep_1_gt

In [None]:
# Per-gene means of the two predictions, flattened to 1-D arrays:
# baseline reconstruction first, counterfactual second.
baseline_pred, cf_pred = (
    np.asarray(reconstruction.X.mean(axis=0)).ravel()
    for reconstruction in (rec_rep2, rec_cf)
)

In [None]:
def _rmse(pred, truth):
    """Root-mean-square error between two per-gene mean vectors."""
    return np.sqrt(np.mean((pred - truth) ** 2))


# Variable naming: <prediction>_<ground truth>, where "f" stands for
# replicate 2 and "m" for replicate 1 (names kept for downstream cells).
f_m = _rmse(baseline_pred, rep_1_gt)  # pred replicate 2      vs true replicate 1
m_m = _rmse(cf_pred, rep_1_gt)        # pred replicate 1 (CF) vs true replicate 1
f_f = _rmse(baseline_pred, rep_2_gt)  # pred replicate 2      vs true replicate 2
m_f = _rmse(cf_pred, rep_2_gt)        # pred replicate 1 (CF) vs true replicate 2

print(f"RMSE pred replicate 2 vs true replicate 1: {f_m:.4f}")
print(f"RMSE pred replicate 1 (counterfactual) vs true replicate 1: {m_m:.4f}")
print(f"RMSE pred replicate 2 vs true replicate 2: {f_f:.4f}")
print(f"RMSE pred replicate 1 (counterfactual) vs true replicate 2: {m_f:.4f}")

In [None]:
import matplotlib.pyplot as plt

# RMSE table: rows are the ground truth, columns are the prediction.
# The cells are filled from the computed RMSE values so the figure cannot
# drift from the numbers calculated above.
rmse_data = [
    [f"{f_f:.4f}", f"{m_f:.4f}"],  # True Replicate 2
    [f"{f_m:.4f}", f"{m_m:.4f}"],  # True Replicate 1 (held out)
]

# Define labels
column_labels = ["Pred Replicate 2", "Pred Replicate 1 (CF)"]
row_labels = ["True Replicate 2", "True Replicate 1"]

# Create the figure and axis
fig, ax = plt.subplots(figsize=(6, 2))
ax.axis('tight')
ax.axis('off')

# Create the table
table = ax.table(
    cellText=rmse_data,
    rowLabels=row_labels,
    colLabels=column_labels,
    cellLoc='center',
    loc='center'
)

table.scale(1, 2)  # Increase row height
table.auto_set_font_size(False)
table.set_fontsize(12)

# Highlight the held-out row.
# Indexing: row 0 is the column-header row, so data rows start at 1;
# data columns start at 0 (row labels live in column -1).
table[(2, 0)].set_facecolor("orange")  # True Replicate 1, Pred Replicate 2
table[(2, 1)].set_facecolor("lightblue")  # True Replicate 1, Pred Replicate 1 (CF)

plt.title("RMSE between Predictions and Ground truth", pad=20)
plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# log1p-transform the per-gene means.
# `log_male_gt` is the held-out replicate 1 ground truth (legacy variable name).
log_male_gt = np.log1p(rep_1_gt)
log_cf_pred = np.log1p(cf_pred)
log_baseline_pred = np.log1p(baseline_pred)

# Scatter plot of log-transformed ground truth vs. predictions
fig, ax = plt.subplots(figsize=(10, 6))

# Each series gets a label; otherwise the legend below has nothing to show.
sns.scatterplot(
    x=log_male_gt, y=log_cf_pred,
    alpha=0.7, ax=ax, label="Counterfactual (replicate 2 -> 1)"
)

sns.scatterplot(
    x=log_male_gt, y=log_baseline_pred,
    alpha=0.7, ax=ax, label="Baseline (replicate 2 reconstruction)"
)

# Identity line (in log space), spanning the ground truth AND both predictions
# so it covers the full extent of the plotted points.
mn = min(log_male_gt.min(), log_cf_pred.min(), log_baseline_pred.min())
mx = max(log_male_gt.max(), log_cf_pred.max(), log_baseline_pred.max())
ax.plot([mn, mx], [mn, mx], ls="--", color="red", label="Identity")


ax.set_xlabel("Log Ground Truth (Replicate 1)")
ax.set_ylabel("Log Predicted Mean Expression")
ax.set_title("Leave One Out analysis" )
ax.legend(loc="upper left")
plt.tight_layout()
ax.grid(False)
plt.show()
