In [None]:
pip install biolord

In [None]:
import warnings
import os
import sys
import re
import biolord
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import sklearn
import torch


from sklearn.metrics import (
    mean_absolute_error,
    mean_squared_error,
    r2_score,
    root_mean_squared_error,
    mutual_info_score
)

In [None]:
# Load the pre-processed, normalized human vascular-cell atlas into an AnnData object.
adata = sc.read(
    filename="/work/Biolord_all/new_data_fixed_normalized/scAtlas_Human_vascular_cells_processed_normalized.h5ad"
)

In [None]:
# Show which cell types exist (and how many cells each has) before choosing the held-out population.
adata.obs["cell_type"].value_counts()

In [None]:
# Leave-one-out split: hold out every male endothelial cell and train on everything else.
# Female endothelial cells and male cells of other types stay in training, so the held-out
# population has to be produced counterfactually.
subset_mask = (adata.obs["sex"] == "male") & (adata.obs["cell_type"] == "endothelial cell")
# Fail fast if the labels don't match (typo / different vocabulary): an empty hold-out
# would silently turn the whole leave-one-out evaluation into NaNs.
assert subset_mask.any(), "held-out subset (male endothelial cells) is empty"
subset_adata = adata[subset_mask].copy()
train_adata = adata[~subset_mask].copy()  # Training data is everything except the subset

In [None]:
# Register the training split with biolord: every supervised attribute is
# categorical, and no ordered (continuous) attributes are used.
biolord.Biolord.setup_anndata(
    train_adata,
    ordered_attributes_keys=None,
    categorical_attributes_keys=[
        "cell_type",
        "fat_type",
        "sex",
        "bmi_group",
        "donor_id",
        "surgery",
        "tissue",
    ],
)

In [None]:
# Architecture / loss hyper-parameters for the biolord module (passed as module_params).
module_params = dict(
    decoder_width=1024,
    decoder_depth=4,
    attribute_nn_width=512,
    attribute_nn_depth=2,
    n_latent_attribute_categorical=4,
    gene_likelihood="normal",
    reconstruction_penalty=1e2,
    unknown_attribute_penalty=1e1,
    unknown_attribute_noise_param=1e-1,
    attribute_dropout_rate=0.1,
    use_batch_norm=False,
    use_layer_norm=False,
    seed=42,
)

In [None]:
# Build the model on the training split only; the held-out male endothelial cells are excluded.
model = biolord.Biolord(
    adata=train_adata,
    module_params=module_params,
    n_latent=32,
    train_classifiers=False,
    model_name="Sc_Atlas_Biolord_run_lOA_fix_new_data_0905",
)

In [None]:
# Optimiser / scheduler settings, passed to training as plan_kwargs.
trainer_params = dict(
    n_epochs_warmup=0,
    latent_lr=1e-4,
    latent_wd=1e-4,
    decoder_lr=1e-4,
    decoder_wd=1e-4,
    attribute_nn_lr=1e-2,
    attribute_nn_wd=4e-8,
    step_size_lr=45,
    cosine_scheduler=True,
    scheduler_final_lr=1e-5,
)

In [None]:
# Fit the model with early stopping enabled. Checkpointing is off; the trained
# model is saved explicitly afterwards.
model.train(
    plan_kwargs=trainer_params,
    max_epochs=500,
    batch_size=512,
    early_stopping=True,
    early_stopping_patience=20,
    check_val_every_n_epoch=10,
    enable_checkpointing=False,
    num_workers=63,  # NOTE(review): hard-coded worker count; assumes a many-core machine
)

In [None]:
# Persist the trained model. NOTE: there is no "/" between "LOO_scatlas_fixed_data_final"
# and the model name, so the directory is "LOO_scatlas_fixed_data_final<model_name>_model/".
# The load path used later follows the same naming; change both together if fixing it.
model.save(dir_path=f"Biolord_ScAtlas_data/LOO_scatlas_fixed_data_final{model.model_name}_model/")

In [None]:
# Load a previously saved model and attach it to the training AnnData it was set up with.
# `load` is called on the class (it returns a new model instance), so this cell does not
# depend on a `model` object left over from the instantiation/training cells.
model = biolord.Biolord.load(
    "/work/scAtlas_runs/Biolord_ScAtlas_data/Biolord_ScAtlas_data/LOO_scatlas_fixed_data_finalSc_Atlas_Biolord_run_lOA_fix_new_data_0905_model/",
    adata=train_adata,
)

In [None]:
# Show the loaded model's repr as this cell's output.
model

In [None]:
# Validation-set training curves for the main biolord metrics, one panel per metric.
# NOTE(review): relies on `model.training_plan.epoch_history`; after a fresh `load`
# (without training in this kernel) that history may be absent — confirm.
size = 4
vals = ["generative_mean_accuracy", "generative_var_accuracy", "biolord_metric"]
# squeeze=False keeps `axs` 2-D even when only one metric is listed, so axs[0, i] always works.
fig, axs = plt.subplots(
    nrows=1, ncols=len(vals), figsize=(size * len(vals), size), squeeze=False
)

# from_dict is a classmethod: call it on the class, not on a throwaway empty frame.
model.epoch_history = pd.DataFrame.from_dict(model.training_plan.epoch_history)
# Filter to validation rows once, outside the loop.
valid_history = model.epoch_history[model.epoch_history["mode"] == "valid"]
for i, val in enumerate(vals):
    sns.lineplot(
        x="epoch",
        y=val,
        hue="mode",
        data=valid_history,
        ax=axs[0, i],
    )

plt.tight_layout()
plt.show()

In [None]:
# Per-epoch training history (one row per epoch and "mode") read from disk.
# NOTE(review): this path is the "Output_pathSc_Atlas_Biolord_run_model" run, not the
# LOO model loaded above — confirm it is the history of the model being evaluated.
df = pd.read_csv(
    filepath_or_buffer="/work/Biolord_ScAtlas_data/Biolord_ScAtlas_data/Output_pathSc_Atlas_Biolord_run_model/history.csv"
)


In [None]:
# Reconstruction loss per epoch, one line per mode (e.g. train / valid).
sns.lineplot(data=df, x="epoch", y="reconstruction_loss", hue="mode")

In [None]:
# biolord_metric per epoch, one line per mode (e.g. train / valid).
sns.lineplot(data=df, x="epoch", y="biolord_metric", hue="mode")

In [None]:
# Female endothelial cells from the training split: the source population for the
# sex counterfactual (their male counterparts are held out).
female_slow = train_adata[
    train_adata.obs["sex"].eq("female")
    & train_adata.obs["cell_type"].eq("endothelial cell")
].copy()

In [None]:
rec_female_slow,_ = model.predict(female_slow, batch_size=256)

In [None]:
cf = female_slow.copy()
cf.obs["sex"] = "male"

In [None]:
cf_pred,_ = model.predict(cf, batch_size=256)

In [None]:
# Compute per-gene ground-truth means
# Gt for females is now females with hepatocytes

female_gt = (
    female_slow.X.toarray().mean(axis=0)
    if hasattr(female_slow.X, "toarray")
    else female_slow.X.mean(axis=0)
)


#    Gt for the male is now the held out dataset
male_gt = (
    subset_adata.X.toarray().mean(axis=0)
    if hasattr(subset_adata.X, "toarray")
    else subset_adata.X.mean(axis=0)
)

In [None]:
#  Compute per-gene prediction means 
baseline_pred = np.asarray(rec_female_slow.X.mean(axis=0)).ravel()
cf_pred       = np.asarray(cf_pred.X.mean(axis=0)).ravel()

In [None]:
# RMSE between per-gene mean vectors. Names read <prediction>_<truth>:
#   f_* = baseline female reconstruction, m_* = counterfactual male prediction;
#   *_m = held-out true male mean,       *_f = observed female mean.
f_m, m_m, f_f, m_f = (
    np.sqrt(np.mean((pred - truth) ** 2))
    for pred, truth in (
        (baseline_pred, male_gt),
        (cf_pred, male_gt),
        (baseline_pred, female_gt),
        (cf_pred, female_gt),
    )
)

print(f"RMSE pred F --- TM: {f_m:.4f}")
print(f"RMSE pred M --- TM (counterfactual)  : {m_m:.4f}")
print(f"RMSE pred F ----- TF : {f_f:.4f}")
print(f"RMSE pred M ---- TF (counterfactual)  : {m_f:.4f}")

In [None]:
# Tabulate the RMSE values computed in the previous cell, so the figure always
# reflects the current run instead of numbers copied from an earlier one.
rmse_data = [
    [f"{f_f:.4f}", f"{m_f:.4f}"],  # True Female: vs pred F, vs pred M (CF)
    [f"{f_m:.4f}", f"{m_m:.4f}"],  # True Male:   vs pred F, vs pred M (CF)
]

column_labels = ["Pred Female", "Pred Male (CF)"]
row_labels = ["True Female", "True Male"]

fig, ax = plt.subplots(figsize=(6, 2))
ax.axis("tight")
ax.axis("off")

table = ax.table(
    cellText=rmse_data,
    rowLabels=row_labels,
    colLabels=column_labels,
    cellLoc="center",
    loc="center",
)

table.scale(1, 2)  # Increase row height
table.auto_set_font_size(False)
table.set_fontsize(12)

# Highlight the "True Male" row. Cell keys are (row, col): row 0 is the
# column-header row, so data rows start at 1; data columns start at 0 and the
# row labels sit in column -1.
table[(2, 0)].set_facecolor("orange")  # True Male, Pred Female
table[(2, 1)].set_facecolor("lightblue")  # True Male, Pred Male

ax.set_title("RMSE between Predictions and Ground truth", pad=20)
plt.tight_layout()
plt.show()


In [None]:
# Leave-one-out check on a log1p scale.
# x: per-gene mean of the held-out male endothelial cells.
# y: the counterfactual male prediction and the female baseline reconstruction.
# Points near the dashed y = x line are well predicted.
# NOTE(review): log1p assumes values > -1; anything lower becomes NaN (the
# identity-line range below ignores NaNs).
log_male_gt = np.log1p(male_gt)
log_cf_pred = np.log1p(cf_pred)
log_baseline_pred = np.log1p(baseline_pred)

fig, ax = plt.subplots(figsize=(10, 6))

# Explicit labels so the legend actually has entries.
sns.scatterplot(
    x=log_male_gt, y=log_cf_pred,
    alpha=0.7, ax=ax, label="Pred Male (counterfactual)",
)
sns.scatterplot(
    x=log_male_gt, y=log_baseline_pred,
    alpha=0.7, ax=ax, label="Pred Female (baseline)",
)

# Identity line covering every plotted value (ground truth and both prediction
# series), so no point falls past the ends of the reference line.
all_values = np.concatenate([log_male_gt, log_cf_pred, log_baseline_pred])
mn, mx = np.nanmin(all_values), np.nanmax(all_values)
ax.plot([mn, mx], [mn, mx], ls="--", color="red", label="y = x")

ax.set_xlabel("Log Ground Truth (Male)")
ax.set_ylabel("Log Predicted Mean Expression")
ax.set_title("Leave One Out analysis")
ax.legend(loc="upper left")
plt.tight_layout()
plt.show()
