# Tutorial: GMM Postprocessing and Hard Labeling 

The following Jupyter notebook focuses on the GMM postprocessing and hard labeling approach described in the article accompanying this package. GMM postprocessing maps signature scores to probabilities. During hard labeling, we assign each cell a state based on the signature scores (or probabilities) associated with the corresponding states. Depending on our assumptions about the data, a cell belongs either to one of the signature-associated states or to none of them. We therefore consider two scenarios on the preprocessed PBMC dataset that is also used in the `basic_usage.ipynb` notebook:
1. We score **three signatures**, one for each of the three available cell states: B cells, monocytes, and NK cells. Since each cell belongs to one of the three cell types, we select **K=3** during GMM postprocessing. Finally, we assign each cell the state associated with its highest-scoring (or highest-probability) signature.
2. We score **two signatures** associated with two of the three available cell states. Now we assume that a cell belongs either to one of the scored cell states **or** to none of them. Therefore, we select **K=2+1** during GMM postprocessing (the +1 being the remaining class). Hard labeling on probabilities assigns each cell the state with the highest probability, including the remaining class. Hard labeling on scores, however, requires finding a suitable threshold for each signature; a cell whose scores all fall below their signature-specific thresholds is assigned to the remaining class. See below for more information.
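The two hard-labeling rules above can be sketched with pandas as follows. This is a minimal, illustrative sketch and not the package's implementation: the helper functions `hard_label_argmax` and `hard_label_with_thresholds`, the toy scores and probabilities, and the threshold values of 0.5 are all hypothetical. It also assumes that, when several scores exceed their thresholds, the cell takes the highest-scoring of those signatures.

```python
import numpy as np
import pandas as pd


def hard_label_argmax(values: pd.DataFrame) -> pd.Series:
    """Label each cell (row) with the column holding its highest value.

    Covers scenario 1 (scores or probabilities) and scenario 2 on
    probabilities, where one column is the remaining ('rest') class.
    """
    return values.idxmax(axis="columns")


def hard_label_with_thresholds(
    scores: pd.DataFrame, thresholds: dict, rest_label: str = "rest"
) -> pd.Series:
    """Scenario 2 on scores (hypothetical helper).

    A cell keeps the highest-scoring signature among those whose score exceeds
    its signature-specific threshold; if no score passes, the cell is assigned
    to the remaining class.
    """
    values = scores.to_numpy(dtype=float)
    thr = np.array([thresholds[col] for col in scores.columns], dtype=float)
    passed = values > thr  # per-signature thresholds broadcast over all cells
    masked = np.where(passed, values, -np.inf)  # ignore scores below threshold
    best = masked.argmax(axis=1)
    labels = np.where(passed.any(axis=1), scores.columns.to_numpy()[best], rest_label)
    return pd.Series(labels, index=scores.index)


cells = ["cell_1", "cell_2", "cell_3"]

# Scenario 2 on scores: toy values and hypothetical thresholds
scores = pd.DataFrame(
    {"ANS_B": [0.9, 0.1, 0.05], "ANS_Mono": [0.2, 0.7, 0.1]}, index=cells
)
thresholds = {"ANS_B": 0.5, "ANS_Mono": 0.5}
score_labels = hard_label_with_thresholds(scores, thresholds)

# Scenario 2 on probabilities: argmax over B, Mono and the remaining class
proba = pd.DataFrame(
    {"B": [0.8, 0.1, 0.2], "Mono": [0.1, 0.7, 0.3], "rest": [0.1, 0.2, 0.5]},
    index=cells,
)
proba_labels = hard_label_argmax(proba)

print(score_labels.tolist())
print(proba_labels.tolist())
```

The argmax rule needs no tuning because probabilities are comparable across columns, whereas raw scores from different signatures generally are not, which is why the score-based rule needs one threshold per signature.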

*We will use the scored data from the `basic_usage.ipynb` jupyter notebook.*

In [1]:
import scanpy as sc
import pandas as pd

from signaturescoring.scoring_methods.gmm_postprocessing import GMMPostprocessor

sc.settings.verbosity = 0

## Load preprocessed and scored data

In [2]:
adata = sc.read_h5ad('tut_data/pp_pbmc_b_mono_nk_scored.h5ad')

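## Ensure adata.uns['log1p']['base'] exists: the key can get lost when the
## h5ad file is written and read back, which breaks some scanpy functions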
if 'log1p' in adata.uns_keys():
    adata.uns['log1p']['base'] = None
else:
    adata.uns['log1p'] = {'base': None}

In [3]:
adata.obs['celltype.l1'].value_counts()

Mono    43553
NK      14408
B       10613
Name: celltype.l1, dtype: int64

## GMM Postprocessing 
As explained above, we map the signature scores to probabilities using a GMM. The `GMMPostprocessor` first fits a GMM on the passed score columns (3 columns in scenario 1, 2 columns in scenario 2). It then assigns each GMM cluster to a cell state based on the correlation between the cluster probabilities and the signature scores.
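To illustrate the idea, the sketch below computes GMM posterior probabilities for toy two-signature scores and then matches clusters to signatures by correlation. It is a stand-in under explicit assumptions, not the `GMMPostprocessor` internals: the GMM parameters are fixed by hand instead of being fitted, the helpers `gmm_posterior` and `assign_clusters_by_correlation` are hypothetical, and greedy one-to-one matching is only one plausible assignment rule. As in scenario 2, the cluster left unmatched is reported under `'rest'`.

```python
import numpy as np


def gmm_posterior(x, means, sigma, weights):
    """Posterior P(cluster k | x) for a GMM with isotropic, shared-variance
    components; the Gaussian normalizing constant cancels across components."""
    sq_dist = ((x[:, None, :] - means[None, :, :]) ** 2).sum(axis=2)
    log_p = np.log(weights)[None, :] - sq_dist / (2 * sigma**2)
    log_p -= log_p.max(axis=1, keepdims=True)  # numerical stability
    p = np.exp(log_p)
    return p / p.sum(axis=1, keepdims=True)


def assign_clusters_by_correlation(scores, proba, signature_names, cluster_names):
    """Hypothetical greedy matching: each signature gets the free cluster whose
    probabilities correlate best with its scores; leftovers become 'rest'."""
    n_sig = scores.shape[1]
    # Rows are variables: score columns first, then probability columns
    corr = np.corrcoef(np.vstack([scores.T, proba.T]))[:n_sig, n_sig:]
    assignments = {}
    for _ in range(n_sig):
        s, k = np.unravel_index(np.argmax(corr), corr.shape)
        assignments[signature_names[s]] = cluster_names[k]
        corr[s, :] = -np.inf  # signature s is matched
        corr[:, k] = -np.inf  # cluster k is taken
    rest = set(cluster_names) - set(assignments.values())
    if rest:
        assignments["rest"] = rest
    return assignments


# Toy scores for two signatures (B, Mono): B-like cells, Mono-like cells and
# cells of neither type (e.g. NK cells)
rng = np.random.default_rng(0)
group_means = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
groups = np.repeat([0, 1, 2], 100)
scores = group_means[groups] + rng.normal(scale=0.15, size=(300, 2))

# Hand-picked K=2+1 GMM standing in for a fitted one; component order is
# arbitrary, so here cluster_0 is the remaining class
gmm_means = group_means[[2, 0, 1]]
proba = gmm_posterior(scores, gmm_means, sigma=0.15, weights=np.full(3, 1 / 3))

assignments = assign_clusters_by_correlation(
    scores, proba, ["B", "Mono"], ["cluster_0", "cluster_1", "cluster_2"]
)
print(assignments)
```

Because the component order of a GMM is arbitrary, a correlation step of this kind is what links each probability column back to a signature, which is why the printed assignments in the cells below map, e.g., `ANS_B` to `ANS_2_GMM_proba` rather than to a fixed index.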

### Scenario 1: Three signatures

In [4]:
# The last 24 obs columns hold the scores of 8 scoring methods x 3 signatures
scoring_names_3_sigs = list(adata.obs.columns[-(8*3):])
scoring_names_3_sigs

['ANS_B',
 'ANS_Mono',
 'ANS_NK',
 'Scanpy_B',
 'Scanpy_Mono',
 'Scanpy_NK',
 'Tirosh_B',
 'Tirosh_Mono',
 'Tirosh_NK',
 'Tirosh_AG_B',
 'Tirosh_AG_Mono',
 'Tirosh_AG_NK',
 'Tirosh_LVG_B',
 'Tirosh_LVG_Mono',
 'Tirosh_LVG_NK',
 'Jasmine_LH_B',
 'Jasmine_LH_Mono',
 'Jasmine_LH_NK',
 'Jasmine_OR_B',
 'Jasmine_OR_Mono',
 'Jasmine_OR_NK',
 'UCell_B',
 'UCell_Mono',
 'UCell_NK']

In [5]:
## Helper list with column names
scoring_names_3_gmm = []

In [6]:
# For the THREE cell-state signature score columns of each scoring method, we fit a
# GMM to convert the scores into probabilities
for i in range(0, len(scoring_names_3_sigs), 3):
    
    # initialize GMMPostprocessor
    gmm_post = GMMPostprocessor(
        n_components=3
    )
    
    # fit the GMM model on the 3 columns of the scores 
    store_name_pred, store_names_proba, _ = gmm_post.fit_and_predict(adata, scoring_names_3_sigs[i:(i+3)])
    
    # assign clusters to signatures 
    assignments = gmm_post.assign_clusters_to_signatures(adata, scoring_names_3_sigs[i:(i+3)], store_names_proba, plot=False)
    print(assignments)
    
    for key, val in assignments.items():
        adata.obs[key+'_gmm_3sigs'] = adata.obs[val].copy()
        scoring_names_3_gmm.append(key+'_gmm_3sigs')

{'ANS_B': 'ANS_2_GMM_proba', 'ANS_Mono': 'ANS_0_GMM_proba', 'ANS_NK': 'ANS_1_GMM_proba'}
{'Scanpy_B': 'Scanpy_2_GMM_proba', 'Scanpy_Mono': 'Scanpy_1_GMM_proba', 'Scanpy_NK': 'Scanpy_0_GMM_proba'}
{'Tirosh_B': 'Tirosh_2_GMM_proba', 'Tirosh_Mono': 'Tirosh_1_GMM_proba', 'Tirosh_NK': 'Tirosh_0_GMM_proba'}
{'Tirosh_AG_B': 'Tirosh_AG_2_GMM_proba', 'Tirosh_AG_Mono': 'Tirosh_AG_0_GMM_proba', 'Tirosh_AG_NK': 'Tirosh_AG_1_GMM_proba'}
{'Tirosh_LVG_B': 'Tirosh_LVG_2_GMM_proba', 'Tirosh_LVG_Mono': 'Tirosh_LVG_0_GMM_proba', 'Tirosh_LVG_NK': 'Tirosh_LVG_1_GMM_proba'}
{'Jasmine_LH_B': 'Jasmine_LH_2_GMM_proba', 'Jasmine_LH_Mono': 'Jasmine_LH_1_GMM_proba', 'Jasmine_LH_NK': 'Jasmine_LH_0_GMM_proba'}
{'Jasmine_OR_B': 'Jasmine_OR_2_GMM_proba', 'Jasmine_OR_Mono': 'Jasmine_OR_1_GMM_proba', 'Jasmine_OR_NK': 'Jasmine_OR_0_GMM_proba'}
{'UCell_B': 'UCell_0_GMM_proba', 'UCell_Mono': 'UCell_1_GMM_proba', 'UCell_NK': 'UCell_2_GMM_proba'}


### Scenario 2: Two signatures

In [7]:
scoring_names_2_sigs = [x for x in scoring_names_3_sigs if 'NK' not in x] 
scoring_names_2_sigs

['ANS_B',
 'ANS_Mono',
 'Scanpy_B',
 'Scanpy_Mono',
 'Tirosh_B',
 'Tirosh_Mono',
 'Tirosh_AG_B',
 'Tirosh_AG_Mono',
 'Tirosh_LVG_B',
 'Tirosh_LVG_Mono',
 'Jasmine_LH_B',
 'Jasmine_LH_Mono',
 'Jasmine_OR_B',
 'Jasmine_OR_Mono',
 'UCell_B',
 'UCell_Mono']

In [8]:
## Helper list with column names
scoring_names_2_gmm = []

In [9]:
# For the TWO cell-state signature score columns of each scoring method, we fit a
# GMM to convert the scores into probabilities

for i in range(0, len(scoring_names_2_sigs), 2):
    # initialize GMMPostprocessor with K=2+1 components (two signatures + remaining class)
    gmm_post = GMMPostprocessor(
        n_components=3
    )
    
    # fit the GMM model on the 2 columns of the scores 
    store_name_pred, store_names_proba, _ = gmm_post.fit_and_predict(adata, scoring_names_2_sigs[i:(i+2)])
    
    # assign clusters to signatures; the unassigned cluster is returned under 'rest'
    assignments = gmm_post.assign_clusters_to_signatures(adata, scoring_names_2_sigs[i:(i+2)], store_names_proba, plot=False)
    
    print(assignments)
    for key, val in assignments.items():
        if key == 'rest':
            continue
        adata.obs[key+'_gmm_2sigs'] = adata.obs[val].copy()
        scoring_names_2_gmm.append(key+'_gmm_2sigs')
    
    # the remaining ('rest') cluster stands for the unscored cell state (NK)
    curr_name = '_'.join(scoring_names_2_sigs[i].split('_')[0:-1])
    adata.obs[curr_name + '_NK_gmm_2sigs'] = adata.obs[next(iter(assignments['rest']))].copy()
    scoring_names_2_gmm.append(curr_name + '_NK_gmm_2sigs')

{'ANS_B': 'ANS_2_GMM_proba', 'ANS_Mono': 'ANS_1_GMM_proba', 'rest': {'ANS_0_GMM_proba'}}
{'Scanpy_B': 'Scanpy_2_GMM_proba', 'Scanpy_Mono': 'Scanpy_1_GMM_proba', 'rest': {'Scanpy_0_GMM_proba'}}
{'Tirosh_B': 'Tirosh_1_GMM_proba', 'Tirosh_Mono': 'Tirosh_0_GMM_proba', 'rest': {'Tirosh_2_GMM_proba'}}
{'Tirosh_AG_B': 'Tirosh_AG_1_GMM_proba', 'Tirosh_AG_Mono': 'Tirosh_AG_0_GMM_proba', 'rest': {'Tirosh_AG_2_GMM_proba'}}
{'Tirosh_LVG_B': 'Tirosh_LVG_2_GMM_proba', 'Tirosh_LVG_Mono': 'Tirosh_LVG_1_GMM_proba', 'rest': {'Tirosh_LVG_0_GMM_proba'}}
{'Jasmine_LH_B': 'Jasmine_LH_1_GMM_proba', 'Jasmine_LH_Mono': 'Jasmine_LH_2_GMM_proba', 'rest': {'Jasmine_LH_0_GMM_proba'}}
{'Jasmine_OR_B': 'Jasmine_OR_1_GMM_proba', 'Jasmine_OR_Mono': 'Jasmine_OR_0_GMM_proba', 'rest': {'Jasmine_OR_2_GMM_proba'}}
{'UCell_B': 'UCell_1_GMM_proba', 'UCell_Mono': 'UCell_0_GMM_proba', 'rest': {'UCell_2_GMM_proba'}}


In [10]:
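## Remove the intermediate GMM columns (cluster predictions and probabilities);
## keep only the *_gmm_3sigs and *_gmm_2sigs columns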
adata.obs.drop(columns=[x for x in adata.obs.columns if ('gmm' in x.lower()) and ('sigs' not in x.lower())], 
               inplace=True)

## Hard Labeling

### Scenario 1: Three signatures

##### Scores 

In [14]:
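for i in range(0, len(scoring_names_3_sigs), 3):
    method = scoring_names_3_sigs[i].rsplit('_', 1)[0]
    print(method)
    # Assign each cell the state (B, Mono or NK) of its highest-scoring signature
    adata.obs[f'{method}_scores_hard_labels_3sigs'] = (
        adata.obs[scoring_names_3_sigs[i:(i+3)]]
        .idxmax(axis='columns')
        .map(lambda col: col.rsplit('_', 1)[-1])
    )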

ANS
Scanpy
Tirosh
Tirosh_AG
Tirosh_LVG
Jasmine_LH
Jasmine_OR
UCell


##### Probabilities 

### Scenario 2: Two signatures

##### Scores 

##### Probabilities 

## Performances 

### Scenario 1: Three signatures

### Scenario 2: Two signatures