In [1]:
from collections import deque
import numpy as np
import matplotlib.pyplot as plt
import random
import pandas as pd
import matplotlib
matplotlib.rcParams['font.size'] = 10
matplotlib.rcParams['font.family'] = 'serif'
from bbutils import BetaBernoulli

In [2]:
# Each row of the predictions file holds the true label followed by one
# (presumably probability-like) score per class.
# NOTE: the 'Correct' column stores the *true label*, not a correctness flag;
# the name is kept because later cells index it.
num_classes = 10
predictions = np.genfromtxt("data/svhn/svhn_predictions.txt", delimiter=' ', dtype=float)
assert predictions.shape[1] >= 1 + num_classes, "expected a label column plus one score per class"
class_scores = predictions[:, 1:1 + num_classes]

df = pd.DataFrame(predictions[:, 0].astype(int), columns=['Correct'])
df['Predicted'] = np.argmax(class_scores, axis=1)   # predicted class = highest score
df['Confidence'] = np.max(class_scores, axis=1)     # score of the predicted class
df = df.reset_index()

In [3]:
# Sanity check: one row per prediction; columns are the reset index plus
# Correct / Predicted / Confidence.
assert len(df) == len(predictions)
assert list(df.columns) == ['index', 'Correct', 'Predicted', 'Confidence']
df.shape

(26032, 4)

In [None]:
# Define the variables used by Robby's active-evaluation code below.
# Categories are plain integer class indices, so both lookup maps are identities.
category2idx = {i: i for i in range(num_classes)}
idx2category = {i: i for i in range(num_classes)}
# Predicted class of every example; examples are queued under this category.
categories = df['Predicted'].tolist()
# Whether each prediction is correct. A vectorized comparison replaces the
# per-row Series lookups; elements remain numpy bools, as before.
observations = list((df['Predicted'] == df['Correct']).to_numpy())

In [None]:
# Try to identify the worst (mode='min') or best (mode='max') class by
# labeling predicted examples one at a time, comparing a uniform-random choice
# of category against an active choice driven by model.sample()
# (Thompson-style, assuming sample() draws from the Beta posterior).
n = df.shape[0]
k = num_classes
runs = 10000
mode = 'min'
SEED = 0

# Fail fast instead of after the full random-baseline pass.
if mode not in ('min', 'max'):
    raise ValueError("mode must be 'min' or 'max', got %r" % (mode,))

# Seed both RNGs for reproducibility: `random` drives shuffling and the random
# baseline; presumably BetaBernoulli.sample uses numpy's global RNG -- TODO
# confirm in bbutils.
random.seed(SEED)
np.random.seed(SEED)

# NOTE(review): each array has runs * k * n entries (~2.6e9 with the defaults);
# the two float64 arrays alone need ~42 GB. Lower `runs` for quick experiments.
# Choice counts are integers, so int32 halves that array without changing results.
active_choices = np.zeros((runs, k, n), dtype=np.int32)
active_thetas = np.zeros((runs, k, n))
random_thetas = np.zeros((runs, k, n))


def _shuffled_queues(categories, observations, k):
    """Queue every observation under its category, then shuffle each queue.

    Shuffling is essential: without it every run would pop a category's
    examples in the same (file) order.
    """
    queues = [deque() for _ in range(k)]
    for category, observation in zip(categories, observations):
        queues[category].append(observation)
    for queue in queues:
        random.shuffle(queue)
    return queues


def _posterior_mean(model):
    """Return alpha / (alpha + beta) for every arm, read from model._params."""
    params = model._params
    return params[:, 0] / (params[:, 0] + params[:, 1])


for r in range(runs):
    # --- Baseline: label an example from a uniformly random non-empty category ---
    queues = _shuffled_queues(categories, observations, k)
    model = BetaBernoulli(k)  # default prior
    for i in range(n):
        category = random.randrange(k)
        while not queues[category]:
            category = random.randrange(k)
        model.update(category, queues[category].pop())
        random_thetas[r, :, i] = _posterior_mean(model)

    # --- Active: label an example from the category with the most extreme sample ---
    queues = _shuffled_queues(categories, observations, k)
    model = BetaBernoulli(k)  # default prior
    for i in range(n):
        order = np.argsort(model.sample())
        if mode == 'max':
            order = order[::-1]
        # First category in preference order that still has unlabeled examples
        category = next(c for c in order if queues[c])
        model.update(category, queues[category].pop())
        active_thetas[r, :, i] = _posterior_mean(model)
        # Running count of how many times each category has been chosen so far
        if i > 0:
            active_choices[r, :, i] = active_choices[r, :, i - 1]
        active_choices[r, category, i] += 1

In [None]:
# Empirical accuracy of each predicted class over the full dataset: the
# ground truth that the sampling runs try to recover.
correct = np.bincount(categories, weights=np.asarray(observations, dtype=float), minlength=k)
total = np.bincount(categories, minlength=k).astype(float)

# NOTE: a class that is never predicted has total == 0 and gets accuracy nan.
empirical_acc = correct / total
# Classes ordered from highest to lowest empirical accuracy
ranked = np.argsort(empirical_acc)[::-1]
print(empirical_acc)

In [None]:
# Cumulative number of times the active strategy chose each category,
# averaged over runs, as a function of the number of labels collected.
avg_active_choices = np.mean(active_choices, axis=0)

fig, ax = plt.subplots()
for i in range(k):
    ax.plot(avg_active_choices[i, :], label=str(idx2category[i]))
ax.set_xlabel('Time')
ax.set_ylabel('Number of times chosen')
ax.legend(title='Category', fontsize='small')
plt.show()

# Categories ordered by how often they had been chosen by the final step
top = np.argsort(avg_active_choices[:, -1])[::-1]
for i in top[:10]:
    print(idx2category[i])

In [None]:
# For a few early time steps, estimate how often (across runs) each of the
# `cutoff` most extreme classes is the one the active strategy has chosen most.
cutoff = 9
timestamps = [100, 200, 300, 400, 500]

if mode == 'max':
    # Highest accuracies, best first
    selection = ranked[:cutoff]
elif mode == 'min':
    # Lowest accuracies, worst first
    selection = ranked[-cutoff:][::-1]
else:
    raise ValueError("mode must be 'min' or 'max', got %r" % (mode,))

# Most-chosen category per run, evaluated only at the reported time steps.
# Taking argmax over all n steps would build a (runs, n) array that is never used.
most_chosen_at_t = np.argmax(active_choices[:, :, timestamps], axis=1)  # (runs, len(timestamps))

table = np.zeros((cutoff + 1, len(timestamps)))
for i, category in enumerate(selection):
    table[i, :] = np.mean(most_chosen_at_t == category, axis=0)

# Last row: fraction of runs whose most-chosen category is outside `selection`
table[cutoff, :] = 1 - np.sum(table, axis=0)

In [None]:
# Print the table with category labels left-aligned in a fixed-width column,
# so the numeric columns line up under the header regardless of label length.
label_width = 10
header = ' ' * label_width + ' '.join('%6i' % x for x in timestamps)
print(header)

for i, row in enumerate(table):
    label = idx2category[selection[i]] if i < cutoff else 'OTHER'
    num_string = ' '.join('{:.04f}'.format(x) for x in row.tolist())
    print('{:<{w}}'.format(label, w=label_width) + num_string)

In [None]:
# Per time step, the fraction of runs whose posterior-mean estimate picks out
# the true extreme class. selection[0] is the worst class for mode 'min' and
# the best class for mode 'max', so the arg-extremum must follow `mode`.
pick_extreme = {'min': np.argmin, 'max': np.argmax}[mode]
random_success = np.mean(pick_extreme(random_thetas, axis=1) == selection[0], axis=0)
active_success = np.mean(pick_extreme(active_thetas, axis=1) == selection[0], axis=0)

In [None]:
# Success rate over time for active vs. random labeling.
# Single-column figure; for a double-column figure use figsize=(6.30, 6.30 / 1.618).
fig, ax = plt.subplots(figsize=(3.03, 3.03 / 1.618), dpi=300)
ax.plot(active_success, label='active')
ax.plot(random_success, label='random')
ax.set_xlabel('Time')
ax.set_ylabel('Success Rate')
ax.legend()
# bbox_inches='tight' stops the axis labels from being clipped in this small figure
fig.savefig("active_svhn_mode_%s_runs_%d.pdf" % (mode, runs), format='pdf', bbox_inches='tight')