### A semi-supervised framework for the annotation problem

**NB**: please refer to the scVI-dev notebook for an introduction to the scVI package.

In this notebook, we investigate how semi-supervised learning combined with the probabilistic modelling of latent variables in scVI can help address the annotation problem.

The annotation problem consists of labelling cells, i.e. **inferring their cell types**, when only some of the labels are known.

In [1]:
cd ..

/home/ubuntu/scVI


In [2]:
from scvi.dataset import load_datasets
from scvi.models import SVAEC
from scvi.dataset.utils import get_data_loaders
from scvi.train import train_semi_supervised_alternately, train_semi_supervised_jointly
from scvi.metrics.classification import compute_accuracy_svc, compute_accuracy_rf
from scvi.dataset.utils import get_raw_data
import numpy as np

In [3]:
gene_dataset = load_datasets('cortex')

use_batches = False
use_cuda = True
data_loader_all, data_loader_labelled, data_loader_unlabelled = get_data_loaders(
    gene_dataset, 10, batch_size=128, pin_memory=use_cuda)

# Sanity checks
print("Number of labelled samples : ",len(data_loader_labelled.sampler.indices))
print("Labels and their proportions: ",np.unique(gene_dataset.labels[data_loader_labelled.sampler.indices], return_counts=True))

Pickle for :  ../scVI-dev/data/cortex-tmp
Number of labelled samples :  70
Labels and their proportions:  (array([0, 1, 2, 3, 4, 5, 6]), array([10, 10, 10, 10, 10, 10, 10]))
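The sanity check above shows a balanced labelled subset: 10 cells for each of the 7 classes. As a rough illustration of how such a split could be drawn, here is a minimal NumPy sketch. The helper `split_labelled_unlabelled` and the toy labels are hypothetical and are not the scVI implementation of `get_data_loaders`.

```python
import numpy as np


def split_labelled_unlabelled(labels, n_per_class, seed=0):
    """Hypothetical helper: draw n_per_class labelled indices per class.

    Every index that is not drawn is treated as unlabelled.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels).ravel()
    labelled = []
    for c in np.unique(labels):
        idx = np.flatnonzero(labels == c)  # positions of class c
        labelled.append(rng.choice(idx, size=n_per_class, replace=False))
    labelled = np.sort(np.concatenate(labelled))
    unlabelled = np.setdiff1d(np.arange(labels.size), labelled)
    return labelled, unlabelled


# Toy labels: 7 classes of unequal size (made-up data, not the cortex dataset)
toy_labels = np.repeat(np.arange(7), [30, 50, 20, 15, 80, 60, 40])
labelled_idx, unlabelled_idx = split_labelled_unlabelled(toy_labels, n_per_class=10)

classes, counts = np.unique(toy_labels[labelled_idx], return_counts=True)
print("Number of labelled samples:", len(labelled_idx))
print("Labels and their counts:", classes, counts)
```

Sampling the same number of cells per class keeps rare cell types represented in the labelled set, even when the full dataset is unbalanced.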


We instantiate the SVAEC model and train it for 200 epochs. Only the labels from `data_loader_labelled` are used for training; the labels of `data_loader_unlabelled` are held out and used only at test time to evaluate the predictions. The accuracy on the `unlabelled` dataset reaches about 94.6% at the end of training.

In [4]:
svaec = SVAEC(gene_dataset.nb_genes, n_batch=gene_dataset.n_batches * use_batches, n_labels=gene_dataset.n_labels,
            use_cuda=use_cuda)
train_semi_supervised_jointly(svaec, data_loader_all, data_loader_labelled, data_loader_unlabelled, n_epochs=200, record_freq=100)

EPOCH [0/200]: 
LL labelled is: 32302.121429
Accuracy labelled is: 0.014286
LL unlabelled is: 38213.741227
Accuracy unlabelled is: 0.026576
EPOCH [100/200]: 
LL labelled is: 1274.317076
Accuracy labelled is: 1.000000
LL unlabelled is: 1367.440013
Accuracy unlabelled is: 0.944804
EPOCH [200/200]: 
LL labelled is: 1193.052344
Accuracy labelled is: 1.000000
LL unlabelled is: 1259.477050
Accuracy unlabelled is: 0.945826
Total runtime for 201 epochs is: 80.55557942390442 seconds for a mean per epoch runtime of 0.4007740269845991 seconds.


<scvi.metrics.stats.Stats at 0x7f41d657ceb8>

### Benchmarking against other algorithms

We can compare these results against random forest and SVM classifiers, whose hyperparameters are selected by grid search with 3-fold cross-validation. This is performed automatically by the functions **`compute_accuracy_svc`** and **`compute_accuracy_rf`**.

These functions take as input the numpy arrays corresponding to the data loaders, which is the purpose of the **`get_raw_data`** method from `scvi.dataset.utils`.

The result is an `Accuracy` named tuple that gives a finer-grained view of the accuracy, with the following attributes:

- **unweighted**: the standard definition of accuracy, i.e. the fraction of correctly classified cells

- **weighted**: the mean of the per-class accuracies, so that every class carries the same weight regardless of its size. It is informative only if the dataset is unbalanced.

- **worst**: the lowest per-class accuracy

- **accuracy_classes**: the accuracy for each class
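To make the four attributes concrete, here is a minimal sketch that computes them with NumPy on toy predictions. The function `accuracy_summary` is a hypothetical stand-in, not the scVI code; it treats **weighted** as the plain mean of the per-class accuracies, which is consistent with the values printed in this notebook.

```python
from collections import namedtuple

import numpy as np

# Same field names as the Accuracy tuple printed in this notebook
Accuracy = namedtuple("Accuracy", ["unweighted", "weighted", "worst", "accuracy_classes"])


def accuracy_summary(y_true, y_pred):
    """Hypothetical helper: summarise classification accuracy per class."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    correct = y_true == y_pred
    # Accuracy restricted to the samples of each true class
    per_class = [float(correct[y_true == c].mean()) for c in np.unique(y_true)]
    return Accuracy(
        unweighted=float(correct.mean()),   # fraction of correct predictions
        weighted=float(np.mean(per_class)), # every class counts equally
        worst=min(per_class),               # hardest class
        accuracy_classes=per_class,
    )


# Toy, unbalanced example: class 0 dominates and is always predicted correctly
y_true = np.array([0, 0, 0, 0, 0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 0, 0, 0, 0, 0, 1, 0, 2, 1])
acc = accuracy_summary(y_true, y_pred)
print(acc)
```

In this toy case the unweighted accuracy is 0.8 while the weighted accuracy is about 0.67, because the two small classes are each only half correct.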


1 - Load the data

In [5]:
(data_train, labels_train), (data_test, labels_test) = get_raw_data(data_loader_labelled, data_loader_unlabelled)

2 - Compute the accuracy score for the SVM classifier (SVC)

In [6]:
accuracy_train, accuracy_test = compute_accuracy_svc(data_train, labels_train, data_test, labels_test)
print(accuracy_test)

Accuracy(unweighted=0.8701873935264055, weighted=0.8465908861701248, worst=0.7223650385604113, accuracy_classes=[0.794392523364486, 0.9066666666666666, 0.8857142857142857, 0.7954545454545454, 0.9345679012345679, 0.8869752421959096, 0.7223650385604113])


3 - Compute the accuracy score for the random forest (RF)

In [7]:
accuracy_train, accuracy_test = compute_accuracy_rf(data_train, labels_train, data_test, labels_test)
print(accuracy_test)

Accuracy(unweighted=0.9281090289608177, weighted=0.8978442470701393, worst=0.7954545454545454, accuracy_classes=[0.8738317757009346, 0.9022222222222223, 0.9857142857142858, 0.7954545454545454, 0.9604938271604938, 0.9677072120559742, 0.7994858611825193])
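For readers who want to see what a grid search with 3-fold cross-validation looks like, here is a minimal scikit-learn sketch for both baselines. It is an illustration under assumptions: the data is synthetic, and the parameter grids are arbitrary examples, since the grids used inside `compute_accuracy_svc` and `compute_accuracy_rf` are not shown in this notebook.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

# Synthetic stand-in for (data_train, labels_train), (data_test, labels_test)
X, y = make_classification(n_samples=300, n_features=20, n_informative=10,
                           n_classes=3, random_state=0)
X_train, y_train = X[:90], y[:90]   # small "labelled" set
X_test, y_test = X[90:], y[90:]     # larger held-out set

# Illustrative grids only (assumption, not the grids used by scVI)
searches = {
    "svc": GridSearchCV(SVC(), {"C": [0.1, 1, 10], "gamma": ["scale", 0.01]}, cv=3),
    "rf": GridSearchCV(RandomForestClassifier(random_state=0),
                       {"n_estimators": [50, 100], "max_depth": [None, 5]}, cv=3),
}

test_accuracy = {}
for name, search in searches.items():
    # Select hyperparameters by 3-fold CV on the training set, then refit on all of it
    search.fit(X_train, y_train)
    # Accuracy of the refitted best estimator on the held-out set
    test_accuracy[name] = search.score(X_test, y_test)
    print(name, search.best_params_, round(test_accuracy[name], 3))
```

Tuning only on the labelled set and scoring on the held-out set mirrors the protocol above, where the baselines never see the labels of the unlabelled cells.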
