# Set Up

In [1]:
# %load_ext nb_black

# %matplotlib inline
# import matplotlib.pyplot as plt

In [2]:
import argparse
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import numpy as np
import torch
import pandas as pd
import scanpy as sc
import sklearn.metrics as met

from utils import (
    import_lung_data,
    get_node_to_leaves,
    get_names,
    get_to_idx_mappers,
    get_high_variable_genes,
    get_leaf_sampling_probs,
    get_classification_performance,
    predict_meta,
)
from sklearn.svm import LinearSVC
from sklearn.neighbors import KNeighborsClassifier
from scvi.dataset import GeneExpressionDataset
from scvi.inference import SemiSupervisedTrainer, UnsupervisedTrainer
from scvi.models import SCANVI, VAE

# Root directory for every artefact written by this notebook (h5ad files, model weights).
SAVE_DIR = "/data/yosef2/users/chenling/TabulaSapiensData/hierarchial_scANVI/scVI_TSP/experiments/lung_chenling/"

# exist_ok=True removes the check-then-create race of `if not exists: makedirs`
# (two concurrent runs could both pass the check and one would then crash).
os.makedirs(SAVE_DIR, exist_ok=True)

import sklearn.metrics as met
import pickle as pkl
from anndata import read_h5ad

os.getcwd()

'/data/yosef2/users/chenling/tabula-sapiens_old/lung_eval'

Notebook parameters:

- `is_fair`: if True, ensures that the training data is balanced by making sure that no more than 100 labelled examples are used for each cell type
- `ON_LEAVES_ONLY`: if True, only the leaves are used to compare the different algorithms

# Notebook parameters

In [3]:
is_fair = True  # when True, labelled training cells are capped per cell type (see N_EXAMPLES_MAX)
ON_LEAVES_ONLY = False  # when True, only cells whose label index is a leaf are kept
N_EXAMPLES_MAX = 100  # per-cell-type cap on labelled examples, applied when is_fair is True
init = 0  # run index; in the visible code it is only used as a suffix of the saved scVI filename

In [4]:
# Directory holding the public lung datasets; `combined` is the merged AnnData
# object that every later cell subsets.
DATA_PATH = "/data/yosef2/users/chenling/TabulaSapiensData/PublicDatasets/lung/"
combined = read_h5ad(os.path.join(DATA_PATH, "combined.lung.3000.clean.h5ad"))

In [5]:
# Command-line parsing is disabled in the notebook; the two values it would
# have provided are set by hand below.
# parser = argparse.ArgumentParser(description="Process some integers.")
# parser.add_argument("--mode", type=int, default=0)
# parser.add_argument("--test_seed", type=int, default=0)
# args = parser.parse_args()
# mode = int(args.mode)
# TEST_SPLIT_SEED = int(args.test_seed)

# mode selects the train -> test batch pairing used further down:
#   0: TS -> CA, 1: TS -> Reyfman, 2: TS -> Barga, 3: CA -> Reyfman, 4: CA -> Barga
mode = 3
# Seed used when subsampling labelled cells; also embedded in the output file name.
TEST_SPLIT_SEED = 0 


In [6]:
# Per-batch summary table: batch name, number of cells, number of distinct annotations.
a, b = np.unique(combined.obs["batch_names"], return_counts=True)
c = [
    len(np.unique(combined[combined.obs["batch_names"] == x].obs["free_annotation"]))
    for x in a
]

pd.DataFrame([a, b, c], index=["batchname", "n_cells", "n_celltypes"]).T

Unnamed: 0,batchname,n_cells,n_celltypes
0,Atlas_droplet,65662,42
1,Atlas_facs,9409,38
2,Barga,5916,26
3,Reyfman,25246,19
4,TSP1_10X,14780,34
5,TSP1_smartseq2,728,31
6,TSP2_10X,20872,34
7,TSP2_smartseq2,782,25


In [7]:
# Number of distinct free-text annotations across all batches.
len(np.unique(combined.obs["free_annotation"]))

52

In [8]:
# Same summary as above, but with batches grouped by the study they come from.
dataset_groups = [
    ["Atlas_droplet", "Atlas_facs"],
    ["Reyfman"],
    ["Barga"],
    ["TSP1_10X", "TSP1_smartseq2", "TSP2_10X", "TSP2_smartseq2"],
]
for dataset in dataset_groups:
    in_group = np.isin(combined.obs["batch_names"], dataset)
    n_cell_types = len(np.unique(combined[in_group].obs["free_annotation"]))
    print(" ".join(dataset), np.sum(in_group), n_cell_types)

Atlas_droplet Atlas_facs 75071 43
Reyfman 25246 19
Barga 5916 26
TSP1_10X TSP1_smartseq2 TSP2_10X TSP2_smartseq2 37162 34


## Labels and ontology processing


Loads all the tools necessary to use hscANVI

In [9]:
# Load adjacency property and matrix
adj = pkl.load(
    open(
        "/data/yosef2/users/chenling/TabulaSapiensData/ontology/ontology.lung.2.flat.pkl",
        "rb",
    )
)
adjm = adj.adjacency_matrix()

adjmc = []
for x in adjm:
    a = x.shape[0]
    b = x.shape[1]
    temp = np.zeros((a, b))
    temp[:a, :b] = np.asarray(x)
    # temp[a, b] = 1
    adjmc.append(temp)

print(len(adjmc))

5


We now construct mappers between the different representations of cell-types:
    
    - Names: ("B cell")
    - Tree representation ("CL:XXXXXXX")
    - Integer representation, used to train the different models

Constructs the dataset labels

In [10]:
# Node -> leaves lookup built from the ontology adjacency matrices
# (presumably indexed by node id — confirm against utils.get_node_to_leaves).
all_to_leaves = get_node_to_leaves(adjm)

# Leaves are the columns of the last adjacency matrix plus two catch-all classes.
leaf_node_names = [*adjm[-1].columns.values, "low_quality", "unassigned"]
is_leaf = np.isin(all_to_leaves.index, leaf_node_names)
other_node_names = list(all_to_leaves.index[~is_leaf])

id_to_label, label_to_id = get_names(all_to_leaves=all_to_leaves, adj=adj)
nodes_to_indices, indices_to_nodes = get_to_idx_mappers(
    leaf_node_names=leaf_node_names, other_node_names=other_node_names,
)

# Per-cell node id, then per-cell integer label.
nodes = label_to_id[combined.obs["free_annotation"]]
labels_f = nodes_to_indices[nodes]

Construct a table of leaf frequencies, used to sample leaf labels from intermediate labels

In [11]:
# Expand every cell's node into the list of leaves stored for it, then count
# how often each leaf occurs overall.
flat_leaves = pd.Series(
    [leaf for li in all_to_leaves[nodes].tolist() for leaf in li]
)
node_counts = flat_leaves.groupby(flat_leaves).size()
# Leaves that never occur get a count of 0 instead of NaN.
leaf_counts = node_counts.reindex(leaf_node_names).fillna(0.0)

nodes_to_leaves_probs = get_leaf_sampling_probs(
    all_to_leaves=all_to_leaves,
    leaf_counts=leaf_counts,
    nodes_to_indices=nodes_to_indices,
)
# Build the CPU tensor from the helper's output first and derive the GPU copy
# from it. The previous code called torch.tensor() on an existing CUDA tensor,
# which PyTorch discourages (copy-construct warning) and which round-trips
# the data through the GPU for no reason.
nodes_to_leaves_probs_cpu = torch.tensor(nodes_to_leaves_probs, device="cpu")
nodes_to_leaves_probs = nodes_to_leaves_probs_cpu.to("cuda")

  1, keepdims=True
  


## Dataset construction

The h5ad data used in this notebook contains 8 batches:

    0 reyman_ann
    1 barga_ann
    2 lung_droplet_ann
    3 lung_facs_ann
    4 lung_ts1_10x
    5 lung_ts1_facs
    6 lung_ts2_10x
    7 lung_ts2_facs
    

    I labelled in unlabelled
    II unlabelled in labelled
    III other
    
    
    CA Reyfman more II than I
    CA Barga more II than I
    TS Reyfman more II than I
    TS Barga III
    TS CA Approximately the same

/data/yosef2/users/chenling/TabulaSapiensData/hierarchial_scANVI/scVI_TSP/experiments/lung_chenling/

In [12]:
# TS -> CA
if mode == 0:
    train_batch_indices = ["TSP1_10X", "TSP1_smartseq2", "TSP2_10X", "TSP2_smartseq2"]
    test_batch_indices = ["Atlas_droplet", "Atlas_facs"]

# TS -> Reyfman
elif mode == 1:
    train_batch_indices = ["TSP1_10X", "TSP1_smartseq2", "TSP2_10X", "TSP2_smartseq2"]
    test_batch_indices = ["Reyfman"]
    

# TS -> Barga
elif mode == 2:
    train_batch_indices = ["TSP1_10X", "TSP1_smartseq2", "TSP2_10X", "TSP2_smartseq2"]
    test_batch_indices = ["Barga"]


# CA -> Reyfman
elif mode == 3:
    train_batch_indices = ["Atlas_droplet", "Atlas_facs"]
    test_batch_indices = ["Reyfman"]
    

# CA -> Barga                                  
elif mode == 4:
    train_batch_indices = ["Atlas_droplet", "Atlas_facs"]
    test_batch_indices = ["Barga"]


all_batch_indices = train_batch_indices + test_batch_indices

## construct cell type labels

In [13]:
# Cells that belong to one of the selected train/test batches.
mask_sset = combined.obs["batch_names"].isin(all_batch_indices).values
print("Len of subset of indices :", mask_sset.sum())

# Leaf labels occupy the first indices; the last two leaf names
# ("low_quality", "unassigned") are not counted as real leaves.
n_labels_hscanvi = len(leaf_node_names) - 2

if ON_LEAVES_ONLY:
    n_labels_final = n_labels_hscanvi
    # Additionally drop every cell annotated with a non-leaf label.
    mask_sset = mask_sset & (labels_f < n_labels_final).values
else:
    n_labels_final = nodes_to_leaves_probs_cpu.shape[0]

# Subset once, after the final mask is known (the original repeated this in
# the ON_LEAVES_ONLY branch).
labels_subset = labels_f[mask_sset]
X = combined.X[mask_sset]
batch_data = combined.obs["batch_names"][mask_sset]
batch_indices, batches_mapper = pd.factorize(batch_data)

in_test_batches = batch_data.isin(test_batch_indices).values
labels_all = labels_subset.values.copy()

# Test cells that carry a leaf label.
where_leaves = (labels_all < n_labels_hscanvi) & in_test_batches

if ON_LEAVES_ONLY:
    # Evaluation is restricted to leaf-labelled test cells.
    where_eval = where_leaves.copy()
    # Map each label to the leaf with the highest sampling probability.
    labels_final = nodes_to_leaves_probs_cpu.numpy()[labels_all].argmax(-1)
else:
    where_eval = in_test_batches
    labels_final = labels_subset.values

print("subsampled eval", where_eval.sum())

# `labels` is read by the kNN cell; it used to be defined only when
# ON_LEAVES_ONLY was False, which made that cell fail with a NameError otherwise.
labels = labels_final

print("labels_final", labels_final.shape)
print("X", X.shape)
print("batch_data", batch_data.shape)
print("batch_indices", batch_indices.shape)
print("labels_all", labels_all.shape)

Len of subset of indices : 100317
subsampled eval 25246
labels_final (100317,)
X (100317, 3000)
batch_data (100317,)
batch_indices (100317,)
labels_all (100317,)


## construct scVI dataset

In [14]:
# Wrap the selected cells in an scVI dataset.
train_data = GeneExpressionDataset()
train_data.populate_from_data(X=X, batch_indices=batch_indices, labels=labels_final)

train_data.cell_types = nodes_to_indices.index.values
train_data.labels = labels_final
train_data.n_labels = len(train_data.cell_types)

# Positional indices of cells from the training batches (labelled) and of all
# remaining cells (unlabelled).
is_train_batch = batch_data.isin(train_batch_indices).values
labelled = np.flatnonzero(is_train_batch)
unlabelled = np.flatnonzero(~is_train_batch)

# Subsampling labelled examples in the case where we impose balance
if is_fair:
    pops = pd.Series(labels_final[labelled], index=labelled)
    cells_by_type = pops.groupby(pops)
    n_examples_ct = np.minimum(cells_by_type.size(), N_EXAMPLES_MAX)

    def _sample_cell_type(group):
        # Draw at most N_EXAMPLES_MAX cell positions for this cell type.
        drawn = group.sample(n_examples_ct[group.name], random_state=TEST_SPLIT_SEED)
        return drawn.index.values

    selected_indices = cells_by_type.apply(_sample_cell_type)
    labelled = np.concatenate(selected_indices.values)
    # Everything that was not drawn is treated as unlabelled.
    all_positions = np.arange(len(train_data))
    unlabelled = all_positions[~np.isin(all_positions, labelled)]


[2020-08-17 01:28:35,947] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2020-08-17 01:28:35,954] INFO - scvi.dataset.dataset | Remapping labels to [0,N]


In [15]:
# Take an explicit copy: `combined[mask_sset]` is a view, and writing to
# `.obs` of a view triggers an implicit copy plus the
# "Trying to set attribute `.obs` of view, copying." warning.
annd = combined[mask_sset].copy()

# Train labels: boolean flags by positional index.
cell_positions = np.arange(len(annd))
annd.obs["labelled"] = np.isin(cell_positions, labelled)
annd.obs["unlabelled"] = np.isin(cell_positions, unlabelled)

# Every index in labelled/unlabelled must be a valid, unique position.
assert (annd.obs["labelled"].sum() == labelled.shape[0]) and (
    annd.obs["unlabelled"].sum() == unlabelled.shape[0]
)

# Eval labels
annd.obs["eval"] = where_eval
annd.obs["eval_leaves"] = where_leaves

# Memory efficiency: store the flags as categoricals.
for flag_column in ("labelled", "unlabelled", "eval", "eval_leaves"):
    annd.obs[flag_column] = annd.obs[flag_column].astype("category")

Trying to set attribute `.obs` of view, copying.


In [16]:
# Ground-truth integer labels, stored next to the predictions added later.
annd.obs["gt"] = labels_final

In [17]:
# Sanity check: every requested batch contributes at least one cell to the dataset.
assert np.unique(train_data.batch_indices).shape[0] == len(all_batch_indices)

In [18]:
# The output name encodes train batches, test batches, the leaves-only flag and the seed.
train_tag = "".join(train_batch_indices)
test_tag = "".join(test_batch_indices)
annd_file = os.path.join(
    SAVE_DIR,
    f"train{train_tag}_test{test_tag}_{ON_LEAVES_ONLY}_{TEST_SPLIT_SEED}3.h5ad",
)

annd.write_h5ad(annd_file)

In [19]:
def update_annd(new_key, annd, predictions):
    """Attach one algorithm's predictions to ``annd.obs`` in three encodings.

    Adds ``predictions_ids_<new_key>`` (integer labels as given),
    ``predictions_nodes_<new_key>`` (looked up in the global ``indices_to_nodes``)
    and ``predictions_names_<new_key>`` (looked up in the global ``id_to_label``).
    All three columns are stored as categoricals. ``annd`` is modified in place
    and nothing is returned.

    :param new_key: suffix identifying the algorithm, e.g. ``"kNN"``
    :param annd: AnnData object whose ``.obs`` receives the new columns
    :param predictions: one integer label per cell, in ``annd`` order
    """
    id_column = "predictions_ids_{}".format(new_key)
    node_column = "predictions_nodes_{}".format(new_key)
    name_column = "predictions_names_{}".format(new_key)

    # Each encoding is derived from the previous one, so the order matters.
    annd.obs[id_column] = predictions
    annd.obs[node_column] = indices_to_nodes[annd.obs[id_column]].values
    annd.obs[name_column] = id_to_label[annd.obs[node_column]].values

    for column in (id_column, node_column, name_column):
        annd.obs[column] = annd.obs[column].astype("category")

# Algorithms

All the results are saved in two ways.
First, classification performance is incrementally added to the list `results`, which we will later convert to a DataFrame that can be saved easily.
Second, all the predictions are saved to an h5ad output file that includes the original data, label predictions and latent predictions.

In [20]:
# Accumulates classification performance entries; converted to a DataFrame later.
results = []

## scVI and scANVI
### parameters

In [21]:
# Keyword arguments for the semi-supervised trainer.
TRAINER_KWARGS = {
    "n_epochs_classifier": 0,
    "lr_classification": 1e-3,
    "batch_size": 500,
    "n_epochs_kl_warmup": 1,
}

# Training schedules (epochs / learning rate) per model.
TRAIN_SCVI_KWARGS = {"n_epochs": 30, "lr": 1e-3}
TRAIN_SCANVI_KWARGS_vanilla = {"n_epochs": 50, "lr": 1e-3}
TRAIN_SCANVI_KWARGS = {"n_epochs": 50, "lr": 1e-3}

# Architecture of the classifier head and of the two generative models.
CLASSIFIER_PARAMETERS = {"n_hidden": 128, "n_layers": 1, "dropout_rate": 0.5}
SCANVI_KWARGS = {
    "n_layers": 1,
    "n_latent": 50,
    "symmetric_kl": True,
    "dispersion": "gene",
    "classifier_parameters": CLASSIFIER_PARAMETERS,
}
SCVI_KWARGS = {
    "n_layers": 1,
    "n_latent": 50,
    "dispersion": "gene",
    # "dropout_rate": 0.5
}

### scVI + kNN (+ ARI Reference)

Trains scVI and generates the latent representation of the data, which will serve as input for the kNN algorithm

In [None]:
# Unsupervised scVI model over all selected cells; the batch count comes from the dataset.
# NOTE(review): no random seed is set in the visible cells before training —
# confirm whether run-to-run reproducibility is expected here.
vae = VAE(n_input=train_data.nb_genes, n_batch=train_data.n_batches, **SCVI_KWARGS)
trainer = UnsupervisedTrainer(
    vae,
    train_data,
    batch_size=128,
    n_epochs_kl_warmup=5,
    data_loader_kwargs={"pin_memory": False},
)
trainer.train(**TRAIN_SCVI_KWARGS)

# Posterior over every cell, iterated sequentially so the latent rows follow
# dataset order (later cells index latent_unsup with positional indices).
full_unsup = trainer.create_posterior(
    vae, train_data, indices=np.arange(len(train_data)),
)
latent_unsup, _, _ = full_unsup.sequential().get_latent()

In [None]:
# Persist the trained scVI weights. The file name embeds the printed list of
# training batches, the fairness flag and the run index.
scVI_filename = os.path.join(
    SAVE_DIR, f"scVI_batch{train_batch_indices}_fair{int(is_fair)}_{init}"
)

torch.save(vae.state_dict(), scVI_filename)

In [None]:
# annd = combined[mask_sset].copy()
annd.obsm["X_scvi"] = latent_unsup

sc.pp.neighbors(annd, n_neighbors=30, n_pcs=50, use_rep="X_scvi")
sc.tl.leiden(annd)
_cluster_unsup = annd.obs["leiden"].values.astype(int)
cluster_unsup = _cluster_unsup[where_eval]
annd.obs["cluster_id"] = _cluster_unsup.astype(str)

### run KNN

In [None]:
# kNN classifier fitted on the scVI latent coordinates of the labelled cells.
knn = KNeighborsClassifier(algorithm="kd_tree")
latent_train = latent_unsup[labelled]
# Use labels_final, not `labels`: `labels` is only assigned when
# ON_LEAVES_ONLY is False (where it is the same array as labels_final), so the
# previous code raised a NameError in leaves-only mode.
labels_train = labels_final[labelled]
knn.fit(latent_train, labels_train)
# Predict for every cell, labelled ones included, and store the result on annd.
y_knn_full = knn.predict(latent_unsup)
update_annd(new_key="kNN", annd=annd, predictions=y_knn_full)

In [22]:
from collections import namedtuple
from typing import List
import numpy as np
import logging

from sklearn import neighbors
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC

import torch
from torch.nn import functional as F
import torch.distributions as db

from scvi.inference import Posterior
from scvi.inference import Trainer
from scvi.inference.inference import UnsupervisedTrainer
from scvi.inference.posterior import unsupervised_clustering_accuracy
import sklearn.metrics as met

logger = logging.getLogger(__name__)


class AnnotationPosterior(Posterior):
    """Posterior with helpers for label prediction and accuracy monitoring.

    When a ``sampling_model`` attribute has been attached to the instance,
    that object is used as the model and ``self.model`` as the classifier on
    top of it; otherwise ``self.model`` is used on its own.
    """

    def __init__(self, *args, model_zl=False, **kwargs):
        """
        :param model_zl: forwarded to the prediction helpers as ``model_zl``
        """
        super().__init__(*args, **kwargs)
        self.model_zl = model_zl

    def _model_and_classifier(self):
        """Return the ``(model, classifier)`` pair used by the prediction helpers."""
        if hasattr(self, "sampling_model"):
            return self.sampling_model, self.model
        return self.model, None

    def accuracy(self):
        """Return the fraction of cells whose predicted label equals the true label."""
        model, cls = self._model_and_classifier()
        acc = compute_accuracy(model, self, classifier=cls, model_zl=self.model_zl)
        logger.debug("Acc: %.4f" % (acc))
        return acc

    accuracy.mode = "max"

    @torch.no_grad()
    def hierarchical_accuracy(self):
        """Log the group-level accuracy and return the plain label accuracy.

        Labels are mapped to groups through ``self.model.labels_groups``.
        NOTE(review): the group-level value is only logged; the returned value
        is the flat accuracy, as in the original code — confirm this is intended.
        """
        all_y, all_y_pred = self.compute_predictions()
        acc = np.mean(all_y == all_y_pred)

        all_y_groups = np.array([self.model.labels_groups[y] for y in all_y])
        all_y_pred_groups = np.array([self.model.labels_groups[y] for y in all_y_pred])
        h_acc = np.mean(all_y_groups == all_y_pred_groups)

        logger.debug("Hierarchical Acc : %.4f\n" % h_acc)
        return acc

    # Previously this line re-assigned ``accuracy.mode`` (copy-paste slip), so
    # ``hierarchical_accuracy`` never received its own ``mode`` attribute.
    hierarchical_accuracy.mode = "max"

    @torch.no_grad()
    def compute_predictions(self, soft=False):
        """
        :param soft: if True, return the raw classifier outputs instead of the argmax
        :return: the true labels and the predicted labels
        :rtype: 2-tuple of :py:class:`numpy.ndarray`
        """
        model, cls = self._model_and_classifier()
        return compute_predictions(
            model, self, classifier=cls, soft=soft, model_zl=self.model_zl
        )

    @torch.no_grad()
    def compute_predictions_full(self):
        """
        :return: the true labels and a list holding one soft-prediction array
            per depth of the model
        :rtype: 2-tuple ``(numpy.ndarray, list)``
        """
        model, cls = self._model_and_classifier()
        return compute_predictions_full(
            model, self, classifier=cls, model_zl=self.model_zl
        )

    @torch.no_grad()
    def unsupervised_classification_accuracy(self):
        """Return the first element of ``unsupervised_clustering_accuracy`` for the predictions."""
        all_y, all_y_pred = self.compute_predictions()
        uca = unsupervised_clustering_accuracy(all_y, all_y_pred)[0]
        logger.debug("UCA : %.4f" % (uca))
        return uca

    unsupervised_classification_accuracy.mode = "max"

    @torch.no_grad()
    def nn_latentspace(self, posterior):
        """Score a kNN classifier fitted on this posterior's latent space.

        :param posterior: posterior whose latent coordinates and labels are scored
        :return: ``KNeighborsClassifier.score`` on ``posterior``
        """
        data_train, _, labels_train = self.get_latent()
        data_test, _, labels_test = posterior.get_latent()
        nn = KNeighborsClassifier()
        nn.fit(data_train, labels_train)
        score = nn.score(data_test, labels_test)
        return score


class ClassifierTrainer(Trainer):
    r"""The ClassifierInference class for training a classifier either on the raw data or on top of the latent
        space of another model (VAE, VAEC, SCANVI).

    Args:
        :model: A model instance from class ``VAE``, ``VAEC``, ``SCANVI``
        :gene_dataset: A gene_dataset instance like ``CortexDataset()``
        :train_size: The train size, either a float between 0 and 1 or an integer for the number of training samples
            to use. Default: ``0.8``.
        :test_size: The test size, either a float between 0 and 1 or an integer for the number of test samples
            to use. Default: ``None``.
        :sampling_model: Model with z_encoder with which to first transform data.
        :sampling_zl: Transform data with sampling_model z_encoder and l_encoder and concat.
        :nodes_to_leaves_probs: Optional table indexed by label; when given, each training batch has its
            labels re-sampled from the corresponding row (see ``convert_to_leaf_nodes``).
        :\**kwargs: Other keywords arguments from the general Trainer class.


    Examples:
        >>> gene_dataset = CortexDataset()
        >>> vae = VAE(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * False,
        ... n_labels=gene_dataset.n_labels)

        >>> classifier = Classifier(vae.n_latent, n_labels=cortex_dataset.n_labels)
        >>> trainer = ClassifierTrainer(classifier, gene_dataset, sampling_model=vae, train_size=0.5)
        >>> trainer.train(n_epochs=20, lr=1e-3)
        >>> trainer.test_set.accuracy()
    """

    def __init__(
        self,
        *args,
        train_size=0.8,
        test_size=None,
        sampling_model=None,
        sampling_zl=False,
        use_cuda=True,
        nodes_to_leaves_probs=None,
        **kwargs,
    ):
        # sampling_model must be set before any posterior is assigned, because
        # __setattr__ below copies it onto train_set / test_set.
        self.sampling_model = sampling_model
        self.sampling_zl = sampling_zl
        super().__init__(*args, use_cuda=use_cuda, **kwargs)
        self.train_set, self.test_set, self.validation_set = self.train_test_validation(
            self.model,
            self.gene_dataset,
            train_size=train_size,
            test_size=test_size,
            type_class=AnnotationPosterior,
        )

        self.nodes_to_leaves_probs = nodes_to_leaves_probs
        self.train_set.to_monitor = ["accuracy"]
        self.test_set.to_monitor = ["accuracy"]
        self.validation_set.to_monitor = ["accuracy"]
        self.train_set.model_zl = sampling_zl
        self.test_set.model_zl = sampling_zl
        self.validation_set.model_zl = sampling_zl

    @property
    def posteriors_loop(self):
        """Names of the posterior attributes iterated during training."""
        return ["train_set"]

    def __setattr__(self, key, value):
        # Attach the sampling model to the train/test posteriors whenever they
        # are (re)assigned. NOTE(review): validation_set is not covered here.
        if key in ["train_set", "test_set"]:
            value.sampling_model = self.sampling_model
        super().__setattr__(key, value)

    def loss(self, tensors_labelled):
        """Cross-entropy between the classifier output and the batch labels.

        With a ``sampling_model`` that exposes ``classify``, that method is
        used directly; otherwise the input is first passed through the
        sampling model's encoder(s) and then through ``self.model``.
        """
        x, _, _, _, labels_train = tensors_labelled
        if self.sampling_model:
            if hasattr(self.sampling_model, "classify"):
                return F.cross_entropy(
                    self.sampling_model.classify(x), labels_train.view(-1)
                )
            else:
                if self.sampling_model.log_variational:
                    x = torch.log(1 + x)
                if self.sampling_zl:
                    # Concatenate the first output of the z and l encoders.
                    x_z = self.sampling_model.z_encoder(x)[0]
                    x_l = self.sampling_model.l_encoder(x)[0]
                    x = torch.cat((x_z, x_l), dim=-1)
                else:
                    x = self.sampling_model.z_encoder(x)[0]
        return F.cross_entropy(self.model(x), labels_train.view(-1))

    @torch.no_grad()
    def compute_predictions(self, soft=False):
        """
        :param soft: if True, return the raw classifier outputs instead of the argmax
        :return: the true labels and the predicted labels
        :rtype: 2-tuple of :py:class:`numpy.int32`
        """
        # sampling_model is always set by __init__, so in practice the first
        # branch is taken (its value may still be None).
        model, cls = (
            (self.sampling_model, self.model)
            if hasattr(self, "sampling_model")
            else (self.model, None)
        )
        full_set = self.create_posterior(type_class=AnnotationPosterior)
        return compute_predictions(
            model, full_set, classifier=cls, soft=soft, model_zl=self.sampling_zl
        )

    def on_training_loop(self, tensors_list):
        """Re-sample the labels of the first batch when a probability table is set, then defer to the parent."""
        if self.nodes_to_leaves_probs is not None:
            new_tensors_list = self.convert_to_leaf_nodes(tensors_list[0])
            new_tensors_list = [new_tensors_list]
        else:
            new_tensors_list = tensors_list
        super().on_training_loop(new_tensors_list)
        # self.current_loss = loss = self.loss(*tensors_list)
        # self.optimizer.zero_grad()
        # loss.backward()
        # self.optimizer.step()

    def convert_to_leaf_nodes(self, tensors):
        """Return ``tensors`` with its labels replaced by samples drawn from ``nodes_to_leaves_probs[y]``.

        The batch is returned unchanged when it carries no labels.
        """
        sample_batch, b, c, d, y = tensors
        if y is None:
            return tensors
        y_probs = self.nodes_to_leaves_probs[y]
        leaves_batch = db.Categorical(probs=y_probs).sample()
        new_tensors = (sample_batch, b, c, d, leaves_batch)
        return new_tensors


class SemiSupervisedTrainer(UnsupervisedTrainer):
    r"""The SemiSupervisedTrainer class for the semi-supervised training of an autoencoder.
    This parent class can be inherited to specify the different training schemes for semi-supervised learning.

    NOTE(review): this definition shadows the ``SemiSupervisedTrainer`` imported
    from ``scvi.inference`` earlier in the notebook.
    """

    def __init__(
        self,
        model,
        gene_dataset,
        n_labels_final: int = None,
        labels_of_use=None,
        n_labelled_samples_per_class=50,
        n_epochs_classifier=1,
        lr_classification=5 * 1e-3,
        classification_ratio=50,
        seed=0,
        nodes_to_leaves_probs: torch.Tensor = None,
        indices_labelled: List = None,
        indices_unlabelled: List = None,
        include_conditionned_elbo: bool = False,
        classify_full_ontology: bool = False,
        **kwargs,
    ):
        """
        :param model: model exposing ``classifier`` and ``classify``
        :param gene_dataset: dataset the posteriors are built from
        :param n_labels_final: number of classes; taken from ``gene_dataset`` when this
            or ``labels_of_use`` is None
        :param labels_of_use: labels used to draw the labelled subset when no indices are given
        :param n_labelled_samples_per_class: number of labelled samples per class
        :param n_epochs_classifier: classifier epochs run at the end of each epoch (0 disables)
        :param lr_classification: learning rate of that classifier training
        :param classification_ratio: weight of the classification loss in the total loss
        :param seed: seed passed to ``np.random.seed`` before drawing the permutation
        :param nodes_to_leaves_probs: optional table indexed by label, used to re-sample
            labels during training (see ``convert_to_leaf_nodes``)
        :param indices_labelled: explicit indices of labelled cells
        :param indices_unlabelled: explicit indices of unlabelled cells
        :param include_conditionned_elbo: add a label-conditioned ELBO term on the labelled batch
        :param classify_full_ontology: compute the classification loss at every ontology depth
        """
        super().__init__(model, gene_dataset, **kwargs)
        self.model = model
        self.gene_dataset = gene_dataset
        self.nodes_to_leaves_probs = nodes_to_leaves_probs

        self.n_epochs_classifier = n_epochs_classifier
        self.lr_classification = lr_classification
        self.include_conditionned_elbo = include_conditionned_elbo
        self.classification_ratio = classification_ratio
        self.classify_full_ontology = classify_full_ontology

        self.metrics = {
            "Heldout accuracy": [],
        }

        if labels_of_use is None or n_labels_final is None:
            labels_of_use = np.array(self.gene_dataset.labels).ravel()
            n_labels_final = self.gene_dataset.n_labels

        n_labelled_samples_per_class_array = [
            n_labelled_samples_per_class
        ] * n_labels_final
        # The global seed and permutation are drawn even when explicit indices
        # are supplied, to keep the random state identical to the original code.
        np.random.seed(seed=seed)
        permutation_idx = np.random.permutation(len(labels_of_use))
        labels_of_use_mix = labels_of_use[permutation_idx]
        indices = []
        current_nbrs = np.zeros(len(n_labelled_samples_per_class_array))
        if (indices_labelled is None) or (indices_unlabelled is None):
            # Labelled samples are pushed to the front of `indices`, everything
            # else (including label -1) to the back.
            for idx, (label) in enumerate(labels_of_use_mix):
                label = int(label)
                if label == -1:
                    indices.append(idx)
                elif current_nbrs[label] < n_labelled_samples_per_class_array[label]:
                    indices.insert(0, idx)
                    current_nbrs[label] += 1
                else:
                    indices.append(idx)
            indices = np.array(indices)
            total_labelled = sum(n_labelled_samples_per_class_array)
            indices_labelled = permutation_idx[indices[:total_labelled]]
            indices_unlabelled = permutation_idx[indices[total_labelled:]]

        # Computed only once both index sets exist: the original evaluated this
        # before the defaults were filled in and raised a TypeError (len(None))
        # whenever the indices were not passed explicitly.
        self.labelled_fraction = len(indices_labelled) / (1.0 * len(indices_unlabelled))

        print(
            "labelled indices: ",
            np.unique(gene_dataset.labels[indices_labelled].squeeze()),
        )
        print(
            "unlabelled indices: ",
            np.unique(gene_dataset.labels[indices_unlabelled].squeeze()),
        )

        # Must exist before `labelled_set` is assigned: __setattr__ forwards
        # that posterior to the classifier trainer.
        self.classifier_trainer = ClassifierTrainer(
            model.classifier,
            gene_dataset,
            train_size=0.99,
            nodes_to_leaves_probs=nodes_to_leaves_probs,
            metrics_to_monitor=[],
            show_progbar=False,
            frequency=0,
            sampling_model=self.model,
        )
        self.full_dataset = self.create_posterior(shuffle=True)
        self.labelled_set = self.create_posterior(indices=indices_labelled)
        self.unlabelled_set = self.create_posterior(indices=indices_unlabelled)

        self.indices_labelled = indices_labelled
        self.indices_unlabelled = indices_unlabelled

        for posterior in [self.labelled_set, self.unlabelled_set]:
            posterior.to_monitor = ["reconstruction_error", "accuracy"]

    @property
    def posteriors_loop(self):
        """Names of the posteriors iterated jointly during training."""
        return ["full_dataset", "labelled_set"]

    def __setattr__(self, key, value):
        # Keep the classifier trainer's training set in sync with the labelled posterior.
        if key == "labelled_set":
            self.classifier_trainer.train_set = value
        super().__setattr__(key, value)

    def loss(self, tensors_all, tensors_labelled):
        """Generative loss on ``tensors_all`` plus weighted classification loss on ``tensors_labelled``."""
        if self.include_conditionned_elbo:
            unsup_loss = super().loss(tensors_all, feed_labels=False)
            sup_loss = super().loss(tensors_labelled, feed_labels=True)
            loss = unsup_loss + (self.labelled_fraction * sup_loss)
        else:
            loss = super().loss(tensors_all, feed_labels=False)
        sample_batch, _, _, _, y = tensors_labelled

        if not self.classify_full_ontology:
            probs = self.model.classify(sample_batch)
            y_gt = y.view(-1)
            classification_loss = F.cross_entropy(probs, y_gt)
        else:
            # Compute all internal node labels by walking the ontology
            # matrices from the last one to the first.
            # NOTE(review): torch.cuda.FloatTensor hard-codes CUDA here.
            all_lbls = [y.view(-1)]
            for a in self.model.ontology[::-1]:
                A = torch.cuda.FloatTensor(a)
                new_labels = A[:, all_lbls[-1].squeeze()].argmax(0)
                all_lbls.append(new_labels)
            all_lbls = all_lbls[::-1]

            # Compute predictions at every depth
            preds = []
            for dep in range(1, self.model.depth + 1):
                preds.append(self.model.classify(sample_batch, depth=dep))

            # Average the cross-entropy over the depths
            classification_loss = 0.0
            for pred, lbl in zip(preds, all_lbls):
                classification_loss += F.cross_entropy(pred, lbl)
            classification_loss = classification_loss / len(preds)
        loss += classification_loss * self.classification_ratio
        return loss

    def on_epoch_end(self):
        """Optionally train the classifier, print evaluation accuracies, then defer to the parent."""
        self.model.eval()

        if self.n_epochs_classifier != 0:
            self.classifier_trainer.train(
                self.n_epochs_classifier, lr=self.lr_classification
            )

        with torch.no_grad():
            # Use the trainer's own dataset rather than the notebook-global
            # `train_data`, so the trainer does not depend on hidden state for it.
            dataset = self.gene_dataset
            full = self.create_posterior(
                self.model, dataset, indices=np.arange(len(dataset))
            )
            gt, pred = full.sequential().compute_predictions()
            # NOTE(review): `where_eval` is still a notebook-level global mask
            # and must match the dataset's cell order.
            acc = met.accuracy_score(gt[where_eval], pred[where_eval])
            acc2 = met.balanced_accuracy_score(gt[where_eval], pred[where_eval])
            print(acc, acc2)
        #     self.metrics["Heldout accuracy"].append(acc)
        self.model.train()
        return super().on_epoch_end()

    def create_posterior(
        self,
        model=None,
        gene_dataset=None,
        shuffle=False,
        indices=None,
        type_class=AnnotationPosterior,
    ):
        """Same as the parent method, but defaulting to ``AnnotationPosterior``."""
        return super().create_posterior(
            model, gene_dataset, shuffle, indices, type_class
        )

    def on_training_loop(self, tensors_list):
        """Run one optimisation step on a (full batch, labelled batch) pair."""
        # Modifies labels!!!
        tensors_all, tensors_labelled = tensors_list
        if self.nodes_to_leaves_probs is not None:
            new_tensors_labelled = self.convert_to_leaf_nodes(tensors_labelled)
            new_tensors_all = self.convert_to_leaf_nodes(tensors_all)
        else:
            new_tensors_labelled = tensors_labelled
            new_tensors_all = tensors_all

        self.current_loss = loss = self.loss(new_tensors_all, new_tensors_labelled)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def convert_to_leaf_nodes(self, tensors):
        """Return ``tensors`` with its labels replaced by samples drawn from ``nodes_to_leaves_probs[y]``.

        The batch is returned unchanged when it carries no labels.
        """
        sample_batch, b, c, d, y = tensors
        if y is None:
            return tensors
        y_probs = self.nodes_to_leaves_probs[y]
        leaves_batch = db.Categorical(probs=y_probs).sample()
        new_tensors = (sample_batch, b, c, d, leaves_batch)
        return new_tensors


class JointSemiSupervisedTrainer(SemiSupervisedTrainer):
    """Semi-supervised trainer that never runs the separate classifier epochs."""

    def __init__(self, model, gene_dataset, **kwargs):
        # Any caller-supplied value is overridden.
        kwargs["n_epochs_classifier"] = 0
        super().__init__(model, gene_dataset, **kwargs)


class AlternateSemiSupervisedTrainer(SemiSupervisedTrainer):
    """Semi-supervised trainer whose own loss is the purely unsupervised one.

    Only ``full_dataset`` is iterated during training. Construction is
    inherited unchanged from ``SemiSupervisedTrainer``.

    NOTE(review): the inherited ``on_training_loop`` unpacks two batches and
    calls ``loss`` with two arguments, which does not match this class's
    single-posterior loop and one-argument ``loss`` — confirm before use.
    """

    @property
    def posteriors_loop(self):
        return ["full_dataset"]

    def loss(self, all_tensor):
        return UnsupervisedTrainer.loss(self, all_tensor)


@torch.no_grad()
def compute_predictions(
    model, data_loader, classifier=None, soft=False, model_zl=False, depth=None
):
    """Collect true labels and predictions over every batch of ``data_loader``.

    The prediction path is chosen per batch:

    1. ``model.classify(x, depth=depth)`` when the model exposes ``classify``;
    2. otherwise ``classifier(...)`` when one is given, applied to the input
       after the model's encoder(s) if ``model`` is not None;
    3. otherwise ``model(x)``, i.e. the model itself is the classifier.

    :param data_loader: iterable of 5-tuples; the first item is the input batch
        and the last item the labels
    :param soft: if True, return the raw outputs instead of their argmax
    :param model_zl: if True, concatenate the z and l encodings (path 2 only)
    :return: ``(all_y, all_y_pred)`` as numpy arrays
    """
    true_chunks = []
    pred_chunks = []

    for tensors in data_loader:
        inputs, _, _, _, batch_labels = tensors
        true_chunks.append(batch_labels.view(-1).cpu())

        if hasattr(model, "classify"):
            scores = model.classify(inputs, depth=depth)
        elif classifier is None:
            # The model is the raw classifier
            scores = model(inputs)
        else:
            # Use the specified classifier, on encoded inputs when a model is given
            if model is not None:
                if model.log_variational:
                    inputs = torch.log(1 + inputs)
                if model_zl:
                    latent_z = model.z_encoder(inputs)[0]
                    latent_l = model.l_encoder(inputs)[0]
                    inputs = torch.cat((latent_z, latent_l), dim=-1)
                else:
                    inputs, _, _ = model.z_encoder(inputs)
            scores = classifier(inputs)

        batch_pred = scores if soft else scores.argmax(dim=-1)
        pred_chunks.append(batch_pred.cpu())

    all_y_pred = np.array(torch.cat(pred_chunks))
    all_y = np.array(torch.cat(true_chunks))

    return all_y, all_y_pred


@torch.no_grad()
def compute_predictions_full(model, data_loader, classifier=None, model_zl=False):
    """Return the true labels and one soft-prediction array per model depth.

    ``classifier`` and ``model_zl`` are accepted for signature compatibility
    but are not forwarded: every depth is predicted with ``classifier=None``.

    :return: ``(all_y, all_preds)`` where ``all_preds[i]`` holds the soft
        predictions at depth ``i + 1``
    """
    # Compute predictions for each layer
    all_preds = []
    depth_levels = range(1, model.depth + 1)
    for level in depth_levels:
        all_y, level_pred = compute_predictions(
            model, data_loader, classifier=None, soft=True, depth=level
        )
        all_preds.append(level_pred)
    return all_y, all_preds


@torch.no_grad()
def compute_accuracy(vae, data_loader, classifier=None, model_zl=False):
    """Return the fraction of hard predictions that equal the labels.

    Labels and predictions are obtained from ``compute_predictions`` with its
    default ``soft=False``.
    """
    y_true, y_hat = compute_predictions(
        vae, data_loader, classifier=classifier, model_zl=model_zl
    )
    is_correct = y_true == y_hat
    return np.mean(is_correct)


# Summary of per-class accuracies:
#   unweighted       - per-class accuracies weighted by class frequency
#                      (i.e. the overall accuracy)
#   weighted         - plain mean of the per-class accuracies
#   worst            - smallest per-class accuracy
#   accuracy_classes - list of per-class accuracies
Accuracy = namedtuple(
    "Accuracy", ["unweighted", "weighted", "worst", "accuracy_classes"]
)


@torch.no_grad()
def compute_accuracy_tuple(y, y_pred):
    """Compute per-class accuracy (recall) and its summaries.

    Classes are the distinct values present in ``y``, in sorted order, so
    label ids do not need to be contiguous or start at 0. For labels that are
    exactly ``0..n-1`` the result is the same as iterating over
    ``range(n)``.

    Parameters
    ----------
    y : numpy array
        Ground-truth labels; flattened with ``ravel`` before use.
    y_pred : numpy array
        Predicted labels, indexable by a boolean mask of ``y``'s flattened
        length.

    Returns
    -------
    Accuracy
        ``accuracy_classes[i]`` is the accuracy of the i-th sorted distinct
        label of ``y``.

    Raises
    ------
    ValueError
        If ``y`` is empty (previously an ``UnboundLocalError``).
    """
    y = y.ravel()
    if len(y) == 0:
        raise ValueError("compute_accuracy_tuple requires at least one label")

    classes_probabilities = []
    accuracy_classes = []
    for cl in np.unique(y):
        idx = y == cl
        classes_probabilities.append(np.mean(idx))
        # This is the "recall" of the class:
        #   n_true_positive / (n_false_negative + n_true_positive)
        # (the "precision" would divide by n_false_positive + n_true_positive).
        accuracy_classes.append(np.mean(y[idx] == y_pred[idx]))

    # Built once, after all classes have been scored.
    return Accuracy(
        unweighted=np.dot(accuracy_classes, classes_probabilities),
        weighted=np.mean(accuracy_classes),
        worst=np.min(accuracy_classes),
        accuracy_classes=accuracy_classes,
    )


@torch.no_grad()
def compute_accuracy_nn(data_train, labels_train, data_test, labels_test, k=5):
    """Score a distance-weighted k-nearest-neighbours classifier.

    A ``KNeighborsClassifier`` with ``k`` neighbours is handed to
    ``compute_accuracy_classifier``, whose return value is passed through.
    """
    knn_classifier = neighbors.KNeighborsClassifier(n_neighbors=k, weights="distance")
    outcome = compute_accuracy_classifier(
        knn_classifier, data_train, labels_train, data_test, labels_test
    )
    return outcome


@torch.no_grad()
def compute_accuracy_classifier(clf, data_train, labels_train, data_test, labels_test):
    """Fit ``clf`` on the training data and score it on both splits.

    Returns
    -------
    ((train_accuracy, test_accuracy), test_predictions)
        The two accuracies are the results of ``compute_accuracy_tuple`` on
        the training and test splits; ``test_predictions`` is the output of
        ``clf.predict(data_test)``.
    """
    clf.fit(data_train, labels_train)

    # Same call order as before: test split first, then the training split.
    test_predictions = clf.predict(data_test)
    train_predictions = clf.predict(data_train)

    train_accuracy = compute_accuracy_tuple(labels_train, train_predictions)
    test_accuracy = compute_accuracy_tuple(labels_test, test_predictions)
    return (train_accuracy, test_accuracy), test_predictions


@torch.no_grad()
def compute_accuracy_svc(
    data_train,
    labels_train,
    data_test,
    labels_test,
    param_grid=None,
    verbose=0,
    max_iter=-1,
):
    """Score an SVC tuned by 3-fold grid search.

    When ``param_grid`` is ``None`` a default grid is used: linear kernels
    with C in {1, 10, 100, 1000}, and RBF kernels with the same C values and
    gamma in {0.001, 0.0001}. The fitted search object is evaluated through
    ``compute_accuracy_classifier``, whose return value is passed through.
    """
    if param_grid is None:
        linear_grid = {"C": [1, 10, 100, 1000], "kernel": ["linear"]}
        rbf_grid = {
            "C": [1, 10, 100, 1000],
            "gamma": [0.001, 0.0001],
            "kernel": ["rbf"],
        }
        param_grid = [linear_grid, rbf_grid]

    grid_search = GridSearchCV(
        SVC(max_iter=max_iter), param_grid, verbose=verbose, cv=3
    )
    return compute_accuracy_classifier(
        grid_search, data_train, labels_train, data_test, labels_test
    )


@torch.no_grad()
def compute_accuracy_rf(
    data_train, labels_train, data_test, labels_test, param_grid=None, verbose=0
):
    """Score a random forest tuned by 3-fold grid search.

    The default grid searches ``max_depth`` over 3..9 and ``n_estimators``
    over {10, 50, 100, 200}. The base estimator is created with
    ``max_depth=2, random_state=0``. The search object is evaluated through
    ``compute_accuracy_classifier``, whose return value is passed through.
    """
    if param_grid is None:
        param_grid = {
            "max_depth": np.arange(3, 10),
            "n_estimators": [10, 50, 100, 200],
        }

    base_forest = RandomForestClassifier(max_depth=2, random_state=0)
    grid_search = GridSearchCV(base_forest, param_grid, verbose=verbose, cv=3)
    return compute_accuracy_classifier(
        grid_search, data_train, labels_train, data_test, labels_test
    )


### hscANVI

In [None]:
# Train the hierarchical scANVI model ("h-scanVI"), save its weights, then
# compute predictions and latent coordinates for every cell of `train_data`.

# NOTE: `train_data` is modified in place; its cell types become the index of
# `nodes_to_indices` and its labels become `labels_all`.
train_data.cell_types = nodes_to_indices.index.values
train_data.labels = labels_all
train_data.n_labels = len(train_data.cell_types)

# `resh` is meant to collect the evaluation metrics for this algorithm.
# NOTE(review): the code that fills and stores it is commented out further
# down in this section, so only these three keys are set here.
resh = dict(batch=train_batch_indices, model="h-scanVI", init=init)

# Hierarchical training & eval
scanvi = SCANVI(
    train_data.nb_genes,
    n_batch=train_data.n_batches,
    n_labels=n_labels_hscanvi,
    use_ontology=True,
    ontology=adjmc,
    provide_onto_info=False,
    **SCANVI_KWARGS,
)
trainer_scanvi = SemiSupervisedTrainer(
    scanvi,
    train_data,
    indices_labelled=labelled,
    indices_unlabelled=unlabelled,
    classification_ratio=50.,
    nodes_to_leaves_probs=nodes_to_leaves_probs,
    include_conditionned_elbo=False,
    classify_full_ontology=True,
    **TRAINER_KWARGS,
)
trainer_scanvi.train(**TRAIN_SCANVI_KWARGS)

# Plain string concatenation: this relies on SAVE_DIR ending with "/".
h_filename = SAVE_DIR + "hier_batch{}_fair{}_{}".format(
    str(train_batch_indices), int(is_fair), init
)
torch.save(scanvi.state_dict(), h_filename)

# Posterior over all cells (labelled and unlabelled), iterated in order.
full = trainer_scanvi.create_posterior(
    trainer_scanvi.model, train_data, indices=np.arange(len(train_data))
)
gt, pred = full.sequential().compute_predictions()
latent, _, _ = full.sequential().get_latent()

# The ON_LEAVES_ONLY switch is disabled here: the raw predictions are always
# used, and the procedure for predicting internal nodes is not run.
predicted_ = pred
# NOTE(review): `labels` is read here rather than `train_data.labels`
# (assigned from `labels_all` above) - confirm both refer to the same labels.
y_true = labels[where_eval]
y_pred = predicted_[where_eval]

labelled indices:  [  0   1   2   3   5   6   7  10  11  12  13  14  15  16  17  18  19  20
  21  22  23  24  25  26  28  29  30  32  33  35  36  37  38  39  40  41
  42  43  44  45  46  47 134]
unlabelled indices:  [  0   1   2   3   5   6   7  10  11  12  15  16  17  19  20  21  22  23
  24  25  28  29  30  32  33  35  36  37  38  39  40  41  42  44  45  46
  47  84 132 134]


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))



0.6260001584409411 0.33024461504026553
0.6292878079695794 0.375659255514279
0.6171670759724313 0.4419814600720896
0.6957141725421849 0.48445512489194453
0.764952863820011 0.48925036389659615
0.8459161847421374 0.5478673625444642
0.8672264913253585 0.5626630210658505
0.8757426919115899 0.5674472387858178
0.8928147033193378 0.5787485766226788
0.9092925611978135 0.5988183555455147
0.8965776756713935 0.5997975481859219
0.9236710766061951 0.6190985580303681
0.9170957775489187 0.5966416718562065
0.9148776043729699 0.6208595676118334
0.9173730491959122 0.605661504876602
0.9107185296680662 0.6044580024631773
0.9130555335498692 0.60347061152838
0.8796641052047849 0.6176510720329184
0.8852095381446565 0.622230258539734
0.8899627663788323 0.6450257418614758
0.8755842509704508 0.6383828839606885
0.8611661253267845 0.6325387013118077
0.8785946288520954 0.6544960639517295
0.8702764794422879 0.6585527477895635
0.8846946050859542 0.6620120694465187
0.8668699992077953 0.6632927374003245
0.8511447357997

In [None]:
# gt, pred = full.sequential().compute_predictions()
# latent, _, _ = full.sequential().get_latent()

# # if ON_LEAVES_ONLY:
# predicted_ = pred

In [None]:
# Recompute hscANVI predictions as soft scores and suppress the leaves that
# have no labelled training cell anywhere in their row of `leaf_paths`.
# NOTE(review): within this section `tqdm`, `labels_train` and `leaf_paths`
# are only assigned in cells located further down; confirm this cell still
# runs on a fresh kernel executed top to bottom.
gt, pred_soft = full.sequential().compute_predictions(soft=True)
unacceptable_cts = []
acceptable_cts = []

for leaf, nodes_subset in tqdm(leaf_paths.iterrows()):
    # Ignore rows whose label index is not below n_labels_hscanvi.
    if nodes_to_indices[leaf] >= n_labels_hscanvi:
        continue
    # Label indices of every node listed in this row of `leaf_paths`.
    ind_subset = nodes_to_indices[nodes_subset].values
    # Mask over `labels_train`: entries whose label is one of those indices.
    cells_subselect = np.isin(labels_train, np.unique(ind_subset))
    if cells_subselect.sum() >= 1:
        acceptable_cts.append(nodes_to_indices[leaf])
    else:
        unacceptable_cts.append(nodes_to_indices[leaf])

# Zero the scores of the unsupported leaves before taking the argmax.
# `acceptable_cts` is collected but not used in this cell.
pred_soft[:, unacceptable_cts] = 0
predicted_ = pred_soft.argmax(-1)

In [None]:
# # trainer_scanvi.model.eval()
# trainer_scanvi.model.train()
# with torch.no_grad():
#     full = trainer_scanvi.create_posterior(
#         trainer_scanvi.model, train_data, indices=np.where(where_eval)[0]
#     )
#     gt, pred = full.sequential().compute_predictions()
#     preds = [full.sequential().compute_predictions(soft=True)[1] for _ in tqdm(range(100))]
# preds = [pred[..., None] for pred in preds]
# preds = np.concatenate(preds, -1)
# pred = preds.sum(-1).argmax(-1)
# preds_oh = preds.copy()
# preds_oh = (preds_oh == preds_oh.max(1)[:, None]).astype(int).sum(-1)
# pred = preds_oh.argmax(-1)

# latent, _, _ = full.sequential().get_latent()

# # if ON_LEAVES_ONLY:
# predicted_ = pred
# # else:
# #     # Fancy procedure to predict internal nodes
# # y_true = labels[where_eval]
# # y_pred = predicted_[where_eval]

In [None]:
# Maps a label index (as a string) to the human-readable label of its node.
renamer = pd.Series(
    id_to_label[indices_to_nodes].values, index=indices_to_nodes.index.astype(str)
)

# For every node, walk the adjacency matrices of `adjm` in reverse order and
# record the ancestor found at each step: column "d0" holds the ancestor found
# in the last matrix of `adjm`, "d1" the next one, and so on. When a node is
# not a column of a matrix it is carried over unchanged to that level.
tree_mapper = []
# `.items()` replaces `.iteritems()`, which was removed in pandas 2.0.
for idx, node in indices_to_nodes.items():
    _node = node
    path = []
    for mat in adjm[::-1]:
        if _node in mat.columns:
            # Row label holding the largest entry of this node's column.
            # `idxmax` returns the label directly in every pandas version,
            # whereas `Series.argmax` returned a label before pandas 1.0 and a
            # position afterwards.
            anc = mat.loc[:, _node].idxmax()
        else:
            anc = _node
        path.append(anc)
        _node = anc
    tree_path = {"d{}".format(i): nod for i, nod in enumerate(path)}
    tree_mapper.append(
        dict(
            node=node,
            **tree_path,
        )
    )
# One row per node, indexed by node, with the node itself repeated as "orig".
tree_mapper = pd.DataFrame(tree_mapper).set_index("node").assign(orig=lambda x: x.index)


In [None]:
# h_class_perfs = get_classification_performance(
#     y_true=y_true, y_pred=y_pred, y_ari=cluster_unsup, labels_mapper=renamer,
# )

# y_true_leaves = y_true[y_true < n_labels_hscanvi]
# y_pred_leaves = y_pred[y_true < n_labels_hscanvi]
# leaves_class_perfs = get_classification_performance(
#     y_true=y_true_leaves,
#     y_pred=y_pred_leaves,
#     y_ari=None,
#     labels_mapper=renamer,
#     prefix="Leaves. ",
# )
# resh = {**resh, **h_class_perfs, **leaves_class_perfs}


# nodes_true = indices_to_nodes[y_true].values
# nodes_pred = indices_to_nodes[y_pred].values
# y_anc_true = tree_mapper.reindex(nodes_true)
# y_anc_pred = tree_mapper.reindex(nodes_pred)
# for depth in ["d0", "d2", "d4"]:
#     h_parent_class_perfs = get_classification_performance(
#         y_true=nodes_to_indices[y_anc_true[depth]].values,
#         y_pred=nodes_to_indices[y_anc_pred[depth]].values,
#         prefix="Anc. {} ".format(depth),
#         y_ari=cluster_unsup,
#         labels_mapper=renamer,
#     )
#     resh = {**resh, **h_parent_class_perfs}
# results.append(resh)

#### Latent kNN approach

In [None]:
# Split the hscANVI latent space: rows selected by `labelled` are used for
# training, rows selected by `where_eval` for evaluation.
latent_train = latent[labelled]
latent_eval = latent[where_eval]
# Labels of the `labelled` cells, taken from the (squeezed) dataset labels.
labels_train = train_data.labels.squeeze()[labelled]

In [None]:
from sklearn.neighbors import KNeighborsClassifier
from tqdm.auto import tqdm

from sklearn.ensemble import RandomForestClassifier

In [None]:
# For every leaf, fit a classifier on the latent coordinates of the labelled
# cells whose label lies in that leaf's row of `tree_mapper`, and use it to
# predict a label for every cell. The leaf predicted by hscANVI (`predicted_`)
# then selects, per cell, which of these per-leaf predictions is kept.
leaf_paths = tree_mapper.loc[leaf_node_names]

bottomsup_classifiers = dict()
extended_predictions = dict()
for leaf, nodes_subset in tqdm(leaf_paths.iterrows(), total=len(leaf_paths)):
    leaf_index = nodes_to_indices[leaf]
    ind_subset = nodes_to_indices[nodes_subset].values

    cells_subselect = np.isin(labels_train, np.unique(ind_subset))
    x_train = latent_train[cells_subselect]
    y_train = labels_train[cells_subselect]
    # NOTE: despite the name (kept because later cells may read it) this is a
    # random forest, not a kNN. It is seeded, as in `compute_accuracy_rf`, so
    # that re-running the cell gives the same predictions.
    knn = RandomForestClassifier(random_state=0)
    print(ind_subset, cells_subselect.sum(), np.unique(y_train, return_counts=True))
    if cells_subselect.sum() >= 1:
        knn.fit(x_train, y_train)
        bottomsup_classifiers[leaf_index] = knn
        extended_predictions[leaf_index] = knn.predict(latent)
    else:
        # No labelled cell for this leaf: predict the leaf itself everywhere.
        _pred = leaf_index * np.ones(len(latent))
        extended_predictions[leaf_index] = _pred


extended_predictions = pd.DataFrame(extended_predictions)

# Row i, column `predicted_[i]`. This is what `DataFrame.lookup` did; that
# method was deprecated in pandas 1.2 and removed in pandas 2.0.
col_positions = extended_predictions.columns.get_indexer(np.asarray(predicted_))
if len(col_positions) != len(extended_predictions):
    raise ValueError("Row labels must have same size as column labels")
if (col_positions == -1).any():
    raise KeyError("One or more column labels was not found")
predicted_all = extended_predictions.values[
    np.arange(len(extended_predictions)), col_positions
]

In [None]:
# Score the combined predictions (`predicted_all`) on the evaluation cells;
# the metrics returned by the last line are displayed as the cell output.
labels_eval = train_data.labels.squeeze()[where_eval]
get_classification_performance(labels_eval, predicted_all[where_eval])


5 neighbors

    {'ARI': 0.6756242870939023,
     'Accuracy': 0.6399808181588097,
     'Accuracy (weighted)': 0.502664518029026,
     'wAccuracy2': 0.6399808181588097,
     'F1 (macro)': 0.3839107657664729,
     'F1 (weighted)': 0.6738910799565824,
     'y_true': array([20, 20, 20, ..., 19, 19, 19], dtype=uint16),
     'y_pred': array([20., 20., 20., ..., 19., 19., 19.]),
     'poe': 0.07313076953817053}
     
     
20 neighbors

    {'ARI': 0.6737769771356849,
     'Accuracy': 0.6376630123483102,
     'Accuracy (weighted)': 0.5017265938647667,
     'wAccuracy2': 0.6376630123483104,
     'F1 (macro)': 0.38358085532087355,
     'F1 (weighted)': 0.6727164574114042,
     'y_true': array([20, 20, 20, ..., 19, 19, 19], dtype=uint16),
     'y_pred': array([20., 20., 20., ..., 19., 19., 19.]),
     'poe': 0.07586151776318419}

In [None]:
# Register the combined predictions under the key "hscANVI (full)".
# NOTE(review): `update_annd` is defined outside this section; presumably it
# stores `predictions` on `annd` under `new_key` - confirm.
update_annd(new_key="hscANVI (full)", annd=annd, predictions=predicted_all)

In [None]:
# Register the leaf-level hscANVI predictions under "hscANVI (leaves)".
update_annd(new_key="hscANVI (leaves)", annd=annd, predictions=predicted_)

#### Bottom-up approach

In [None]:
# ## Fancy procedure
# # hscanvi
# labels_train = np.unique(train_data.labels.squeeze())
# y_gt, y_preds = full.sequential().compute_predictions_full()
# chil_to_parents = [np.arange(scanvi.ontology[-1].shape[-1])]
# chil_to_parent0 = np.array([np.where(y == 1)[0][0] for y in scanvi.ontology[-1].T])
# chil_to_parents.append(chil_to_parent0)
# associated_children = []
# for itx, ont in enumerate(scanvi.ontology[-2::-1]):
#     print(ont.shape)
#     chil_to_parent = np.array([np.where(y == 1)[0][0] for y in ont.T])
#     res = chil_to_parent[chil_to_parent0]
#     associated_children.append(res)
#     print(res.shape, np.unique(res).shape)
#     chil_to_parent0 = res.copy()
#     chil_to_parents.append(res)

# chil_to_parents = chil_to_parents[::-1]

# y_preds_unif = []
# for yp, mapper in zip(y_preds, chil_to_parents):
#     print(yp.shape)
#     y_preds_unif.append(yp[:, mapper][..., None])
# y_preds_unif = np.concatenate(y_preds_unif, axis=-1)

# y_preds_pred = y_preds_unif[..., -1].argmax(-1)[:, None, None]
# y_preds_unif.shape

# for p_thresh in [0.7, 0.9, 0.95]:
#     y_probs_best_path = np.take_along_axis(y_preds_unif, y_preds_pred, 1).squeeze()
#     predicted_leaf_node = indices_to_nodes[y_preds_pred.squeeze()].values
#     predicted_path_nodes = tree_mapper.loc[predicted_leaf_node]
#     selected_depth = (len(adjm) + 1) - (y_probs_best_path >= p_thresh).sum(1)
#     selected_depth = np.clip(selected_depth, a_min=0, a_max=len(adjm))
#     predicted_node = predicted_path_nodes.loc[
#         :, ["orig", "d0", "d1", "d2", "d3", "d4"]
#     ].values[np.arange(len(predicted_path_nodes)), selected_depth]
#     predicted_multi = nodes_to_indices[predicted_node].values
#     print(get_classification_performance(labels[where_eval], predicted_multi[where_eval]))

In [None]:
# get_classification_performance(labels[where_eval], y_knn_full[where_eval])

In [None]:
# get_classification_performance(labels[where_eval], predicted_[where_eval])

In [None]:
# print(resh.keys())

In [None]:
# NOTE(review): same call, with the same key and arguments, as the
# "hscANVI (leaves)" update made earlier in the Latent kNN section; likely
# redundant unless `predicted_` changed in between - confirm.
update_annd(new_key="hscANVI (leaves)", annd=annd, predictions=predicted_)

### Vanilla

In [None]:
# Train a scANVI model without the ontology, save its weights, then compute
# predictions and latent coordinates for every cell of `train_data`.
# NOTE(review): `train_data` labels were reassigned in the hscANVI cell;
# confirm they are compatible with `n_labels_final` used here.
scanvi_vanilla = SCANVI(
    train_data.nb_genes,
    n_batch=train_data.n_batches,
    n_labels=n_labels_final,
    use_ontology=False,
    **SCANVI_KWARGS,
)
trainer_scanvi_vanilla = SemiSupervisedTrainer(
    scanvi_vanilla,
    train_data,
    indices_labelled=labelled,
    indices_unlabelled=unlabelled,
    include_conditionned_elbo=True,
    **TRAINER_KWARGS,
)
# Uses its own training kwargs, distinct from the hscANVI ones.
trainer_scanvi_vanilla.train(**TRAIN_SCANVI_KWARGS_vanilla)

# Plain string concatenation: this relies on SAVE_DIR ending with "/".
vanilla_filename = SAVE_DIR + "vanilla_batch{}_fair{}_{}".format(
    str(train_batch_indices), int(is_fair), init
)
torch.save(scanvi_vanilla.state_dict(), vanilla_filename)

# Posterior over all cells, iterated in order.
full_vanilla = trainer_scanvi_vanilla.create_posterior(
    trainer_scanvi_vanilla.model, train_data, indices=np.arange(len(train_data)),
)
(gt_vanilla, pred_vanilla,) = full_vanilla.sequential().compute_predictions()
latent_vanilla, _, _ = full_vanilla.sequential().get_latent()

In [None]:
# Score the vanilla scANVI predictions on the evaluation cells; the metrics
# returned by the last line are displayed as the cell output.
labels_eval = train_data.labels.squeeze()[where_eval]
get_classification_performance(labels_eval, pred_vanilla[where_eval])

In [None]:
# Register the vanilla scANVI predictions under the key "scANVI".
update_annd(new_key="scANVI", annd=annd, predictions=pred_vanilla)

### scVI + scANVI

This section can be ignored: it corresponds to the original scANVI algorithm, which does not use the symmetrical KL divergence.

In [None]:
# scanvi = SCANVI(
#     train_data.nb_genes,
#     n_batch=train_data.n_batches,
#     n_labels=n_labels_final,
#     use_ontology=False,
#     **SCANVI_KWARGS,
# )
# trainer_scanvi = SemiSupervisedTrainer(
#     scanvi,
#     train_data,
#     indices_labelled=labelled,
#     indices_unlabelled=unlabelled,
#     include_conditionned_elbo=True,
#     **TRAINER_KWARGS,
# )
# trainer_scanvi.model.load_state_dict(torch.load(scVI_filename), strict=False)
# trainer_scanvi.model.eval()
# trainer_scanvi.train(**TRAIN_SCANVI_KWARGS)

# scanvi_filename = SAVE_DIR + "scanvi_batch{}_fair{}_{}".format(
#     str(train_batch_indices), int(is_fair), init
# )
# torch.save(scanvi.state_dict(), scanvi_filename)

# full = trainer_scanvi.create_posterior(
#     trainer_scanvi.model, train_data, indices=np.arange(len(train_data)),
# )
# (gt, pred_scanvi,) = full.sequential().compute_predictions()
# latent_scanvi, _, _ = full.sequential().get_latent()

In [None]:
# update_annd(
#     new_key="scVI_scANVI", annd=annd, predictions=pred_scanvi,
# )

## SVM

In [None]:
# Baseline: linear SVM fitted on log1p-transformed `X`, using the labelled
# cells only, then applied to every cell.
data = np.log1p(X)
data_train = data[labelled]
y_train = labels[labelled]
# Default LinearSVC settings; no random_state is set, so results may vary
# between runs.
svm = LinearSVC()
svm.fit(data_train, y_train)
y_svm_full = svm.predict(data)
update_annd(new_key="SVM", annd=annd, predictions=y_svm_full)

## MetaClassifier

In [None]:
# Combine the predictions of the individual classifiers with `predict_meta`.
# NOTE(review): in this section `pred_scanvi` is only assigned inside the
# commented-out "scVI + scANVI" cell, and no assignment of `y_knn_full` is
# visible here; confirm both exist on a fresh kernel, otherwise this cell
# raises a NameError.
y_meta = predict_meta(
    preds=[pred_vanilla, pred_scanvi, y_knn_full, y_svm_full], n_classes=n_labels_final
)

update_annd(new_key="META", annd=annd, predictions=y_meta)

# Results comparison

## `.h5ad` save

**SUPER IMPORTANT SECTION: Saves all predictions in H5AD**

In [None]:
# Attach the latent spaces to `annd`, compute a UMAP, and write everything
# (including the predictions registered earlier) to an .h5ad file.
annd.obsm["vanilla_latent"] = latent_vanilla
# NOTE(review): in this section `latent_scanvi` is only assigned inside the
# commented-out "scVI + scANVI" cell - confirm it exists on a fresh kernel.
annd.obsm["scanvi_latent"] = latent_scanvi
annd.obsm["scvi_latent"] = latent_unsup

# NOTE(review): neighbours and UMAP are computed on `final_data`, while the
# latents above are stored on `annd` and `annd` is what gets written below.
# Confirm that both names refer to the same object.
sc.pp.neighbors(final_data,n_neighbors=30, n_pcs=50, use_rep='scvi_latent')
sc.tl.umap(final_data)

# "".join(...) requires the batch indices to be sequences of strings.
annd_file = os.path.join(
    SAVE_DIR,
    "results1layer_train{}_test{}_leavesonly{}_{}3.h5ad".format(
        "".join(train_batch_indices),
        "".join(test_batch_indices),
        ON_LEAVES_ONLY,
        TEST_SPLIT_SEED,
    ),
)
annd.write_h5ad(annd_file)

## Path clustering


In [None]:
# is_leaf = y_gt < n_labels_hscanvi

# leaf_cells = np.random.choice(np.where(is_leaf)[0], 300)
# internal_cells = np.random.choice(np.where(~is_leaf)[0], 300)

# plt.hist(y_probs_best_path[is_leaf, -1], label="leaves", alpha=0.5, density=True)
# plt.hist(y_probs_best_path[~is_leaf, -1], label="internal", alpha=0.5, density=True)

In [None]:
# y_ = ~is_leaf
# y_pred_ = y_probs_best_path[:, -1] <= 0.9

# print(met.classification_report(y_, y_pred_))

In [None]:
# paths_n = os.path.join(
#     SAVE_DIR,
#     "paths_eval{}_{}.pdf".format(
#         "".join(train_batch_indices), "".join(test_batch_indices),
#     ),
# )

In [None]:
# is_leaf = y_gt < n_labels_hscanvi
# is_leaf = is_leaf & where_eval
# is_internal = (~is_leaf) & where_eval

# leaf_cells = np.random.choice(np.where(is_leaf)[0], 100)
# internal_cells = np.random.choice(np.where(is_internal)[0], 100)

# fig, axes = plt.subplots(ncols=2, sharex=True, sharey=True, figsize=(8, 3.5))
# plt.xlabel("Tree depth", fontsize=13)
# plt.sca(axes[0])
# plt.ylabel("Posterior Probability", fontsize=13)
# plt.title("Leaf cells", fontsize=15)

# _y_probs_best_path = y_probs_best_path[leaf_cells]
# plt.plot(_y_probs_best_path.T)
# _ = plt.plot(np.median(y_probs_best_path[is_leaf].T, 1), color="red", linewidth=5)
# plt.sca(axes[1])
# plt.yscale("log")
# plt.title("Internal nodes cells", fontsize=15)

# _y_probs_best_path = y_probs_best_path[internal_cells]
# plt.plot(_y_probs_best_path.T)
# _ = plt.plot(np.median(y_probs_best_path[~is_leaf].T, 1), color="red", linewidth=5)

# plt.suptitle("Predictions with hscANVI", fontsize=16)
# plt.tight_layout()
# plt.savefig(paths_n)

In [None]:
# pathsumap_n = "pathsumap_eval{}_{}.pdf".format(
#     "".join(train_batch_indices), "".join(test_batch_indices),
# )


# # pathsumap_n = os.path.join(
# #     SAVE_DIR,
# #     "pathsumap_eval{}_{}.pdf".format(
# #         "".join(train_batch_indices), "".join(test_batch_indices),
# #     ),
# # )

In [None]:
# pathsumap_n

In [None]:
# y_gt_vanilla, y_preds_vanilla = full_vanilla.sequential().compute_predictions(soft=True)


# annd.obsm["logp_path"] = np.log(y_probs_best_path)
# annd.obsm["p_path"] = y_probs_best_path

# y_true = labels
# nodes_true = indices_to_nodes[y_true].values
# y_anc_true = tree_mapper.reindex(nodes_true)

# keys = []
# for depth in ["orig", "d2", "d4"]:
#     annd.obs[depth] = id_to_label[y_anc_true[depth]].values
#     annd.obs[depth] = annd.obs[depth].astype("category")
#     keys.append(depth)

# annd.obs["logp_max"] = np.log(y_probs_best_path).max(-1)
# keys.append("logp_max")

# annd.obs["is_leaf"] = labels < n_labels_hscanvi
# keys.append("is_leaf")

# anndeval = annd[where_eval]
# sc.pp.neighbors(anndeval, n_neighbors=20, use_rep="p_path")
# sc.tl.umap(anndeval, min_dist=0.1)

# sc.pl.umap(
#     anndeval, color=keys, legend_fontsize=5, ncols=2, save=pathsumap_n,
# )

# res_study = pd.DataFrame(
#     {
#         "Leaf probability": (y_probs_best_path)[:, -1][where_eval],
#         "Probability (scANVI)": y_preds_vanilla.max(-1)[where_eval],
#         "Cell-type": anndeval.obs["orig"],
#         "Ancestor": anndeval.obs["d4"],
#     }
# )

# _res_study = res_study.groupby("Cell-type").mean()

# import seaborn as sns

# fig, axes = plt.subplots(ncols=2, figsize=(12, 6))
# chart = sns.barplot(
#     x="Cell-type",
#     y="Leaf probability",
#     order=_res_study.sort_values("Leaf probability").index.values,
#     data=res_study,
#     ax=axes[0],
# )
# _ = chart.set_xticklabels(chart.get_xticklabels(), rotation=90)

# chart = sns.barplot(
#     x="Cell-type",
#     y="Probability (scANVI)",
#     order=_res_study.sort_values("Leaf probability").index.values,
#     data=res_study,
#     ax=axes[1],
# )
# _ = chart.set_xticklabels(chart.get_xticklabels(), rotation=90)

In [None]:
# assert False