In [2]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, StratifiedKFold
from scHPL import train, predict, update, progressive_learning, utils, evaluate

### Data loading

#### Reading a csv file
In this tutorial, we work with the simulated data. This dataset can be downloaded from the Zenodo repository (https://doi.org/10.5281/zenodo.4557712).

scHPL expects the data as a pandas DataFrame (rows: cells, columns: genes).

In [None]:
# Transpose so that rows are cells and columns are genes
data = pd.read_csv('Simulated_data.csv', index_col = 0)
data = data.T

labels = pd.read_csv('Simulated_labels.csv')

#### Anndata objects

If you're working with an AnnData object, you can convert it into a pandas DataFrame with the following lines:

In [None]:
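import scanpy as sc

adata_object = sc.read('anndata.h5ad')
# If adata_object.X is a sparse matrix, convert it to a dense array first (adata_object.X.toarray())
data = pd.DataFrame(data = adata_object.X, index = adata_object.obs_names, columns=adata_object.var_names)
# 'labels' is the column of adata_object.obs that holds the cell annotations; adjust it to your dataset
labels = pd.DataFrame(data = adata_object.obs['labels'].values)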

### scHPL without progressive learning

Here, we explain how to train a classifier without progressive learning, which can be used when the hierarchy is known beforehand.

First, we split the dataset into a training set and a test set.

In [None]:
x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42)

#### Creating a tree

We create a tree with the utils.create_tree() function, which builds the tree from a string in Newick format.

In [None]:
tree = utils.create_tree('((Group1, Group2)Group12, Group3,(Group4, (Group5, Group6)Group56)Group456)root')
utils.print_tree(tree)
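
To make the Newick string easier to read, here is a minimal sketch of how such a string maps onto a hierarchy. It is a plain-Python illustration, not scHPL's implementation: the helpers `parse_newick` and `render` are hypothetical names, and the parser assumes labelled internal nodes, node names without spaces, and no branch lengths.

```python
def parse_newick(newick):
    """Parse a Newick string with labelled internal nodes into (name, children) tuples."""
    # Assumption: node names contain no spaces, so all spaces can be dropped.
    text = newick.replace(' ', '').rstrip(';')
    pos = 0

    def parse_node():
        nonlocal pos
        children = []
        if pos < len(text) and text[pos] == '(':
            pos += 1  # skip the opening bracket
            while True:
                children.append(parse_node())
                if pos >= len(text):
                    raise ValueError('unbalanced brackets')
                if text[pos] == ',':
                    pos += 1  # another sibling follows
                elif text[pos] == ')':
                    pos += 1  # the list of children is complete
                    break
                else:
                    raise ValueError('unexpected character at position %d' % pos)
        # The node name follows its (optional) bracketed children.
        start = pos
        while pos < len(text) and text[pos] not in ',()':
            pos += 1
        return (text[start:pos], children)

    return parse_node()


def render(node, depth=0):
    """Return the tree as a list of tab-indented lines, one per node."""
    name, children = node
    lines = ['\t' * depth + name]
    for child in children:
        lines.extend(render(child, depth + 1))
    return lines


newick = '((Group1, Group2)Group12, Group3,(Group4, (Group5, Group6)Group56)Group456)root'
lines = render(parse_newick(newick))
print('\n'.join(lines))
```

Each pair of brackets lists the children of the node whose name follows the closing bracket, so the last name in the string ('root') is the top of the hierarchy.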

#### Training the tree

Next, we train this tree. The following options are available:
    - classifier: either 'svm_occ' for the one-class SVM or 'svm' for the linear SVM
    - dimred: whether to apply dimensionality reduction to select features. 
      For the one-class SVM, this is recommended. 
      For the linear SVM, we recommend turning it off and relying on the built-in L2-regularization
    - useRE: whether cells are rejected based on the reconstruction error
    - FN: percentage of false negatives allowed when using the reconstruction error

In [None]:
tree = train.train_tree(x_train, y_train, tree, classifier = 'svm_occ', dimred = True, useRE = True, FN = 1)
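
One way to read the FN option is as a percentile cut-off on the reconstruction errors of the training cells. The sketch below illustrates that reading in plain Python. The function `rejection_threshold` and the toy error values are hypothetical, and the way scHPL actually derives its threshold may differ.

```python
import math


def rejection_threshold(train_errors, fn_percent):
    """Pick a threshold so that at most fn_percent % of the training cells exceed it."""
    ordered = sorted(train_errors)
    # Number of training cells we accept to reject (the allowed false negatives).
    n_rejected = math.floor(len(ordered) * fn_percent / 100)
    n_kept = max(len(ordered) - n_rejected, 1)
    # The largest error among the kept cells becomes the threshold.
    return ordered[n_kept - 1]


# Toy reconstruction errors for 100 training cells (illustrative values only).
errors = [i / 100 for i in range(1, 101)]
threshold = rejection_threshold(errors, fn_percent=1)

# Under this reading, a cell is rejected when its error is above the threshold.
n_rejected_train = sum(e > threshold for e in errors)
print(threshold, n_rejected_train)
```

Under this reading, a larger FN lowers the threshold: more cells from unseen populations are rejected, and so are more cells from known populations.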

### scHPL with progressive learning

scHPL can be used to learn a hierarchy of cell populations by combining the annotations of different datasets. scHPL, however, is not robust to batch effects between the datasets, so we recommend aligning the datasets before using scHPL.

#### Preprocessing the simulated data

We again split the data into a training and a test set. We then split the training set into 3 batches to simulate different datasets. To simulate batches annotated at different resolutions, we relabel some of the populations (e.g. 'Group1' and 'Group2' are renamed to 'Group12').

In [None]:
x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42)

# Split the training set into 3 stratified folds; each held-out fold becomes one batch
skf = StratifiedKFold(n_splits = 3, shuffle = True, random_state = 0)

x_batches = []
y_batches = []

for trainindex, testindex in skf.split(x_train, y_train):
    x_batches.append(x_train.iloc[testindex])
    y_batches.append(y_train.iloc[testindex])

# In batch 1, we merge Group1 and Group2 into Group12
y_batches[0] = y_batches[0].replace({'Group1': 'Group12', 'Group2': 'Group12'})

# In batch 1, we also merge Group4, Group5 and Group6 into Group456
y_batches[0] = y_batches[0].replace({'Group4': 'Group456', 'Group5': 'Group456', 'Group6': 'Group456'})

# In batch 2, we merge Group5 and Group6 into Group56
y_batches[1] = y_batches[1].replace({'Group5': 'Group56', 'Group6': 'Group56'})
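
To see which labels each batch ends up with after relabeling, the merges can be written as one mapping per batch. This is a small plain-Python sketch for illustration only; the names `merges` and `batch_label_sets` are hypothetical and not part of scHPL.

```python
# One mapping per batch: original label -> label used in that batch.
merges = [
    {'Group1': 'Group12', 'Group2': 'Group12',
     'Group4': 'Group456', 'Group5': 'Group456', 'Group6': 'Group456'},  # batch 1
    {'Group5': 'Group56', 'Group6': 'Group56'},                          # batch 2
    {},                                                                  # batch 3: unchanged
]

fine_labels = ['Group%d' % i for i in range(1, 7)]

# Labels without an entry in the mapping keep their original name.
batch_label_sets = [
    sorted({mapping.get(label, label) for label in fine_labels})
    for mapping in merges
]

for i, label_set in enumerate(batch_label_sets, start=1):
    print('Batch %d:' % i, label_set)
```

Batch 1 is the coarsest, batch 2 is intermediate, and batch 3 keeps all six original populations. This difference in resolution is what lets the hierarchy be learned.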

#### Learning and training the tree

We learn the tree using the progressive_learning.learn_tree() function. 
It takes the following input parameters:
    - data: list of datasets. Each dataset is a pandas DataFrame as described in the Data loading section. 
    - labels: list of labels belonging to the datasets, in the same order as the datasets.
    - classifier: either 'svm_occ' for the one-class SVM or 'svm' for the linear SVM
    - dimred: whether to apply dimensionality reduction to select features. 
      For the one-class SVM, this is recommended. 
      For the linear SVM, we recommend turning it off and relying on the built-in L2-regularization
    - useRE: whether cells are rejected based on the reconstruction error
    - FN: percentage of false negatives allowed when using the reconstruction error
    - threshold: matching threshold (default = 0.25)
    - return_missing: populations that caused a complex scenario and are not added to the tree can be returned to the user (return_missing = True) or attached to the root node (return_missing = False)

In [None]:
tree = progressive_learning.learn_tree(x_batches, y_batches, classifier = 'svm_occ', dimred = True, useRE = True, 
                  FN = 1, threshold = 0.25, return_missing = True)

### Predict labels

A trained tree (with or without progressive learning) can be used to predict the labels of another dataset.

In [None]:
y_pred = predict.predict_labels(x_test, tree)

### Evaluate

Here, we evaluate the predictions based on the hierarchical F1-score and look at the confusion matrix

In [None]:
HF1_score = evaluate.hierarchical_F1(y_test.values, y_pred, tree)
confmatrix = evaluate.confusion_matrix(y_test, y_pred)

In [None]:
print(HF1_score)
print(confmatrix)
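
For intuition about the hierarchical F1-score, the sketch below implements one common definition: precision and recall are computed on the sets of nodes between each label and the root, so a prediction that stops at the correct parent still earns partial credit. This is a plain-Python illustration under that assumption. The helpers `path_to_root` and `hierarchical_f1_sketch` are hypothetical, and evaluate.hierarchical_F1 may handle details such as rejected cells differently.

```python
# Child -> parent map for the tree used in this tutorial (the root has no entry).
parent = {
    'Group12': 'root', 'Group3': 'root', 'Group456': 'root',
    'Group1': 'Group12', 'Group2': 'Group12',
    'Group4': 'Group456', 'Group56': 'Group456',
    'Group5': 'Group56', 'Group6': 'Group56',
}


def path_to_root(label, parent):
    """Return the label and all its ancestors, excluding the root."""
    nodes = set()
    while label in parent:  # stops at the root or at an unknown label
        nodes.add(label)
        label = parent[label]
    return nodes


def hierarchical_f1_sketch(y_true, y_pred, parent):
    overlap = n_pred = n_true = 0
    for t, p in zip(y_true, y_pred):
        true_nodes = path_to_root(t, parent)
        pred_nodes = path_to_root(p, parent)
        overlap += len(true_nodes & pred_nodes)
        n_pred += len(pred_nodes)
        n_true += len(true_nodes)
    precision = overlap / n_pred if n_pred else 0.0
    recall = overlap / n_true if n_true else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# One exact match, one prediction at the correct parent node, one wrong branch.
toy_true = ['Group1', 'Group5', 'Group3']
toy_pred = ['Group1', 'Group56', 'Group4']
toy_score = hierarchical_f1_sketch(toy_true, toy_pred, parent)
print(round(toy_score, 3))
```

In this toy example, predicting 'Group56' for a 'Group5' cell shares two of the three nodes on the true path. It therefore lowers the score less than the prediction in the wrong branch.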

### Other useful functions

Using the utils functions, nodes in the tree can be added, removed, or renamed. Note that after adding or removing a node, the tree has to be retrained.

In [7]:
tree = utils.create_tree('((Group1, Group2)Group12, Group3,(Group4, (Group5, Group6)Group56)Group456)root')

print('Original tree:')
utils.print_tree(tree)

# Now we add a node to the tree
tree = utils.add_node(name = 'extra node', tree = tree, parent = 'Group2')
print('Tree after adding the new node:')
utils.print_tree(tree)

# Now we remove a node from the tree
# children = False indicates that the children are kept and attached to the parent of the removed node
tree = utils.remove_node(name = 'Group56', tree = tree, children = False)
print('Tree after removing the node:')
utils.print_tree(tree)

# We rename a node
tree = utils.rename_node(old_name = 'Group12', new_name = 'new name', tree = tree)
print('Tree after renaming the node:')
utils.print_tree(tree)



Original tree:
root
	Group12
		Group1
		Group2
	Group3
	Group456
		Group4
		Group56
			Group5
			Group6
Tree after adding the new node:
root
	Group12
		Group1
		Group2
			extra node
	Group3
	Group456
		Group4
		Group56
			Group5
			Group6
Tree after removing the node:
root
	Group12
		Group1
		Group2
			extra node
	Group3
	Group456
		Group4
		Group5
		Group6
Tree after renaming the node:
root
	new name
		Group1
		Group2
			extra node
	Group3
	Group456
		Group4
		Group5
		Group6
