In [1]:
%load_ext autoreload
%autoreload 2


import sys
import os

# Report the kernel's current working directory. The relative data paths used
# further down ('../test_data/...') are resolved against this directory, so the
# printed value is a quick sanity check that the notebook was started from the expected place.
cwd = os.getcwd()
print("Current kernel CWD:", cwd)

## 2a. Option: go 'down' into the subdirectory where your module lives
#module_dir = os.path.join(cwd, "conformalized_single_cell_annotator")
#sys.path.insert(0, module_dir)


import numpy as np  
import scanpy as sc
import pandas as pd


from conformalSC_annotator import  ConformalSCAnnotator
from torchcp.classification.score import  APS,RAPS, THR


# Name of the .var column that holds the gene names.
# Gene names sometimes live only in the index (adata_query.var_names); in that case
# they have to be copied explicitly into a new column before annotating.
# Here the column is assumed to exist already under the name "features".
gene_names_column = "features"

# Load the already preprocessed query dataset with scanpy.
# In this example the data comes from an .h5ad file.
query_data_path = "../test_data/gastrulation/gastrulation_query_1.h5ad"
adata_query = sc.read_h5ad(query_data_path)




Current kernel CWD: /beegfs/home/mlopezdecas/single_cell/conformalized_single_cell_annotator


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import inspect

# Print the constructor's parameter list to check which keyword arguments
# (and which defaults) ConformalSCAnnotator accepts.
init_signature = inspect.signature(ConformalSCAnnotator.__init__)
print(init_signature)

(self, X, var_query=None, obs_query=None, var_query_gene_column_name='gene_name', underlying_model='torch_net')


In [None]:

# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------

# Network architecture; only relevant when "torch_net" is the underlying model.
network_architecture: dict = dict(
    hidden_sizes=[128, 128, 64, 64],
    dropout_rates=[0.4, 0.3, 0.4, 0.25],
    learning_rate=0.0001,
)

# Settings of the OOD detector.
OOD_detector_config = dict(
    pvalues="marginal",               # "marginal" or "conditional" (default: "marginal")
    alpha=0.1,                        # significance level for the hypothesis test
    delta=0.1,                        # only used with conditional p-values
    hidden_sizes=[556, 124, 64],      # AE hidden sizes and topology of the network
    dropout_rates=[0.2, 0.25, 0.25],
    learning_rate=0.0001,
    batch_size=32,
    n_epochs=2,
)

# Location of the reference data.
reference_data_path = "../test_data/gastrulation/gastrulation_reference_1.h5ad"

# ---------------------------------------------------------------------------
# Build the annotator
# ---------------------------------------------------------------------------

# Underlying model of the annotator.
# Options: "torch_net", "celltypist", "scmap" (default: "torch_net").
model = "scmap"

# obs_query is optional and is not passed here; per the original note it is
# used to annotate the predicted cells.
annotator = ConformalSCAnnotator(
    adata_query,
    var_query_gene_column_name=gene_names_column,
    underlying_model=model,
)

# ---------------------------------------------------------------------------
# Configure
# ---------------------------------------------------------------------------
annotator.configure(
    # Reference data: path to an .h5ad file, or an adata object.
    reference_path=reference_data_path,
    # Optional; default values are used when not provided.
    model_architecture=network_architecture,
    # Optional; default values are used when not provided.
    OOD_detector=OOD_detector_config,
    # Conformal predictor flavour: "standard", "mondrian" or "cluster".
    CP_predictor="standard",
    # Column with the class names used to fit the model
    # (e.g. "cell_type", "celltype_level3" or "celltype").
    cell_names_column="celltype",
    # Minimum-size threshold for cell types; presumably cell types with fewer
    # cells are excluded.
    # NOTE(review): the original comment said "less than 50 cells" while the
    # value passed is 45 -- confirm which one is intended.
    cell_types_excluded_treshold=45,
    # Perform the internal test of the model.
    test=True,
    # Significance level(s) of the conformal prediction sets; a single value
    # is also accepted.
    alpha=[0.01, 0.05, 0.1],
    # Non-conformity function provided by / compatible with torchCP (APS, RAPS, THR).
    non_conformity_function=APS(),
    epoch=200,
    batch_size=42,
    # NOTE(review): no fixed seed is passed, so results presumably vary between
    # runs -- consider setting one for reproducibility.
    random_state=None,
)

# ---------------------------------------------------------------------------
# Annotate the query data
# ---------------------------------------------------------------------------
# If batch-corrected data is available it can be used instead of adata.X:
# obsm_layer names an entry of .obsm holding a reduction / integration
# (None, "X_pca_harmony" or "X_pca").
annotator.annotate(obsm_layer=None)



Data stored in adata.X will be used for annotation
Model not trained yet. Fitting the model first.
Loading reference data...
Reference data shape: (1843, 2000)

Reference data loaded.

Initial reference data label distribution:
celltype
Rostral neurectoderm              300
Epiblast                          283
Mesenchyme                        218
Mixed mesoderm                    190
Nascent mesoderm                  160
Primitive Streak                  143
Visceral endoderm                 104
Caudal epiblast                    80
Def. endoderm                      79
Haematoendothelial progenitors     69
Blood progenitors 2                56
Gut                                56
Surface ectoderm                   53
Blood progenitors 1                52
Name: count, dtype: int64

Query data label distribution:
celltype
Rostral neurectoderm              736
Epiblast                          641
Nascent mesoderm                  485
Primitive Streak                  395
Mixed mesode

In [7]:
# The annotations are written into the .obs table of the original query data object.
query_adata = annotator.adata_query
annotated_cells = query_adata.obs

annotations_header = "\nPredicted annotations sets: \n"
print(annotations_header, annotated_cells)

# Results of the model's internal test are exposed on the annotator as well.
test_results = annotator.test_results


Predicted annotations sets: 
                    barcode  sample stage sequencing.batch theiler  \
index                                                               
cell_43479  AAACATACCATGCA      19  E7.5                2    TS11   
cell_43481  AAACATACGGTACT      19  E7.5                2    TS11   
cell_43483  AAACATACTTGACG      19  E7.5                2    TS11   
cell_43484  AAACATTGACGGAG      19  E7.5                2    TS11   
cell_43487  AAACATTGGGTGAG      19  E7.5                2    TS11   
...                    ...     ...   ...              ...     ...   
cell_52454  TTTAGGCTGTTCAG      20  E7.5                2    TS11   
cell_52457  TTTATCCTCTGTTT      20  E7.5                2    TS11   
cell_52458  TTTCAGTGACTACG      20  E7.5                2    TS11   
cell_52464  TTTGCATGACGCTA      20  E7.5                2    TS11   
cell_52465  TTTGCATGGACGTT      20  E7.5                2    TS11   

            doub.density  doublet  cluster  cluster.sub  cluster.stage 

In [8]:

# Gather the annotations stored in the adata object into a classical DataFrame.
# "predicted_labels" stands for the predictions of the underlying model
# without conformal prediction; the "prediction_sets_*" columns hold the
# conformal prediction sets, one column per alpha.

query_obs = annotator.adata_query.obs

# Source column in .obs -> column name in the output table.
obs_to_output_column = {
    "predicted_labels": "Predicted",
    "prediction_sets_0.01": "CP 0.01",
    "prediction_sets_0.05": "CP 0.05",
    "prediction_sets_0.1": "CP 0.1",
}

# One record per cell, built by walking the four columns in lockstep.
results = [
    dict(zip(obs_to_output_column.values(), cell_values))
    for cell_values in zip(*(query_obs[column] for column in obs_to_output_column))
]

df_results = pd.DataFrame(results)
#df_results.to_csv("saved_Results.csv", index=False)  # Save to CSV if needed

In [10]:
# Preview the first ten rows of the results table.
df_results.head(n=10)

Unnamed: 0,Predicted,CP 0.01,CP 0.05,CP 0.1
0,Mesenchyme,[Mesenchyme],[Mesenchyme],[Mesenchyme]
1,Rostral neurectoderm,[Rostral neurectoderm],[],[Rostral neurectoderm]
2,OOD,[OOD],[OOD],[OOD]
3,OOD,[OOD],[OOD],[OOD]
4,Haematoendothelial progenitors,"[Haematoendothelial progenitors, Surface ectod...",[Haematoendothelial progenitors],[Haematoendothelial progenitors]
5,OOD,[OOD],[OOD],[OOD]
6,Mesenchyme,"[Blood progenitors 1, Gut, Haematoendothelial ...","[Haematoendothelial progenitors, Mesenchyme]","[Haematoendothelial progenitors, Mesenchyme]"
7,Rostral neurectoderm,"[Rostral neurectoderm, Surface ectoderm]",[Rostral neurectoderm],[Rostral neurectoderm]
8,Haematoendothelial progenitors,"[Blood progenitors 1, Haematoendothelial proge...",[Haematoendothelial progenitors],[Haematoendothelial progenitors]
9,Nascent mesoderm,"[Mixed mesoderm, Nascent mesoderm, Primitive S...","[Mixed mesoderm, Nascent mesoderm]",[Nascent mesoderm]


[0.9324392080307007,
 0.994827151298523,
 0.9989463686943054,
 0.996710479259491,
 0.9997898936271667,
 0.9871706366539001,
 0.5648959875106812,
 0.974643886089325,
 0.9802080392837524,
 0.9994446039199829,
 0.9996454119682312,
 0.9999656081199646,
 0.9975572824478149,
 0.9892067909240723]