In [None]:
pip install lamindb

In [None]:
import os 
# Move the working directory one level up; later cells use paths relative to
# it (e.g. the model save directory and one of the .h5ad outputs).
# NOTE(review): this is not idempotent -- re-running the cell moves up one more
# level each time, so run it exactly once per kernel session.
os.chdir("../")
# NOTE(review): `warnings` is imported but not used in the cells visible here.
import warnings

In [None]:
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import gdown

In [None]:
# Configure the scanpy/matplotlib figure defaults in ONE call.
# Each call to set_figure_params re-applies its default for every argument
# that is not passed, so splitting dpi / frameon / figsize over several calls
# lets the last call silently undo the earlier dpi and frameon settings.
sc.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))
# Compact tensor printing: 3 decimals, no scientific notation, 7 edge items.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

In [None]:
# Read the input .h5ad file into `adata`, the object every later cell works on.
# NOTE(review): absolute, machine-specific path -- adjust it (or move it into a
# config variable) when running somewhere else.
adata= sc.read('/work/trvae_new/New_fixed_data/scAtlas_Human_vascular_cells_processed_RAW_1.h5ad')

In [None]:
# Display the loaded object's summary as the cell output.
adata

In [None]:
# Hold out 10% of the observations as a test set (seeded, so reproducible).
from sklearn.model_selection import train_test_split

train_ids, test_ids = train_test_split(
    adata.obs_names,
    test_size=0.1,
    random_state=42,
)

# Record the assignment in `obs`: everything starts as "train", then the
# held-out observation names are relabelled "test".
adata.obs["split"] = "train"
adata.obs.loc[test_ids, "split"] = "test"

# Slice once per subset using a single boolean mask; "split" only ever holds
# the two labels written above, so the complement of "test" is "train".
is_test = adata.obs["split"] == "test"
train_adata = adata[~is_test]
test_adata = adata[is_test]

In [None]:
# Early-stopping / learning-rate-reduction options; this dict is handed to
# `trvae.train(...)` through its `early_stopping_kwargs` argument.
early_stopping_kwargs = dict(
    early_stopping_metric="val_unweighted_loss",
    threshold=0,
    patience=20,
    reduce_lr=True,
    lr_patience=13,
    lr_factor=0.1,
)


In [None]:
# The condition labels are the distinct donor ids present in the training split.
donor_conditions = train_adata.obs["donor_id"].unique().tolist()

# Build the trVAE model on the training split, conditioned on donor.
trvae = sca.models.TRVAE(
    adata=train_adata,
    condition_key="donor_id",
    conditions=donor_conditions,
    hidden_layer_sizes=[128, 128],
)

# Train with the early-stopping configuration defined in the previous cell.
trvae.train(
    n_epochs=300,
    alpha_epoch_anneal=200,
    early_stopping_kwargs=early_stopping_kwargs,
)

In [None]:
# Persist the trained model. The directory is relative, so it resolves against
# the working directory set by the earlier os.chdir("../") cell.
trvae.save("trvae_new/fixed_models/trvae_scAtlas_raw_model_batch_remove")

In [None]:

# Restore a previously trained model from disk onto the CPU.
# scArches' `load` is a classmethod that builds and RETURNS a new model; it
# does not modify the object it is called on. The result is therefore bound
# back to `trvae`, otherwise the loaded weights would be silently discarded.
# NOTE(review): this replaces the model trained above with the checkpoint at
# this (absolute, machine-specific) path -- confirm that is the intended one.
trvae = sca.models.TRVAE.load(
    "/work/trvae_new/new_model_runs_GPU/trVAE_scAtlas_new",
    adata=train_adata,
    map_location=torch.device("cpu"),
)


In [None]:
# Alias used by the prediction cells below; `model` and `trvae` are the same object.
model = trvae

In [None]:
from scarches.trainers.trvae._utils import make_dataset, custom_collate

In [None]:
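# With a GPU, run this version instead. Note: the next cell defines predict_trvae again, so whichever of the two cells is executed last is the one in effect.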

def predict_trvae(model, adata, condition_key, batch_size=128):
    """Encode `adata` with a trVAE model and decode it back to a reconstruction.

    Parameters
    ----------
    model
        Model wrapper exposing the network as `model.model` (with `encoder`,
        `sampling`, `decoder` and `condition_encoder`).
    adata
        Data to run through the network.
    condition_key
        Name of the condition column, forwarded to `make_dataset`.
    batch_size
        Number of rows per forward pass.

    Returns
    -------
    (latent, reconstructed)
        Two numpy arrays, concatenated along axis 0 in dataloader order
        (the loader does not shuffle).

    Notes
    -----
    * The network is switched to eval mode and left that way.
    * The latent is drawn with `model.model.sampling(mean, log_var)`.
      NOTE(review): if that draw is random, repeated calls give different
      results -- confirm whether the mean should be used instead.
    """
    network = model.model
    network.eval()

    # Take the device from the weights themselves. `model.trainer` is only
    # available once `.train()` has been called, so a model restored from disk
    # would fail on `model.trainer.device`.
    device = next(network.parameters()).device

    # Wrap the whole AnnData (train_frac=1.0) as a dataset, reusing the
    # model's own condition encoding; no cell-type labels are used.
    predict_data, _ = make_dataset(
        adata,
        train_frac=1.0,
        condition_key=condition_key,
        cell_type_keys=None,
        condition_encoder=network.condition_encoder,
        cell_type_encoder=None,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset=predict_data,
        batch_size=batch_size,
        shuffle=False,  # keep output rows aligned with the input order
        collate_fn=custom_collate,
        num_workers=0,
    )

    latent_list = []
    reconstructed_list = []

    with torch.no_grad():
        for batch_data in dataloader:
            # Move every tensor of the batch to the model's device.
            batch_data = {key: value.to(device) for key, value in batch_data.items()}
            x = batch_data["x"]
            conditions = batch_data["batch"]

            # Per-row total of x, kept as a column so it broadcasts over features.
            size_factor = x.sum(dim=1, keepdim=True)

            # The encoder is fed log(1 + x).
            z1_mean, z1_log_var = network.encoder(torch.log1p(x), conditions)
            latent = network.sampling(z1_mean, z1_log_var)
            latent_list.append(latent.cpu().numpy())

            # Only the first decoder output is used; it is rescaled by the
            # per-row size factor to give the reconstruction.
            recon_x, _ = network.decoder(latent, conditions)
            reconstructed_list.append((size_factor * recon_x).cpu().numpy())

    latent = np.concatenate(latent_list, axis=0)
    reconstructed = np.concatenate(reconstructed_list, axis=0)

    return latent, reconstructed

In [None]:
def predict_trvae(model, adata, condition_key, batch_size=128):
    """Run `adata` through the trVAE encoder and decoder.

    Parameters
    ----------
    model
        Model wrapper whose network is reachable as `model.model`.
    adata
        Data to encode and reconstruct.
    condition_key
        Condition column name passed on to `make_dataset`.
    batch_size
        Rows processed per forward pass.

    Returns
    -------
    (latent, reconstructed)
        numpy arrays stacked along axis 0, in the same row order as the
        (unshuffled) dataloader.

    Notes
    -----
    * Works on whatever device the network's parameters are on.
    * Leaves the network in eval mode.
    * NOTE(review): the latent comes from `model.model.sampling(...)`; if that
      is a random draw, outputs vary between calls -- confirm this is wanted.
    """
    # evaluation mode
    model.model.eval()

    # Create a dataset covering all of `adata` (train_frac=1.0), encoded with
    # the model's own condition encoder; cell-type labels are not used.
    predict_data, _ = make_dataset(
        adata,
        train_frac=1.0,
        condition_key=condition_key,
        cell_type_keys=None,
        condition_encoder=model.model.condition_encoder,
        cell_type_encoder=None,
    )
    # Create dataloader; shuffle=False keeps rows aligned with the input.
    dataloader = torch.utils.data.DataLoader(
        dataset=predict_data,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=custom_collate,
        num_workers=0,
    )

    # store results
    latent_list = []
    reconstructed_list = []

    # Device the network weights are on.
    device = next(model.model.parameters()).device

    with torch.no_grad():
        for batch_data in dataloader:
            # Move every tensor of the batch to the model's device.
            for key, batch in batch_data.items():
                batch_data[key] = batch.to(device)

            # Size factor = per-row total of x. Computed entirely in torch:
            # passing the tensor through numpy raises for tensors that are not
            # on the CPU and forces a needless host round-trip otherwise.
            sf = batch_data["x"].sum(dim=1)
            size_factor_view = sf.unsqueeze(1).expand(
                batch_data["x"].size(0), batch_data["x"].size(1)
            )

            # Get latent: the encoder is fed log(1 + x).
            x_log = torch.log1p(batch_data["x"])
            z1_mean, z1_log_var = model.model.encoder(x_log, batch_data["batch"])
            latent = model.model.sampling(z1_mean, z1_log_var)
            latent_list.append(latent.cpu().numpy())

            # Get recon: decode the latent; only the first decoder output is
            # used, rescaled by the per-row size factor.
            outputs = model.model.decoder(latent, batch_data["batch"])
            recon_x, _ = outputs
            sf_rate = size_factor_view * recon_x

            reconstructed_list.append(sf_rate.cpu().numpy())

    latent = np.concatenate(latent_list, axis=0)
    reconstructed = np.concatenate(reconstructed_list, axis=0)

    return latent, reconstructed

In [None]:
# Latent representation and reconstruction for the held-out test split only.
latent,rec = predict_trvae(model,test_adata,condition_key="donor_id")

In [None]:
# Same prediction on the full dataset (train and test rows together).
# NOTE(review): in the cells visible here `rec_2` only appears in a
# commented-out line and `latent_2` is unused -- confirm this run is needed.
latent_2, rec_2 = predict_trvae(model, adata, condition_key="donor_id")

In [None]:
import os
# Diagnostic: show the working directory and whether it is writable.
print("CWD:", os.getcwd(), "Writable?", os.access(os.getcwd(), os.W_OK))

# 1) Materialise the test slice as its own object so that writing to `.obsm`
#    does not operate on a view of `adata`.
test_adata = test_adata.copy()
test_adata.obsm["X_reconstructed"] = rec

# 2) Write the result to disk. The target is an absolute path under /work
#    (not /tmp); change it to a location you can write to.
# NOTE(review): the file name mentions "latent", but only the reconstruction
# is stored in `.obsm` here -- `latent` is not saved.
outfn = "/work/trvae_new/trvae_newpredict/adata_post_with_latent_and_reconstructed_Atlas_RAW_trVAE.h5ad"
test_adata.write(outfn)
print("Wrote to", outfn)


In [None]:
# Defensive unwrapping: if `rec` were an AnnData object, keep only its data matrix.
# NOTE(review): predict_trvae returns numpy arrays, so this branch is not
# expected to trigger for the `rec` produced above.
import anndata
if isinstance(rec, anndata.AnnData):
    rec = rec.X

# Store the reconstruction alongside the test observations.
test_adata.obsm["X_reconstructed"] = rec

# Save the whole object, this time to a path relative to the working directory.
# NOTE(review): the file name mentions "latent", but `latent` is not stored here either.
test_adata.write("adata_post_with_latent_and_reconstructed_Atlas_trVAE_removed_batch.h5ad")

In [None]:
# Per-row totals of the reconstructed matrix, shown as the cell output.
rec.sum(axis=1)

In [None]:
# Ground-truth matrix for the test rows, selected from the full dataset by
# observation name. Despite its name, `adata_2` is a matrix, not an AnnData.
adata_2 = adata[test_adata.obs_names].X

# Convert to a dense numpy array if it is anything else (e.g. a sparse matrix).
if not isinstance(adata_2, np.ndarray):
    print("Converting y_true from sparse to dense.")
    adata_2 = adata_2.toarray()



# Flatten to 1-D for element-wise metrics.
adata_2_flat = adata_2.flatten()
# Disabled: would flatten the full-dataset reconstruction `rec_2` instead.
#rec_2_flat = rec_2.flatten()

In [None]:
from sklearn.metrics import (
    mean_absolute_error,
    mean_squared_error,
    r2_score,
    mutual_info_score
)

In [None]:
# Coefficient of determination (R2) between the observed test matrix and its
# reconstruction, computed over all entries at once.
# Both matrices are flattened to 1-D so every entry is treated as one sample.
# Note that `rec_2_flat` holds the flattened TEST reconstruction `rec`.
adata_2_flat, rec_2_flat = adata_2.flatten(), rec.flatten()

r_square = r2_score(adata_2_flat, rec_2_flat)
print("R2:", r_square)

In [None]:
# MSE: mean squared error between the observed test matrix and its reconstruction.
# The 2-D matrices are passed as-is (not flattened).
# NOTE(review): with scikit-learn's default multi-output averaging this should
# equal the mean over all entries -- confirm for the installed version.

mse = mean_squared_error(adata_2, rec)
print(mse)

In [None]:
# MAE: mean absolute error between the observed test matrix and its reconstruction.
# The 2-D matrices are passed as-is (not flattened).
# NOTE(review): with scikit-learn's default multi-output averaging this should
# equal the mean over all entries -- confirm for the installed version.

mae = mean_absolute_error(adata_2, rec)
print(f"Mean absolute error (MAE): {mae}")