In [None]:
import os
import warnings

# Move to the project root exactly once per kernel session. A bare
# os.chdir("../") climbs one more directory every time this cell is
# re-executed, silently breaking every relative path used afterwards.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
os.chdir(PROJECT_ROOT)

In [None]:
pip install scarches numpy anndata scvi pandas==2.2.3

In [None]:
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import gdown

In [None]:
# scanpy's set_figure_params resets every option it is not given to its own
# default (e.g. dpi=80, frameon=True). The previous three separate calls
# therefore undid each other, and only figsize survived. Set everything at once.
sc.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))

# Compact tensor printing for interactive inspection.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

In [None]:
# Source dataset for the leave-one-out experiment (absolute path on the
# compute host; adjust when running elsewhere).
ADATA_PATH = '/work/trvae_new/New_fixed_data/healthy_hamstring_processed_adata_raw.h5ad'
adata = sc.read(ADATA_PATH)

In [None]:
# Inspect the loaded AnnData object before splitting it.
adata

In [None]:
# Leave-one-out split: every male slow-muscle cell is held out, and the model
# is trained on all remaining cells.
subset_mask = (
    adata.obs['sex'].eq('male')
    & adata.obs['cell_type'].eq('slow muscle cell')
)
train_adata = adata[~subset_mask].copy()   # training data: everything else
subset_adata = adata[subset_mask].copy()   # held-out male slow-muscle cells

In [None]:
# Sanity check on the split: the held-out combination must be absent from the
# training data, and the held-out set must be non-empty. An empty held-out set
# would make the later male ground truth meaningless.
check_mask = (train_adata.obs["sex"] == "male") & (train_adata.obs["cell_type"] == "slow muscle cell")
num_bad_cells = check_mask.sum()

print(f"Number of male slow muscle cells in train_adata: {num_bad_cells}")

assert num_bad_cells == 0, "held-out male slow muscle cells leaked into train_adata"
assert subset_adata.n_obs > 0, "no male slow muscle cells were held out"


In [None]:
# Early-stopping / learning-rate-reduction settings, passed to trvae.train().
early_stopping_kwargs = dict(
    early_stopping_metric="val_unweighted_loss",
    threshold=0.001,
    patience=40,
    reduce_lr=True,
    lr_patience=25,
    lr_factor=0.1,
)

In [None]:
# Inspect the training split (all cells except the held-out male slow-muscle cells).
train_adata

In [None]:
# Condition trVAE on sex. Cast to plain strings so the labels compare directly
# with the string list in `conditions`.
train_adata.obs['sex'] = train_adata.obs['sex'].astype(str)
conditions = ['female', 'male']

# The counterfactual step later decodes with the 'male' label, so both levels
# must really be present in the training data. No other label may appear either,
# because it would have no entry among the conditions given to the model.
observed_sex_levels = set(train_adata.obs['sex'].unique())
assert observed_sex_levels == set(conditions), (
    f"expected sex levels {conditions}, found {sorted(observed_sex_levels)}"
)

# Initialize TRVAE conditioned on sex.
trvae = sca.models.TRVAE(
    adata=train_adata,
    condition_key='sex',
    conditions=conditions,
    hidden_layer_sizes=[128, 128],
)

# Train with the early-stopping settings defined above.
trvae.train(n_epochs=200, alpha_epoch_anneal=200, early_stopping_kwargs=early_stopping_kwargs)

In [None]:
# Display the full (unsplit) dataset again for reference.
adata

In [None]:
# Display the trained TRVAE object.
trvae

In [None]:
# TRVAE.save writes a *directory* (model weights plus attributes), despite the
# .h5ad suffix in this path. overwrite=True keeps the cell re-runnable; without
# it, a second run raises because the directory already exists.
trvae.save("/work/trvae_new/LOO_trVAE/LOO_HH_trvae_Gpu_run.h5ad", overwrite=True)

In [None]:
from scarches.trainers.trvae._utils import make_dataset, custom_collate

In [None]:
# Run a trained TRVAE over an AnnData on whichever device it was trained on
# (CPU or GPU) and collect latent codes and count-scale reconstructions.
def predict_trvae(model, adata, condition_key, batch_size=128, use_mean=False):
    """Encode and decode every cell of ``adata`` with a trained scArches TRVAE.

    Parameters
    ----------
    model : sca.models.TRVAE
        Trained model. ``model.model`` (the torch module) is used for the
        forward pass, and every batch is moved to ``model.trainer.device``.
    adata : AnnData
        Cells to run through the model. ``adata.obs[condition_key]`` is encoded
        with the model's own ``condition_encoder``, so relabelling that column
        yields a counterfactual decoding of the same cells.
    condition_key : str
        ``obs`` column holding the condition labels.
    batch_size : int, default 128
        Number of cells per forward pass.
    use_mean : bool, default False
        If True, use the encoder mean as the latent code (deterministic).
        If False, draw a sample via ``model.model.sampling`` (the previous
        behaviour; repeated calls then give slightly different outputs).

    Returns
    -------
    latent : np.ndarray
        One row per cell with its latent code.
    reconstructed : np.ndarray
        One row per cell with the decoder output. For count losses it is
        multiplied by the cell's total count (the sum of its ``x`` row).

    Raises
    ------
    ValueError
        If ``adata`` contains no cells.
    """
    inner = model.model
    device = model.trainer.device
    # NOTE(review): the branches below are intended to mirror scArches'
    # trVAE.forward: log1p input and size-factor scaling for count losses,
    # raw input/output for "mse". Confirm against the installed scArches version.
    recon_loss = getattr(inner, "recon_loss", "nb")

    predict_data, _ = make_dataset(
        adata,
        train_frac=1.0,
        condition_key=condition_key,
        cell_type_keys=None,
        condition_encoder=inner.condition_encoder,
        cell_type_encoder=None,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset=predict_data,
        batch_size=batch_size,
        shuffle=False,  # keep output rows in dataset order (presumably adata's row order)
        collate_fn=custom_collate,
        num_workers=0,
    )

    latent_list = []
    reconstructed_list = []

    # Remember the module's mode and restore it afterwards, so prediction has no
    # lasting effect on the model's train/eval state.
    was_training = inner.training
    inner.eval()
    try:
        with torch.no_grad():
            for batch_data in dataloader:
                batch_data = {key: value.to(device) for key, value in batch_data.items()}
                x = batch_data["x"]
                cond = batch_data["batch"]

                encoder_input = x if recon_loss == "mse" else torch.log1p(x)
                z_mean, z_log_var = inner.encoder(encoder_input, cond)
                latent = z_mean if use_mean else inner.sampling(z_mean, z_log_var)
                latent_list.append(latent.cpu().numpy())

                # The first decoder output is the reconstruction. Indexing also
                # works for decoders that return more than two tensors.
                recon_x = inner.decoder(latent, cond)[0]
                if recon_loss != "mse":
                    # Per-cell total count, broadcast across genes.
                    size_factor = x.sum(dim=1, keepdim=True)
                    recon_x = size_factor * recon_x
                reconstructed_list.append(recon_x.cpu().numpy())
    finally:
        inner.train(was_training)

    if not latent_list:
        raise ValueError("predict_trvae received an AnnData with no cells")

    latent = np.concatenate(latent_list, axis=0)
    reconstructed = np.concatenate(reconstructed_list, axis=0)
    return latent, reconstructed

In [None]:
# 1. Select your baseline (female slow‐muscle) cells from train_adata
base_adata = train_adata[
    (train_adata.obs['cell_type'] == 'slow muscle cell') &
    (train_adata.obs['sex']       == 'female')
].copy()

In [None]:

# 2. Reconstruct the baseline cells under their true ('female') label.
#    Requires the predict_trvae cell above to have been run in this session.
latent_base, rec_base = predict_trvae(
    model=trvae,
    adata=base_adata,
    condition_key="sex",
    batch_size=128,
)


In [None]:
# 3. Flip the sex label on the exact same cells
cf_adata = base_adata.copy()
cf_adata.obs['sex'] = 'male'


In [None]:

# Decode the relabelled cells: the model's prediction of these cells as male.
latent_cf, rec_cf = predict_trvae(
    model=trvae,
    adata=cf_adata,
    condition_key='sex',
    batch_size=128,
)


In [None]:
import numpy as np


def _mean_expression(adata_view):
    """Return the per-gene mean of ``adata_view.X`` as a 1-D array.

    Works for dense arrays and scipy sparse matrices alike without densifying
    the whole matrix. ``X.mean(axis=0)`` on a sparse matrix returns a
    1 x n_genes ``np.matrix``, hence the ``asarray`` + ``ravel``.
    """
    return np.asarray(adata_view.X.mean(axis=0)).ravel()


# Observed per-gene mean expression.
# female_gt: the baseline female slow-muscle cells (seen during training).
female_gt = _mean_expression(base_adata)

# male_gt: the held-out male slow-muscle cells, which the model never saw.
male_gt = _mean_expression(subset_adata)

In [None]:
# 2. Predicted means
base_mean = rec_base.mean(axis=0).ravel()
cf_mean   = rec_cf.mean(axis=0).ravel()

In [None]:
def _rmse(pred, truth):
    """Root-mean-square error between two per-gene expression vectors."""
    return np.sqrt(np.mean((pred - truth) ** 2))


# Compare each predicted mean profile with each observed one.
# base_mean: predicted with the true 'female' label.
# cf_mean:   counterfactual prediction with the 'male' label.
f_m = _rmse(base_mean, male_gt)    # pred F vs true M
m_m = _rmse(cf_mean, male_gt)      # pred M (counterfactual) vs true M
f_f = _rmse(base_mean, female_gt)  # pred F vs true F
m_f = _rmse(cf_mean, female_gt)    # pred M (counterfactual) vs true F

print("\n".join([
    f"RMSE pred F --- TM: {f_m:.4f}",
    f"RMSE pred M --- TM (counterfactual)  : {m_m:.4f}",
    f"RMSE pred F ----- TF : {f_f:.4f}",
    f"RMSE pred M ---- TF (counterfactual)  : {m_f:.4f}",
]))

In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Predicted vs observed male mean expression per gene, in log1p space.
log_male_gt = np.log1p(male_gt)
log_cf_pred = np.log1p(cf_mean)
log_baseline_pred = np.log1p(base_mean)

fig, ax = plt.subplots(figsize=(10, 6))

# Labels are required for ax.legend() below; without them the legend is empty.
sns.scatterplot(
    x=log_male_gt, y=log_cf_pred,
    alpha=0.7, ax=ax, label="Counterfactual prediction (male)",
)
sns.scatterplot(
    x=log_male_gt, y=log_baseline_pred,
    alpha=0.7, ax=ax, label="Baseline prediction (female)",
)

# Identity line spanning every plotted value: the x data and both y series.
mn = min(log_male_gt.min(), log_cf_pred.min(), log_baseline_pred.min())
mx = max(log_male_gt.max(), log_cf_pred.max(), log_baseline_pred.max())
ax.plot([mn, mx], [mn, mx], ls="--", color="red", label="y = x")

ax.set_xlabel("Log Ground Truth (Male)")
ax.set_ylabel("Log Predicted Mean Expression")
ax.set_title("Leave One Out analysis")
ax.legend(loc="upper left")
ax.grid(False)
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt

# Use the RMSEs computed in the RMSE cell rather than hand-typed numbers,
# so the table always matches the current run.
rmse_data = [
    [f"{f_f:.4f}", f"{m_f:.4f}"],  # True Female: pred F, pred M (CF)
    [f"{f_m:.4f}", f"{m_m:.4f}"],  # True Male:   pred F, pred M (CF)
]

column_labels = ["Pred Female", "Pred Male (CF)"]
row_labels = ["True Female", "True Male"]

fig, ax = plt.subplots(figsize=(6, 2))
ax.axis('tight')
ax.axis('off')

table = ax.table(
    cellText=rmse_data,
    rowLabels=row_labels,
    colLabels=column_labels,
    cellLoc='center',
    loc='center'
)

table.scale(1, 2)  # Increase row height
table.auto_set_font_size(False)
table.set_fontsize(12)

# Highlight the held-out (True Male) row. Cell keys are (row, col).
# Row 0 is the column-header row, so data rows start at 1. Data columns start
# at 0, and the row labels sit in column -1.
table[(2, 0)].set_facecolor("orange")     # True Male, Pred Female
table[(2, 1)].set_facecolor("lightblue")  # True Male, Pred Male (CF)

ax.set_title("RMSE between Predictions and Ground truth", pad=20)
plt.tight_layout()
plt.show()
