In [None]:
pip install session_info

In [None]:
import session_info

In [None]:
pip install git+https://github.com/theislab/scgen.git

In [None]:
import scanpy as sc
import torch
import logging
import scgen 
import sklearn
import seaborn as sns
import torch
import warnings
import os
import sys
import re





#import numpy as np

# Remember to downgrade scvi-tools when needed (it is not always required), e.g.
# `pip install scvi-tools==1.1.1` (versions noted: 1.6, 1.1.1 — TODO confirm which one is meant).
# Known issue: a sqrt problem in the latent space.
# 2. Install the released scgen (not the development version) -- use that one for now.

In [None]:
# Display the versions of the imported packages for reproducibility.
session_info.show()

In [None]:
from sklearn.metrics import mean_squared_error, mean_absolute_error

In [None]:
# Processed focal cortical dataset, read with scanpy into an AnnData object.
DATA_PATH = "/work/scGen_Human_vascular/new_data_fix_may/focal_cortical_processed_RAW.h5ad"
adata = sc.read(DATA_PATH)

In [None]:
# Display the AnnData summary (cell/gene counts and the available obs/var columns).
adata

In [None]:
# Hold out the male microglial cells: they are never shown to the model and serve as
# the ground truth for the counterfactual female→male prediction. Every other cell
# is used for training.
is_male = adata.obs['sex'] == 'male'
is_microglia = adata.obs['cell_type'] == 'microglial cell'
subset_mask = is_male & is_microglia
subset_adata = adata[subset_mask].copy()
train_adata = adata[~subset_mask].copy()

In [None]:
# Register train_adata with scGen: "sex" is the condition (batch) key whose categories
# ("female"/"male") are used as ctrl/stim keys in `predict`, and "cell_type" is the label key.
scgen.SCGEN.setup_anndata(train_adata, batch_key="sex", labels_key="cell_type")

In [None]:
import random

import numpy as np

# Seed the RNGs before the model is built so a Restart & Run All reproduces results.
# torch presumably drives weight initialisation and mini-batch shuffling; numpy's
# global RNG is presumably used by scGen's random cell sampling in `predict` (verify).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

model = scgen.SCGEN(train_adata)
# Optional checkpoint of the freshly initialised model:
# model.save("scGen_Human_vascular/LOA/scgen_focal_LOA", overwrite=True)

In [None]:
# Train for at most 300 epochs; early stopping ends training once the monitored
# metric (scvi-tools default) has not improved for 100 consecutive epochs.
train_kwargs = {
    "max_epochs": 300,
    "early_stopping": True,
    "early_stopping_patience": 100,
}
model.train(**train_kwargs)


In [None]:
# Load a previously saved scGen model and bind it to the training AnnData.
# scvi-tools models are saved as a directory, so they are restored with the model's own
# `load` classmethod (torch.load cannot open a directory and takes no `adata` argument).
# NOTE(review): this replaces the `model` trained in the cell above — run either the
# training cell or this one, depending on which model the analysis should use.
PRETRAINED_MODEL_DIR = "/work/scGen_Human_vascular/work/scGen_Human_vascular/saved_models/model_perturbation_2/"
model = scgen.SCGEN.load(PRETRAINED_MODEL_DIR, adata=train_adata)


In [None]:
# Persist the model, overwriting any previous save at this location.
# NOTE(review): scvi-tools' `save` writes a directory, so the ".pt" suffix here presumably
# names a folder, not a single file — confirm. The path is relative to the working
# directory, unlike the absolute /work/... paths used when loading.
model.save("scGen_Human_vascular/work/scGen_Human_vascular/LOA_models_gou/scGen_Focal_cortical_LOA_FINAL.pt", overwrite=True)

In [None]:
# Save a second copy of the model into the "new_fixed_models" directory (overwriting).
# NOTE(review): relative path — resolves against the notebook's working directory.
model.save("scGen_Human_vascular/work/scGen_Human_vascular/new_fixed_models/", overwrite=True)

In [None]:
# Mark the model as trained, presumably so downstream calls such as `predict` accept it.
# Only the underlying `is_trained_` attribute is set: `is_trained` is a read-only view
# of it in scvi-tools (verify for the installed version), so assigning it is redundant
# at best and raises if the property has no setter.
# NOTE(review): forcing this flag hides the case where the model was never trained or
# loaded; it should be unnecessary after `model.train(...)` or `SCGEN.load(...)`.
model.is_trained_ = True
model

In [None]:
# Female microglial cells from the training set: the common input for both the
# baseline (female→female) and counterfactual (female→male) predictions.
is_female = train_adata.obs["sex"] == "female"
is_microglia_train = train_adata.obs["cell_type"] == "microglial cell"
female_micro = train_adata[is_female & is_microglia_train].copy()

In [None]:
# Baseline: encode the female microglia and decode them with a female→female shift,
# i.e. without applying a sex perturbation. `predict` returns a pair; the first
# element is the predicted AnnData and the second (presumably the latent shift) is unused.
baseline_result = model.predict(
    ctrl_key="female",
    stim_key="female",
    adata_to_predict=female_micro,
)
base_pred = baseline_result[0]


In [None]:
# Copy of the female microglia relabelled as male.
# NOTE(review): expression values are unchanged; only the "sex" annotation differs.
cf = female_micro.copy()
cf.obs["sex"] = "male"

In [None]:
# Counterfactual: encode the same female microglia, add the learned female→male
# latent shift, then decode. Only the predicted AnnData (first element) is kept.
cf_result = model.predict(
    ctrl_key="female",
    stim_key="male",
    adata_to_predict=cf,
)
cf_pred = cf_result[0]

In [None]:
import numpy as np

# Per-gene ground-truth means (one value per gene, averaged over cells).
# Female ground truth: the female microglial cells that were part of training.
# Male ground truth: the held-out male microglial cells (`subset_adata`), which the
# model never saw; the counterfactual prediction is compared against these.
# `np.asarray(X.mean(axis=0)).ravel()` gives a flat 1-D vector for both dense arrays and
# scipy sparse matrices (whose mean is a 1 x n_genes matrix) without densifying X.
female_gt = np.asarray(female_micro.X.mean(axis=0)).ravel()
male_gt = np.asarray(subset_adata.X.mean(axis=0)).ravel()


In [None]:
import numpy as np

In [None]:
# Per-gene prediction means (one value per gene, averaged over predicted cells).
# `cf_pred` is rebound below from the predicted AnnData to its mean vector, so the
# AnnData is kept in `cf_pred_adata`; this keeps the cell safe to re-run.
if hasattr(cf_pred, "X"):
    cf_pred_adata = cf_pred
baseline_pred = np.asarray(base_pred.X.mean(axis=0)).ravel()
cf_pred = np.asarray(cf_pred_adata.X.mean(axis=0)).ravel()

In [None]:
# Matched RMSEs: baseline vs. female truth, counterfactual vs. held-out male truth.
baseline_sq_err = (baseline_pred - female_gt) ** 2
cf_sq_err = (cf_pred - male_gt) ** 2
rmse_baseline = np.sqrt(baseline_sq_err.mean())
rmse_cf = np.sqrt(cf_sq_err.mean())

print(f"RMSE baseline      (female→female): {rmse_baseline:.4f}")
print(f"RMSE counterfactual(female→male)  : {rmse_cf:.4f}")

In [None]:
# Full 2x2 RMSE grid between per-gene prediction means and ground-truth means.
# Naming is <prediction>_<truth>: e.g. f_m = predicted female vs. true male.
def _rmse(pred, truth):
    """Root-mean-square error between two equally shaped per-gene mean vectors."""
    return np.sqrt(np.mean((pred - truth) ** 2))

f_m = _rmse(baseline_pred, male_gt)    # pred F --- true M
m_m = _rmse(cf_pred, male_gt)          # pred M (counterfactual) --- true M
f_f = _rmse(baseline_pred, female_gt)  # pred F --- true F
m_f = _rmse(cf_pred, female_gt)        # pred M (counterfactual) --- true F

print(f"RMSE pred F --- TM: {f_m:.4f}")
print(f"RMSE pred M --- TM (counterfactual)  : {m_m:.4f}")
print(f"RMSE pred F ----- TF : {f_f:.4f}")
print(f"RMSE pred M ---- TF (counterfactual)  : {m_f:.4f}")

In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Compare per-gene mean expression of the held-out male microglia (x axis) with the
# counterfactual (female→male) and baseline (female→female) predictions (y axis).
# log1p compresses the dynamic range; any value below -1 becomes NaN, which is why
# the identity-line limits below use nan-aware min/max.
log_male_gt = np.log1p(male_gt)
log_cf_pred = np.log1p(cf_pred)
log_baseline_pred = np.log1p(baseline_pred)

fig, ax = plt.subplots(figsize=(10, 6))

sns.scatterplot(
    x=log_male_gt, y=log_cf_pred,
    alpha=0.7, ax=ax, color="tab:blue",
    label="Pred male (counterfactual F→M)",
)
sns.scatterplot(
    x=log_male_gt, y=log_baseline_pred,
    alpha=0.7, ax=ax, color="tab:orange",
    label="Pred female (baseline F→F)",
)

# Identity line (in log space) spanning every plotted value, truth and predictions.
all_log_values = np.concatenate([log_male_gt, log_cf_pred, log_baseline_pred])
mn = np.nanmin(all_log_values)
mx = np.nanmax(all_log_values)
ax.plot([mn, mx], [mn, mx], ls="--", color="red", label="y = x")

ax.set_xlabel("Log Ground Truth (Male)")
ax.set_ylabel("Log Predicted Mean Expression")
ax.set_title("Leave One Out analysis")
ax.legend(loc="upper left")
fig.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt

# 2x2 table of the RMSE values computed in this run (f_f, m_f, f_m, m_m), formatted to
# 4 decimals, so the table always matches the current model and predictions.
rmse_data = [
    [f"{f_f:.4f}", f"{m_f:.4f}"],  # True Female: pred F, pred M (CF)
    [f"{f_m:.4f}", f"{m_m:.4f}"],  # True Male:   pred F, pred M (CF)
]

column_labels = ["Pred Female", "Pred Male (CF)"]
row_labels = ["True Female", "True Male"]

fig, ax = plt.subplots(figsize=(6, 2))
ax.axis('tight')
ax.axis('off')

table = ax.table(
    cellText=rmse_data,
    rowLabels=row_labels,
    colLabels=column_labels,
    cellLoc='center',
    loc='center'
)

table.scale(1, 2)  # Increase row height
table.auto_set_font_size(False)
table.set_fontsize(12)

# Cell keys are (row, col): row 0 holds the column labels, so data rows start at 1;
# data columns start at 0 and the row labels sit in column -1.
table[(2, 0)].set_facecolor("orange")     # True Male, Pred Female
table[(2, 1)].set_facecolor("lightblue")  # True Male, Pred Male (CF)

ax.set_title("RMSE between Predictions and Ground truth", pad=20)
fig.tight_layout()
plt.show()
