In [1]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.image import imread
from scipy.optimize import brentq
from scipy.stats import binom
import pandas as pd
import seaborn as sns
import random
import pdb
import csv
from tabulate import tabulate



# DataFrame Columns:

`probs`: A column whose entries are NumPy arrays of per-class probabilities (one score per class). `get_selective_results` calls `.max()` and `.argmax()` on each entry to get the top score and the predicted class. Plain Python lists do not have these methods, so the entries must be arrays; `process_df` builds them from the `prob_*` columns.

`label_numeric`: The integer index of each observation's true class, in the same order as the entries of `probs`. It is compared with the predicted class (`yhat`, the argmax of `probs`) to decide whether each prediction is correct.

# Parameters Dictionary (params):
`get_selective_results` also takes a dictionary `params` that must contain the following keys:

`n`: The number of calibration observations. If no validation DataFrame (`val_df`) is passed, `df` is split at random (seed 0) into `n` calibration rows and `len(df) - n` validation rows. If `val_df` is passed, `n` is replaced by the length of the calibration DataFrame.

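`alpha`: The target selective error rate, i.e. the error rate among the observations on which the model predicts rather than abstains. The threshold $\hat{\lambda}$ is chosen so that the upper confidence bound on this error does not exceed `alpha`. The accuracy plot marks the corresponding target accuracy $1-\alpha$.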

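`delta`: The confidence level of that bound. For a threshold $\lambda$, let $n_\lambda$ be the number of calibration points whose top score exceeds $\lambda$, and $k_\lambda$ the number of errors among them. The upper bound is the value $r$ solving $P[\mathrm{Binomial}(n_\lambda, r) \le k_\lambda] = \delta$, so it holds with probability at least $1-\delta$.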

`lambda_max` and `lambda_min`: The endpoints of the grid of candidate confidence thresholds $\lambda$. The classifier makes a prediction only when the top softmax score is at least $\lambda$, and abstains otherwise.

`N_lambda`: The number of evenly spaced thresholds (`np.linspace`) in the grid from `lambda_min` to `lambda_max`.

In [2]:
def fix_randomness(seed=0):
    """Seed both Python's `random` module and NumPy's global RNG with `seed`.

    Pandas' `DataFrame.sample` draws from NumPy's global RNG when no
    `random_state` is given, so this also makes such splits repeatable.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    
def _top1_summary(frame):
    """Add top-1 `max_prob` / `yhat` columns to `frame` (in place) and summarize them.

    `frame` needs a `probs` column whose entries support `.max()` / `.argmax()`
    (NumPy arrays of class scores) and a `label_numeric` column. Returns
    `(labels, phats, yhats, corrects)`: the true labels, the top score per row,
    the argmax class per row, and a boolean Series that is True where they match.
    """
    frame['max_prob'] = frame['probs'].apply(lambda p: p.max())
    frame['yhat'] = frame['probs'].apply(lambda p: p.argmax())
    labels = frame['label_numeric']
    return labels, frame['max_prob'], frame['yhat'], frame['yhat'] == labels


def _choose_lambda_hat(phats, corrects, lambdas, alpha, delta, min_selected=50):
    """Pick the selection threshold lambda-hat from calibration scores and correctness.

    Scans the candidate thresholds from the largest downwards. It stops at the first
    `lhat` for which the (1 - delta) binomial upper bound on the selective error at
    `lhat - step` exceeds `alpha`, and returns that `lhat`.

    Raises:
        Exception: if no grid value keeps at least `min_selected` calibration points,
            or if the bound already fails at the largest candidate.
    """
    def n_selected(lam):
        return (phats > lam).sum()

    def n_errors(lam):
        # Count errors as an integer. The form (1 - accuracy) * n_selected can land just
        # below an integer (e.g. (1 - 0.9) * 50 == 4.999...). binom.cdf floors k, so that
        # form silently under-estimates the upper bound.
        return (~corrects[phats > lam]).sum()

    def risk_ub(lam):
        # Upper confidence bound on the selective error: the r solving
        # P[Binomial(n_selected, r) <= n_errors] = delta.
        k, n_sel = n_errors(lam), n_selected(lam)
        return brentq(lambda r: binom.cdf(k, n_sel, r) - delta, 0, 0.9999)

    # Make sure there's some data in the top bin.
    candidates = np.array([lam for lam in lambdas if n_selected(lam) >= min_selected])
    if candidates.size == 0:
        raise Exception(f"No lambda in the grid keeps at least {min_selected} calibration points selected!")
    # NOTE(review): stepping back by 1/len(candidates) approximates the grid spacing only for an
    # untruncated np.linspace(0, 1, N) grid (spacing 1/(N-1)). After truncation the step is larger.
    # The original rule is kept here; confirm whether the true grid spacing was intended.
    step = 1 / candidates.shape[0]
    for lhat in np.flip(candidates):
        if risk_ub(lhat - step) > alpha:
            break
    if lhat == candidates[-1]:
        raise Exception("This level is too stringent!")
    return lhat


def _save_lambda_plot(path, lambdas, curve, lhat, ylabel, color, legend_loc,
                      curve_label=None, target_accuracy=None):
    """Plot `curve` against the lambda grid, mark lambda-hat, and save the figure to `path`.

    If `target_accuracy` is given, a dotted horizontal reference line is drawn at it.
    The figure is always closed, even if saving fails.
    """
    fig, ax = plt.subplots()
    try:
        ax.plot(lambdas, curve, label=curve_label, color=color, linewidth=2)
        if target_accuracy is not None:
            ax.axhline(y=target_accuracy, linewidth=1.5, linestyle='dotted',
                       label=r'target accuracy ($1-\alpha$)', color='gray')
        ax.axvline(x=lhat, linewidth=1.5, linestyle='--', label=r'$\hat{\lambda}$', color='gray')
        sns.despine(ax=ax, top=True, right=True)
        ax.legend(loc=legend_loc)
        ax.set_xlabel(r'$\lambda$')
        ax.set_ylabel(ylabel)
        ax.set_xlim(left=0)
        fig.tight_layout()
        fig.savefig(path)
    finally:
        plt.close(fig)


def get_selective_results(df, name, params, val_df=None):
    """Calibrate a selective classifier, evaluate it, and save a results table and plots.

    The classifier predicts only when its top score is at least lambda-hat, and abstains
    otherwise. Lambda-hat is chosen on the calibration set so that a (1 - delta) binomial
    upper bound on the selective error stays at most alpha.

    Args:
        df: Calibration data, or the full data to split when `val_df` is None. Needs a
            `probs` column (NumPy arrays of class scores) and a `label_numeric` column.
        name: Sub-folder of ./results/ receiving results.csv and three PDF plots.
        params: Dict with keys 'n', 'alpha', 'delta', 'lambda_max', 'lambda_min' and
            'N_lambda'. 'n' is the calibration-set size. When `val_df` is given, 'n' is
            replaced by len(df) in the recorded results. The caller's dict is not modified.
        val_df: Optional validation data. When None, `df` is split at random (seed 0)
            into n calibration rows and len(df) - n validation rows.

    Raises:
        ValueError: if 'n' exceeds the number of rows available for the split.
        Exception: if no admissible lambda-hat exists (see `_choose_lambda_hat`).
    """
    # Work on copies so the caller's params dict and frames are never mutated.
    params = dict(params)
    if val_df is not None:
        cal_df = df.copy(deep=True)
        val_df = val_df.copy(deep=True)
        params['n'] = len(cal_df)
    n, alpha, delta, lambda_max, lambda_min, N_lambda = (
        params[k] for k in ('n', 'alpha', 'delta', 'lambda_max', 'lambda_min', 'N_lambda'))

    # If only one dataframe is passed, split into cal/val.
    if val_df is None:
        if n > len(df):
            raise ValueError(f"params['n'] = {n} exceeds the {len(df)} rows available to split.")
        fix_randomness()
        val_df = df.sample(n=len(df) - n)
        cal_df = df.drop(val_df.index)

    lambdas = np.linspace(lambda_min, lambda_max, N_lambda)

    cal_labels, cal_phats, cal_yhats, cal_corrects = _top1_summary(cal_df)
    val_labels, val_phats, val_yhats, val_corrects = _top1_summary(val_df)

    results = {
        'labels' : [np.unique(np.concatenate([cal_labels, val_labels]))],
        'label counts (cal)' : [np.unique(cal_labels, return_counts=True)[1]],
        'label counts (val)' : [np.unique(val_labels, return_counts=True)[1]],
        'accuracy (cal, before selection)' : cal_corrects.mean(),
        'accuracy (val, before selection)' : val_corrects.mean(),
        'dataset size (cal)' : cal_yhats.shape[0],
        'dataset size (val)' : val_yhats.shape[0],
    }
    results.update(params)

    lhat = _choose_lambda_hat(cal_phats, cal_corrects, lambdas, alpha, delta)

    # Outcome counts for the deployed rule (predict when top score >= lhat).
    for split, corrects, phats in (('cal', cal_corrects, cal_phats), ('val', val_corrects, val_phats)):
        predicted, abstained = phats >= lhat, phats < lhat
        results[f'number correct ({split}, predicted)'] = corrects[predicted].sum()
        results[f'number incorrect ({split}, predicted)'] = (~corrects[predicted]).sum()
        results[f'number correct ({split}, abstained)'] = corrects[abstained].sum()
        results[f'number incorrect ({split}, abstained)'] = (~corrects[abstained]).sum()
    results['lambda_hat'] = lhat

    # Save and pretty-print results. to_csv returns None when given a path, so the frame
    # itself (not to_csv's return value) is what gets printed.
    base_folder = os.path.join('.', 'results', name)
    os.makedirs(base_folder, exist_ok=True)
    results_df = pd.DataFrame(results)
    results_df.to_csv(os.path.join(base_folder, 'results.csv'), index=False)
    print(tabulate(results_df.T, headers=['metric', 'value']))

    # Validation curves over the whole lambda grid.
    sns.set(style='white')
    n_val_correct = val_corrects.sum()
    accuracy_selected = [val_corrects[val_phats > lam].mean() for lam in lambdas]
    fraction_selected = [(val_phats > lam).mean() for lam in lambdas]
    fraction_correct_points_selected = [val_corrects[val_phats > lam].sum() / n_val_correct for lam in lambdas]

    _save_lambda_plot(os.path.join(base_folder, 'accuracy_vs_lambda.pdf'),
                      lambdas, accuracy_selected, lhat,
                      ylabel='accuracy', color='#2222AA', legend_loc='lower right',
                      curve_label='all', target_accuracy=1 - alpha)
    _save_lambda_plot(os.path.join(base_folder, 'fraction_predicted_vs_lambda.pdf'),
                      lambdas, fraction_selected, lhat,
                      ylabel='fraction selected of total val', color='#DD5500', legend_loc='lower left')
    _save_lambda_plot(os.path.join(base_folder, 'correct_points_predicted_over_total_correct_vs_lambda.pdf'),
                      lambdas, fraction_correct_points_selected, lhat,
                      ylabel='fraction selected of total correct', color='#AA22AA', legend_loc='lower left')

In [3]:
# Process dataframe
def _as_list(values):
    """Wrap a scalar stratum value (e.g. 'UCSF' or False) in a list so `Series.isin` accepts it."""
    return values if pd.api.types.is_list_like(values) else [values]


def process_df(df, institution=None, abnormal=None, specimen_type=None, scanner=None, split=None):
    """Build `probs` / `label_numeric` columns and optionally restrict to one stratum.

    Args:
        df: Raw frame with one `prob_<CLASS>` column per class, a `label` column, and
            every stratification column that is filtered on.
        institution, abnormal, specimen_type, scanner, split: Allowed value(s) for the
            column of the same name, given as a list or a single value. None keeps all
            rows. Values used in this notebook include institution 'MSKCC'/'UCSF',
            abnormal True/False, specimen_type 'PB'/'BMA' and split 'val'/'test'.
            NOTE(review): an earlier comment listed split as Train/Test/Val and scanner
            as Aperio/Hamamatsu. `isin` is case-sensitive, so confirm the casing.

    Returns:
        A new frame (the input is not modified) without the junk classes U2 and MO3.
        It is restricted to the requested stratum and has two added columns:
        `probs`, a NumPy array per row holding the `prob_*` scores in column order;
        and `label_numeric`, the position of the row's label among those columns.

    Raises:
        KeyError: if a remaining label has no matching `prob_<label>` column.
        Exception: if no rows remain after filtering.
    """
    # Class order is the order of the prob_* columns; label_numeric indexes into it.
    prob_cols = [col for col in df.columns if 'prob_' in col]
    label_to_num = {col.split('prob_')[1]: i for i, col in enumerate(prob_cols)}

    proc_df = df.copy(deep=True)
    # Vectorised replacement for a row-wise apply: one score array per row.
    proc_df['probs'] = list(proc_df[prob_cols].to_numpy(dtype=float))

    # Drop junk classes before numbering labels. (An earlier variant merged MO3 into MO2 instead.)
    proc_df = proc_df[~proc_df['label'].isin(['U2', 'MO3'])]
    # assign() returns a new frame, avoiding SettingWithCopyWarning on the filtered slice.
    proc_df = proc_df.assign(label_numeric=proc_df['label'].map(lambda lab: label_to_num[lab]))

    strata = {
        'institution': institution,
        'abnormal': abnormal,
        'specimen_type': specimen_type,
        'scanner': scanner,
        'split': split,
    }
    for column, allowed in strata.items():
        if allowed is not None:
            proc_df = proc_df[proc_df[column].isin(_as_list(allowed))]
    if len(proc_df) == 0:
        raise Exception("There are no data points in this stratum!")

    return proc_df

In [4]:
# Load the labeled-cell table that every experiment below passes through `process_df`.
df = pd.read_csv(filepath_or_buffer="labeled_cells.csv")

In [5]:
# Experiment 1: Random split UCSF.
# Calibrate on a random subset of the UCSF val+test cells; evaluate on the remainder.
name = 'UCSF-val-test-random-split'
params = dict(
    n=5000,  # calibration size; replaced by the calibration frame's length if a val_df is passed
    alpha=0.05,
    delta=0.2,
    lambda_max=1,
    lambda_min=0,
    N_lambda=5000,
)
total_df = process_df(df, institution=['UCSF'], split=['val', 'test'])
get_selective_results(total_df, name, params)




In [6]:
# Experiment 2: Random split MSKCC normal.
# MSKCC bone-marrow aspirate (BMA) cells from normal cases only.
name = 'MSKCC-normal-bma-random-split'
params = dict(
    n=1500,  # calibration size; replaced by the calibration frame's length if a val_df is passed
    alpha=0.08,
    delta=0.2,
    lambda_max=1,
    lambda_min=0,
    N_lambda=5000,
)
total_df = process_df(df, institution=['MSKCC'], specimen_type=['BMA'], abnormal=[False])
get_selective_results(total_df, name, params)




In [7]:
# Experiment 3: Random split MSKCC abnormal.
# MSKCC bone-marrow aspirate (BMA) cells from abnormal cases only.
name = 'MSKCC-abnormal-bma-random-split'
params = dict(
    n=1500,  # calibration size; replaced by the calibration frame's length if a val_df is passed
    alpha=0.08,
    delta=0.2,
    lambda_max=1,
    lambda_min=0,
    N_lambda=5000,
)
total_df = process_df(df, institution=['MSKCC'], specimen_type=['BMA'], abnormal=[True])
get_selective_results(total_df, name, params)




In [11]:
# Experiment 4: Random split MSKCC PB.
# MSKCC peripheral-blood (PB) cells.
# NOTE(review): the output name says "abnormal", but no `abnormal=` filter is applied.
# Normal and abnormal PB cells are therefore both included; confirm which is intended.
name = 'MSKCC-abnormal-pb-random-split'
params = dict(
    n=1000,  # calibration size; replaced by the calibration frame's length if a val_df is passed
    alpha=0.08,
    delta=0.2,
    lambda_max=1,
    lambda_min=0,
    N_lambda=5000,
)
total_df = process_df(df, institution=['MSKCC'], specimen_type=['PB'])
get_selective_results(total_df, name, params)




This method was first introduced in https://arxiv.org/abs/2110.01052.