<a href="https://colab.research.google.com/github/kundajelab/labelshiftexperiments/blob/master/notebooks/demo/blog_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Maximum Likelihood + Bias-Corrected Temperature Scaling

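This notebook demonstrates how to perform label shift domain adaptation using Maximum Likelihood (implemented with Expectation Maximization, EM) together with Bias-Corrected Temperature Scaling (BCTS) for calibration, via the `abstention` package.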
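
In standard notation, the two ingredients look roughly as follows; this is our summary of the usual formulation, not a transcription of the `abstention` source. Bias-corrected temperature scaling (BCTS) divides the model's logits $z(x) \in \mathbb{R}^K$ by a single temperature $T > 0$ and adds a per-class bias $b \in \mathbb{R}^K$, with $T$ and $b$ fit by minimizing the negative log-likelihood on the labeled validation set:

$$\hat{p}_s(y = k \mid x) = \frac{\exp\left(z_k(x)/T + b_k\right)}{\sum_{j=1}^{K} \exp\left(z_j(x)/T + b_j\right)}$$

Maximum likelihood estimation of the test-set class priors $\pi$ then runs EM over the $N$ unlabeled test examples, starting from $\pi^{(0)} = \pi^{s}$, where the source priors $\pi^{s}$ are, for instance, the class frequencies in the validation labels:

$$p^{(t)}(y = k \mid x_i) = \frac{\left(\pi^{(t)}_k / \pi^{s}_k\right)\hat{p}_s(y = k \mid x_i)}{\sum_{j=1}^{K} \left(\pi^{(t)}_j / \pi^{s}_j\right)\hat{p}_s(y = j \mid x_i)}, \qquad \pi^{(t+1)}_k = \frac{1}{N} \sum_{i=1}^{N} p^{(t)}(y = k \mid x_i)$$

The adapted predictions are the E-step posteriors $p^{(t)}(y = k \mid x_i)$ at convergence.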

### Setup

Download the datasets

In [1]:
!wget https://raw.github.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_valid_labels.txt.gz -O demo_valid_labels.txt.gz
!wget https://raw.github.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_shifted_test_preds.txt.gz -O demo_shifted_test_preds.txt.gz
!wget https://raw.github.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_valid_preds.txt.gz -O demo_valid_preds.txt.gz


Install the necessary package

In [2]:
!pip install abstention

Collecting abstention
  Downloading https://files.pythonhosted.org/packages/c2/cb/b9a4ef4a0efecf1ac74fc12a459f05d17dc76ebba9c9ee1c62b9d651bb18/abstention-0.1.3.1.tar.gz
Building wheels for collected packages: abstention
  Building wheel for abstention (setup.py) ... done
  Created wheel for abstention: filename=abstention-0.1.3.1-cp36-none-any.whl size=25470 sha256=f16debecbfdee13d197c22ed52c9a349d2bb8822780917b58c1f7741e5f9de71
  Stored in directory: /root/.cache/pip/wheels/7c/a8/fc/5ddf92c0e5934d70543ea30142078287d911f01e75cffb808c
Successfully built abstention
Installing collected packages: abstention
Successfully installed abstention-0.1.3.1


Import relevant modules and define functions for reading in the data

In [3]:
import gzip
import numpy as np
from collections import defaultdict
from scipy.special import softmax

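def read_labels(fh):
    # Each line holds one integer class index (0-9); convert it to a
    # one-hot vector over the 10 classes.
    to_return = []
    for line in fh:
        the_class = int(line.rstrip())
        to_add = np.zeros(10)
        to_add[the_class] = 1
        to_return.append(to_add)
    return np.array(to_return)

def read_preds(fh):
    # Each line holds one tab-separated row of predicted class probabilities.
    return np.array([[float(x) for x in y.decode("utf-8").rstrip().split("\t")]
                     for y in fh])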
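
The two readers above parse each file line by line. A shorter equivalent, sketched here under the assumption that the files follow the same layout (one integer class index per line for labels, one tab-separated row of probabilities per line for predictions), uses `np.loadtxt`, which decompresses `.gz` files on its own, and `np.eye` for the one-hot encoding. The helper names `read_labels_onehot` and `read_pred_matrix` are hypothetical, and the round trip below runs on made-up values in a temporary directory rather than on the demo files.

```python
import gzip
import os
import tempfile

import numpy as np


def read_labels_onehot(path, num_classes=10):
    # Hypothetical helper: one integer class index per line;
    # ndmin=1 keeps a single-line file 1-D.
    labels = np.loadtxt(path, dtype=int, ndmin=1)
    # Row k of the identity matrix is the one-hot vector for class k.
    return np.eye(num_classes)[labels]


def read_pred_matrix(path):
    # Hypothetical helper: one tab-separated row of class probabilities per line.
    return np.loadtxt(path, delimiter="\t", ndmin=2)


with tempfile.TemporaryDirectory() as tmpdir:
    labels_path = os.path.join(tmpdir, "labels.txt.gz")
    preds_path = os.path.join(tmpdir, "preds.txt.gz")
    # Made-up contents in the same layout as the demo files.
    with gzip.open(labels_path, "wt") as f:
        f.write("3\n0\n9")
    with gzip.open(preds_path, "wt") as f:
        f.write("0.7\t0.3\n0.1\t0.9\n")
    onehot = read_labels_onehot(labels_path)
    preds = read_pred_matrix(preds_path)

print(onehot.shape, onehot.argmax(axis=1))  # (3, 10) [3 0 9]
print(preds.shape)  # (2, 2)
```

Parameterizing `num_classes` avoids the hard-coded 10, which only fits 10-class data such as this demo.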

Read in the validation-set labels and predictions, as well as the predictions on the label-shifted test set

In [4]:
valid_labels = read_labels(gzip.open("demo_valid_labels.txt.gz", "rb"))
valid_preds = read_preds(gzip.open("demo_valid_preds.txt.gz", "rb"))
shifted_test_preds = read_preds(gzip.open("demo_shifted_test_preds.txt.gz", "rb"))

### Perform label shift adaptation

Apply Maximum Likelihood + BCTS

In [5]:
from abstention.calibration import TempScaling
from abstention.label_shift import EMImbalanceAdapter

#Instantiate the BCTS calibrator factory
bcts_calibrator_factory = TempScaling(verbose=False, bias_positions='all')
#Specify that we would like to use Maximum Likelihood (EM) for the
# label shift adaptation, with BCTS for calibration
imbalance_adapter = EMImbalanceAdapter(calibrator_factory=
                                       bcts_calibrator_factory)
#Get the function that will do the label shift adaptation (creating this
# function requires supplying the validation set labels/predictions as well as
# the test-set predictions)
imbalance_adapter_func = imbalance_adapter(valid_labels=valid_labels,
                          tofit_initial_posterior_probs=shifted_test_preds,
                          valid_posterior_probs=valid_preds)
#Get the adapted test-set predictions
adapted_shifted_test_preds = imbalance_adapter_func(shifted_test_preds)
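
For readers who want to see the mechanics, here is a minimal, self-contained NumPy/SciPy sketch of the same two steps on synthetic data: fit BCTS (one temperature plus per-class biases) on a labeled validation set by minimizing negative log-likelihood, then run EM to re-estimate the test-set class priors and reweight the calibrated posteriors. It is an independent illustration, not the `abstention` implementation; the function names (`fit_bcts`, `apply_bcts`, `em_label_shift`), the toy data generator, its three classes and the prior values are all hypothetical. The sketch works on logits; if only probabilities are available, as with `valid_preds` here, their logarithm can stand in, because log-probabilities differ from the logits by a per-row constant that softmax ignores even after division by T.

```python
import numpy as np
from scipy.optimize import minimize
from scipy.special import log_softmax, softmax


def fit_bcts(valid_logits, valid_onehot):
    """Fit one temperature T and per-class biases b by minimizing validation NLL."""
    num_classes = valid_logits.shape[1]

    def nll(params):
        # params[0] is log(T), so T = exp(params[0]) is always positive.
        temperature, bias = np.exp(params[0]), params[1:]
        log_probs = log_softmax(valid_logits / temperature + bias, axis=1)
        return -np.mean(np.sum(valid_onehot * log_probs, axis=1))

    # Start from T = 1 and b = 0, i.e. the uncalibrated model.
    result = minimize(nll, np.zeros(num_classes + 1), method="L-BFGS-B")
    return np.exp(result.x[0]), result.x[1:]


def apply_bcts(logits, temperature, bias):
    return softmax(logits / temperature + bias, axis=1)


def em_label_shift(test_probs, source_priors, max_iter=1000, tol=1e-8):
    """EM re-estimation of test-set class priors from calibrated posteriors."""
    priors = source_priors.copy()
    for _ in range(max_iter):
        # E-step: reweight each posterior by pi_test / pi_source, renormalize rows.
        weighted = test_probs * (priors / source_priors)
        posteriors = weighted / weighted.sum(axis=1, keepdims=True)
        # M-step: the new prior estimate is the mean adapted posterior.
        new_priors = posteriors.mean(axis=0)
        converged = np.max(np.abs(new_priors - priors)) < tol
        priors = new_priors
        if converged:
            break
    # Final E-step with the converged priors gives the adapted predictions.
    weighted = test_probs * (priors / source_priors)
    return weighted / weighted.sum(axis=1, keepdims=True), priors


# Hypothetical toy data: 3 classes, the true class's logit is shifted up by 2.
rng = np.random.default_rng(0)
num_classes = 3


def simulate(n, class_priors):
    y = rng.choice(num_classes, size=n, p=class_priors)
    onehot = np.eye(num_classes)[y]
    return 2.0 * onehot + rng.normal(size=(n, num_classes)), onehot


valid_logits, valid_onehot = simulate(5000, np.full(num_classes, 1 / num_classes))
true_test_priors = np.array([0.8, 0.15, 0.05])
test_logits, test_onehot = simulate(5000, true_test_priors)

temperature, bias = fit_bcts(valid_logits, valid_onehot)
calibrated_test = apply_bcts(test_logits, temperature, bias)
adapted_test, est_priors = em_label_shift(calibrated_test, valid_onehot.mean(axis=0))

truth = test_onehot.argmax(axis=1)
print("Fitted temperature:", round(float(temperature), 3))
print("Estimated test priors:", np.round(est_priors, 3))
print("Accuracy without adaptation:", np.mean(calibrated_test.argmax(axis=1) == truth))
print("Accuracy with adaptation:", np.mean(adapted_test.argmax(axis=1) == truth))
```

In this toy model the noise is unit Gaussian, so under the balanced validation priors the correct posterior is exactly a softmax of twice the logits; BCTS can represent that, and the fitted temperature should land near 0.5. Real networks are rarely miscalibrated this tidily.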

### Evaluation

Download and read in the labels for the test set


In [6]:
!wget https://raw.github.com/kundajelab/labelshiftexperiments/master/notebooks/demo/demo_shifted_test_labels.txt.gz -O demo_shifted_test_labels.txt.gz

shifted_test_labels = read_labels(gzip.open("demo_shifted_test_labels.txt.gz", "rb"))


Evaluate the improvement in performance due to domain adaptation

In [7]:
#Get the test set accuracy WITHOUT label shift adaptation
unadapted_test_accuracy = np.mean(np.argmax(shifted_test_labels,axis=-1)==np.argmax(shifted_test_preds,axis=-1))
#Get the test-set accuracy WITH label shift adaptation
adapted_test_accuracy = np.mean(np.argmax(shifted_test_labels,axis=-1)==np.argmax(adapted_shifted_test_preds,axis=-1))

print("Accuracy without label shift adaptation:", unadapted_test_accuracy)
print("Accuracy with label shift adaptation:", adapted_test_accuracy)

Accuracy without label shift adaptation: 0.707
Accuracy with label shift adaptation: 0.986


## Misc

This is the code that was used to generate the `demo_*` files:

```
import gzip
import glob
import numpy as np
from collections import defaultdict
from scipy.special import softmax


def sample_from_probs_arr(arr_with_probs):
    rand_num = np.random.random()
    cdf_so_far = 0
    for (idx, prob) in enumerate(arr_with_probs):
        cdf_so_far += prob
        if (cdf_so_far >= rand_num
            or idx == (len(arr_with_probs) - 1)):
            # The idx == len(arr_with_probs) - 1 clause is needed because
            # floating-point error can leave arr_with_probs summing to
            # slightly less than 1
            return idx


def draw_test_indices(total_to_return, label_proportions):
    indices_to_use = []
    for class_index, class_proportion in enumerate(label_proportions):
        indices_to_use.extend(np.random.choice(
                TEST_CLASS_TO_INDICES[class_index],
                int(total_to_return*class_proportion),
                replace=True))
    for i in range(total_to_return-len(indices_to_use)):
        class_index = sample_from_probs_arr(label_proportions)
        indices_to_use.append(
            np.random.choice(TEST_CLASS_TO_INDICES[class_index]))
    return indices_to_use


def write_preds(preds, filename):
    # One tab-separated row of class probabilities per line.
    with open(filename, "w") as f:
        for pred in preds:
            f.write("\t".join([str(x) for x in pred]) + "\n")


def write_labels(labels, filename):
    # One integer class index (argmax of each one-hot row) per line.
    with open(filename, "w") as f:
        f.write("\n".join([str(np.argmax(x, axis=-1)) for x in labels]))


def read_labels(fh):
    to_return = []
    for line in fh:
        the_class=int(line.rstrip())
        to_add = np.zeros(10)
        to_add[the_class] = 1
        to_return.append(to_add)
    return np.array(to_return)


def read_preds(fh):
    return np.array([[float(x) for x in y.decode("utf-8").rstrip().split("\t")]
                     for y in fh])


!wget "https://zenodo.org/record/3406662/files/test_labels.txt.gz?download=1" -O test_labels.txt.gz
!wget "https://zenodo.org/record/3406662/files/testpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz?download=1" -O testpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz
!wget "https://zenodo.org/record/3406662/files/validpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz?download=1" -O validpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz
!wget "https://zenodo.org/record/3406662/files/valid_labels.txt.gz?download=1" -O demo_valid_labels.txt.gz


test_labels = read_labels(gzip.open("test_labels.txt.gz"))
test_preds = softmax(read_preds(gzip.open(
  "testpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz")),
                       axis=1)
valid_preds = softmax(read_preds(gzip.open(
    "validpreacts_model_cifar10_balanced_seed-0_bestbefore-100_currentepoch-100_valacc-91_vgg.txt.gz")),
                      axis=1)


dirichlet_alpha = 0.1
samplesize = 1000
dirichlet_dist = np.random.RandomState(123).dirichlet(
                  [dirichlet_alpha for x in range(10)])

TEST_CLASS_TO_INDICES = defaultdict(list)
for index,row in enumerate(test_labels):
    row_label = np.argmax(row)
    TEST_CLASS_TO_INDICES[row_label].append(index)

test_indices = draw_test_indices(total_to_return=samplesize,
                                 label_proportions=dirichlet_dist)
shifted_test_labels = test_labels[test_indices]
shifted_test_preds = test_preds[test_indices]

write_preds(preds=valid_preds, filename="demo_valid_preds.txt")
write_preds(preds=shifted_test_preds, filename="demo_shifted_test_preds.txt")
write_labels(labels=shifted_test_labels, filename="demo_shifted_test_labels.txt")
!gzip -f *.txt
```
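
The generation code above draws from numpy's unseeded global RNG (`np.random.random` and `np.random.choice`), so rerunning it produces a different shifted test set each time; only the Dirichlet proportions are fixed by `RandomState(123)`. If reproducibility matters, one option is to pass a seeded `numpy.random.Generator` through the sampling. The sketch below is a hypothetical rewrite (`draw_shifted_indices` is our name, and the toy labels stand in for the CIFAR-10 test labels). It keeps the same scheme of taking `int(total * p_k)` examples of each class and sampling the leftover classes from the proportions, but replaces the per-draw inverse-CDF loop in `sample_from_probs_arr` with `Generator.choice(..., p=...)` and returns indices grouped by class rather than in the original order.

```python
from collections import defaultdict

import numpy as np


def draw_shifted_indices(labels, proportions, total, rng):
    """Hypothetical seeded variant of draw_test_indices: sample `total` indices
    whose labels follow `proportions`, with replacement within each class."""
    class_to_indices = defaultdict(list)
    for index, label in enumerate(labels):
        class_to_indices[label].append(index)

    proportions = np.asarray(proportions)
    # Deterministic part: int(total * p_k) examples of class k, as in the original.
    counts = np.floor(total * proportions).astype(int)
    # Remainder: draw classes from the same proportions (sample_from_probs_arr
    # does this one draw at a time via the inverse CDF).
    leftover = rng.choice(len(proportions), size=total - counts.sum(), p=proportions)
    counts += np.bincount(leftover, minlength=len(proportions))

    indices = []
    for class_index, count in enumerate(counts):
        if count > 0:
            # Generator.choice samples with replacement by default.
            indices.extend(rng.choice(class_to_indices[class_index], size=count))
    return np.array(indices)


rng = np.random.default_rng(123)
# Toy stand-in for the test labels: 100 examples of each of 10 classes.
toy_labels = np.repeat(np.arange(10), 100)
proportions = rng.dirichlet(np.full(10, 0.1))
idx = draw_shifted_indices(toy_labels, proportions, total=1000, rng=rng)
print(len(idx))  # 1000
print(np.bincount(toy_labels[idx], minlength=10))
print(np.round(1000 * proportions).astype(int))
```

Because the indices are only used to subset predictions and labels together, their order does not affect the evaluation.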

