In [None]:
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
import scvi
import seaborn as sns
import torch
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error

In [None]:
# Fix the global scvi-tools seed so stochastic steps (e.g. model training) give
# the same numbers on every "Restart Kernel -> Run All".
scvi.settings.seed = 0
print("Last run with scvi-tools version:", scvi.__version__) # 1.2.2 when running cpu but 1.2.1 when running GPU

In [None]:
# Plot defaults: scanpy figure parameters first, then the seaborn theme.
# NOTE(review): sns.set_theme() runs after sc.set_figure_params() and may reset
# some of the rcParams configured there -- confirm this order is intended.
sc.set_figure_params(figsize=(6, 6), frameon=False)
sns.set_theme()
# Request the "high" float32 matmul precision setting from PyTorch.
torch.set_float32_matmul_precision("high")
# Temporary directory handle; it is not referenced by any other visible cell.
save_dir = tempfile.TemporaryDirectory()

# IPython inline-backend options: white figure background and "retina" figure format.
%config InlineBackend.print_figure_kwargs={"facecolor": "w"}
%config InlineBackend.figure_format="retina"

In [None]:
adata = sc.read("/work/SCVI_models/scAtlas_data/atlas_raw/scAtlas_Human_vascular_cells_processed_RAW_1.h5ad")

In [None]:
adata

In [None]:
# Subset the data
subset_mask = (adata.obs['sex'] == 'male') & (adata.obs['cell_type'] == 'endothelial cell')
subset_adata = adata[subset_mask].copy()  
train_adata = adata[~subset_mask].copy()  # Training data is everything except the subset

In [None]:
train_adata

In [None]:
subset_adata

In [None]:
subset_adata.obs["cell_type"]

In [None]:
print(train_adata.obs.groupby(['sex', 'cell_type']).size())

In [None]:
train_adata.obs["cell_type"]

In [None]:
# obs columns registered with scvi-tools as categorical covariates.
categorical_covariates = [
    "cell_type",
    "sex",
    "self_reported_ethnicity",
    "bmi_group",
    "donor_id",
    "surgery",
    "fat_type",
    "tissue",
]

# Register the training AnnData with SCVI: counts are read from .X (layer=None)
# and no continuous covariates are used.
scvi.model.SCVI.setup_anndata(
    train_adata,
    layer=None,
    categorical_covariate_keys=categorical_covariates,
    continuous_covariate_keys=None,
)

In [None]:
model = scvi.model.SCVI(train_adata, n_layers=2, n_latent=30, gene_likelihood="nb")

In [None]:
model

In [None]:
model.train()

In [None]:
# Persist the freshly trained model. overwrite=True keeps the notebook
# re-runnable: without it, save() refuses to write into an existing directory,
# so a second "Run All" would stop here.
# NOTE: this replaces any model previously saved at this path.
model.save(
    "/work/SCVI_models/LOO_models_SCVI_NEW/SCVI_scAtlas_LOO_raw_fixed_data",
    overwrite=True,
)

In [None]:
cf_adata.obs.columns.get_loc("sex")

In [None]:
model

In [None]:
cf_adata.obs.columns

In [None]:
# 1. Define the two subsets:

#    - female slow‐muscle cells (these are in the training AnnData)
female_endothelial = train_adata[(train_adata.obs["sex"] == "female") &
                          (train_adata.obs["cell_type"] == "endothelial cell")].copy()

In [None]:
library_female=female_endothelial.X.sum(axis = 1)

In [None]:
# 2. Baseline prediction: leave them as female
#    would be same as rec2 in biolord
y_pred_base = model.get_normalized_expression(
    female_endothelial, return_numpy=True
)  # shape: (n_cells, n_genes)

In [None]:
library_female=female_endothelial.X.sum(axis = 1)

In [None]:
library_female = np.array(library_female.flatten())

In [None]:
rec_female = (library_female * y_pred_base.T ).T

In [None]:
# 3. Counterfactual prediction: flip sex → male
cf = female_endothelial.copy()
cf.obs["sex"] = "male"

In [None]:

y_pred_cf = model.get_normalized_expression(
    cf, return_numpy=True
)

In [None]:
library_counter=cf.X.sum(axis = 1)

In [None]:
library_counter = np.array(library_counter.flatten())

In [None]:
rec_cf_male = (library_counter * y_pred_cf.T ).T

In [None]:
# Per-gene mean over cells of each reconstruction, flattened to 1-D:
# baseline (cells kept as female) first, counterfactual (sex set to male) second.
baseline_pred_mean, cf_pred_mean = (
    np.asarray(reconstruction.mean(axis=0)).ravel()
    for reconstruction in (rec_female, rec_cf_male)
)

In [None]:
def _mean_expression_per_gene(adata_like):
    """Return the per-gene mean of ``adata_like.X`` as a 1-D NumPy array.

    Handles both dense arrays and SciPy sparse matrices. The mean is taken
    directly on ``.X``, so a sparse matrix is never converted to dense.
    """
    return np.asarray(adata_like.X.mean(axis=0)).ravel()


# Per-gene ground-truth means.
# Ground truth for females: the observed female endothelial cells.
female_gt = _mean_expression_per_gene(female_endothelial)

# Ground truth for males: the held-out male endothelial cells.
male_gt = _mean_expression_per_gene(subset_adata)


In [None]:
def _rmse(pred, truth):
    """Root-mean-square error between two per-gene vectors."""
    squared_error = (pred - truth) ** 2
    return np.sqrt(np.mean(squared_error))


# Naming: <prediction>_<ground truth>; f = female (baseline), m = male.
f_m = _rmse(baseline_pred_mean, male_gt)    # predicted female vs. true male
m_m = _rmse(cf_pred_mean, male_gt)          # predicted male (counterfactual) vs. true male
f_f = _rmse(baseline_pred_mean, female_gt)  # predicted female vs. true female
m_f = _rmse(cf_pred_mean, female_gt)        # predicted male (counterfactual) vs. true female

print(f"RMSE pred F --- TM: {f_m:.4f}")
print(f"RMSE pred M --- TM (counterfactual)  : {m_m:.4f}")
print(f"RMSE pred F ----- TF : {f_f:.4f}")
print(f"RMSE pred M ---- TF (counterfactual)  : {m_f:.4f}")

In [None]:
# RMSE table: rows are the ground truth, columns are the prediction.
# The cell text is built from the computed RMSE variables (f_f, m_f, f_m, m_m)
# rather than hard-coded numbers, so the figure always matches the current run.
rmse_data = [
    [f"{f_f:.4f}", f"{m_f:.4f}"],  # True Female
    [f"{f_m:.4f}", f"{m_m:.4f}"],  # True Male
]

# Define labels
column_labels = ["Pred Female", "Pred Male (CF)"]
row_labels = ["True Female", "True Male"]

# Create the figure and axis
fig, ax = plt.subplots(figsize=(6, 2))
ax.axis('tight')
ax.axis('off')

# Create the table
table = ax.table(
    cellText=rmse_data,
    rowLabels=row_labels,
    colLabels=column_labels,
    cellLoc='center',
    loc='center'
)

table.scale(1, 2)  # Increase row height
table.auto_set_font_size(False)
table.set_fontsize(12)

# Color the "True Male" row.
# Table cells are keyed by (row, col): row 0 is the column-header row, so data
# rows start at 1, while data columns start at 0 (row labels sit in column -1).
table[(2, 0)].set_facecolor("orange")  # True Male, Pred Female
table[(2, 1)].set_facecolor("lightblue")  # True Male, Pred Male

plt.title("RMSE between Predictions and Ground truth", pad=20)
plt.tight_layout()
plt.show()


In [None]:
# Log-transform (log1p) the per-gene means and make sure they are 1-D.
log_male_gt = np.log1p(male_gt).ravel()
log_cf_pred = np.log1p(cf_pred_mean).ravel()
log_baseline_pred = np.log1p(baseline_pred_mean).ravel()

# Create the figure and axis
fig, ax = plt.subplots(figsize=(10, 6))

# Scatter: Counterfactual (Female → Male)
sns.scatterplot(
    x=log_male_gt, y=log_cf_pred,
    alpha=0.7, ax=ax, label="Counterfactual (Female → Male)"
)

# Scatter: Baseline (Female → Female)
sns.scatterplot(
    x=log_male_gt, y=log_baseline_pred,
    alpha=0.7, ax=ax, label="Baseline (Female → Female)"
)

# Identity line spanning the full range of BOTH axes (ground truth and both
# prediction sets), so it is not clipped to the x-range only.
mn = min(log_male_gt.min(), log_cf_pred.min(), log_baseline_pred.min())
mx = max(log_male_gt.max(), log_cf_pred.max(), log_baseline_pred.max())
ax.plot([mn, mx], [mn, mx], ls="--", color="red", label="Identity (y = x)")

# Labels and legend (every artist above carries a label, so the legend is populated)
ax.set_xlabel("Log Ground Truth (Male)")
ax.set_ylabel("Log Predicted Mean Expression")
ax.set_title("Leave-One-Out Analysis")
ax.legend(loc="upper left")
plt.tight_layout()
plt.show()
