### A semi-supervised framework for the annotation problem

**NB**: please refer to the scVI-dev notebook for an introduction to the scVI package.

In this notebook, we investigate how semi-supervised learning combined with the probabilistic modelling of latent variables in scVI can help address the annotation problem.

The annotation problem consists in labelling cells, i.e. **inferring their cell types**, when only part of the labels is known.

In [1]:
cd ~/scVI

/home/ubuntu/scVI


In [2]:
from run_benchmarks import load_datasets
from scvi.models import SVAEC, VAE
from scvi.inference import JointSemiSupervisedVariationalInference

We instantiate the SVAEC model and train it for 50 epochs. Only the labels from `data_loader_labelled` are used for training; the labels of `data_loader_unlabelled` are held out and used only at test time to validate the results. The accuracy on the `unlabelled` set reaches about 93% at the end of training.
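To make the labelled/unlabelled split concrete, here is a minimal sketch of how a fixed number of labelled cells per class could be drawn, with every remaining cell treated as unlabelled. The helper `split_labelled_unlabelled` and the toy labels are hypothetical and only illustrate the idea behind `n_labelled_samples_per_class`; this is not scVI's own implementation.

```python
import random
from collections import defaultdict


def split_labelled_unlabelled(labels, n_per_class, seed=0):
    """Hypothetical helper: pick at most n_per_class labelled indices per class."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    labelled = []
    for label in sorted(by_class):
        indices = by_class[label][:]
        rng.shuffle(indices)
        # Classes with fewer than n_per_class cells contribute all of their cells.
        labelled.extend(indices[:n_per_class])

    labelled_set = set(labelled)
    unlabelled = [i for i in range(len(labels)) if i not in labelled_set]
    return sorted(labelled), unlabelled


# Toy example: three cell types of unequal size (made-up data).
toy_labels = [0] * 50 + [1] * 30 + [2] * 5
labelled_idx, unlabelled_idx = split_labelled_unlabelled(toy_labels, n_per_class=10)
print(len(labelled_idx), len(unlabelled_idx))  # prints: 25 60
```

Sampling per class rather than uniformly ensures that rare cell types still contribute labelled examples to training.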

In [3]:
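# Load the cortex dataset
gene_dataset = load_datasets('cortex')

use_batches = False
use_cuda = True

# Semi-supervised model, trained with only 10 labelled cells per class
svaec = SVAEC(gene_dataset.nb_genes, gene_dataset.n_labels)
infer = JointSemiSupervisedVariationalInference(svaec, gene_dataset, n_labelled_samples_per_class=10)
infer.fit(n_epochs=50)

# Accuracy on the cells whose labels were held out during training
infer.accuracy('unlabelled')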

File data/expression.bin already downloaded
training: 100%|██████████| 50/50 [00:19<00:00,  2.50it/s]
Acc for unlabelled is : 0.9376


0.9376490712165833

### Benchmarking against other algorithms

We can compare our model against random forest and SVM classifiers, whose hyperparameters are tuned by grid search with 3-fold cross-validation. This is performed automatically by the functions **`compute_accuracy_svc`** and **`compute_accuracy_rf`**.

These functions take as input the numpy arrays corresponding to the equivalent data loaders, which is the purpose of the **`get_raw_data`** method from `scvi.dataset.utils`.
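For readers unfamiliar with this kind of baseline, the sketch below shows a grid search with 3-fold cross-validation using scikit-learn's `GridSearchCV`. It is an illustration only: the synthetic data and the parameter grids are assumptions, and they are not the data or grids used inside `compute_accuracy_svc` and `compute_accuracy_rf`.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.svm import SVC

# Synthetic stand-in for an expression matrix (cells x genes) and cell-type labels.
X, y = make_classification(
    n_samples=300, n_features=20, n_informative=10, n_classes=3, random_state=0
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, stratify=y, random_state=0
)

# Assumed hyperparameter grids, chosen only for illustration.
svc_grid = {"C": [0.1, 1, 10], "kernel": ["linear", "rbf"]}
rf_grid = {"n_estimators": [50, 100], "max_depth": [None, 5]}

searches = {
    "svc": GridSearchCV(SVC(), svc_grid, cv=3),
    "rf": GridSearchCV(RandomForestClassifier(random_state=0), rf_grid, cv=3),
}

test_scores = {}
for name, search in searches.items():
    # Each grid point is scored by 3-fold cross-validation on the training split;
    # the best estimator is then refit on the whole training split.
    search.fit(X_train, y_train)
    test_scores[name] = search.score(X_test, y_test)
    print(name, search.best_params_, round(test_scores[name], 3))
```

Hyperparameters are selected on the training split only, so the score on the held-out split remains an unbiased estimate.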

The result is an `Accuracy` named tuple that gives a finer-grained view of the accuracy, with attributes:

- **unweighted**: the standard definition of accuracy, i.e. the fraction of correctly classified cells

- **weighted**: the accuracy obtained when every class is given the same weight, i.e. the mean of the per-class accuracies. Informative only if the dataset is unbalanced.

- **worst**: the lowest accuracy among the classes

- **accuracy_classes**: the accuracy for each class
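As an illustration, the four attributes can be computed as in the sketch below. The function `compute_accuracy` is a hypothetical pure-Python stand-in, not the scVI implementation. Its definition of `weighted` as the mean of the per-class accuracies is an assumption that agrees with the outputs printed in this notebook, where the mean of the seven per-class values equals the reported `weighted` score.

```python
from collections import namedtuple

Accuracy = namedtuple("Accuracy", ["unweighted", "weighted", "worst", "accuracy_classes"])


def compute_accuracy(y_true, y_pred):
    """Hypothetical helper mirroring the attributes described above."""
    classes = sorted(set(y_true))
    accuracy_classes = []
    for cl in classes:
        # Accuracy restricted to the cells whose true label is `cl`.
        idx = [i for i, y in enumerate(y_true) if y == cl]
        correct = sum(1 for i in idx if y_pred[i] == cl)
        accuracy_classes.append(correct / len(idx))

    n_correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return Accuracy(
        unweighted=n_correct / len(y_true),
        weighted=sum(accuracy_classes) / len(accuracy_classes),
        worst=min(accuracy_classes),
        accuracy_classes=accuracy_classes,
    )


# Unbalanced toy example: 8 cells of class 0, 2 cells of class 1 (made-up data).
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]
acc = compute_accuracy(y_true, y_pred)
print(acc)
# prints: Accuracy(unweighted=0.9, weighted=0.75, worst=0.5, accuracy_classes=[1.0, 0.5])
```

In this toy case the unweighted accuracy looks high because the large class dominates, while the weighted and worst scores expose the poor performance on the rare class.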


Compute the accuracy scores for the random forest and SVC classifiers:

In [4]:
svc_scores, rf_scores = infer.svc_rf()
print("\nSVC score test :\n", svc_scores[1])
print("\nRF score test :\n", rf_scores[1])


SVC score test :
 Accuracy(unweighted=0.87018739352640551, weighted=0.84659088617012479, worst=0.72236503856041134, accuracy_classes=[0.79439252336448596, 0.90666666666666662, 0.88571428571428568, 0.79545454545454541, 0.9345679012345679, 0.88697524219590962, 0.72236503856041134])

RF score test :
 Accuracy(unweighted=0.92810902896081771, weighted=0.89784424707013932, worst=0.79545454545454541, accuracy_classes=[0.87383177570093462, 0.90222222222222226, 0.98571428571428577, 0.79545454545454541, 0.96049382716049381, 0.96770721205597421, 0.7994858611825193])
