In [None]:
import scHash
import anndata as ad

from sklearn.metrics import f1_score, precision_score, recall_score
from statistics import median

# Tutorial for atlas level cell annotations
Here is a demonstration on an atlas-level dataset. We demonstrate atlas-level annotation with the Tabula Muris Senis dataset, which can be downloaded [here](https://figshare.com/projects/Tabula_Muris_Senis/64982). We followed scArches' preprocessing pipeline, and the preprocessed data can be downloaded [here](https://drive.google.com/file/d/1lfDu-TGsUvHrmXoSWkj0tptvWNYFgs2x/view?usp=share_link). The dataset contains 356213 cells with 5000 highly variable genes, along with cell type, method, and tissue annotations.

## Load data

In [None]:
# Location of the preprocessed .h5ad file
data_dir = '../../../../share_data/Tabula_Muris_Senis(TM)/tabula_senis_normalized_all_hvg.h5ad'

# A single file holds the cells for both the reference and the query set
data = ad.read_h5ad(data_dir)

# Stratified 80/20 split over row positions, keyed on the cell type label;
# random_state is fixed so the same split is produced on every run
from sklearn.model_selection import train_test_split
reference_indicies, query_indicies = train_test_split(
    list(range(data.shape[0])),
    train_size=0.8,
    stratify=data.obs['cell_ontology_class'],
    random_state=42,
)

# Reference rows form the training set; the remaining rows form the query/test set
train = data[reference_indicies]
test = data[query_indicies]

## Training Model

In [None]:
# Build the training datamodule from the reference cells
datamodule = scHash.setup_training_data(
    train_data=train,
    cell_type_key='cell_ontology_class',
    batch_key='tissue',
    batch_size=512,
)

# Directory passed to the training routine as its checkpoint path
checkpointPath = '../checkpoint/'

# Initialize the scHash model and train it
model = scHash.scHashModel(datamodule)
trainer, best_model_path, training_time = scHash.training(
    model=model,
    datamodule=datamodule,
    checkpointPath=checkpointPath,
    max_epochs=200,
)

## Test Model

In [None]:
# Register the held-out query cells with the datamodule
datamodule.setup_test_data(test)

# Run the test step with the trainer, the model and the best checkpoint path
pred_labels, hash_codes = scHash.testing(trainer, model, best_model_path)

# Score the predictions: one F1 value per class (average=None),
# summarized by the median and rounded to three decimals
labels_true = test.obs['cell_ontology_class']
per_class_f1 = f1_score(labels_true, pred_labels, average=None)
f1_median = round(median(per_class_f1), 3)

print(f'F1 Median: {f1_median}')