In [None]:
import scanpy as sc
import logging
import scgen # Development version only works
import sklearn
import seaborn as sns
import torch
import warnings
import os
import sys
import re




#import numpy as np

# Remember: scvi-tools sometimes (not always) has to be downgraded,
# e.g. `pip install scvi-tools==1.6` (sqrt issue in latent space).

# 1. Download scanpy
# 2. Download scgen (not development version) 

In [None]:
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

In [None]:
# Read the .h5ad file into `adata`.
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
adata = sc.read("/work/scGen_Human_vascular/new_data_fix_may/HTAPP_997_processed_raw_FINAL.h5ad")

In [None]:
# Remove all cells where the cell_type is 'mature NK T cell',
# because there is only one such cell in replicate 1.
# .copy() materialises the subset: boolean indexing on an AnnData returns a view,
# and assigning into the .obs of a view later on forces an implicit copy (with a warning).
adata = adata[adata.obs["cell_type"] != "mature NK T cell"].copy()

In [None]:
# Write the current `adata` to an .h5ad file (path is relative to the working directory).
adata.write("HTAPP_997_processed_raw_FINAL_fixed.h5ad")

In [None]:
# Show the number of cells per cell type.
adata.obs["cell_type"].value_counts()

In [None]:
# Hold out 10% of the cells as a test set; the fixed random_state keeps the split reproducible.
from sklearn.model_selection import train_test_split

split_key = "split"

# Every cell starts as "train"; the held-out cells are relabelled below.
adata.obs[split_key] = "train"
idx = list(range(len(adata)))
idx_train, idx_test = train_test_split(
    adata.obs_names,
    test_size=0.1,
    random_state=42,
)
adata.obs.loc[idx_train, split_key] = "train"
adata.obs.loc[idx_test, split_key] = "test"

# Materialise independent copies of each partition (not views of `adata`).
adata_train = adata[adata.obs[split_key] == "train"].copy()
adata_test = adata[adata.obs[split_key] == "test"].copy()

In [None]:
# Display `adata` (now carrying the "split" column in .obs).
adata

In [None]:
# Register the training data with scGen: "replicate" is the batch key, "cell_type" the labels key.
scgen.SCGEN.setup_anndata(adata_train, batch_key = "replicate", labels_key="cell_type")

In [None]:
# Build a scGen model on the training data.
model = scgen.SCGEN(adata_train)
# NOTE(review): this save runs before model.train() is called, so it presumably stores an
# untrained model; the relative directory "work/abtch_removal" also looks misspelled — confirm both.
model.save("work/abtch_removal/HTAPP_batchremoval", overwrite=True)

In [None]:
# Display the model.
model

In [None]:
# Train for at most 300 epochs with a batch size of 32; early stopping is enabled
# with a patience of 100.
model.train(
    max_epochs=300,
    batch_size=32,
    early_stopping=True,
    early_stopping_patience=100,
)

In [None]:
# Display the model after the training call.
model

In [None]:
# Load the previously saved model. `load` is a classmethod that returns a new model
# instance instead of updating the object it is called on, so the result must be
# bound to `model`; otherwise the loaded weights are silently discarded.
# NOTE(review): absolute, machine-specific path — consider a configurable model directory.
model = scgen.SCGEN.load(
    "/work/scGen_Human_vascular/work/scGen_Human_vascular_new_run_fix/saved_models/scGen_HTAPP_GPU_run_fix_raw",
    adata=adata_train,
)

In [None]:
# Manually flag the model as trained.
# NOTE(review): presumably a workaround for the model not reporting itself as trained — confirm it is needed.
model.is_trained = True

In [None]:
# Display the model.
model

In [None]:
# Batch removal
# No AnnData is passed — presumably the data the model was set up with is corrected; TODO confirm.
corrected_adata = model.batch_removal()
corrected_adata

In [None]:
# Display the held-out test set.
adata_test

In [None]:
# Inspect the model's AnnData manager.
model.adata_manager.summary()

In [None]:
# Run scGen prediction on the held-out cells.
# NOTE(review): ctrl_key and stim_key are both "1" — presumably so that `pred` acts as a
# reconstruction of the input rather than a perturbation prediction; confirm intent.
pred, delta = model.predict(
    ctrl_key="1",
    stim_key="1",
    adata_to_predict=adata_test
)


In [None]:
# `pred` may be an AnnData object; in that case keep only its data matrix (.X),
# since an obsm entry must be array-like.
import anndata
pred = pred.X if isinstance(pred, anndata.AnnData) else pred

# Attach the predicted matrix to the test cells.
adata_test.obsm["X_reconstructed"] = pred

# Write the test set, including the new obsm entry, to disk.
adata_test.write("scGen_HTAPP_raw_fix_adata_post_with_latent_and_recon_batch_removed.h5ad")

In [None]:
# Compute the latent representation for `adata`, using batch_size=256.
latent_X=model.get_latent_representation(adata, batch_size=256)
# NOTE(review): latent_X is computed from `adata`, so pairing it with adata_train.obs
# (as in the disabled line below) would presumably mismatch in row count — verify before enabling.
#latent_adata = sc.AnnData(X=latent_X, obs=adata_train.obs.copy())