<a href="https://colab.research.google.com/github/kundajelab/labelshiftexperiments/blob/master/notebooks/demo/blog_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## This notebook demonstrates how to perform label shift domain adaptation using Maximum Likelihood + Bias-Corrected Temperature Scaling

Install the necessary package

In [35]:
!pip install abstention



Download the datasets

In [36]:
#TODO: fetch the files read by the cells below: demo_valid_labels.txt.gz,
# demo_valid_preds.txt.gz, demo_shifted_test_preds.txt.gz and
# demo_shifted_test_labels.txt.gz (the reference code at the end of this
# notebook shows how they were produced)

In [37]:
#Import relevant modules and define functions for reading in the data
import gzip
import numpy as np
from collections import defaultdict
from scipy.special import softmax
from abstention.calibration import TempScaling
from abstention.label_shift import EMImbalanceAdapter

def read_labels(fh, num_classes=10):
    """Read integer class labels (one per line) and one-hot encode them.

    Args:
        fh: iterable of lines (str or bytes), each holding a single integer
            class index, e.g. an open gzip file handle.
        num_classes: length of each one-hot row. Defaults to 10, the value
            that used to be hard-coded.

    Returns:
        A numpy array containing one one-hot row of length `num_classes`
        per input line.

    Raises:
        ValueError: if a line cannot be parsed as an integer.
        IndexError: if a class index does not fit in `num_classes`.
    """
    to_return = []
    for line in fh:
        the_class = int(line.rstrip())
        to_add = np.zeros(num_classes)
        to_add[the_class] = 1
        to_return.append(to_add)
    return np.array(to_return)

def read_preds(fh):
    """Parse tab-separated floats from an iterable of UTF-8 encoded byte lines.

    Returns a numpy array with one row per line and one entry per
    tab-separated value on that line.
    """
    rows = []
    for raw_line in fh:
        fields = raw_line.decode("utf-8").rstrip().split("\t")
        rows.append([float(field) for field in fields])
    return np.array(rows)

In [38]:
#Read in the validation set predictions and labels, as well as the predictions
# on the (label shifted) test set. Each file is read inside a `with` block so
# that its gzip handle is closed as soon as the contents have been parsed.
with gzip.open("demo_valid_labels.txt.gz", "rb") as fh:
    valid_labels = read_labels(fh)
with gzip.open("demo_valid_preds.txt.gz", "rb") as fh:
    valid_preds = read_preds(fh)
with gzip.open("demo_shifted_test_preds.txt.gz", "rb") as fh:
    shifted_test_preds = read_preds(fh)

In [47]:
#Calibrator: Bias-Corrected Temperature Scaling (BCTS)
bcts_calibrator_factory = TempScaling(bias_positions='all', verbose=False)

#Label shift adaptation method: Maximum Likelihood (EM), using the calibrator
# factory defined above
imbalance_adapter = EMImbalanceAdapter(
    calibrator_factory=bcts_calibrator_factory)

#Build the adaptation function from the validation-set ("valid") labels and
# predictions together with the predictions on the shifted test set
imbalance_adapter_func = imbalance_adapter(
    valid_labels=valid_labels,
    tofit_initial_posterior_probs=shifted_test_preds,
    valid_posterior_probs=valid_preds)

#Apply the adaptation function to obtain the adapted test-set predictions
adapted_shifted_test_preds = imbalance_adapter_func(shifted_test_preds)

In [49]:
#Let's evaluate the improvement in performance due to domain adaptation
# (the label file is read inside a `with` block so the gzip handle is closed)
with gzip.open("demo_shifted_test_labels.txt.gz", "rb") as fh:
    shifted_test_labels = read_labels(fh)
#Class index of every test example; computed once and used for both comparisons
true_classes = np.argmax(shifted_test_labels, axis=-1)
test_accuracy = np.mean(
    true_classes == np.argmax(shifted_test_preds, axis=-1))
adapted_test_accuracy = np.mean(
    true_classes == np.argmax(adapted_shifted_test_preds, axis=-1))
print(test_accuracy, adapted_test_accuracy)

0.707 0.986


For reference, this is the code that was used to generate the `demo_*` files

```
import gzip
import glob
import numpy as np
from collections import defaultdict
from scipy.special import softmax


def sample_from_probs_arr(arr_with_probs):
    """Draw an index at random, weighted by the entries of arr_with_probs.

    One uniform number is drawn from the global numpy RNG; the first index
    whose cumulative probability reaches that number is returned. The last
    index is always accepted as a fallback, because floating point error can
    leave the probabilities summing to slightly less than 1. Returns None
    when arr_with_probs is empty.
    """
    threshold = np.random.random()
    last_position = len(arr_with_probs) - 1
    running_total = 0
    for position, probability in enumerate(arr_with_probs):
        running_total += probability
        if position == last_position or running_total >= threshold:
            return position


def draw_test_indices(total_to_return, label_proportions):
    """Sample test-set indices whose class mix follows label_proportions.

    For every class, int(total_to_return * proportion) indices are drawn with
    replacement from TEST_CLASS_TO_INDICES[class]. Truncation to int can leave
    the total short of total_to_return, so the remainder is filled one index
    at a time, choosing the class of each extra draw with
    sample_from_probs_arr(label_proportions).

    Returns the list of drawn indices: the per-class draws in class order,
    followed by any top-up draws.
    """
    chosen = []
    for cls, proportion in enumerate(label_proportions):
        quota = int(total_to_return * proportion)
        chosen.extend(
            np.random.choice(TEST_CLASS_TO_INDICES[cls], quota, replace=True))
    # Top up whatever the per-class truncation left missing
    shortfall = total_to_return - len(chosen)
    for _ in range(shortfall):
        cls = sample_from_probs_arr(label_proportions)
        chosen.append(np.random.choice(TEST_CLASS_TO_INDICES[cls]))
    return chosen


def write_preds(preds, filename):
    """Write predictions to `filename`, one row per line, tab-separated.

    Args:
        preds: iterable of rows; every value in a row is written with str().
        filename: path of the text file to create (overwritten if present).
    """
    # `with` ensures the file is flushed and closed even if a write fails
    with open(filename, 'w') as f:
        for pred in preds:
            f.write("\t".join(str(x) for x in pred) + "\n")


def write_labels(labels, filename):
    """Write the class index of every label row to `filename`.

    The index written for a row is its argmax. Indices are separated by
    newlines; no newline is written after the last one.

    Args:
        labels: iterable of label rows (e.g. one-hot vectors).
        filename: path of the text file to create (overwritten if present).
    """
    # `with` ensures the file is flushed and closed even if the write fails
    with open(filename, 'w') as f:
        f.write("\n".join(str(np.argmax(x, axis=-1)) for x in labels))


def read_labels(fh, num_classes=10):
    """Read integer class labels (one per line) and one-hot encode them.

    Args:
        fh: iterable of lines (str or bytes), each holding a single integer
            class index, e.g. an open gzip file handle.
        num_classes: length of each one-hot row. Defaults to 10, the value
            that used to be hard-coded.

    Returns:
        A numpy array containing one one-hot row of length `num_classes`
        per input line.

    Raises:
        ValueError: if a line cannot be parsed as an integer.
        IndexError: if a class index does not fit in `num_classes`.
    """
    to_return = []
    for line in fh:
        the_class = int(line.rstrip())
        to_add = np.zeros(num_classes)
        to_add[the_class] = 1
        to_return.append(to_add)
    return np.array(to_return)


def read_preds(fh):
    """Turn lines of tab-separated numbers into a 2-D style numpy array.

    Every item yielded by `fh` must be a bytes object; it is decoded as
    UTF-8, stripped of trailing whitespace, split on tabs, and each field is
    converted with float().
    """
    parsed_rows = []
    for encoded_line in fh:
        text = encoded_line.decode("utf-8").rstrip()
        parsed_rows.append(list(map(float, text.split("\t"))))
    return np.array(parsed_rows)


!wget https://zenodo.org/record/3406662/files/test_labels.txt.gz?download?=1 -O test_labels.txt.gz
!wget https://zenodo.org/record/3406662/files/testpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz?download=1 -O testpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz
!wget https://zenodo.org/record/3406662/files/validpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz?download=1 -O validpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz
!wget https://zenodo.org/record/3406662/files/valid_labels.txt.gz?download?=1 -O demo_valid_labels.txt.gz


# Load the test labels, then apply a softmax over axis 1 to the values read
# from the test and validation "preacts" files. Each file is read inside a
# `with` block so that its gzip handle is closed once it has been parsed.
with gzip.open("test_labels.txt.gz") as fh:
    test_labels = read_labels(fh)
with gzip.open(
        "testpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz") as fh:
    test_preds = softmax(read_preds(fh), axis=1)
with gzip.open(
        "validpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz") as fh:
    valid_preds = softmax(read_preds(fh), axis=1)


# Target class proportions for the shifted test set: a single draw from a
# Dirichlet with the same concentration (dirichlet_alpha) for each of 10
# classes. A dedicated RandomState seeded with 123 is used, so this draw is
# the same on every run.
dirichlet_alpha = 0.1
samplesize = 1000
dirichlet_dist = np.random.RandomState(123).dirichlet(
                  [dirichlet_alpha for x in range(10)])

# Map every class index to the positions of the test examples with that label
TEST_CLASS_TO_INDICES = defaultdict(list)
for index,row in enumerate(test_labels):
    row_label = np.argmax(row)
    TEST_CLASS_TO_INDICES[row_label].append(index)

# NOTE(review): draw_test_indices samples through the global np.random state,
# which is not seeded anywhere in this script, so the selected indices differ
# from run to run -- confirm whether that is intended.
test_indices = draw_test_indices(total_to_return=samplesize,
                                 label_proportions=dirichlet_dist)
# Select the sampled rows from the full test-set labels and predictions
shifted_test_labels = test_labels[test_indices]
shifted_test_preds = test_preds[test_indices]

write_preds(preds=valid_preds, filename="demo_valid_preds.txt")
write_preds(preds=shifted_test_preds, filename="demo_shifted_test_preds.txt")
write_labels(labels=shifted_test_labels, filename="demo_shifted_test_labels.txt")
!gzip -f *.txt
```

