In [None]:

%load_ext autoreload
%autoreload 2



import numpy as np
import pandas as pd
import scanpy as sc

# Import the ConformalSCAnnotator class 
from conformalSC_annotator import  ConformalSCAnnotator
from torchcp.classification.score import  APS


Current kernel CWD: /beegfs/home/mlopezdecas/single_cell/conformalized_single_cell_annotator


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load the labelled reference dataset and the query dataset to be annotated.
# Replace both placeholder paths with the locations of your own .h5ad files.

reference_adata_path = 'path_to_reference/your_reference.h5ad'
adata_reference = sc.read_h5ad(reference_adata_path)

query_adata_path = 'path_to_query/your_query.h5ad'
adata_query = sc.read_h5ad(query_adata_path)

# Column of the query's gene table holding gene names (e.g. "features", "gene_names").
# If that column does not exist, it is created from adata.var_names.
gene_names_column = "features"

# Underlying classifier: one of "torch_net", "celltypist" or "scmap".
underlying_model = "torch_net"

# Build the conformal annotator around the query dataset.
annotator = ConformalSCAnnotator(
    adata_query,
    var_query_gene_column_name=gene_names_column,
    underlying_model=underlying_model,
)

# Architecture of the "torch_net" classifier. "hidden_sizes" and "dropout_rates"
# are given per hidden layer (both lists have the same length).
network_architecture: dict = dict(
    hidden_sizes=[72, 64, 32, 32],
    dropout_rates=[0.15] * 4,
    learning_rate=1e-4,
)

# Configuration of the autoencoder (AE) based out-of-distribution (OOD) detector.
OOD_detector_config = dict(
    alpha=None,                     # Significance level of the OOD hypothesis test; None lets the annotator determine it (read back later as annotator.alpha_OOD)
    delta=0.1,                      # Only used for conditional p-values
    hidden_sizes=[50, 48, 32, 24],  # AE hidden-layer sizes, i.e. the network topology
    dropout_rates=[0.15] * 4,
    learning_rate=1e-4,
    batch_size=72,
    n_epochs=1000,                  # Number of epochs for training the OOD detector
    patience=9,                     # presumably early-stopping patience in epochs — TODO confirm
    noise_level=0.1,                # presumably input-noise level for a denoising AE — TODO confirm
    lambda_sparse=1e-3,             # presumably weight of a sparsity penalty — TODO confirm
)





do_test = True              # If True, a small fraction of the reference set is held out as an independent test set (results in annotator.test_results).

taxonomy = "standard"       # Conformal predictor taxonomy: "standard", "mondrian" or "cluster"
cells_OOD = 50              # Either an int (reference cell types with fewer cells than this are excluded) or a list of cell types
nc_function = APS()         # Non-conformity score function provided by, or compatible with, torchCP
ref_column = "cell_type"    # Column in the reference data holding the class labels (e.g. "cell_type", "cell_class")

annotator.configure(
    reference_path=adata_reference,               # Path to a .h5ad file, or an already-loaded AnnData object
    model_architecture=network_architecture,      # Optional: user-defined architecture; defaults are used otherwise
    OOD_detector=OOD_detector_config,             # Optional: custom OOD-detector configuration
    CP_predictor=taxonomy,
    cell_names_column=ref_column,
    cell_types_excluded_treshold=cells_OOD,       # (sic) keyword spelled this way by the ConformalSCAnnotator API
    test=do_test,
    # Significance (miscoverage) levels, NOT confidence levels: in conformal
    # prediction, alpha = 0.1 targets prediction sets with ~90% (1 - alpha) coverage.
    # A single float is also accepted, e.g. alpha = 0.1.
    alpha=[0.01, 0.05, 0.1],
    non_conformity_function=nc_function,
    epoch=1000,                                   # Only applicable if using "torch_net" as underlying model
    batch_size=72,                                # Only applicable if using "torch_net" as underlying model
    random_state=None,                            # None = unseeded, so results may differ between runs; pass an int seed for reproducibility
)



# Which representation of the query/reference cells the classifier uses:
#   None    -> adata.X
#   "obsm"  -> an embedding in adata.obsm (named by obsm_)
#   "layer" -> a matrix in adata.layers (named by layer_)
obsm_layer_ = "obsm"
layer_ = None                 # Layer name; only needed when obsm_layer_ == "layer"
obsm_ = "X_pca_harmony"       # Embedding name; only needed when obsm_layer_ == "obsm"

# Embedding fed to the OOD detector (typically the same as obsm_); None falls back to adata.X.
obsm_OOD_ = "X_pca_harmony"

annotator.annotate(
    obsm_layer=obsm_layer_,
    obsm=obsm_,
    layer=layer_,
    obsm_OOD=obsm_OOD_,
)


# Retrieve the annotated observations (the .obs table of the query dataset).
annotated_cells = annotator.adata_query.obs
print("\nPredicted annotations sets: \n")
# Rich (HTML) display of the first rows instead of a plain-text print() of the
# whole table; `display` is provided by the IPython/Jupyter kernel.
display(annotated_cells.head())

# Internal test results (only populated when do_test = True).
test_results = annotator.test_results

# Unique class labels found in the reference data.
unique_labels = annotator.unique_labels

# Significance level used by the OOD detector; determined automatically when
# OOD_detector_config["alpha"] is None.
alpha_OOD = annotator.alpha_OOD


# Extract the point predictions and conformal prediction sets into a pandas DataFrame.
# CP_ALPHAS must match the `alpha` list passed to annotator.configure(); the query
# .obs holds one "prediction_sets_<alpha>" column per significance level.
CP_ALPHAS = [0.01, 0.05, 0.1]

query_obs = annotator.adata_query.obs
# Output column label -> source column in .obs, e.g. "CP 0.1" -> "prediction_sets_0.1".
cp_columns = {f"CP {a}": f"prediction_sets_{a}" for a in CP_ALPHAS}

missing = [col for col in ["predicted_labels", *cp_columns.values()] if col not in query_obs.columns]
if missing:
    raise KeyError(
        f"Columns {missing} not found in adata_query.obs; "
        "check that CP_ALPHAS matches the alpha passed to annotator.configure()."
    )

# .to_numpy() drops the .obs index, so df_results gets a plain 0..n-1 RangeIndex.
df_results = pd.DataFrame({
    "Predicted": query_obs["predicted_labels"].to_numpy(),
    **{label: query_obs[col].to_numpy() for label, col in cp_columns.items()},
})
#df_results.to_csv("saved_Results.csv", index=False)  # Save to CSV if needed

df_results.head(10)