In [None]:
import warnings
warnings.simplefilter(action='ignore')
import scanpy as sc
import torch
import scarches as sca
import numpy as np
import gdown
import os

# Plot defaults. All figure options go through ONE call: each call to
# sc.set_figure_params() re-applies its own defaults for every argument that
# is not passed (e.g. dpi=80, frameon=True), so separate calls undo each other
# and only the last call's explicit argument would survive.
sc.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))
# Compact tensor printing: 3 decimals, no scientific notation, 7 edge items.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

In [None]:
# Directory that holds the input and output .h5ad files; "~" is expanded to
# the current user's home. The trailing "/" matters: later cells build file
# paths by plain string concatenation onto base_path.
base_path = os.path.expanduser("~/top_adatas/")
base_path

In [None]:
# Dataset selection: which file to load, and which obs column holds the
# labels used to colour the UMAP at the end of the notebook.
dataset_name, cell_type_key = "pancreas", "clusters"

# The input file sits directly under base_path (plain string concatenation).
input_file_name = f"mivelo_{dataset_name}.h5ad"
adata_path = base_path + input_file_name
adata = sc.read_h5ad(adata_path)

In [None]:
# Inspect the shape of the "I" entry stored in the loaded file's varm.
# NOTE(review): the model input built later takes its mask from
# adata.uns["mask"], not from this varm entry — confirm which one is intended.
adata.varm["I"].shape

In [None]:
# Model input: the "Mu" and "Ms" layers placed side by side along the feature
# axis, so the matrix has as many columns as both layers combined.
# (Presumably scVelo's unspliced/spliced moment layers — TODO confirm.)
data = np.concatenate([adata.layers["Mu"], adata.layers["Ms"]], axis=1)

# Hand the annotation column to the constructor so the new object keeps
# adata's cell index. A bare AnnData(X=...) gets a default obs index, and
# assigning a Series afterwards aligns on index labels — which silently
# yields NaN for every cell whose label is not in that default index.
adata_expimap = sc.AnnData(X=data, obs=adata.obs[[cell_type_key]].copy())

# Annotation mask stored under varm["I"].
# NOTE(review): varm entries need one row per column of `data`; confirm
# adata.uns["mask"] is laid out that way.
adata_expimap.varm["I"] = adata.uns["mask"]

In [None]:
# Every cell gets the same condition label, i.e. a single "study".
adata_expimap.obs["study"] = "0"
print(f"Hard mask shape: {adata_expimap.varm['I'].shape}")

# expiMap model over the concatenated matrix, conditioned on the constant
# "study" column. soft_mask=False matches the "Hard mask" wording printed
# above; n_ext=0 and use_hsic=False presumably switch off extension terms
# and the HSIC penalty — verify against the scArches documentation.
intr_cvae = sca.models.EXPIMAP(
    adata=adata_expimap,
    condition_key='study',
    recon_loss='mse',
    hidden_layer_sizes=[512, 512, 512],
    soft_mask=False,
    use_hsic=False,
    n_ext=0,
)

In [None]:
# Value forwarded as `alpha` to train() below.
ALPHA = 0.7

# Early-stopping settings are forwarded to train(), but use_early_stopping is
# False and train_frac is 1 below, so they are presumably not acted upon.
early_stopping_kwargs = dict(
    early_stopping_metric="val_unweighted_loss",
    threshold=0,
    patience=50,
    reduce_lr=False,
    lr_patience=13,
    lr_factor=0.1,
)

# Fixed-length run on all cells with a fixed seed.
intr_cvae.train(
    n_epochs=1000,
    seed=2020,
    train_frac=1,
    alpha=ALPHA,
    alpha_epoch_anneal=500,
    alpha_kl=1e-5,
    weight_decay=0.01,
    use_early_stopping=False,
    early_stopping_kwargs=early_stopping_kwargs,
    monitor_only_val=False,
    print_stats=True,
)

In [None]:
# Store the trained model's latent representation on the original object
# under obsm['z'], then build the neighbour graph and UMAP from it.
latent = intr_cvae.get_latent(mean=False, only_active=True)
adata.obsm['z'] = latent

sc.pp.neighbors(adata, use_rep='z')
sc.tl.umap(adata)

# Colour the embedding by the annotation column chosen at load time.
sc.pl.umap(adata, color=[cell_type_key], frameon=False)

In [None]:
# Save the annotated object twice under one file name: once under base_path
# and once in the current working directory.
output_name = f"expimap_{dataset_name}.h5ad"
adata.write_h5ad(base_path + output_name)
adata.write_h5ad(output_name)

In [20]:
# Derive the file name from dataset_name rather than hard-coding "pancreas",
# so this cell stays in step with the dataset chosen at the top.
# NOTE(review): this writes the same working-directory file as the preceding
# save cell and looks like a leftover — consider dropping one of the two.
adata.write_h5ad(f"expimap_{dataset_name}.h5ad")