https://docs.scvi-tools.org/en/stable/user_guide/notebooks/api_overview.html  
https://github.com/YosefLab/scvi-tools

In [None]:
# !pip install leidenalg
# !pip install scanpy==1.7.0
# !pip install scvi-tools
#!pip install --user scikit-misc

In [None]:
# Setup cell: library imports plus notebook-level side effects.
# NOTE: this cell contains IPython magics (`%...`), so it only runs inside
# IPython/Jupyter, not as a plain Python script.
import sys
# Make the parent directory importable (e.g. project-local modules one level up).
sys.path.append("..")
# NOTE(review): several imports below (argparse, KMeans, metrics, torch, nn,
# copy, PCA, sp, Counter, pickle) are not used in the cells shown here —
# confirm before removing, other cells may rely on them.
import argparse
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, calinski_harabasz_score, silhouette_score
from sklearn.cluster import KMeans
from sklearn import metrics

import torch
import torch.nn as nn
import copy
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import h5py
import scipy as sp
# Alternative import form: `import scanpy as sc`.
# import scanpy as sc # scanpy version 1.5
# NOTE(review): `scanpy.api` is deprecated upstream (scanpy's own warning says
# to use `import scanpy as sc` instead) — confirm it still exists in the
# installed scanpy, otherwise switch to the commented form above.
import scanpy.api as sc # scanpy version 1.7
from collections import Counter
import time
import scvi
import pickle
import os
import glob2
# Interactive matplotlib mode so figures render without blocking the notebook.
plt.ion()
plt.show()
# Auto-reload edited modules (IPython extension) before each cell executes.
%load_ext autoreload
%autoreload 2


In [None]:
# Benchmark scVI + Leiden clustering on every dataset of each category.
#
# For each category the .h5 files are discovered, each dataset is clustered
# three times (scVI training is stochastic), and one result row per run is
# appended to a DataFrame. The DataFrame is pickled after every run so
# partial results survive an interruption.
#
# Per-row columns: dataset name, ARI and NMI against the stored labels `Y`,
# silhouette ("sil") and Calinski-Harabasz ("cal") scores on the UMAP
# embedding, run index, wall-clock seconds, and the predicted labels.
path = ".."
for category in [
        "balanced_data",
        "imbalanced_data",
        "real_data",
]:
    # Simulated categories live under R/simulated_data/<category>;
    # real datasets share a single folder.
    if category in ["balanced_data", "imbalanced_data"]:
        data_dir = f"{path}/R/simulated_data/{category}"
    else:
        data_dir = f"{path}/real_data"

    # Dataset name = file name without directory and ".h5" extension.
    files = [
        os.path.splitext(os.path.basename(f))[0]
        for f in glob2.glob(f"{data_dir}/*.h5")
    ]
    print(files)

    out_dir = f"{path}/output/pickle_results/{category}"
    # to_pickle does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)

    df = pd.DataFrame(
        columns=["dataset", "ARI", "NMI", "sil", "run", "time", "pred", "cal"])
    for dataset in files:
        # Read everything into memory, then let the context manager close
        # the HDF5 handle.
        with h5py.File(f"{data_dir}/{dataset}.h5", "r") as data_mat:
            Y = np.array(data_mat['Y'])  # reference labels, one per row of X
            X = np.array(data_mat['X'])  # matrix passed to AnnData (rows = obs)
        print(f">>>>dataset {dataset}")

        # scVI is trained on integer counts. Builtin `int` is what the
        # removed `np.int` alias (NumPy >= 1.24) pointed to.
        X = np.ceil(X).astype(int)
        for run in range(3):
            start = time.time()
            adata = sc.AnnData(X)
            adata.obs['Group'] = Y
            adata.var_names_make_unique()

            adata.layers["counts"] = adata.X.copy()  # preserve counts
            sc.pp.normalize_total(adata, target_sum=1e4)
            sc.pp.log1p(adata)
            adata.raw = adata  # freeze the state in `.raw`
            # "seurat_v3" selects genes from raw counts, hence layer="counts".
            sc.pp.highly_variable_genes(
                adata,
                n_top_genes=2000,
                subset=True,
                flavor="seurat_v3",
                layer="counts",
            )
            # NOTE(review): newer scvi-tools releases replace this with
            # `scvi.model.SCVI.setup_anndata(adata, layer="counts")` —
            # confirm against the installed version.
            scvi.data.setup_anndata(adata, layer="counts")
            model = scvi.model.SCVI(adata)
            model.train()
            latent = model.get_latent_representation()
            adata.obsm["X_scVI"] = latent
            adata.layers["scvi_normalized"] = model.get_normalized_expression(
                library_size=10e4)

            # Graph + Leiden clustering on the scVI latent space. The UMAP
            # embedding is used only for the internal validity scores below.
            sc.pp.neighbors(adata, use_rep="X_scVI")
            sc.tl.umap(adata, min_dist=0.2)
            sc.tl.leiden(adata, key_added="leiden_scVI")

            pred = [int(x) for x in adata.obs['leiden_scVI'].to_list()]

            elapsed = time.time() - start
            ARI = adjusted_rand_score(Y, pred)
            NMI = np.around(normalized_mutual_info_score(Y, pred), 5)
            # Both internal scores raise ValueError unless
            # 2 <= n_clusters <= n_samples - 1. Record NaN for a degenerate
            # clustering instead of aborting the whole benchmark.
            n_clusters = len(set(pred))
            if 2 <= n_clusters <= len(pred) - 1:
                ss = silhouette_score(adata.obsm["X_umap"], pred)
                cal = calinski_harabasz_score(adata.obsm["X_umap"], pred)
            else:
                ss = cal = np.nan

            df.loc[df.shape[0]] = [
                dataset, ARI, NMI, ss, run, elapsed, pred, cal
            ]
            # Checkpoint after every run.
            df.to_pickle(f"{out_dir}/{category}_scvi.pkl")