In [None]:
# Re-import edited modules (e.g. a locally installed drvi checkout) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import warnings

In [None]:
warnings.filterwarnings("ignore")

In [None]:
import anndata as ad
import scanpy as sc
from matplotlib import pyplot as plt
from IPython.display import display
from gprofiler import gprofiler
import torch
import drvi
from drvi.model import DRVI
from drvi.utils.misc import hvg_batch
import numpy as np
import pandas as pd

In [None]:
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error

In [None]:
# Plot defaults for this notebook.
# scanpy's set_figure_params resets every argument it is not given to its default
# (dpi=80, frameon=True, ...) on each call. Splitting these settings across several
# calls would therefore put frameon back to True, so configure everything at once.
sc.set_figure_params(dpi=100, frameon=False, figsize=(3, 3))
# Mirror the same values directly on matplotlib for plots made outside scanpy.
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.figsize"] = (3, 3)

In [None]:
# Input atlas (the file name marks it as the RAW scAtlas human vascular-cell export).
# NOTE(review): absolute server path — adjust when running elsewhere.
DATA_PATH = "/work/DRVI/fixed_data/scAtlas/scAtlas_Human_vascular_cells_processed_RAW_1.h5ad"
adata = sc.read(DATA_PATH)
adata

In [None]:
# Snapshot the current matrix into the "counts" layer; DRVI.setup_anndata reads layer="counts".
# NOTE(review): assumes adata.X holds raw counts at this point (file name says RAW) — confirm.
counts_snapshot = adata.X.copy()
adata.layers["counts"] = counts_snapshot

In [None]:
# Split the data set into train and test
from sklearn.model_selection import train_test_split


split_key = "split"
test_fraction = 0.1  # share of cells held out for evaluation
split_seed = 42      # fixed so the train/test assignment is reproducible

# obs_names are used as .loc labels below; duplicated names would let one label
# hit several rows and corrupt the split.
assert adata.obs_names.is_unique, "adata.obs_names must be unique (try adata.obs_names_make_unique())"

idx_train, idx_test = train_test_split(adata.obs_names, test_size=test_fraction, random_state=split_seed)

# Every cell starts in "train"; only the held-out cells need relabeling.
adata.obs[split_key] = "train"
adata.obs.loc[idx_test, split_key] = "test"
assert (adata.obs[split_key] == "test").sum() == len(idx_test)

In [None]:
# Boolean cell masks derived from the split column.
is_train = adata.obs[split_key] == "train"
is_test = adata.obs[split_key] == "test"

# Materialize each split as an independent AnnData (copies, not views).
train_adata = adata[is_train].copy()
test_adata = adata[is_test].copy()

In [None]:
# Register the training split with DRVI and build the model.
# The same covariate keys are given to setup_anndata and to the model constructor.
batch_covariates = ["donor_id"]

DRVI.setup_anndata(
    train_adata,
    layer="counts",                                     # count input; non-count data needs a different gene_likelihood
    categorical_covariate_keys=list(batch_covariates),  # always a list (several covariates are allowed)
    is_count_data=True,                                 # use False for log-normalized data with a normal (MSE) likelihood
)

model = DRVI(
    train_adata,
    categorical_covariates=list(batch_covariates),  # see the DRVI advanced-usage docs for more options
    n_latent=128,
    encoder_dims=[128, 128],  # hidden widths, one int per layer
    decoder_dims=[128, 128],
)
model

In [None]:
# Training hardware (arguments to add to model.train):
#   GPU:                  nothing extra
#   CPU:                  accelerator="cpu", devices=1
#   Apple silicon (MPS):  accelerator="mps", devices=1
# Details: https://lightning.ai/docs/pytorch/stable/accelerators/gpu_basic.html
#
# Recorded runtimes for this run:
#   CPU laptop (M1)   208 minutes
#   MacBook GPU (M1)  64 minutes
#   GPU (A100)        17 minutes

n_epochs = 400

# KL warm-up spans the whole run. plan_kwargs is optional when n_epochs >= 400,
# but it is passed explicitly here.
training_plan = {"n_epochs_kl_warmup": n_epochs}

model.train(
    max_epochs=n_epochs,
    early_stopping=True,
    early_stopping_patience=20,
    plan_kwargs=training_plan,
)

In [None]:
# Persist the freshly trained model. overwrite=True keeps this cell re-runnable:
# NOTE(review): assumes DRVI.save follows scvi-tools' BaseModelClass.save, which refuses
# to write into an existing directory unless overwrite=True.
model.save("FIXED_trained_models/DRVI_scAtlas_batch_removed", overwrite=True)

In [None]:
# Reload the checkpoint from the directory the save cell writes to, so every downstream
# output (including the *batch_removed* files) comes from this run's model.
# NOTE(review): to evaluate another checkpoint instead
# (e.g. "FIXED_trained_models/DRVI_scAtlas_train_raw_new/"), change the path here.
model = DRVI.load("FIXED_trained_models/DRVI_scAtlas_batch_removed", adata=train_adata)

In [None]:
# Display the loaded model object.
model

In [None]:
def predict(model, adata, batch_size=128, posinf=100):
    """Reconstruct expression for every cell in ``adata`` with a trained DRVI model.

    Runs the model's module forward pass (without loss) batch by batch, in cell
    order, and collects the mean of the generative distribution ``px``.

    Parameters
    ----------
    model
        Trained DRVI model; ``adata`` is checked with ``model._validate_anndata``.
    adata
        AnnData whose cells should be reconstructed.
    batch_size
        Number of cells per forward pass.
    posinf
        Value substituted for ``+inf`` in the reconstructed means. NaN and
        ``-inf`` are replaced by 0.

    Returns
    -------
    AnnData
        A copy of ``adata`` whose ``X`` is the reconstructed mean matrix, rows
        in the same order as ``adata``. The input object is not modified.
    """
    model._validate_anndata(adata)
    model.module.eval()

    loader = model._make_data_loader(adata=adata, indices=None, batch_size=batch_size, shuffle=False)
    batch_means = []
    # Inference only: skip building autograd graphs (saves memory and time).
    with torch.no_grad():
        for tensors in loader:
            _, generative_outputs = model.module.forward(tensors, compute_loss=False)
            px_mean = torch.nan_to_num(generative_outputs["px"].mean, nan=0, neginf=0, posinf=posinf)
            batch_means.append(px_mean.cpu().numpy())

    out_adata = adata.copy()
    out_adata.X = np.concatenate(batch_means, axis=0)
    return out_adata


# Reconstruct the held-out cells. predict() validates test_adata against the
# model itself, so no separate _validate_anndata call is needed here.
rec = predict(model, test_adata)

In [None]:
# Latent DRVI embedding for every cell (train and test).
z = model.get_latent_representation(adata, batch_size=256)

# Keep the test cells' embedding on test_adata so it is written out together with
# the reconstruction. test_adata was sliced from adata with this same mask, so the
# row order matches.
test_mask = (adata.obs[split_key] == "test").to_numpy()
test_adata.obsm["X_latent"] = z[test_mask]
assert test_adata.obsm["X_latent"].shape[0] == test_adata.n_obs

In [None]:
import anndata

# Unwrap the AnnData returned by predict() so .obsm receives a plain matrix.
# A value that is already a matrix (e.g. after re-running this cell) passes through.
rec = rec.X if isinstance(rec, anndata.AnnData) else rec
test_adata.obsm["X_reconstructed"] = rec

# Persist the test split together with its reconstruction.
# NOTE(review): "remvoed" looks like a typo; left as-is because other code may read this exact file name.
out_path = "adata_post_with_latent_and_reconstructed_scAtlas_DRVI_raw_batch_remvoed.h5ad"
test_adata.write(out_path)


In [None]:
# Observed expression of the test cells, presumably the reference for the sklearn
# metrics imported above (r2_score, mean_absolute_error, mean_squared_error).
# NOTE(review): if test_adata.X is sparse, those metrics will need a dense copy (.toarray()).
y_true = test_adata.X

In [None]:
# Display the reconstruction; by this point `rec` has been rebound from the AnnData to its .X matrix.
rec