# Imports

In [None]:
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import argparse
import os
import random
import shutil
import time
from sklearn.metrics.pairwise import rbf_kernel
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable
from PIL import Image
import torch.utils.data as data
from torchvision.datasets.utils import download_url, check_integrity
import sys
from sklearn.decomposition import PCA
from matplotlib import pyplot as plt
from sklearn.preprocessing import StandardScaler
import pickle
from sklearn.gaussian_process.kernels import RBF
import torch
from scipy.stats import multivariate_normal
import  scipy.stats as st
from matplotlib import cm
import torch.optim as optim
from __future__ import print_function
from tqdm import tqdm
from sklearn.metrics.pairwise import rbf_kernel
from scipy.special import expit

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)


# Training and Algorithm

Data generating code

In [None]:

def sample(mu, var, nb_samples=500):
    """
    Draw independent Gaussian samples, one torch.normal call per sample.

    :param mu: torch.Tensor (features), mean of the distribution
    :param var: torch.Tensor (features), per-feature variance (zero
        covariance); its square root is used as the standard deviation
    :param nb_samples: number of samples (rows) to draw
    :return: torch.Tensor (nb_samples, features)
    """
    std = var.sqrt()
    draws = [torch.normal(mu, std) for _ in range(nb_samples)]
    return torch.stack(draws, dim=0)

    

Linear classifier definition and training

In [None]:

class Linear_net_sig(nn.Module):
    """Logistic-regression model: a single affine layer followed by a sigmoid."""

    def __init__(self, input_dim):
        super().__init__()
        # Affine map y = Wx + b from input_dim features to one output.
        self.fc1 = nn.Linear(input_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid(fc1(x)); the last dimension of the result has size 1."""
        return self.sigmoid(self.fc1(x))

def run_classifier_sig(net, data_x, data_y, n_epochs = 10000):
    '''
    Train `net` in place with full-batch gradient descent on binary
    cross-entropy (SGD, lr=0.01, no momentum).

    net: module whose output is indexed as net(x)[:, 0] and lies in [0, 1]
    data_x: tensor of inputs, one row per example
    data_y: tensor of binary labels, one per example
    n_epochs: number of gradient steps

    Returns None. The loss is printed whenever epoch % 10000 == 0.
    data_x and data_y are left untouched: the per-epoch shuffle works on
    re-indexed copies instead of permuting the caller's tensors in place.
    '''
    # `reduction='mean'` is the supported spelling of the deprecated
    # `size_average=True`.
    criterion = torch.nn.BCELoss(reduction='mean')
    optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0)
    n_examples = len(data_x)
    for epoch in range(n_epochs):
        # Draw a fresh row order for this epoch.
        order = np.array(range(n_examples))
        np.random.shuffle(order)
        index = torch.as_tensor(order, dtype=torch.long)
        inputs = data_x[index]
        labels = data_y[index]

        # forward + backward + optimize
        outputs = net(inputs)[:, 0]
        # BCELoss needs floating-point targets of the same dtype as the outputs.
        loss = criterion(outputs, labels.to(outputs.dtype))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch % 10000 == 0:
            print("loss " + str(loss.item()))


def test_classifier_sig(net, data_x, data_y):
    correct = 0
    total = 0
    with torch.no_grad():
        inputs =  data_x
        labels = data_y
        outputs = net(inputs)
        predicted = torch.round(outputs.data)
        total = labels.size(0)
        for i in range(total):
            correct += predicted[i].item() == labels[i].item()
        #correct = (predicted == labels).sum()
    print('Accuracy of the network on the 10000 test images: %d %%' % (
        100 * correct / total))

Simplified human learner model

In [None]:
class HumanLearner:
    """
    Simplified human learner that memorises teaching examples.

    Each teaching example is indexed as (x, label, gamma). A query point is
    "covered" by an example when their RBF similarity is at least that
    example's gamma; the prediction is a similarity-weighted vote over the
    labels of all covering examples.
    """

    def __init__(self):
        self.teaching_set = []
        self.kernel_raw = RBF()

    def kernel(self, x,y):
        """Return the scalar `kernel_raw` value between the points x and y."""
        return self.kernel_raw(x.reshape(1,-1),y.reshape(1,-1))[0][0]

    def _vote(self, x):
        """Return the 0/1 weighted vote for x, or None if no teaching example covers it."""
        if len(self.teaching_set) == 0:
            return None
        query = x.reshape(1,-1)
        anchors = np.asarray([example[0] for example in self.teaching_set])
        similarities = rbf_kernel(query, anchors)[0]
        # Keep the examples whose ball (similarity >= gamma) contains x.
        ball_at_x = [
            example
            for example, similarity in zip(self.teaching_set, similarities)
            if similarity >= example[2]
        ]
        if len(ball_at_x) == 0:
            return None
        ball_anchors = np.asarray([example[0] for example in ball_at_x])
        ball_similarities = rbf_kernel(query, ball_anchors)[0]
        weights = [ball_similarities[i] for i in range(len(ball_at_x))]
        normalization = np.sum(weights)
        score_one = np.sum(
            [weight * example[1] for weight, example in zip(weights, ball_at_x)]
        )
        return 1 if score_one / normalization >= 0.5 else 0

    def predict(self, xs):
        '''
        xs: expected array of inputs
        Returns a list with one entry per input: the 0/1 vote, or -1 when
        the input is covered by no teaching example.
        '''
        votes = (self._vote(x) for x in xs)
        return [-1 if vote is None else vote for vote in votes]

    def predict_prior(self, xs, prior_rejector_preds, to_print = False):
        '''
        Same as predict, except that an uncovered input falls back to the
        entry of prior_rejector_preds at the same position instead of -1.
        to_print is accepted but not used.
        '''
        preds = []
        for j, x in enumerate(xs):
            vote = self._vote(x)
            preds.append(prior_rejector_preds[j] if vote is None else vote)
        return preds

    def add_to_teaching(self, exam):
        """Append one teaching example to the teaching set."""
        self.teaching_set.append(exam)

    def remove_teaching(self):
        """Drop the most recently added teaching example (no-op when empty)."""
        self.teaching_set = self.teaching_set[:-1]

Evaluation code for the prior and posterior rejectors of the human learner

In [None]:

def test_prior(net_hum, net_mach, epsilon, data_x, data_y):
    correct = 0
    correct_sys = 0
    exp = 0
    exp_total = 0
    total = 0
    real_total = 0
    alone_correct = 0
    prior_rejectors = []
    with torch.no_grad():
        inputs =  data_x
        labels = data_y
        m = net_mach(inputs)
        predicted_exp = torch.round(m.data)
        outputs = net_hum(inputs)
        predicted = torch.round(outputs.data)
        for i in range(len(inputs)):
            r_score = max(1 - outputs.data[i].item(), outputs.data[i].item())
            r = 0
            if r_score <  epsilon:
                r = 1
            else:
                r =  0
            prior_rejectors.append(r)
            if r == 1:
                exp += (predicted_exp[i] == labels[i]).item()
                correct_sys += (predicted_exp[i] == labels[i]).item()
                exp_total += 1
            elif r == 0:
                correct += (predicted[i] == labels[i]).item() 
                correct_sys += (predicted[i] == labels[i]).item()
                total += 1
        real_total += labels.size(0)
    cov = str(total) + str(" out of") + str(real_total)
    to_print={"coverage":cov, "system accuracy": 100*correct_sys/real_total, "expert accuracy":100* exp/(exp_total+0.0002),"classifier accuracy":100*correct/(total+0.0001), "alone classifier": 100*alone_correct/real_total }
    #print(to_print)
    return to_print, prior_rejectors


def test_posterior(net_hum, net_mach, epsilon, knn_learner, data_x, data_y):
    correct = 0
    correct_sys = 0
    exp = 0
    exp_total = 0
    total = 0
    real_total = 0
    alone_correct = 0
    mistakes = []
    rejector_preds = []
    with torch.no_grad():
        inputs =  data_x
        labels = data_y
        m = net_mach(inputs)
        predicted_exp = torch.round(m.data)
        outputs = net_hum(inputs)
        predicted = torch.round(outputs.data)
        post_preds = knn_learner.predict(inputs.numpy())
        for i in range(len(inputs)):
            r = post_preds[i]
            if r == -1: # if no point in  ball
                r_score = max(1 - outputs.data[i].item(), outputs.data[i].item())
                if r_score <  epsilon:
                    r = 1
                else:
                    r =  0
            rejector_preds.append(r)
            if r == 1:
                exp += (predicted_exp[i] == labels[i]).item()
                correct_sys += (predicted_exp[i] == labels[i]).item()
                mistakes.append((predicted_exp[i] == labels[i]).item()*1.0)
                exp_total += 1
            elif r == 0:
                correct += (predicted[i] == labels[i]).item() 
                correct_sys += (predicted[i] == labels[i]).item()
                mistakes.append((predicted[i] == labels[i]).item()*1.0)
                total += 1
        real_total += labels.size(0)
    cov = str(total) + str(" out of") + str(real_total)
    to_print={"coverage":cov, "system accuracy": 100*correct_sys/real_total, "expert accuracy":100* exp/(exp_total+0.0002),"classifier accuracy":100*correct/(total+0.0001), "alone classifier": 100*alone_correct/real_total }
    #print(to_print)
    return mistakes, to_print, rejector_preds

Algorithm helpers

In [None]:
def get_improvement_defer(current_defer_preds, opt_defer_preds, gammas, xs, ai_preds, hum_preds, truth):
    '''
    Evaluates, for each candidate teaching point, the improvement obtained
    if it were to be added.

    Candidate i covers every point j whose RBF similarity to it is at least
    gammas[i]. Adding it would set the deferral decision of the covered
    points to opt_defer_preds[i]; the improvement is the net number of
    additional correct predictions over the covered points whose current
    decision differs from it.

    current_defer_preds: current 0/1 deferral decision per point (1 = AI)
    opt_defer_preds: deferral decision each candidate would teach
    gammas: similarity threshold of each candidate
    xs: feature vectors, one per point
    ai_preds, hum_preds, truth: AI predictions, human predictions and labels

    Returns a list of len(gammas) floats.
    '''
    similarities_embeds_all = rbf_kernel(np.asarray(xs), np.asarray(xs))
    # 1.0 where the prediction is correct, 0.0 otherwise.
    hum_correct = (np.asarray(hum_preds) == np.asarray(truth)) * 1.0
    ai_correct = (np.asarray(ai_preds) == np.asarray(truth)) * 1.0
    current = np.asarray(current_defer_preds)
    gain_to_ai = ai_correct - hum_correct
    gain_to_hum = hum_correct - ai_correct

    error_improvements = []
    for i, gamma in enumerate(gammas):
        in_ball = similarities_embeds_all[i] >= gamma
        if opt_defer_preds[i] == 1:
            # Covered points currently kept by the human would move to the AI.
            improvement = gain_to_ai[in_ball & (current == 0)].sum()
        else:
            # Covered points currently deferred would move back to the human.
            improvement = gain_to_hum[in_ball & (current == 1)].sum()
        error_improvements.append(float(improvement))
    return error_improvements

Plotting code

In [None]:
def mscatter(x,y,ax=None, m=None, **kw):
    """
    Scatter plot that can give every point its own marker.

    x, y: point coordinates
    ax: axes to draw on; the current axes are used when it is falsy
    m: optional sequence of markers (MarkerStyle objects or anything
       MarkerStyle accepts); applied only when it has one entry per point
    kw: forwarded to ax.scatter
    Returns the object returned by ax.scatter.
    """
    import matplotlib.markers as mmarkers
    if not ax:
        ax = plt.gca()
    sc = ax.scatter(x, y, **kw)
    if m is None or len(m) != len(x):
        return sc
    styles = [
        mk if isinstance(mk, mmarkers.MarkerStyle) else mmarkers.MarkerStyle(mk)
        for mk in m
    ]
    # One transformed marker path per point.
    sc.set_paths(
        [style.get_path().transformed(style.get_transform()) for style in styles]
    )
    return sc
def conv_to_color(arr):
    """Map each entry of arr to "blue" when it equals 1 and to "red" otherwise."""
    return ["blue" if a == 1 else "red" for a in arr]

# Generate Data

In [None]:

# Problem size and cluster geometry.
d = 2
total_samples = 500
mean_scale = 12
variance = 1.0
group_proportion = 0.5  # share of the points that fall in clusters 1 and 2

# Number of points per cluster: clusters 1-2 split one share, clusters 3-4 the rest.
n_first_pair = math.floor(total_samples * group_proportion * 0.5)
n_second_pair = math.floor(total_samples * (1 - group_proportion) * 0.5)

# Training clusters. Each mean is drawn right before its samples so the
# sequence of random draws is: mean 1, cluster 1, mean 2, cluster 2, ...
cluster1_mean = torch.rand(d) * mean_scale
cluster1_var = torch.tensor(variance)
cluster1 = sample(cluster1_mean, cluster1_var, nb_samples=n_first_pair)
cluster1_labels = torch.ones([n_first_pair], dtype=torch.long)

cluster2_mean = torch.rand(d) * mean_scale
cluster2_var = torch.tensor(variance)
cluster2 = sample(cluster2_mean, cluster2_var, nb_samples=n_first_pair)
cluster2_labels = torch.zeros([n_first_pair], dtype=torch.long)

cluster3_mean = torch.rand(d) * mean_scale
cluster3_var = torch.tensor(variance)
cluster3 = sample(cluster3_mean, cluster3_var, nb_samples=n_second_pair)
cluster3_labels = torch.ones([n_second_pair], dtype=torch.long)

cluster4_mean = torch.rand(d) * mean_scale
cluster4_var = torch.tensor(variance)
cluster4 = sample(cluster4_mean, cluster4_var, nb_samples=n_second_pair)
cluster4_labels = torch.zeros([n_second_pair], dtype=torch.long)

# Test data: fresh samples from the same four distributions.
cluster1_test = sample(cluster1_mean, cluster1_var, nb_samples=n_first_pair)
cluster1_labels_test = torch.ones([n_first_pair], dtype=torch.long)

cluster2_test = sample(cluster2_mean, cluster2_var, nb_samples=n_first_pair)
cluster2_labels_test = torch.zeros([n_first_pair], dtype=torch.long)

cluster3_test = sample(cluster3_mean, cluster3_var, nb_samples=n_second_pair)
cluster3_labels_test = torch.ones([n_second_pair], dtype=torch.long)

cluster4_test = sample(cluster4_mean, cluster4_var, nb_samples=n_second_pair)
cluster4_labels_test = torch.zeros([n_second_pair], dtype=torch.long)

# Scatter the four training clusters on one figure.
fig, ax = plt.subplots(1)
x1, x2, x3, x4 = (points.numpy() for points in (cluster1, cluster2, cluster3, cluster4))
epsilon = 0.8
ax.set_facecolor('white')

for pts, tag in (
    (x1, "human 0"),
    (x2, "human 1"),
    (x3, "machine 0"),
    (x4, "machine 1"),
):
    scatter = mscatter(pts[:, 0], pts[:, 1], cmap='RdBu', ax=ax, s=100, label=tag)
plt.legend()
plt.show()

# Obtain Human and AI predictors

Train the AI as a linear model on two of the clusters (3 and 4) and the human as a linear model on the other two clusters (1 and 2).

In [None]:
# The AI is a linear model fitted on clusters 3 and 4 only.
print("Obtaining AI")
net_machine = Linear_net_sig(d)
data_x = torch.cat([cluster3, cluster4])
data_y = torch.cat([cluster3_labels, cluster4_labels])
run_classifier_sig(net_machine, data_x, data_y, n_epochs=50000)

# The human is a linear model fitted on clusters 1 and 2 only.
print("Obtaining human")
net_human = Linear_net_sig(d)
data_x = torch.cat([cluster1, cluster2])
data_y = torch.cat([cluster1_labels, cluster2_labels])
run_classifier_sig(net_human, data_x, data_y, n_epochs=50000)

# Human learner with an empty teaching set.
knn_learner = HumanLearner()

In [None]:
# Decision line of the human model: the points where w0*x + w1*y + b = 0,
# evaluated at two far-apart x values.
weights = net_human.fc1.weight.detach().numpy()[0]
bias = net_human.fc1.bias.detach().numpy()[0]
x1h, x2h = -100, 100
y1h, y2h = [-(1/weights[1])*(weights[0]*x_end + bias) for x_end in (x1h, x2h)]

# Decision line of the machine model, obtained the same way.
weights = net_machine.fc1.weight.detach().numpy()[0]
bias = net_machine.fc1.bias.detach().numpy()[0]
x1m, x2m = -100, 100
y1m, y2m = [-(1/weights[1])*(weights[0]*x_end + bias) for x_end in (x1m, x2m)]

# Per-cluster correctness flags of the current system (first value returned
# by test_posterior).
mists1, mists2, mists3, mists4 = [
    test_posterior(
        net_human, net_machine, epsilon, knn_learner,
        torch.cat([pts]), torch.cat([lbls]),
    )[0]
    for pts, lbls in (
        (cluster1, cluster1_labels),
        (cluster2, cluster2_labels),
        (cluster3, cluster3_labels),
        (cluster4, cluster4_labels),
    )
]
data_x = torch.cat([cluster4])
data_y = torch.cat([cluster4_labels])

fig, ax = plt.subplots(1)
ax.set_facecolor('white')
ax.set(xlim=(-2, 13), ylim=(-2, 13))
ax.plot([x1h, x2h], [y1h, y2h], color='black', marker='x',label = "human")
ax.plot([x1m, x2m], [y1m, y2m], color='green', marker='x',label = "machine")

# Blue where the flag equals 1, red otherwise; the marker alternates by cluster.
for pts, flags, mark in (
    (x1, mists1, 'o'),
    (x2, mists2, 'x'),
    (x3, mists3, 'o'),
    (x4, mists4, 'x'),
):
    ax.scatter(pts[:, 0], pts[:, 1], c = conv_to_color(flags), cmap='RdBu', marker = mark)

plt.legend()
plt.show()

Test the prior rejector with a given confidence threshold `epsilon_prior`.

In [None]:
# Evaluate the confidence-threshold (prior) rejector on all training clusters.
knn_learner = HumanLearner()
epsilon_prior = 0.9
data_x = torch.cat([cluster1, cluster2, cluster3, cluster4])
data_y = torch.cat(
    [cluster1_labels, cluster2_labels, cluster3_labels, cluster4_labels]
)
to_print = test_prior(net_human, net_machine, epsilon_prior, data_x, data_y)[0]
print(to_print)

# Run Teaching

In [None]:
# Full training set, as tensors and as numpy arrays.
data_x = torch.cat([cluster1, cluster2, cluster3, cluster4])
data_y = torch.cat([cluster1_labels, cluster2_labels, cluster3_labels, cluster4_labels])
data_x_np, data_y_np = data_x.numpy(), data_y.numpy()

# Hard 0/1 predictions of the human and of the machine on every point.
outputs = net_human(data_x)
predicted_hum = torch.round(outputs.data).numpy()[:, 0]
outputs = net_machine(data_x)
predicted_mach = torch.round(outputs.data).numpy()[:, 0]

points_chosen = []

Get the optimal deferral decision for every training point.

In [None]:
# Optimal deferral decision for every point: 0 = keep the human, 1 = defer to the AI.
opt_defer_teaching = []
emperical_deferall = True # empirical per-point correctness (True) or cluster membership (False)
for ex in range(len(predicted_hum)):
    if not emperical_deferall:
        # Clusters 1 and 2 come first in the data and belong to the human.
        if ex < len(cluster1) + len(cluster2):
            opt_defer_teaching.append(0)
        else:
            opt_defer_teaching.append(1)
    else:
        # Despite the names, these flags are True when the prediction is correct.
        error_hum = (predicted_hum[ex] == data_y_np[ex])
        error_ai = (predicted_mach[ex] == data_y_np[ex])
        # Keep the human only when it is right and the AI is wrong; ties go to the AI.
        if error_hum > error_ai:
            opt_defer_teaching.append(0)
        else:
            opt_defer_teaching.append(1)

# Optimal gamma for every point: scanning from the most similar point
# downwards, the similarity at which the optimal decision first changes.
from tqdm import tqdm
optimal_gammas = []
with tqdm(total=len(opt_defer_teaching)) as pbar:
    for i in range(len(opt_defer_teaching)):
        opt_defer_ex = opt_defer_teaching[i]
        similarities_embeds = rbf_kernel(data_x_np[i].reshape(1,-1), data_x_np)[0]
        # (similarity, optimal decision) pairs in increasing similarity.
        sorted_sim = sorted(zip(similarities_embeds, opt_defer_teaching), key=lambda tup: tup[0])
        for k in range(len(sorted_sim) - 1, 0, -1):
            if sorted_sim[k][1] == opt_defer_ex and sorted_sim[k - 1][1] != opt_defer_ex:
                optimal_gammas.append(sorted_sim[k][0])
                break
        else:
            # No decision change anywhere: use the smallest similarity so the
            # ball covers every point. Appending here keeps optimal_gammas
            # aligned index-for-index with opt_defer_teaching.
            optimal_gammas.append(sorted_sim[0][0])
        pbar.update(1)

In [None]:
def plot(indexx, human_learner, points_chosen):
    """
    Draw the four training clusters coloured by system correctness, the two
    decision lines, the chosen teaching points and a circle around each.

    indexx: figure index; not used by the code below
    human_learner: learner passed to test_posterior as the rejector
    points_chosen: sequence of 2-D points selected for teaching

    Reads the module-level clusters, labels, net_human, net_machine and
    `gammas` (one similarity threshold per chosen point).
    """
    # NOTE(review): this fallback threshold is fixed at 0.8 here, while the
    # teaching cells set the module-level epsilon to 0.9 -- confirm intent.
    epsilon = 0.8

    def boundary_ys(net, x_left, x_right):
        # y coordinates of the line w0*x + w1*y + b = 0 at the two x values.
        weights = net.fc1.weight.detach().numpy()[0]
        bias = net.fc1.bias.detach().numpy()[0]
        return [-(1/weights[1])*(weights[0]*x_end + bias) for x_end in (x_left, x_right)]

    x1h, x2h = -100, 100
    y1h, y2h = boundary_ys(net_human, x1h, x2h)
    x1m, x2m = -100, 100
    y1m, y2m = boundary_ys(net_machine, x1m, x2m)

    # (points, labels, marker) for every cluster, in drawing order.
    clusters = (
        (cluster1, cluster1_labels, 'o'),
        (cluster2, cluster2_labels, 'x'),
        (cluster3, cluster3_labels, 'o'),
        (cluster4, cluster4_labels, 'x'),
    )
    # First value returned by test_posterior: per-point correctness flags.
    flags_per_cluster = [
        test_posterior(
            net_human, net_machine, epsilon, human_learner,
            torch.cat([pts]), torch.cat([lbls]),
        )[0]
        for pts, lbls, _ in clusters
    ]

    fig, ax = plt.subplots(1)
    ax.set_facecolor('white')
    ax.set(xlim=(0, 15), ylim=(-2, 15))
    ax.plot([x1h, x2h], [y1h, y2h], color='black', marker='x',label = "human")
    ax.plot([x1m, x2m], [y1m, y2m], color='green', marker='x',label = "machine")
    for (pts, _, mark), flags in zip(clusters, flags_per_cluster):
        coords = pts.numpy()
        ax.scatter(coords[:, 0], coords[:, 1], c = conv_to_color(flags), cmap='RdBu', marker = mark)

    chosen_x = [point[0] for point in points_chosen]
    chosen_y = [point[1] for point in points_chosen]
    ax.scatter(chosen_x, chosen_y, label = "points chosen", marker = "X", color="green",s=70 )

    # Circle radius derived from each similarity threshold.
    gs = np.sqrt(-2*np.log(np.array(gammas)))
    for i, radius in enumerate(gs):
        # Only the first circle carries a legend entry.
        extra = {"label": "radius"} if i == 0 else {}
        circle = plt.Circle((points_chosen[i][0], points_chosen[i][1]), radius, color='b', fill=False, **extra)
        ax.add_patch(circle)

    plt.legend()
    plt.show()

In [None]:
# Fresh learner and bookkeeping for one teaching run.
knn_learner = HumanLearner()
points_chosen, data_sizes, indices_used, gammas = [], [], [], []

epsilon = 0.9
RESOLUTION = 1  # evaluate every RESOLUTION iterations
MAX_SIZE = 5  # size of teaching set
MAX_TRIALS = 1
test_errors = [[] for _ in range(MAX_TRIALS)]

# Decisions of the prior rejector on the training data.
_, prior_preds = test_prior(net_human, net_machine, epsilon, data_x, data_y)

# State before any teaching point is added.
plot(1, knn_learner, points_chosen)



In [None]:

# Greedy teaching: at every step add the point whose ball yields the largest
# improvement under the current rejector.
for itt in range(MAX_SIZE):
    # One evaluation per iteration gives both the metrics and the current
    # rejector decisions. With an empty teaching set these decisions are the
    # confidence-threshold ones (test_posterior falls back to that rule).
    _, metrics, preds_teach = test_posterior(net_human, net_machine, epsilon, knn_learner, data_x, data_y)
    if itt % RESOLUTION == 0:
        print("###########################")
        print("evaluating at size " + str(itt) )
        print(metrics)
        print("###########################")
        # Stored value is the system accuracy, despite the list's name.
        test_errors[0].append(metrics["system accuracy"])
        data_sizes.append(itt)

    error_improvements = get_improvement_defer(preds_teach, opt_defer_teaching, optimal_gammas, data_x_np, predicted_mach, predicted_hum, data_y_np )
    best_index = np.argmax(error_improvements)
    print(f'index {best_index} val {max(error_improvements)}')
    indices_used.append(best_index) # add found element to set used
    repr_x = data_x_np[best_index]

    # Teach the point together with its optimal decision and ball threshold.
    knn_learner.add_to_teaching((repr_x, opt_defer_teaching[best_index], optimal_gammas[best_index]))
    gammas.append(optimal_gammas[best_index])
    points_chosen.append(list(repr_x))
    points_chosen = [list(pp) for pp in points_chosen]
    plot(1, knn_learner, points_chosen)

# Final evaluation with the complete teaching set.
_, metrics, _b = test_posterior(net_human, net_machine, epsilon, knn_learner, data_x, data_y)
print(metrics)
