In [1]:
from __future__ import division, print_function

In [9]:
import abstention
reload(abstention)
reload(abstention.calibration)
reload(abstention.label_shift)
from abstention.calibration import TempScaling, ConfusionMatrix, softmax
from abstention.label_shift import EMImbalanceAdapter, BBSEImbalanceAdapter, ShiftWeightFromImbalanceAdapter
import glob
import gzip
import numpy as np
from collections import defaultdict

def read_labels(fh):
    to_return = []
    for line in fh:
        the_class=int(line.rstrip())
        to_add = np.zeros(10)
        to_add[the_class] = 1
        to_return.append(to_add)
    return np.array(to_return)

def read_preds(fh):
    return np.array([[float(x) for x in y.rstrip().split("\t")]
                     for y in fh])

def sample_from_probs_arr(arr_with_probs):
    rand_num = np.random.random()
    cdf_so_far = 0
    for (idx, prob) in enumerate(arr_with_probs):
        cdf_so_far += prob
        if (cdf_so_far >= rand_num
            or idx == (len(arr_with_probs) - 1)):  # need the
            # letterIdx==(len(row)-1) clause because of potential floating point errors
            # that mean arrWithProbs doesn't sum to 1
            return idx
        
test_labels = read_labels(gzip.open(glob.glob("test_labels.txt.gz")[0]))
test_class_to_indices = defaultdict(list)
for index,row in enumerate(test_labels):
    row_label = np.argmax(row)
    test_class_to_indices[row_label].append(index)
def draw_test_indices(total_to_return, label_proportions):
    indices_to_use = []
    for class_index, class_proportion in enumerate(label_proportions):
        indices_to_use.extend(np.random.choice(
                test_class_to_indices[class_index],
                int(total_to_return*class_proportion),
                replace=True))
    for i in range(total_to_return-len(indices_to_use)):
        class_index = sample_from_probs_arr(label_proportions)
        indices_to_use.append(
            np.random.choice(test_class_to_indices[class_index]))
    return indices_to_use

valid_labels = read_labels(gzip.open(glob.glob("valid_labels.txt.gz")[0]))

imbalance_adapters = [
    ('em_calib-None_init-default', EMImbalanceAdapter(calibrator_factory=None)),
    ('em_calib-tsnobiascorr_init-default', EMImbalanceAdapter(calibrator_factory=TempScaling(verbose=False))),
    ('em_calib-confusionmat_init-default', EMImbalanceAdapter(calibrator_factory=ConfusionMatrix(), verbose=False)),
    ('em_calib-confusionmat_init-BBSE-hard', EMImbalanceAdapter(calibrator_factory=ConfusionMatrix(), verbose=False,
                                                                    initialization_weight_ratio=
                                                                     ShiftWeightFromImbalanceAdapter(BBSEImbalanceAdapter(soft=False)))),
    ('em_calib-tswithbiascorr_init-default', EMImbalanceAdapter(calibrator_factory=
                                                       TempScaling(verbose=False,bias_positions=[0,1,2,3,4,5,6,7,8,9]))),
    ('bbse-hard_calib-None', BBSEImbalanceAdapter(soft=False, calibrator_factory=None)),
    ('bbse-soft_calib-None', BBSEImbalanceAdapter(soft=True, calibrator_factory=None)),
    #('bbse-soft_calib-tsnobiascorr', BBSE(soft=True, calibrator_factory=TempScaling(verbose=False))),
    #('bbse-hard_calib-tsnobiascorr', BBSE(soft=False, calibrator_factory=TempScaling(verbose=False))),
    #('bbse-soft_calib-tswithbiascorr', BBSE(soft=True, calibrator_factory=TempScaling(verbose=False,
    #                                                                                  bias_positions=[0,1,2,3,4,5,6,7,8,9]))), 
    #('bbse-hard_calib-tswithbiascorr', BBSE(soft=False, calibrator_factory=TempScaling(verbose=False,
    #                                                                                  bias_positions=[0,1,2,3,4,5,6,7,8,9]))),
]

dirichlet_alphas = [0.1, 1.0, 1.0]

In [None]:
import numpy as np
import random
import sys

dirichletalpha_to_adaptername_to_weightdiffnorm = defaultdict(lambda: defaultdict(list))
dirichletalpha_to_baselineacc = defaultdict(list)
dirichletalpha_to_adaptername_to_deltaacc = defaultdict(lambda: defaultdict(list))
num_trials = 10
for dirichlet_alpha in dirichlet_alphas:
    print("On alpha",dirichlet_alpha)
    for seed in [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]:
        print("Seed",seed)
        test_preds = softmax(preact=read_preds(gzip.open(glob.glob("cifar10_balanced_seed-"+str(seed)+"_*testpreds.txt.gz")[0])),
                             temp=1, biases=None)
        valid_preds = softmax(preact=read_preds(gzip.open(glob.glob("cifar10_balanced_seed-"+str(seed)+"_*validpreds.txt.gz")[0])),
                              temp=1, biases=None)
        for trial_num in range(num_trials):
            print("On trial num",trial_num)
            sys.stdout.flush()
            np.random.seed(trial_num*100)
            random.seed(trial_num*100)
            dirichlet_dist = np.random.dirichlet([dirichlet_alpha for x in range(10)])
            test_indices = draw_test_indices(total_to_return=10000,
                                             label_proportions=dirichlet_dist)
            shifted_test_labels = test_labels[test_indices]
            shifted_test_preds = test_preds[test_indices]
            
            shifted_test_baseline_accuracy = np.mean(np.argmax(shifted_test_labels,axis=-1)==
                                                     np.argmax(shifted_test_preds,axis=-1))
            dirichletalpha_to_baselineacc[dirichlet_alpha].append(shifted_test_baseline_accuracy)
            
            ideal_shift_weights = np.mean(shifted_test_labels,axis=0)/np.mean(valid_labels,axis=0)
            for adapter_name,imbalance_adapter in imbalance_adapters:
                #print(adapter_name)
                imbalance_adapter_func = imbalance_adapter(valid_labels=valid_labels,
                                                           tofit_initial_posterior_probs=shifted_test_preds,
                                                           valid_posterior_probs=valid_preds)  
                shift_weights = imbalance_adapter_func.multipliers
                adapted_shifted_test_preds = imbalance_adapter_func(shifted_test_preds)
                adapted_shifted_test_accuracy = np.mean(np.argmax(shifted_test_labels,axis=-1)==
                                                        np.argmax(adapted_shifted_test_preds,axis=-1))
                delta_from_baseline = adapted_shifted_test_accuracy-shifted_test_baseline_accuracy
                
                dirichletalpha_to_adaptername_to_weightdiffnorm[dirichlet_alpha][adapter_name].append(
                    np.linalg.norm(shift_weights-ideal_shift_weights))
                dirichletalpha_to_adaptername_to_deltaacc[dirichlet_alpha][adapter_name].append(
                    delta_from_baseline)
                
            for adapter_name in [x[0] for x in imbalance_adapters]:
                n = len(dirichletalpha_to_adaptername_to_deltaacc[dirichlet_alpha][adapter_name])
                if ((n%num_trials == 0 and n > 0) or n==1):
                    print(adapter_name),
                    print("delt acc",
                          np.mean(dirichletalpha_to_adaptername_to_deltaacc[dirichlet_alpha][adapter_name]),
                          "+/-",
                          (1.0/n)*np.std(dirichletalpha_to_adaptername_to_deltaacc[dirichlet_alpha][adapter_name],
                                         ddof=1))
                    print("weight diff norm",
                          np.mean(dirichletalpha_to_adaptername_to_weightdiffnorm[dirichlet_alpha][adapter_name]),
                          "+/-",
                          (1.0/n)*np.std(dirichletalpha_to_adaptername_to_weightdiffnorm[dirichlet_alpha][adapter_name],
                                         ddof=1))
                    sys.stdout.flush()

On alpha 0.1
Seed 0
On trial num 0
em_calib-None_init-default
delt acc 0.07799999999999996 +/- nan
weight diff norm 0.20063927995368255 +/- nan
em_calib-tsnobiascorr_init-default
delt acc 0.08619999999999994 +/- nan
weight diff norm 0.07704549049340365 +/- nan
em_calib-confusionmat_init-default
delt acc 0.0706 +/- nan
weight diff norm 0.27401555026902186 +/- nan
em_calib-confusionmat_init-BBSE-hard
delt acc 0.0706 +/- nan
weight diff norm 0.2738624097121636 +/- nan
em_calib-tswithbiascorr_init-default
delt acc 0.08650000000000002 +/- nan
weight diff norm 0.0326774041884536 +/- nan
bbse-hard_calib-None
delt acc 0.07450000000000001 +/- nan
weight diff norm 0.26571104401334134 +/- nan
bbse-soft_calib-None
delt acc 0.07450000000000001 +/- nan
weight diff norm 0.21937459988020055 +/- nan
On trial num 1
On trial num 2
On trial num 3
On trial num 4
On trial num 5
On trial num 6
On trial num 7
On trial num 8
On trial num 9
em_calib-None_init-default
delt acc 0.05738 +/- 0.0018145326916010345
w

In [14]:
dirichletalpha_to_adaptername_to_deltaacc[dirichlet_alpha][adapter_name]

[0.07799999999999996]

In [None]:
import json
import os
file_out = "label_shift_adaptation_results.json"
open(file_out, 'w').write(
    json.dumps(dict([(dirichlet_alpha,
                      dict([(estimator_name, [x[2] for x in results])
                             for (estimator_name,results) in
                             estimatorname_to_results.items()]))
                     for (dirichlet_alpha,estimatorname_to_results)
                     in dirichletalpha_to_estimatorname_to_results.items()]),
               sort_keys=True, indent=4, separators=(',', ': ')))
os.system("gzip -f "+file_out)