In [None]:
pip install session_info

In [None]:
import session_info

In [None]:
pip install git+https://github.com/theislab/scgen.git

In [None]:
pip install scvi-tools==1.1.1

In [None]:
import scanpy as sc
import torch
import logging
import scgen # Development version only works!!!!!!!!! Confirmed 25
import sklearn
import seaborn as sns
import torch
import warnings
import os
import sys
import re





#import numpy as np

# Remember: scvi-tools sometimes (not always) needs to be downgraded, e.g. `pip install scvi-tools==1.1.1` (versions noted: 1.6, 1.1.1).
# sqrt issue in latent space

# 1. Download scanpy
# 2. Download scgen (not development version) --use that one for now

In [None]:
# Report the session's package versions so the run can be reproduced later.
session_info.show()

In [None]:
from sklearn.metrics import mean_squared_error,mean_absolute_error

In [None]:
# Absolute path of the processed HTAPP_997 AnnData file.
HTAPP_H5AD_PATH = "/work/scGen_Human_vascular/new_data_fix_may/HTAPP_997_processed_raw_FINAL.h5ad"

# Load the full dataset; every later subset is derived from this object.
adata = sc.read(HTAPP_H5AD_PATH)

In [None]:
# Inspect the per-cell "replicate" annotation (used below as the scGen batch key).
adata.obs["replicate"]

In [None]:
# Inspect the full per-cell annotation table.
adata.obs

In [None]:
# Leave-one-out split: hold out the hepatocytes of replicate '1' and train on
# every other cell.
in_replicate_1 = adata.obs['replicate'] == '1'
is_hepatocyte = adata.obs['cell_type'] == 'hepatocyte'
subset_mask = in_replicate_1 & is_hepatocyte

subset_adata = adata[subset_mask].copy()   # held-out cells
train_adata = adata[~subset_mask].copy()   # complement of the held-out cells

In [None]:
# Show the AnnData summary (dimensions and available annotation fields).
adata

In [None]:
# Inspect the per-cell "Phase" annotation.
adata.obs["Phase"]

In [None]:
# Number of training cells per (replicate, cell type) combination; the held-out
# combination should no longer contribute any cells.
cells_per_group = train_adata.obs.groupby(['replicate', 'cell_type']).size()
print(cells_per_group)

In [None]:
# Quick look at the scale of the expression matrix.
expr_max = adata.X.max()
expr_mean = adata.X.mean()
print("Max:", expr_max)
print("Mean:", expr_mean)

In [None]:
import numpy as np

# Densify the matrix when it is stored in a sparse format; otherwise use it as-is.
if hasattr(adata.X, "toarray"):
    dense_X = adata.X.toarray()
else:
    dense_X = adata.X

# A value is a whole number exactly when it equals its own floor.
is_integer_counts = np.equal(dense_X, np.floor(dense_X)).all()

print("Is count matrix all integers:", is_integer_counts)


In [None]:
# Register the training data with scGen: "replicate" is the batch/condition
# column and "cell_type" is the label column.
scgen.SCGEN.setup_anndata(
    train_adata,
    batch_key="replicate",
    labels_key="cell_type",
)

In [None]:
# Build the scGen model on the training data registered by setup_anndata.
model = scgen.SCGEN(train_adata)
# Disabled: saving the model at this point (it has not been trained yet).
#model.save("scGen_Human_vascular/LOA/LOA_scGen_HTAPP", overwrite=True)

In [None]:
# Train for a fixed budget of 400 epochs.
# NOTE(review): early stopping is switched off, so `early_stopping_patience`
# presumably has no effect here — confirm against the scvi-tools training docs,
# then either enable early stopping or drop the patience argument.
model.train(
    max_epochs=400,
    early_stopping=False,
    early_stopping_patience=25, 
)


In [None]:
# Persist the trained model, replacing anything already stored at this location.
# NOTE(review): this is a relative path and it differs from the absolute path
# that the following load cell reads from — confirm both are intended.
model.save("scGen_Human_vascular/work/scGen_Human_vascular/LOA_HTAPP_fixed_rub/", overwrite=True)

In [None]:
# Load a previously saved model and attach `train_adata` to it.
# NOTE(review): the return value is not assigned, so `model` still refers to the
# model trained above and the loaded one is only displayed. If the saved
# checkpoint is the one that should be evaluated, assign it:
# `model = scgen.SCGEN.load(...)`.
scgen.SCGEN.load("/work/scGen_Human_vascular/LOA/scGen_Human_vascular/work/scGen_Human_vascular/LOA_models_gou/scGen_HTAPP_LOA_FINAL",
                 adata = train_adata)


In [None]:
# Hepatocytes of replicate "2": the cells that are fed into the predictions and
# that also serve as the replicate-2 ground truth.
from_replicate_2 = train_adata.obs["replicate"] == "2"
hepatocyte_cells = train_adata.obs["cell_type"] == "hepatocyte"
hepatocyte_2 = train_adata[from_replicate_2 & hepatocyte_cells].copy()

In [None]:
# Show the summary of the replicate-2 hepatocyte subset (check it is not empty).
hepatocyte_2

In [None]:
# Baseline: run the replicate-2 hepatocytes through `predict` with the same
# replicate as both source and target, i.e. no replicate shift is requested.
# NOTE(review): as I read scgen's `predict`, it shifts the latent codes of
# `adata_to_predict` by the (stim - ctrl) latent difference, so identical keys
# should give something close to a plain reconstruction — confirm in the scgen docs.
baseline_adata,_ = model.predict(
    ctrl_key="2",        # source condition: replicate 2
    stim_key="2",        # target condition: replicate 2 as well
    adata_to_predict=hepatocyte_2
)

In [None]:
# Counterfactual input: a copy of the replicate-2 hepatocytes whose replicate
# label is overwritten with "1". Only the annotation changes; the expression
# values are those of `hepatocyte_2`.
# NOTE(review): confirm whether `predict` reads this column from
# `adata_to_predict` at all, or only uses its expression matrix.
cf = hepatocyte_2.copy()
cf.obs["replicate"] = "1"

In [None]:
# Verify that every cell in the counterfactual copy is now labelled replicate "1".
cf.obs["replicate"]

In [None]:
# Counterfactual: predict the replicate-2 hepatocytes as replicate-1 cells.
# The first return value is kept as `cf_pred`; the second is discarded.
cf_pred, _ = model.predict(
    ctrl_key="2",        # source condition: replicate 2
    stim_key="1",          # target condition: replicate 1
    adata_to_predict=cf
)

In [None]:
def _per_gene_mean(matrix):
    """Return the column-wise (per-gene) mean of a dense or sparse matrix."""
    if hasattr(matrix, "toarray"):
        # Sparse storage: densify first so the mean is computed on an ndarray.
        matrix = matrix.toarray()
    return matrix.mean(axis=0)


# Ground truth for replicate 2: the observed replicate-2 hepatocytes.
rep_2_gt = _per_gene_mean(hepatocyte_2.X)

# Ground truth for replicate 1: the held-out replicate-1 hepatocytes.
rep_1_gt = _per_gene_mean(subset_adata.X)


In [None]:
# Show the summary of the held-out subset (replicate-1 hepatocytes).
subset_adata

In [None]:
import numpy as np

In [None]:
# Compute per-gene prediction means (1-D arrays with one entry per gene).
baseline_pred = np.asarray(baseline_adata.X.mean(axis=0)).ravel()

# `cf_pred` arrives here as the AnnData returned by model.predict and is rebound
# to its per-gene mean vector, which is what the later cells expect under this
# name. Keep the AnnData under its own name the first time through so that
# re-running this cell does not fail with "ndarray has no attribute X".
if hasattr(cf_pred, "X"):
    cf_pred_adata = cf_pred
cf_pred = np.asarray(cf_pred_adata.X.mean(axis=0)).ravel()

In [None]:
# Root-mean-square error of each per-gene prediction mean against the held-out
# replicate-1 ground truth.
baseline_residual = baseline_pred - rep_1_gt
cf_residual = cf_pred - rep_1_gt

rmse_baseline = np.sqrt(np.mean(baseline_residual ** 2))
rmse_cf = np.sqrt(np.mean(cf_residual ** 2))

print(f"RMSE baseline      (rep2→rep2): {rmse_baseline:.4f}")
print(f"RMSE counterfactual(rep2→rep1)  : {rmse_cf:.4f}")

In [None]:
# Sanity check on the held-out ground-truth means: range and number of genes
# whose mean is exactly zero.
gt_max = np.max(rep_1_gt)
gt_min = np.min(rep_1_gt)
n_zero_genes = np.sum(rep_1_gt == 0)

print("Max ground truth:", gt_max)
print("Min ground truth:", gt_min)
print("How many zeros:", n_zero_genes)


In [None]:
def _rmse(pred, truth):
    """Root-mean-square error between two per-gene mean vectors."""
    return np.sqrt(np.mean((pred - truth) ** 2))


# Naming is <prediction>_<ground truth>. In the names and labels below, "f" / F
# is the replicate-2 side (baseline prediction, rep_2_gt) and "m" / M is the
# replicate-1 side (counterfactual prediction, rep_1_gt).
f_m = _rmse(baseline_pred, rep_1_gt)  # baseline prediction vs. replicate-1 truth
m_m = _rmse(cf_pred, rep_1_gt)        # counterfactual prediction vs. replicate-1 truth
f_f = _rmse(baseline_pred, rep_2_gt)  # baseline prediction vs. replicate-2 truth
m_f = _rmse(cf_pred, rep_2_gt)        # counterfactual prediction vs. replicate-2 truth

print(f"RMSE pred F --- TM: {f_m:.4f}")
print(f"RMSE pred M --- TM (counterfactual)  : {m_m:.4f}")
print(f"RMSE pred F ----- TF : {f_f:.4f}")
print(f"RMSE pred M ---- TF (counterfactual)  : {m_f:.4f}")

In [None]:
import matplotlib.pyplot as plt

# RMSE values for the table, taken from the values computed in the RMSE cell
# rather than pasted in by hand, so the table always reflects the current run.
# Rows are ground truths, columns are predictions.
rmse_data = [
    [f"{f_f:.4f}", f"{m_f:.4f}"],  # True Replicate 2: baseline, counterfactual
    [f"{f_m:.4f}", f"{m_m:.4f}"],  # True Replicate 1: baseline, counterfactual
]

# Define labels
column_labels = ["Pred Replicate 2", "Pred Replicate 1 (CF)"]
row_labels = ["True Replicate 2", "True Replicate 1"]

# Create the figure and axis; the axes frame is hidden because only the table is shown
fig, ax = plt.subplots(figsize=(6, 2))
ax.axis('tight')
ax.axis('off')

# Create the table
table = ax.table(
    cellText=rmse_data,
    rowLabels=row_labels,
    colLabels=column_labels,
    cellLoc='center',
    loc='center'
)

table.scale(1, 2)  # Increase row height
table.auto_set_font_size(False)
table.set_fontsize(12)

# Color the cells of the lower row.
# Table cells are addressed as (row, col): row 0 is the column-header row, so
# data rows start at 1, while data columns start at 0 (row labels sit at col -1).
table[(2, 0)].set_facecolor("orange")  # True Replicate 1, Pred Replicate 2
table[(2, 1)].set_facecolor("lightblue")  # True Replicate 1, Pred Replicate 1 (CF)

plt.title("RMSE between Predictions and Ground truth", pad=20)
plt.tight_layout()
plt.show()

In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Log-transform the per-gene means (log1p keeps zero-expression genes at 0)
log_male_gt = np.log1p(rep_1_gt)
log_cf_pred = np.log1p(cf_pred)
log_baseline_pred = np.log1p(baseline_pred)

# Scatter plot of log-transformed ground truth vs. predictions
fig, ax = plt.subplots(figsize=(10, 6))

# Each series gets a label so that the legend below has entries to show.
sns.scatterplot(
    x=log_male_gt, y=log_cf_pred,
    alpha=0.7, ax=ax, label="Counterfactual (rep2→rep1)"
)

sns.scatterplot(
    x=log_male_gt, y=log_baseline_pred,
    alpha=0.7, ax=ax, label="Baseline (rep2→rep2)"
)

# Identity line (in log space). Its extent covers the ground truth AND both
# prediction series; non-finite values (log1p of inputs <= -1) are ignored so
# they cannot break the line.
plotted_values = np.concatenate([
    np.ravel(log_male_gt),
    np.ravel(log_cf_pred),
    np.ravel(log_baseline_pred),
])
plotted_values = plotted_values[np.isfinite(plotted_values)]
mn = plotted_values.min()
mx = plotted_values.max()
ax.plot([mn, mx], [mn, mx], ls="--", color="red", label="Identity (y = x)")


# NOTE(review): the x-axis shows the held-out replicate-1 ground truth; the
# "(Male)" wording is kept as-is — confirm it is the intended label.
ax.set_xlabel("Log Ground Truth (Male)")
ax.set_ylabel("Log Predicted Mean Expression")
ax.set_title("Leave One Out analysis" )
ax.legend(loc="upper left")
plt.tight_layout()
plt.show()
