<a href="https://colab.research.google.com/github/blindauth/abstention_experiments/blob/master/abstention_experiments/diabetic_retinopathy/RunDiabeticRetinopathyExperiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!rm -r abstention
!pip uninstall abstention
!git clone https://github.com/blindauth/abstention.git
!pip install abstention/
![[ -e abstention_experiments ]] || git clone https://github.com/blindauth/abstention_experiments.git

Uninstalling abstention-0.1.2.1:
  Would remove:
    /usr/local/lib/python3.6/dist-packages/abstention-0.1.2.1.dist-info/*
    /usr/local/lib/python3.6/dist-packages/abstention/*
Proceed (y/n)? y
  Successfully uninstalled abstention-0.1.2.1
Cloning into 'abstention'...
remote: Enumerating objects: 51, done.[K
remote: Counting objects: 100% (51/51), done.[K
remote: Compressing objects: 100% (38/38), done.[K
remote: Total 51 (delta 18), reused 42 (delta 9), pack-reused 0[K
Unpacking objects: 100% (51/51), done.
Processing ./abstention
Building wheels for collected packages: abstention
  Building wheel for abstention (setup.py) ... [?25l[?25hdone
  Stored in directory: /tmp/pip-ephem-wheel-cache-dmeowolv/wheels/60/8e/d2/9e9ca02e7b5f76bfda2e2daa6dcbe42c19095c502ccb653729
Successfully built abstention
Installing collected packages: abstention
Successfully installed abstention-0.1.2.1


In [2]:
# List the contents of the diabetic-retinopathy experiment data directory.
!ls abstention_experiments/diabetic_retinopathy

flip-False_rotamt-0    flip-False_rotamt-90  flip-True_rotamt-270
flip-False_rotamt-180  flip-True_rotamt-0    flip-True_rotamt-90
flip-False_rotamt-270  flip-True_rotamt-180  valid_labels.txt.gz


In [0]:
import numpy as np
import gzip
from sklearn.preprocessing import LabelBinarizer

#The model used was the one made available at
# https://github.com/JeffreyDF/kaggle_diabetic_retinopathy
# It was the fifth-place winner of the kaggle DR competition.
#Because this was a kaggle competition, we don't have access to the test-set
# labels; thus we are going to split the validation set that was used to
# train this publicly-available model into a 'pseudo test set' and a
# 'pseudo validation set'.
#Each line of valid_labels.txt.gz is tab-separated and the label is taken
# from the second column; LabelBinarizer then turns the labels into one-hot
# rows. The file is opened in a `with` block so that the handle is closed as
# soon as parsing is done.
with gzip.open(
        "abstention_experiments/diabetic_retinopathy/valid_labels.txt.gz",
        'rb') as labels_fh:
    orig_valid_labels = LabelBinarizer().fit_transform(
        np.array([float(x.decode("utf-8").split("\t")[1])
                  for x in labels_fh]))

#num_folds is the number of folds we are going to use to split the data
# into a 'pseudo test' and 'pseudo validation' set
num_folds = 30

In [0]:
#Each 'parent folder' stores the predictions of the model under different
# input data transformations (flip or rotation). We do this to facilitate
# data augmentation, which is helpful when we simulate distribution shift
parent_folders = ["flip-False_rotamt-0",
                  "flip-True_rotamt-0",
                  "flip-False_rotamt-90",
                  "flip-True_rotamt-90",
                  "flip-False_rotamt-180",
                  "flip-True_rotamt-180",
                  "flip-False_rotamt-270",
                  "flip-True_rotamt-270"]

#number of test-time dropout prediction files stored per parent folder
# (nondeterministic_preds_0.txt.gz ... nondeterministic_preds_99.txt.gz)
num_nondet_runs = 100


def _read_preds_file(path):
    """Parse a gzipped, tab-separated predictions file into a float array.

    The first column of every line is skipped (presumably an example
    identifier -- TODO confirm); the remaining columns are converted to
    floats, giving one row per line. The file handle is closed before
    returning.
    """
    with gzip.open(path, 'rb') as preds_fh:
        return np.array([
            [float(y) for y in x.decode("utf-8").split("\t")[1:]]
            for x in preds_fh])


#read in the predictions for the different transformations, both deterministic
# and with test-time dropout enabled
parent_folder_to_det_pred = {}
for parent_folder in parent_folders:
    det_preds = _read_preds_file(
        "abstention_experiments/diabetic_retinopathy/"
        +parent_folder+"/deterministic_preds.txt.gz")
    parent_folder_to_det_pred[parent_folder] = det_preds

parent_folder_to_nondet_pred = {}
parent_folder_to_mean_nondet_pred = {}
for parent_folder in parent_folders:
    #stack the individual dropout runs along a new leading axis
    nondet_preds = np.array([
        _read_preds_file(
            "abstention_experiments/diabetic_retinopathy/"
            +parent_folder+"/nondeterministic_preds_"+str(run_idx)+".txt.gz")
        for run_idx in range(num_nondet_runs)])
    parent_folder_to_nondet_pred[parent_folder] = nondet_preds
    #Also take the mean of the test-time dropout runs so that we can compare
    # the results of weight rescaling to the results of monte-carlo dropout
    parent_folder_to_mean_nondet_pred[parent_folder] = np.mean(nondet_preds,axis=0)

In [5]:
import abstention
from abstention.calibration import (compute_ece, compute_ece_with_bins,
                                    TempScaling)
from abstention.label_shift import EMImbalanceAdapter
from abstention.abstention import (
    weighted_kappa_metric, EstMarginalWeightedKappa, DistMaxClassProbFromOne,
    Entropy, Uncertainty, OneMinusJSDivFromClassFreq)
from collections import OrderedDict
import numpy as np
import sys
import json
import os


def inverse_softmax(preds):
    """Map probabilities to log-space values centred per row.

    Takes the elementwise natural log of `preds` and subtracts, from every
    entry, the mean of the log-values along axis 1, so the result sums to
    zero along that axis.
    """
    log_probs = np.log(preds)
    row_means = log_probs.mean(axis=1, keepdims=True)
    return log_probs - row_means

#Quadratic disagreement weights over the 5 classes: entry [j, i] equals
# (i - j)**2, built by broadcasting a row of class indices against a column.
quadratic_weights = np.square(
    np.arange(5)[None, :] - np.arange(5)[:, None])

#When a test-set distribution shift is simulated, these are the factors
# by which different classes will be relatively upsampled (one per class).
imbalance_upsampling = [1, 2, 5, 8, 8]

#Abstention methods under comparison, keyed by a short name. Entries are
# inserted in the order in which the methods are to be iterated over.
abstname_to_factory = OrderedDict()
abstname_to_factory["estmarginalweightedkappa"] = EstMarginalWeightedKappa(
    weights=quadratic_weights, verbose=False, mode='argmax')
abstname_to_factory["jsdiv"] = OneMinusJSDivFromClassFreq()
abstname_to_factory["distmaxclassprobfromone"] = DistMaxClassProbFromOne()
abstname_to_factory["entropy"] = Entropy()
abstname_to_factory["variance"] = Uncertainty()

#Fractions of the examples on which to abstain
abstfracs = [0.15, 0.30]

#Factory for the EM-based adapter used when labelshift is enabled
imbalanceadapterfactory = EMImbalanceAdapter(verbose=False)

#Main experiment loop. For every combination of
# (useweightrescalepreds, labelshift) the evaluation is repeated over
# `num_folds` random pseudo-validation / pseudo-test splits. For each split we
# record the weighted kappa before abstention and after abstaining on each
# fraction in `abstfracs` with each method in `abstname_to_factory`. The
# results of each combination are written to a JSON file which is then gzipped.
for useweightrescalepreds in [True, False]:
  
  #useweightrescalepreds=True: use the deterministic predictions;
  # otherwise use the mean over the test-time dropout runs.
  if (useweightrescalepreds == True):
    preds_lookup = parent_folder_to_det_pred
  else:
    preds_lookup = parent_folder_to_mean_nondet_pred
  #the individual test-time dropout samples are needed in both cases
  # (they feed the variance-based abstention score)
  ttdsamples_lookup = parent_folder_to_nondet_pred
  
  for labelshift in [True, False]:
    
    if (labelshift==True):
      #if labelshift is True, need to use bias-corrected temperature
      # scaling to get rid of bias in the calibration that can compromise
      # label shift adaptation
      calibfactory = TempScaling(ece_bins=15, verbose=False,
                           bias_positions=[0,1,2,3,4])
    else:
      #otherwise, use regular temperature scaling.
      calibfactory = TempScaling(ece_bins=15, verbose=False)
    
    print("\nuseweightrescalepreds", useweightrescalepreds)
    print("labelshift", labelshift)
    
    #Per-fold performance before abstention, keyed by method prefix.
    # Without label shift we compare calibrated vs. uncalibrated predictions;
    # with label shift we compare adapted vs. non-adapted predictions.
    methodprefix_to_baseperfs = OrderedDict()
    if (labelshift==False):
      methodprefix_to_baseperfs["yescalib"] = []
      methodprefix_to_baseperfs["nocalib"] = []
    else:
      methodprefix_to_baseperfs["yesadapt"] = []
      methodprefix_to_baseperfs["noadapt"] = []
    
    #Per-fold performance after abstention:
    # abstfrac -> "<methodprefix>:<abstname>" -> list with one entry per fold
    abstfrac_to_methodname_to_perfs = OrderedDict()
    for abstfrac in abstfracs:
      methodname_to_perfs = OrderedDict()
      for abstname in abstname_to_factory:
        if (labelshift==False):
          methodname_to_perfs["yescalib:"+abstname] = []
          methodname_to_perfs["nocalib:"+abstname] = []
        else:
          methodname_to_perfs["yesadapt:"+abstname] = []
          methodname_to_perfs["noadapt:"+abstname] = []
      abstfrac_to_methodname_to_perfs[abstfrac] = methodname_to_perfs
    
    for fold_number in range(num_folds):
      
      print("\non fold",fold_number)

      #the seed depends only on the fold number, so a given fold uses the
      # same patient shuffle in every configuration
      rng = np.random.RandomState(fold_number*1000)

      #to avoid contamination between the validation and test sets, we will
      # do the split according to patient ids.
      #the data is in pairs of (left eye, right eye) per patient (entry for
      # the right eye comes after the entry for the left eye); hence, the number of
      # unique patients is 0.5*len(valid_labels)
      patient_id_ordering = list(range(int(0.5*len(orig_valid_labels))))
      rng.shuffle(patient_id_ordering) #shuffle the patient id order

      #prepare the valid-test split, with an imbalance shift if appropriate
      pseudovalid_uncalib_preds = []
      pseudotest_uncalib_preds = []
      pseudovalid_uncalib_ttdsamples = [] #ttd = test-time dropout
      pseudotest_uncalib_ttdsamples = []
      pseudovalid_labels = []
      pseudotest_labels = []
      pseudovalid_label_counts = np.zeros(5)
      pseudotest_label_counts = np.zeros(5)
      for i in patient_id_ordering:
        left_eye_label = orig_valid_labels[2*i]
        right_eye_label = orig_valid_labels[(2*i)+1]
        most_diseased_label = max(np.argmax(left_eye_label),
                                  np.argmax(right_eye_label))
        #Strive for roughly equal representation of all classes in the
        # two sets: the patient goes to whichever set currently has fewer
        # examples of the patient's most-diseased class (ties go to the
        # test set). If labelshift=True, the test-set distribution is
        # modified in a later step (see the augmentation loop below).
        if (pseudovalid_label_counts[most_diseased_label] <
          pseudotest_label_counts[most_diseased_label]):
          in_test = False #append these examples to the validation set
          appendto_uncalib_preds = pseudovalid_uncalib_preds
          appendto_uncalib_ttdsamples = pseudovalid_uncalib_ttdsamples
          appendto_labels = pseudovalid_labels
          appendto_label_counts = pseudovalid_label_counts
        else:
          in_test = True #append these examples to the test set
          appendto_uncalib_preds = pseudotest_uncalib_preds
          appendto_uncalib_ttdsamples = pseudotest_uncalib_ttdsamples
          appendto_labels = pseudotest_labels
          appendto_label_counts = pseudotest_label_counts

        #increment by label counts for left and right eye
        # (in-place add, so the selected counts array itself is updated)
        appendto_label_counts += orig_valid_labels[2*i]
        appendto_label_counts += orig_valid_labels[(2*i)+1]
        #iterate through all the augmentation splits
        for parent_folder_idx,parent_folder in enumerate(parent_folders):
          #we include a particular augmentation split if:
          # (1) we are in the validation set, OR
          # (2) labelshift=False, OR
          # (3) labelshift=True, AND imbalance_upsampling for the example's
          #     class is greater than the parent folder idx.
          #Left eye:
          if ((in_test==False) or (labelshift==False) or
              imbalance_upsampling[
               np.argmax(orig_valid_labels[2*i])] > parent_folder_idx):
              appendto_labels.append(orig_valid_labels[2*i])
              appendto_uncalib_preds.append(
                      preds_lookup[parent_folder][2*i])
              appendto_uncalib_ttdsamples.append(
                  ttdsamples_lookup[parent_folder][:,(2*i)])
          #Right eye:
          if ((in_test==False) or (labelshift==False) or
              imbalance_upsampling[
                  np.argmax(orig_valid_labels[(2*i) + 1])] > parent_folder_idx):
              appendto_labels.append(orig_valid_labels[(2*i)+1])
              appendto_uncalib_preds.append(
                  preds_lookup[parent_folder][(2*i)+1])
              appendto_uncalib_ttdsamples.append(
                  ttdsamples_lookup[parent_folder][:,(2*i)+1])

      #cast things to np arrays, infer softmax logits where needed
      # (logits are used during calibration)
      pseudovalid_uncalib_preds = np.array(pseudovalid_uncalib_preds)
      pseudotest_uncalib_preds = np.array(pseudotest_uncalib_preds)
      pseudovalid_uncalib_pred_logits = inverse_softmax(pseudovalid_uncalib_preds)
      pseudotest_uncalib_pred_logits = inverse_softmax(pseudotest_uncalib_preds)
      #the ttd samples were appended per example; the transpose moves the
      # dropout-run axis to the front
      pseudovalid_uncalib_ttdsamples = np.array(
          pseudovalid_uncalib_ttdsamples).transpose((1,0,2))
      pseudotest_uncalib_ttdsamples = np.array(
          pseudotest_uncalib_ttdsamples).transpose((1,0,2))
      #NOTE(review): pseudovalid_uncalib_ttdsamples_logits is not read again
      # within this loop; kept in case later cells rely on it.
      pseudovalid_uncalib_ttdsamples_logits = np.array([
              inverse_softmax(x) for x in pseudovalid_uncalib_ttdsamples])
      pseudotest_uncalib_ttdsamples_logits = np.array([
              inverse_softmax(x) for x in pseudotest_uncalib_ttdsamples])
      pseudovalid_labels = np.array(pseudovalid_labels)
      pseudotest_labels = np.array(pseudotest_labels)

      #Apply calibration (fit on the pseudo-validation set only)
      the_calibrator = calibfactory(
                          valid_preacts=pseudovalid_uncalib_pred_logits,
                          valid_labels=pseudovalid_labels)
      pseudovalid_calib_preds = the_calibrator(pseudovalid_uncalib_pred_logits)
      pseudotest_calib_preds = the_calibrator(pseudotest_uncalib_pred_logits)
      pseudotest_calib_ttdsamples = np.array(
          [the_calibrator(x) for x in pseudotest_uncalib_ttdsamples_logits])

      #Apply labelshift adaptation to the calibrated preds, if applicable
      if (labelshift==True):
        imbalance_adaptation_func = imbalanceadapterfactory(
            valid_labels=None,
            tofit_initial_posterior_probs=pseudotest_calib_preds,
            valid_posterior_probs=pseudovalid_calib_preds)
        pseudotest_adapted_preds = imbalance_adaptation_func(
            pseudotest_calib_preds)
        pseudotest_adapted_ttdsamples = np.array([
                      imbalance_adaptation_func(x) for
                      x in pseudotest_calib_ttdsamples])

      #Select the pair of prediction variants to compare in this configuration
      if (labelshift==False):
        methodprefix_and_valstouse = [
            ('yescalib', pseudotest_calib_preds, pseudotest_calib_ttdsamples),
            ('nocalib', pseudotest_uncalib_preds, pseudotest_uncalib_ttdsamples),
        ]
      else:
        methodprefix_and_valstouse = [
            ('yesadapt', pseudotest_adapted_preds, pseudotest_adapted_ttdsamples),
            ('noadapt', pseudotest_calib_preds, pseudotest_calib_ttdsamples),
        ]

      for (methodprefix, pseudotest_preds_to_use,
           pseudotest_ttdsamples_to_use) in methodprefix_and_valstouse:

        #compute the pre-abstention perf
        base_weighted_kappa = weighted_kappa_metric(
                predprobs=pseudotest_preds_to_use,
                true_labels=pseudotest_labels,
                weights=quadratic_weights,
                mode='argmax')
        methodprefix_to_baseperfs[methodprefix].append(base_weighted_kappa)
        
        print("\n"+methodprefix,"base perf", base_weighted_kappa)

        #take the variance in the most confident class, as described in
        # https://arxiv.org/pdf/1705.08500.pdf
        # (variance across dropout runs, picked at the argmax class of the
        # mean prediction for each example)
        pseudotest_ttdsamples_var =\
          np.var(pseudotest_ttdsamples_to_use, axis=0)[
           list(range(pseudotest_ttdsamples_to_use.shape[1])),
           np.argmax(np.mean(pseudotest_ttdsamples_to_use, axis=0),axis=-1)]
        #Iterate over abstention methods
        for abstname in abstname_to_factory:
          abstfunc = abstname_to_factory[abstname](
                      valid_labels=None, valid_posterior=None)
          abstpriorities = abstfunc(posterior_probs=pseudotest_preds_to_use,
                                    uncertainties=pseudotest_ttdsamples_var)
          #example indices in ascending order of abstention priority
          sortedabstindices = np.array([
              x[0] for x in sorted(enumerate(abstpriorities),
                                   key=lambda x: x[1])])
          #Iterate over abstention levels
          for abstfrac in abstfracs:
            #keep the lowest-priority (1-abstfrac) fraction of the examples;
            # the highest-priority examples are the ones abstained on
            indices_to_retain =\
              sortedabstindices[:int(len(sortedabstindices )*(1-abstfrac))]
            retained_pseudotest_preds = np.array([
                pseudotest_preds_to_use[i] for i in indices_to_retain])
            retained_pseudotest_labels = np.array([
                pseudotest_labels[i] for i in indices_to_retain])
            #Compute performance
            postabst_weighted_kappa = weighted_kappa_metric(
                predprobs=retained_pseudotest_preds,
                true_labels=retained_pseudotest_labels,
                weights=quadratic_weights,
                mode='argmax')
            abstfrac_to_methodname_to_perfs[abstfrac][
                methodprefix+":"+abstname].append(postabst_weighted_kappa)
            print("abstfrac",abstfrac,
                  abstname,"perf",postabst_weighted_kappa)
    
    #Write the results of this configuration to disk.
    # NOTE: the spelling "useweightrescalpreds" in the file name is kept
    # unchanged so that existing consumers of these files keep working.
    file_out = (
        "diabeticretinopathy_useweightrescalpreds-"+str(useweightrescalepreds)
        +"_labelshift-"+str(labelshift)+"_abstention_results.json")
    dict_to_write = {
        "methodprefix_to_baseperfs": methodprefix_to_baseperfs,
        "abstfrac_to_methodname_to_perfs": abstfrac_to_methodname_to_perfs}
    #the `with` block guarantees the JSON is flushed and the handle closed
    # before gzip reads the file
    with open(file_out, 'w') as results_fh:
      results_fh.write(
          json.dumps(dict_to_write,
                     sort_keys=True, indent=4, separators=(',', ': ')))
    os.system("gzip -f "+file_out)

    
    



useweightrescalepreds True
labelshift True

on fold 0

yesadapt base perf 0.7754512877213324
abstfrac 0.15 estmarginalweightedkappa perf 0.8324768888107463
abstfrac 0.3 estmarginalweightedkappa perf 0.8537599471919868
abstfrac 0.15 jsdiv perf 0.8347235392239909
abstfrac 0.3 jsdiv perf 0.8403852847447665
abstfrac 0.15 distmaxclassprobfromone perf 0.8052640243357561
abstfrac 0.3 distmaxclassprobfromone perf 0.8347340755396746
abstfrac 0.15 entropy perf 0.8158252986087732
abstfrac 0.3 entropy perf 0.8506095772251889
abstfrac 0.15 variance perf 0.7722051019463386
abstfrac 0.3 variance perf 0.7975799297885889

noadapt base perf 0.7170236582198334
abstfrac 0.15 estmarginalweightedkappa perf 0.7906293630157774
abstfrac 0.3 estmarginalweightedkappa perf 0.8221966093192521
abstfrac 0.15 jsdiv perf 0.8159231695944269
abstfrac 0.3 jsdiv perf 0.8420341533130891
abstfrac 0.15 distmaxclassprobfromone perf 0.7475502648034935
abstfrac 0.3 distmaxclassprobfromone perf 0.7710964288192184
abstfrac 0.15 