This notebook can be run as-is on Google Colab to reproduce each subfigure of Figure 3 in the paper individually. We recommend running only 1 trial: running 10 trials takes significantly longer because the code below is not optimized for parallel computation.

# Imports

Install the required packages (needed when running on Google Colab).

In [None]:
!pip install transformers
!pip install sentence_transformers
!pip install datasets
!pip install sentencepiece
!pip install scikit-learn-extra
!pip install lime

In [None]:
from __future__ import print_function
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.backends.cudnn as cudnn
import torch.nn.parallel
from torchvision.datasets.utils import download_url, check_integrity
from torch.utils.data import DataLoader
import argparse
from collections import Counter
import string
import re
import argparse
import json
import pickle
import os
import random
import shutil
from matplotlib import pyplot as plt
import lime
import lime.lime_tabular
from tqdm import tqdm
import sys
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.gaussian_process.kernels import RBF
from sklearn.metrics.pairwise import rbf_kernel
from scipy.stats import multivariate_normal
import  scipy.stats as st
from matplotlib import cm
import torch.optim as optim
from __future__ import print_function
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AdamW, AutoModel
import pickle
import spacy
from gensim import corpora, models, similarities
import gensim
from spacy.lang.en import English
from nltk.corpus import wordnet as wn
from sklearn.metrics.pairwise import rbf_kernel
from nltk.stem.wordnet import WordNetLemmatizer
from sklearn.cluster import KMeans
import pickle
import matplotlib
from sklearn.metrics.pairwise import rbf_kernel
import copy
import nltk
from nltk.corpus import stopwords


In [None]:
# Fetch the NLTK resources used by the notebook.
nltk.download('stopwords')
nltk.download('punkt')
# Read the corpus through `nltk.corpus` instead of the imported `stopwords` name: this cell
# rebinds `stopwords` to a plain list of words, so `stopwords.words(...)` would raise
# AttributeError if the cell were executed a second time.
stopwords = nltk.corpus.stopwords.words('english')
# Font type 42 for PDF/PS output, no LaTeX text rendering, serif font family.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
plt.rc('text', usetex=False)
plt.rc('font', family='serif')
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# HotpotQA
Load the hotpotQA dataset with our modifications.


In [None]:
from datasets import load_dataset
# Not referenced again in the visible part of this notebook.
DISTRACTORS_PARS_LEN = 0
# Load the 'distractor' configuration of HotpotQA and keep a handle on each of its two splits.
dataset = load_dataset("hotpot_qa", 'distractor')
train_dataset = dataset['train']
validation_dataset = dataset['validation']

In [None]:
def get_token_lenght(context, question):
    """Return how many token ids the (question, context) pair is encoded into.

    Uses the module-level `tokenizer`; special tokens are included in the count.
    """
    encoded = tokenizer.encode_plus(
        question,
        context,
        add_special_tokens=True,
        return_tensors="pt",
    )
    first_sequence_ids = encoded['input_ids'][0]
    return len(first_sequence_ids)

def remove_distractors(example, n_distractors = 0):
    """Keep only the gold paragraphs (plus a few random distractors) of one HotpotQA example.

    The example is modified in place and also returned, so the function can be passed to
    `Dataset.map`.

    Args:
        example: dict with `example['context']['title']`, `example['context']['sentences']`
            (one list of sentences per title) and `example['supporting_facts']['title']`.
        n_distractors: how many non-gold paragraphs to keep, sampled with `random.sample`.
            Values larger than the number of available distractors keep all of them.

    Returns:
        The same dict, with `context` reduced to the kept paragraphs and two new keys:
        `intros` (first sentence of every kept paragraph joined by a space) and
        `passage` (kept paragraphs joined by " #$ ").
    """
    all_pars = example['context']['title']
    # dict.fromkeys de-duplicates while keeping first-seen order. A plain set() would order the
    # titles by string hash, which changes between interpreter sessions and therefore changes
    # the generated passage text from run to run.
    gold_pars = list(dict.fromkeys(example['supporting_facts']['title']))
    gold_lookup = set(gold_pars)
    distractor_pars = [title for title in dict.fromkeys(all_pars) if title not in gold_lookup]
    # get indices to keep from distractors (never ask random.sample for more than exist)
    n_keep = min(n_distractors, len(distractor_pars))
    if n_keep > 0:
        distract_indices = random.sample(range(len(distractor_pars)), n_keep)
    else:
        distract_indices = []
    distractor_pars = [distractor_pars[idx] for idx in distract_indices]
    keep_pars = gold_pars + distractor_pars
    # Position of every kept title inside the original context (first occurrence).
    keep_pars_indices = [all_pars.index(keep_par) for keep_par in keep_pars]
    kept_sentences = [example['context']['sentences'][idx] for idx in keep_pars_indices]
    example['context']['title'] = [all_pars[idx] for idx in keep_pars_indices]
    example['context']['sentences'] = kept_sentences

    example['intros'] = " ".join([paragraph[0] for paragraph in kept_sentences])
    example['passage'] = " #$ ".join(["".join(paragraph) for paragraph in kept_sentences])
    return example



In [None]:
# Run remove_distractors over every example of both splits; n_distractors keeps its default of 0.
train_dataset = train_dataset.map(remove_distractors)
validation_dataset = validation_dataset.map(remove_distractors)

In [None]:
# Validation split: drop every question whose answer is literally "yes" or "no".
all_indices = set(range(len(validation_dataset)))
yes_no_indices = {
    idx
    for idx in range(len(validation_dataset))
    if validation_dataset[idx]['answer'] in ("yes", "no")
}
free_indices_val = all_indices - yes_no_indices

# Training split: additionally drop questions whose level is 'easy' or 'medium'.
all_indices = set(range(len(train_dataset)))
yes_no_indices_train = {
    idx
    for idx in range(len(train_dataset))
    if train_dataset[idx]['answer'] in ("yes", "no")
    or train_dataset[idx]['level'] in ('easy', 'medium')
}
free_indices_train = all_indices - yes_no_indices_train

print(len(free_indices_val))
print(len(free_indices_train))


In [None]:
# Kept training examples first ...
hard_dataset = [train_dataset[i] for i in free_indices_train]
# ... so the first validation example sits right after them. The length is taken from the
# list that was just built instead of materialising the whole training subset a second time.
VAL_START = len(hard_dataset)
hard_dataset += [validation_dataset[i] for i in free_indices_val]
dataset_indices = set(range(len(hard_dataset)))
len(hard_dataset)

## Utils
Utilities for evaluating QA models

In [None]:

def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    punctuation = set(string.punctuation)
    # Same order as before: lowercase -> strip punctuation -> blank out articles -> squeeze spaces.
    lowered = s.lower()
    without_punctuation = ''.join(filter(lambda ch: ch not in punctuation, lowered))
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punctuation)
    return ' '.join(without_articles.split())


def f1_score(prediction, ground_truth):
    """Token-level F1 between a prediction and one reference, both passed through normalize_answer.

    Returns 0 when the two share no token.
    """
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    # Multiset intersection: every shared token counts as often as it occurs in both.
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0
    precision = 1.0 * overlap / len(pred_tokens)
    recall = 1.0 * overlap / len(gold_tokens)
    return (2 * precision * recall) / (precision + recall)


def exact_match_score(prediction, ground_truth):
    """True when prediction and reference are identical after normalize_answer."""
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score the prediction against every reference with metric_fn and return the best score."""
    return max([metric_fn(prediction, reference) for reference in ground_truths])



def evaluate_truth_pred(truths, preds):
    '''
    Exact-match and F1 percentages for matched arrays of ground-truth answers and predictions.

    Examples whose ground truth is 'yes' or 'no' are not scored but still count in the
    denominator. A tiny constant is added to the denominator, so empty input yields 0.0.
    '''
    em_sum = 0
    f1_sum = 0
    n_examples = 0
    for position, truth in enumerate(truths):
        n_examples += 1
        if truth in ('yes', 'no'):
            continue
        references = [truth]
        candidate = preds[position]
        em_sum += metric_max_over_ground_truths(exact_match_score, candidate, references)
        f1_sum += metric_max_over_ground_truths(f1_score, candidate, references)

    denominator = n_examples + 0.00000000001
    return {
        'exact_match': 100.0 * em_sum / denominator,
        'f1': 100.0 * f1_sum / denominator,
    }


def get_answer( model, tokenizer, context, question):
    """Return the answer span a span-prediction QA model extracts from `context` for `question`.

    Args:
        model: callable returning `(start_scores, end_scores)` when called with the encoded
            inputs and `return_dict=False`.
        tokenizer: tokenizer providing `encode_plus`, `convert_ids_to_tokens` and
            `convert_tokens_to_string`.
        context: passage text.
        question: question text.

    Returns:
        The decoded answer string (empty when the predicted end precedes the predicted start).
    """
    # 1. TOKENIZE THE INPUT
    # note: if you don't include return_tensors='pt' you'll get a list of lists which is easier for
    # exploration but you cannot feed that into a model.
    # NOTE(review): no truncation arguments are passed, so over-long pairs are not handled here.
    inputs = tokenizer.encode_plus(question, context, return_tensors="pt")
    inputs = inputs.to(device)
    # 2. OBTAIN MODEL SCORES
    # the model returns answer start and end scores for each token in the text.
    # This is inference only (the scores go straight into argmax), so run the forward pass
    # without autograd instead of building a graph that is never used.
    with torch.no_grad():
        answer_start_scores, answer_end_scores = model(**inputs, return_dict=False)
    answer_start = torch.argmax(answer_start_scores)  # most likely beginning of the answer
    answer_end = torch.argmax(answer_end_scores) + 1  # most likely end of the answer (exclusive)
    # 3. GET THE ANSWER SPAN
    # once we have the most likely start and end tokens, we grab all the tokens between them
    # and convert tokens back to words!
    span_ids = inputs["input_ids"][0][answer_start:answer_end]
    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(span_ids))
    return answer
# https://huggingface.co/transformers/migration.html
def model_evaluation(model, tokenizer, questions, contexts, answers, to_print = False):
    """Answer every question with get_answer and score the predictions.

    `answers[i]['text']` is used as the reference for example i. The score dict from
    evaluate_truth_pred is printed and returned. With `to_print`, every 100th example is echoed.
    """
    n_examples = len(questions)
    preds = []
    with tqdm(total=n_examples) as pbar:
        for ex in range(n_examples):
            context, question = contexts[ex], questions[ex]
            answer = get_answer(model, tokenizer, context, question)
            preds.append(answer)
            if to_print and ex % 100 == 0:
                print("context " + context)
                print("quest " + question)
                print("truth " + answers[ex]['text'])
                print("pred " + answer)
            pbar.update(1)
    truths = [entry['text'] for entry in answers]
    scores = evaluate_truth_pred(truths, preds)
    print(scores)
    return scores


def kernel_similarity(x, y):
    """RBF-kernel value between two vectors, each reshaped to a single row first."""
    row_x = x.reshape(1, -1)
    row_y = y.reshape(1, -1)
    return rbf_kernel(row_x, row_y)[0][0]


# FakeAI model
The simulated AI model

In [None]:
class FakeAI:
    r"""
    Simulated AI whose chance of answering correctly depends on the k-means cluster of the passage.

    Args:
        train_passages: list of embedded passages (used to fit k-means when `kmeans` is None)
        train_questions: list of embedded questions (used to fit k-means when `kmeans` is None)
        n_clus_p: number of centroids in kmeans for passages
        n_clus_q: number of centroids in kmeans for questions
        beta_params: (a, b) of the Beta distribution the per-passage-cluster values are drawn from
        kmeans: optional fitted clustering object, used for both passages and questions
        err_p: optional per-passage-cluster values to use instead of drawing them

    Attributes:
        clust_err_p: one value per passage cluster; a prediction is marked right (1) when a
            uniform draw falls below it, so it acts as the probability of being correct.
        clust_err_q: one value per question cluster, drawn from Beta(0.2, 0.8); not read by
            any method of this class.
    """
    def __init__(self, train_passages, train_questions, n_clus_p, n_clus_q, beta_params = (1,1), kmeans = None, err_p = None):

        self.train_passages = train_passages
        self.train_questions = train_questions
        self.n_clus_p = n_clus_p
        self.n_clus_q = n_clus_q
        # `is None`, not `== None`: err_p may be a numpy array, for which `== None` is an
        # element-wise comparison whose truth value is ambiguous.
        if err_p is None:
            self.clust_err_p = np.random.beta(beta_params[0],beta_params[1],self.n_clus_p)
        else:
            self.clust_err_p = err_p
        # One value per *question* cluster (previously sized with n_clus_p by mistake).
        self.clust_err_q = np.random.beta(0.2,0.8,self.n_clus_q)
        if kmeans is None:
            self.build_kmeans()
        else:
            self.kmeans_pass = kmeans
            self.kmeans_quest = kmeans

    def build_kmeans(self):
        """Fit one KMeans model on the training questions and one on the training passages."""
        self.kmeans_quest = KMeans(n_clusters=self.n_clus_q, max_iter = 100000).fit(self.train_questions)
        self.kmeans_pass = KMeans(n_clusters=self.n_clus_p, max_iter = 100000).fit(self.train_passages)

    def _sample_one_hot(self, clusts_p):
        """Draw one uniform number per cluster id; rows are [1.0, 0.0] when the draw is below
        the cluster's value and [0.0, 1.0] otherwise."""
        preds = []
        for cluster in clusts_p:
            coin = random.random() # random number between [0,1]
            if coin >= self.clust_err_p[cluster]:
                preds.append([0.0,1.0])
            else:
                preds.append([1.0,0.0])
        return np.asarray(preds)

    def predict_right_wrong(self, passages, questions):
        '''
        Args:
            passages: list of embedded passages
            questions: list of embedded questions (same size as passages); accepted for
                interface compatibility, the outcome depends on the passage cluster only
        Returns:
            preds: binary list indicating if AI is right (1) or wrong (0)
        '''
        clusts_p = self.kmeans_pass.predict(passages)
        preds = []
        for cluster in clusts_p:
            coin = random.random() # random number between [0,1]
            if coin >= self.clust_err_p[cluster]:
                preds.append(0)
            else:
                preds.append(1)
        return preds

    def predict_raw_probas(self, passages, questions):
        '''
        Args:
            passages: list of embedded passages
            questions: list of embedded questions (same size as passages); not used
        Returns:
            preds: list with, for every passage, the clust_err_p value of its cluster
        '''
        clusts_p = self.kmeans_pass.predict(passages)
        return [self.clust_err_p[cluster] for cluster in clusts_p]

    def prior_rejector(self, passages, questions, epsilon_reject):
        '''
        rejector defined as : 1{prob_of_correct <= epsilon_reject}

        `questions` is accepted for interface compatibility and not used.
        '''
        clusts_p = self.kmeans_pass.predict(passages)
        preds = []
        for cluster in clusts_p:
            # The draw is not used for the decision; it is kept so the global `random` stream
            # advances exactly as before and seeded runs keep producing the same numbers.
            random.random()
            if  self.clust_err_p[cluster] <= epsilon_reject:
                preds.append(1)
            else:
                preds.append(0)
        return preds

    def predict_proba(self, raw_passages, encoder = None):
        '''
        Args:
            raw_passages: list of passage strings
            encoder: object with an `encode(list_of_str)` method; when omitted the
                module-level `model` is used, which must then still exist
        Returns:
            array with one sampled row per passage, [1.0, 0.0] or [0.0, 1.0]
        '''
        if encoder is None:
            encoder = model
        embed_ps = encoder.encode(raw_passages)
        embed_ps = [embed_ps[i] for i in range(len(embed_ps))]
        return self._sample_one_hot(self.kmeans_pass.predict(embed_ps))

    def predict_proba_encoded(self, passages):
        '''
        Same as predict_proba, for passages that are already embedded (cast to float32).
        '''
        clusts_p = self.kmeans_pass.predict(np.array(passages, dtype=np.float32))
        return self._sample_one_hot(clusts_p)


def is_a_stopword(feature, weight):
    """Return True when none of the space-separated words of `feature` is in `stopwords`.

    Note the polarity: a feature containing a stopword yields False. `weight` is not used.
    """
    return not any(word in stopwords for word in feature.split(' '))


In [None]:
from sentence_transformers import SentenceTransformer
# sentence embedding model
# Bound to the module-level name `model`: FakeAI.predict_proba reads that name, and a later
# cell removes it again with `del model`.
model = SentenceTransformer('paraphrase-distilroberta-base-v1')
# https://www.sbert.net/docs/pretrained_models.html

# Teaching Prelims



## Human classes and utilities

In [None]:
class HumanLearner:
    """Simulated human who decides when to defer, based on a set of teaching examples.

    Every teaching example is `(x, label, gamma)`. It applies to a query when the RBF
    similarity between the query and `x` is at least `gamma`. When at least one example
    applies, the decision is the similarity-weighted vote of the applicable labels;
    otherwise the human falls back to the prior decision supplied by the caller.
    """
    def __init__(self, kernel):
        '''
        kernel: function that takes two inputs and returns a similarity
                (stored only; predictions use rbf_kernel directly)
        '''
        self.teaching_set = []
        self.kernel = kernel

    def _posterior_prediction(self, x, teaching_points):
        """Vote of the teaching examples whose ball contains x, or None when there is none."""
        if len(self.teaching_set) == 0:
            # Nothing taught yet: rbf_kernel cannot be evaluated against an empty matrix.
            return None
        query = x.reshape(1,-1)
        similarities = rbf_kernel(query, teaching_points)[0]
        ball_at_x = [example for example, similarity in zip(self.teaching_set, similarities)
                     if similarity >= example[2]]
        if len(ball_at_x) == 0:
            return None
        ball_similarities = rbf_kernel(query, np.asarray([example[0] for example in ball_at_x]))[0]
        normalization = np.sum(ball_similarities)
        score_one = np.sum([ball_similarities[i]*ball_at_x[i][1] for i in range(len(ball_at_x))])
        pred = score_one / normalization
        return 1 if pred >= 0.5 else 0

    def _predict_all(self, xs, prior_rejector_preds, pbar):
        """Predict every x; returns (predictions, number of predictions made from the teaching set)."""
        preds = []
        used_posterior = 0
        # The teaching matrix is the same for every query, so build it once per call.
        teaching_points = np.asarray([example[0] for example in self.teaching_set])
        for idx, x in enumerate(xs):
            pred = self._posterior_prediction(x, teaching_points)
            if pred is None:
                # use prior rejector
                preds.append(prior_rejector_preds[idx])
            else:
                used_posterior += 1
                preds.append(pred)
            if pbar is not None:
                pbar.update(1)
        return preds, used_posterior

    def predict(self, xs, prior_rejector_preds, to_print = False):
        '''
        xs: expected array of inputs (each supporting `.reshape`)
        prior_rejector_preds: fallback decision for every input, same length as xs
        to_print: show a progress bar and the share of inputs decided from the teaching set

        Returns the list of decisions, one per input.
        '''
        if to_print:
            print("-- Human making reject predictions --")
            with tqdm(total=len(xs)) as pbar:
                preds, used_posterior = self._predict_all(xs, prior_rejector_preds, pbar)
            # max(..., 1) avoids a division by zero when xs is empty.
            print(f'Used posterior {used_posterior/max(len(xs), 1)*100:.2f}')
        else:
            preds, _ = self._predict_all(xs, prior_rejector_preds, None)
        return preds

    def add_to_teaching(self, teaching_example):
        '''
        teaching_example: (x, label, gamma)
        '''
        self.teaching_set.append(teaching_example)

    def remove_last_teaching_item(self):
        """Drop the most recently added teaching example (no-op when the set is empty)."""
        self.teaching_set = self.teaching_set[:-1]

In [None]:
def compute_predictions_humanai(hum_preds, hum_rejector, ai_preds, data_x):
    '''
    Combine human and AI predictions according to the rejector's decisions.

    hum_preds: array of human predictions
    ai_preds: array of AI predictions
    hum_rejector: called as hum_rejector(data_x); must return one decision per input,
                  where 1 means "defer to the AI"
    data_x: array of inputs

    Returns array of final predictions and deferrals
    '''
    with torch.no_grad():
        reject_decisions = hum_rejector(data_x)
        predictions = [
            ai_preds[i] if reject_decisions[i] == 1 else hum_preds[i]
            for i in range(len(data_x))
        ]
    return predictions, reject_decisions

def get_metrics(preds, truths):
    """Score predictions against truths; the metric is custom for each use case.

    Note the argument order is swapped when delegating to evaluate_truth_pred.
    """
    scores = evaluate_truth_pred(truths, preds)
    return scores

def compute_metrics(human_preds, ai_preds, reject_decisions, truths, to_print = False):
    """Metrics of the combined system, of the human on kept examples and of the AI on deferred ones.

    A decision equal to 1 means the example is deferred to the AI. Coverage is the share of
    examples that are not deferred.

    Returns (coverage, system metrics, human metrics, AI metrics).
    """
    coverage = 1 - np.sum(reject_decisions)/len(reject_decisions)
    positions = range(len(reject_decisions))
    deferred = [reject_decisions[i] == 1 for i in positions]

    humanai_preds = [ai_preds[i] if deferred[i] else human_preds[i] for i in positions]
    ai_preds_sys = [ai_preds[i] for i in positions if deferred[i]]
    truths_ai = [truths[i] for i in positions if deferred[i]]
    human_preds_sys = [human_preds[i] for i in positions if not deferred[i]]
    truths_human = [truths[i] for i in positions if not deferred[i]]

    humanai_metrics = get_metrics(humanai_preds, truths)
    human_metrics = get_metrics(human_preds_sys, truths_human)
    ai_metrics = get_metrics(ai_preds_sys, truths_ai)
    if to_print:
        print(f'Coverage is {coverage*100:.2f}')
        print(f' metrics of system are: {humanai_metrics}')
        print(f' metrics of human are: {human_metrics}')
        print(f' metrics of AI are: {ai_metrics}')
    return coverage, humanai_metrics, human_metrics, ai_metrics

## Get predictions 
Obtain embeddings of the passages and other data used for teaching

In [None]:
# Define teaching and validation
# Seed both global generators: besides `random`, the notebook also draws from NumPy's global
# generator (np.random.beta in FakeAI, np.random.rand for the noisy radius), which was
# previously left unseeded.
random.seed(66)
np.random.seed(66)

# Embed every passage once; the rows are kept as a list of per-example vectors.
all_passages = [example['passage'] for example in hard_dataset]
embeddings_passage = model.encode(all_passages)
embeddings_passage = [embeddings_passage[i] for i in range(len(embeddings_passage))]

all_questions = [example['question'] for example in hard_dataset]
all_answers = [example['answer'] for example in hard_dataset]
all_sentences =  [example['context']['sentences'] for example in hard_dataset]

# hard_dataset holds the training examples first and the validation examples from VAL_START on.
train_indices = list(range(VAL_START))
test_indices = list(range(VAL_START,len(hard_dataset)))

train_passages = [embeddings_passage[i] for i in train_indices]
#train_questions = [embeddings_question[i] for i in train_indices]
test_passages = [embeddings_passage[i] for i in test_indices]
#test_questions = [embeddings_question[i] for i in test_indices]
train_sentences = [all_sentences[i] for i in train_indices]
test_sentences = [all_sentences[i] for i in test_indices]

train_answers = [all_answers[i] for i in train_indices]
test_answers = [all_answers[i] for i in test_indices]
train_raw_passages = [all_passages[i] for i in train_indices]
test_raw_passages = [all_passages[i] for i in test_indices]
train_raw_questions= [all_questions[i] for i in train_indices]
test_raw_questions = [all_questions[i] for i in test_indices]


In [None]:

# Collect the per-example embeddings into arrays; np.concatenate over a one-element list
# simply copies that element.
teaching_embeddings = np.asarray([np.concatenate([train_passages[i]]) for i in range(len(train_passages))])
validation_embeddings = np.asarray([np.concatenate([test_passages[i]]) for i in range(len(test_passages))])
# Drop the names that are no longer needed. After this cell `model` no longer exists, so
# FakeAI.predict_proba (which reads the global `model`) cannot be used past this point.
del embeddings_passage
del model
del train_passages
del test_passages


In [None]:
# Pairwise RBF similarities between all teaching embeddings.
similarities_embeds_all = rbf_kernel(np.asarray(teaching_embeddings), np.asarray(teaching_embeddings))
# sorted_sims[i] is an (n, 2) float array of (similarity, column index) pairs for row i, in
# ascending similarity. A stable argsort yields the same ordering (ties keep index order) as
# sorting the (similarity, index) tuples by similarity, without a Python-level sort per row.
sorted_sims = []
for i in range(len(similarities_embeds_all)):
    row = similarities_embeds_all[i]
    order = np.argsort(row, kind='stable')
    sorted_sims.append(np.column_stack((row[order], order)).astype(np.float64, copy=False))

In [None]:
# get optimal gammas
def get_optimal_gammas():
    """Return one similarity threshold per teaching point.

    For point i the other points are ranked by similarity to i. Walking from the most
    similar one downwards, the threshold is the similarity at the first rank whose label
    (from the global `opt_defer_teaching`) equals the label of i while the next lower rank
    has a different label. When no such rank exists the threshold is 1.
    """
    n_points = len(teaching_embeddings)
    gammas = []
    with tqdm(total=n_points) as progress:
        kernel_matrix = rbf_kernel(np.asarray(teaching_embeddings), np.asarray(teaching_embeddings))
        descending_ranks = range(len(opt_defer_teaching) - 1, 0, -1)
        for center in range(n_points):
            # get all similarities of this point, paired with the label of the other point
            row = kernel_matrix[center]
            target_label = opt_defer_teaching[center]
            ranked = [(row[k], opt_defer_teaching[k]) for k in range(n_points)]
            ranked.sort(key=lambda pair: pair[0])
            boundary = next(
                (ranked[k][0] for k in descending_ranks
                 if ranked[k][1] == target_label and ranked[k - 1][1] != target_label),
                1,
            )
            gammas.append(boundary)
            progress.update(1)
    return gammas

# Methods

## Selection Algorithm- consistent (ours)

In [None]:
def get_improvement_defer(current_defer_preds, opt_defer_preds, gammas, xs, coin_prob = 0.15):
    """Gain obtained by teaching each candidate point with its given radius.

    For candidate i, every point j with similarity to i of at least gammas[i] lies in its
    ball. If i should be deferred, each ball point currently not deferred contributes
    (AI score - human score); otherwise each ball point currently deferred contributes
    (human score - AI score). Scores and similarities are read from the globals
    `ai_teaching_preds_b`, `hum_teaching_preds_b` and `similarities_embeds_all`.
    `xs` and `coin_prob` are not used.

    Returns the list of gains, one per entry of `gammas`.
    """
    gains = []
    for center, radius in enumerate(gammas):
        random.random()  # result unused; the draw keeps the global RNG stream unchanged
        gain = 0
        for neighbour, similarity in enumerate(similarities_embeds_all[center]):
            if not similarity >= radius:
                continue
            hum_score = hum_teaching_preds_b[neighbour]
            ai_score = ai_teaching_preds_b[neighbour]
            if opt_defer_preds[center] == 1:
                if current_defer_preds[neighbour] == 0:
                    gain += ai_score - hum_score
            elif current_defer_preds[neighbour] == 1:
                gain += hum_score - ai_score
        gains.append(gain)
    return gains


def get_greedy_gamma(i, current_defer_preds, opt_defer_preds, gammas, xs):
    """Pick the radius around teaching point i that gives the largest cumulative gain.

    Neighbours are visited from the most similar one downwards (using the global
    `sorted_sims[i]`, rank 0 excluded) while the gain is accumulated; the similarity at the
    last rank where the running gain is at least the best seen so far is returned as radius.
    `gammas` and `xs` are not used.

    Returns (best gain, radius).
    """
    ranked = sorted_sims[i]
    best_gain = -1000
    best_radius = 1
    running_gain = 0
    for rank in range(len(opt_defer_teaching) - 1, 0, -1):
        neighbour = int(ranked[rank][1])
        hum_score = hum_teaching_preds_b[neighbour]
        ai_score = ai_teaching_preds_b[neighbour]
        if opt_defer_preds[i] == 1:
            if current_defer_preds[neighbour] == 0:
                running_gain += ai_score - hum_score
        elif current_defer_preds[neighbour] == 1:
            running_gain += hum_score - ai_score

        if running_gain >= best_gain:
            best_gain, best_radius = running_gain, ranked[rank][0]

    return best_gain, best_radius


In [None]:
def teach_ours(greedy_gamma = False):
    """Greedily build a teaching set of MAX_SIZE examples and track validation exact match.

    At every step the candidate with the largest gain (get_improvement_defer with the
    precomputed `optimal_gammas`) is added to a fresh HumanLearner.

    Args:
        greedy_gamma: when true, the radius of each chosen example is re-optimised with
            get_greedy_gamma; otherwise its entry of `optimal_gammas` is used.

    Returns:
        (errors, indices_used): validation exact-match values recorded every PLOT_INTERVAL
        steps, and the index of the example chosen at every step.
    """
    human_learner = HumanLearner(None)
    errors = []
    indices_used = []
    for itt in range(MAX_SIZE):
        print(f'New size {itt}')
        # predict with current human learner (the prior alone before anything is taught)
        if itt == 0:
            preds_teach = priorhum_teaching_preds
        else:
            preds_teach = human_learner.predict(teaching_embeddings, priorhum_teaching_preds)
        error_improvements = get_improvement_defer(preds_teach, opt_defer_teaching, optimal_gammas, teaching_embeddings)
        best_index = np.argmax(error_improvements)
        indices_used.append(best_index) # add found element to set used
        ex_embed = teaching_embeddings[best_index]
        ex_label = opt_defer_teaching[best_index]

        if greedy_gamma:
            # Keep the result in its own names: assigning it to `greedy_gamma` replaced the
            # boolean flag by a float, so a radius of 0.0 silently disabled this branch for
            # all later iterations.
            best_improvement, gamma = get_greedy_gamma(best_index, preds_teach, opt_defer_teaching, optimal_gammas, teaching_embeddings)
            print(f'got improvements with max {best_improvement}')
        else:
            gamma = optimal_gammas[best_index]
            print(f'got improvements with max {max(error_improvements)}')

        human_learner.add_to_teaching([ex_embed, ex_label, gamma])

        if itt % PLOT_INTERVAL == 0:
            print("####### val eval " +str(itt)+ " ###########")
            preds_teach = human_learner.predict(validation_embeddings, priorhum_validation_preds)
            _, metrics, __, ___ = compute_metrics(hum_validation_preds, ai_validation_preds, preds_teach, test_answers, True)
            errors.append(metrics['exact_match'])
            print("##############################")
    return errors, indices_used
#errors, indices_used = teach_ours(True)

## Selection double greedy

In [None]:

# Ranks len(teaching_embeddings)-1 down to 1, computed once when this cell runs and read as a
# module-level name by get_improvement_defer_greedy.
indicess = list(range(1, len(teaching_embeddings)))
indicess.reverse()
def get_improvement_defer_greedy(current_defer_preds, opt_defer_preds, gammas, xs):

    error_improvements = []
    #similarities_embeds_all = rbf_kernel(np.asarray(xs), np.asarray(xs))
    error_at_i = 0
    found_gammas = []
    for i in range(len(opt_defer_preds)):
        coin = random.random() # random number between [0,1]
        similarities_embeds = similarities_embeds_all[i]
        sorted_sim = sorted_sims[i]#sorted([(similarities_embeds[k], k) for k in range(len(teaching_embeddings))], key=lambda tup: tup[0])

        max_improve = -1000
        gamma_value = 1
        current_improve = 0
        so_far = 0
        for j in indicess:
            so_far += 1
            idx = int(sorted_sim[j][1])
            f1_hum = hum_teaching_preds_b[idx]
            f1_ai = ai_teaching_preds_b[idx]
            if opt_defer_preds[i] == 1:
                if current_defer_preds[idx] == 0:
                    current_improve += f1_ai - f1_hum
            else:
                if current_defer_preds[idx] == 1:
                    current_improve += f1_hum - f1_ai

            if current_improve >= max_improve:
                max_improve = current_improve 
                gamma_value = sorted_sim[j][0]
            
        error_improvements.append(max_improve)
        found_gammas.append(gamma_value)
    return error_improvements, found_gammas


In [None]:
def teach_ours_doublegreedy(noisy_radius = False):
    """Build a teaching set of MAX_SIZE examples, choosing example and radius jointly.

    At every step get_improvement_defer_greedy supplies the best gain and radius per
    candidate; the candidate with the largest gain is taught with its radius. With
    `noisy_radius` the radius is shifted by a uniform random offset whose width is
    1 - optimal_gammas[chosen index].

    Returns (validation exact-match values recorded every PLOT_INTERVAL steps,
    indices of the chosen examples).
    """
    learner = HumanLearner(None)
    errors = []
    indices_used = []
    for step in range(MAX_SIZE):
        print(f'New size {step}')
        # predict with current human learner
        preds_teach = (
            priorhum_teaching_preds
            if step == 0
            else learner.predict(teaching_embeddings, priorhum_teaching_preds)
        )
        error_improvements, best_gammas = get_improvement_defer_greedy(preds_teach, opt_defer_teaching, optimal_gammas, teaching_embeddings)
        print(f'got improvements with max {max(error_improvements)}')
        best_index = np.argmax(error_improvements)
        indices_used.append(best_index) # add found element to set used

        gamma = best_gammas[best_index]
        if noisy_radius:
            slack = 1-optimal_gammas[best_index]
            gamma = best_gammas[best_index]  + (np.random.rand(1)[0])*slack-slack/2 # random choice
        learner.add_to_teaching([teaching_embeddings[best_index], opt_defer_teaching[best_index], gamma])

        if step % PLOT_INTERVAL == 0:
            print("####### val eval " +str(step)+ " ###########")
            preds_teach = learner.predict(validation_embeddings, priorhum_validation_preds)
            _, metrics, __, ___ = compute_metrics(hum_validation_preds, ai_validation_preds, preds_teach, test_answers, True)
            errors.append(metrics['exact_match'])
            print("##############################")
    return errors, indices_used
#errors_doublegreedy, indices_used_doublegreedy = teach_ours_doublegreedy()

## Random baseline

In [None]:
def teach_random(greedy_gamma = False):
    """Baseline: teach MAX_SIZE examples drawn uniformly at random without replacement.

    Args:
        greedy_gamma: when true, the radius of each example is optimised with
            get_greedy_gamma against the learner's current decisions; otherwise its entry of
            `optimal_gammas` is used.

    Returns:
        (errors_random, indices_used_random): validation exact-match values recorded every
        PLOT_INTERVAL steps, and the sampled example indices in teaching order.
    """
    human_learner_random = HumanLearner(None)
    errors_random = []
    indices_used_random = random.sample(list(range(len(teaching_embeddings))), MAX_SIZE)
    for itt in range(MAX_SIZE):
        print(f'New size {itt}')
        best_index = indices_used_random[itt]
        ex_embed = teaching_embeddings[best_index]
        ex_label = opt_defer_teaching[best_index]
        if greedy_gamma:
            # Before anything is taught the learner's decisions are the prior decisions.
            if itt == 0:
                preds_teach = priorhum_teaching_preds
            else:
                preds_teach = human_learner_random.predict(teaching_embeddings, priorhum_teaching_preds)
            # Unpack into `gamma` directly: assigning the result to `greedy_gamma` replaced
            # the boolean flag by a float, so a radius of 0.0 silently disabled this branch
            # for all later iterations.
            _, gamma = get_greedy_gamma(best_index, preds_teach, opt_defer_teaching, optimal_gammas, teaching_embeddings)
        else:
            gamma = optimal_gammas[best_index]
        human_learner_random.add_to_teaching([ex_embed, ex_label, gamma])

        if  (itt) % PLOT_INTERVAL == 0:
            print("####### val eval " +str(itt)+ " ###########")
            preds_teach = human_learner_random.predict(validation_embeddings, priorhum_validation_preds)
            _, metrics, __, ___ = compute_metrics(hum_validation_preds, ai_validation_preds, preds_teach, test_answers, True)
            errors_random.append(metrics['exact_match'])
            print("##############################")
    return errors_random, indices_used_random
#errors_random, indices_used_random = teach_random()

## K-Medoids baseline

In [None]:
from sklearn_extra.cluster import KMedoids
def teach_medoids(greedy_gamma = False):
    """K-Medoids teaching baseline.

    For every teaching-set size that is evaluated (itt = 0, PLOT_INTERVAL,
    2*PLOT_INTERVAL, ... below MAX_SIZE) a K-Medoids clustering with itt+1
    clusters is fit on the teaching embeddings, a fresh HumanLearner is taught
    the medoids in the order returned by the clustering, and the validation
    'exact_match' obtained with the learner's predictions is recorded.

    Relies on notebook-level globals: teaching_embeddings, validation_embeddings,
    opt_defer_teaching, optimal_gammas, priorhum_teaching_preds,
    priorhum_validation_preds, hum_validation_preds, ai_validation_preds,
    test_answers, MAX_SIZE and PLOT_INTERVAL.

    Args:
        greedy_gamma: if False, each medoid is stored with its precomputed
            optimal_gammas[index]; if True, gamma is obtained from
            get_greedy_gamma using the learner's current predictions on the
            teaching set (the prior predictions for the first medoid of a set).

    Returns:
        (errors_medoid, indices_used_medoid): the validation exact-match values
        (one per evaluated size) and the medoid indices of the last teaching set
        that was evaluated (empty if no size was evaluated).
    """
    errors_medoid = []
    indices_used_medoid = []

    # Only sizes that are actually evaluated need a clustering, so step directly.
    for itt in range(0, MAX_SIZE, PLOT_INTERVAL):
        # Medoid sets of different sizes are not nested, so each size starts from scratch.
        human_learner_medoid = HumanLearner(None)
        kmedoids = KMedoids(n_clusters=itt+1, init='k-medoids++').fit(teaching_embeddings)
        teaching_indices = kmedoids.medoid_indices_
        # Previously this list was never filled and the function always returned [].
        indices_used_medoid = teaching_indices.tolist()

        print(f'New size {itt}')
        for position, best_index in enumerate(teaching_indices):
            if greedy_gamma:
                if position == 0:
                    # Nothing has been taught yet, so the learner still follows its prior.
                    preds_teach = priorhum_teaching_preds
                else:
                    preds_teach = human_learner_medoid.predict(teaching_embeddings, priorhum_teaching_preds)
                # Store the result in `gamma`, NOT in `greedy_gamma`: overwriting the flag
                # with a numeric gamma would disable the greedy branch whenever gamma is 0.
                _, gamma = get_greedy_gamma(best_index, preds_teach, opt_defer_teaching, optimal_gammas, teaching_embeddings)
            else:
                gamma = optimal_gammas[best_index]
            ex_embed = teaching_embeddings[best_index]
            ex_label = opt_defer_teaching[best_index]
            human_learner_medoid.add_to_teaching([ex_embed, ex_label, gamma])

        print("####### val eval " +str(itt)+ " ###########")
        preds_teach = human_learner_medoid.predict(validation_embeddings, priorhum_validation_preds)
        _, metrics, __, ___ = compute_metrics(hum_validation_preds, ai_validation_preds, preds_teach, test_answers, True)
        errors_medoid.append(metrics['exact_match'])
        print("##############################")
    return errors_medoid, indices_used_medoid
#errors_medoid, indices_used_medoid = teach_medoids()

## LIME baseline

In [None]:
def get_improvement_lime(features_covered, xs, indices_used_lime):
    """Score every candidate by the global importance of its still-uncovered features.

    For candidate i the score is the sum of global_feature_importance[f] over the
    features f listed in individual_feature_importance[i] whose coverage flag is
    still 0; candidates already present in `indices_used_lime` start from -10000
    so that they are effectively never picked again. The features of the
    best-scoring candidate are then flagged as covered.

    Reads the notebook-level globals individual_feature_importance and
    global_feature_importance.

    Args:
        features_covered: dict feature id -> 0/1 coverage flag; updated IN PLACE.
        xs: candidate collection; only its length is used.
        indices_used_lime: container of already selected candidate indices.

    Returns:
        (error_improvements, features_covered): the list of per-candidate scores
        and the same (mutated) coverage dict that was passed in.
    """
    error_improvements = []
    for candidate in range(len(xs)):
        # Heavy penalty keeps previously chosen candidates out of the argmax.
        score = -10000 if candidate in indices_used_lime else 0
        for feature in individual_feature_importance[candidate]:
            if features_covered[feature] == 0:
                score += global_feature_importance[feature]
        error_improvements.append(score)

    # Mark everything the winning candidate explains as covered.
    winner = np.argmax(error_improvements)
    for feature in individual_feature_importance[winner]:
        if features_covered[feature] == 0:
            features_covered[feature] = 1
    return error_improvements, features_covered

In [None]:
def teach_lime(greedy_gamma = False):
    """LIME-based teaching baseline.

    At every step the candidate with the highest score from get_improvement_lime
    (global importance of its not-yet-covered features) is added to a HumanLearner;
    every PLOT_INTERVAL steps the validation 'exact_match' obtained with the
    learner's predictions is recorded.

    Relies on notebook-level globals: teaching_embeddings, validation_embeddings,
    opt_defer_teaching, optimal_gammas, priorhum_teaching_preds,
    priorhum_validation_preds, hum_validation_preds, ai_validation_preds,
    test_answers, DATA_DIM, cutoff_size, MAX_SIZE and PLOT_INTERVAL (and, through
    get_improvement_lime, the LIME feature-importance tables).

    Args:
        greedy_gamma: if False, each example is stored with its precomputed
            optimal_gammas[index]; if True, gamma is obtained from
            get_greedy_gamma using the learner's current predictions on the
            teaching set (the prior predictions for the very first example).

    Returns:
        (errors_lime, indices_used_lime): the validation exact-match values (one
        per evaluated step) and a dict whose keys are the selected teaching
        indices (every value is 1; the dict is used as a set).
    """
    human_learner = HumanLearner(None)
    errors_lime = []
    indices_used_lime = {}
    features_covered = {feature: 0 for feature in range(DATA_DIM)}

    for itt in range(MAX_SIZE):
        print(f'New size {itt}')

        # get_improvement_lime also flags the winner's features as covered, so the
        # argmax taken below refers to the same candidate it just accounted for.
        error_improvements, features_covered = get_improvement_lime(features_covered, teaching_embeddings[:cutoff_size], indices_used_lime)
        print(f'got improvements with max {max(error_improvements)}')
        best_index = np.argmax(error_improvements)
        indices_used_lime[best_index] = 1 # add found element to set used
        ex_embed = teaching_embeddings[best_index]
        ex_label = opt_defer_teaching[best_index]

        if greedy_gamma:
            if itt == 0:
                # Nothing has been taught yet, so the learner still follows its prior.
                preds_teach = priorhum_teaching_preds
            else:
                preds_teach = human_learner.predict(teaching_embeddings, priorhum_teaching_preds)
            # Store the result in `gamma`, NOT in `greedy_gamma`: overwriting the flag
            # with a numeric gamma would disable the greedy branch whenever gamma is 0.
            _, gamma = get_greedy_gamma(best_index, preds_teach, opt_defer_teaching, optimal_gammas, teaching_embeddings)
        else:
            gamma = optimal_gammas[best_index]

        human_learner.add_to_teaching([ex_embed, ex_label, gamma])

        if itt % PLOT_INTERVAL == 0:
            print("####### val eval " +str(itt)+ " ###########")
            preds_teach = human_learner.predict(validation_embeddings, priorhum_validation_preds)
            _, metrics, __, ___ = compute_metrics(hum_validation_preds, ai_validation_preds, preds_teach, test_answers, True)
            errors_lime.append(metrics['exact_match'])
            print("##############################")
    return errors_lime, indices_used_lime
#errors_lime, indices_used_lime = teach_lime()

## Learn AI behavior baseline

In [None]:
from sklearn.neighbors import RadiusNeighborsClassifier, KNeighborsClassifier
def teach_learnai(greedy_gamma = False):
    """'Learn AI behavior' teaching baseline.

    At every step a random subset of teaching examples is screened: each candidate
    is temporarily added to a small reference set, a distance-weighted k-NN
    classifier is fit on that set to predict the AI's right/wrong flags
    (ai_teaching_preds_b), and its accuracy on a random held-out subset of the
    teaching data is measured. The candidate with the highest accuracy (ties go to
    the candidate screened last) is then taught to the HumanLearner. Every
    PLOT_INTERVAL steps the validation 'exact_match' obtained with the learner's
    predictions is recorded.

    Relies on notebook-level globals: teaching_embeddings, validation_embeddings,
    ai_teaching_preds_b, opt_defer_teaching, optimal_gammas,
    priorhum_teaching_preds, priorhum_validation_preds, hum_validation_preds,
    ai_validation_preds, test_answers, MAX_SIZE and PLOT_INTERVAL.

    Args:
        greedy_gamma: if False, each example is stored with its precomputed
            optimal_gammas[index]; if True, gamma is obtained from
            get_greedy_gamma using the learner's current predictions on the
            teaching set (the prior predictions for the very first example).

    Returns:
        (errors_learnai, indices_used_learnai): the validation exact-match values
        (one per evaluated step) and the list of selected teaching indices.
    """
    human_learner_learnai = HumanLearner(None)
    errors_learnai = []
    indices_used_learnai = []
    # NOTE(review): the reference set is seeded with example 0 and the selected
    # examples are never added to it, so every candidate is scored against
    # {example 0, candidate} only and the 5-neighbour branch below is never reached.
    # This looks unintended for a greedy selection, but it is kept as-is so the
    # baseline's numbers do not change - confirm with the authors.
    set_xs = [teaching_embeddings[0]]
    set_ys = [ai_teaching_preds_b[0]]
    n_available = len(teaching_embeddings)
    for itt in range(MAX_SIZE):
        print(f'New size {itt}')
        best_index = 0
        best_value = 0
        # Clamp the subset sizes so datasets with fewer than 2000 / 1000 examples do
        # not make random.sample raise ValueError.
        random_teach_subset = random.sample(range(n_available), min(2000, n_available))
        random_test_subset = random.sample(range(n_available), min(1000, n_available))
        # The held-out arrays are the same for every candidate: build them once.
        test_xs = np.asarray([teaching_embeddings[kk] for kk in random_test_subset])
        test_ys = np.asarray([ai_teaching_preds_b[kk] for kk in random_test_subset])
        for j in random_teach_subset:
            # Tentatively add candidate j ...
            set_xs.append(teaching_embeddings[j])
            set_ys.append(ai_teaching_preds_b[j])
            n_neighbors = 5 if len(set_xs) > 5 else 1
            neigh = KNeighborsClassifier(n_neighbors = n_neighbors, weights='distance')
            neigh.fit(np.asarray(set_xs), np.asarray(set_ys))
            acc = neigh.score(test_xs, test_ys)
            if acc >= best_value:
                best_value = acc
                best_index = j
            # ... and take it out again before trying the next one.
            set_xs.pop()
            set_ys.pop()
        print(best_value)
        indices_used_learnai.append(best_index)
        ex_embed = teaching_embeddings[best_index]
        ex_label = opt_defer_teaching[best_index]

        if greedy_gamma:
            if itt == 0:
                # Nothing has been taught yet, so the learner still follows its prior.
                preds_teach = priorhum_teaching_preds
            else:
                preds_teach = human_learner_learnai.predict(teaching_embeddings, priorhum_teaching_preds)
            # Store the result in `gamma`, NOT in `greedy_gamma`: overwriting the flag
            # with a numeric gamma would disable the greedy branch whenever gamma is 0.
            _, gamma = get_greedy_gamma(best_index, preds_teach, opt_defer_teaching, optimal_gammas, teaching_embeddings)
        else:
            gamma = optimal_gammas[best_index]
        human_learner_learnai.add_to_teaching([ex_embed, ex_label, gamma])

        if itt % PLOT_INTERVAL == 0:
            print("####### val eval " +str(itt)+ " ###########")
            preds_teach = human_learner_learnai.predict(validation_embeddings, priorhum_validation_preds)
            _, metrics, __, ___ = compute_metrics(hum_validation_preds, ai_validation_preds, preds_teach, test_answers, True)
            errors_learnai.append(metrics['exact_match'])
            print("##############################")
    return errors_learnai, indices_used_learnai
#errors_learnai, indices_used_learnai = teach_learnai()

# Replicate Figure 3

In [None]:
# Build one FakeAI with kmeans=None and keep its `kmeans_pass` attribute. The trial
# loop below hands this object back to every human FakeAI it creates via `kmeans=`,
# presumably so the clustering is computed once and shared across trials
# (FakeAI is defined earlier in the notebook - confirm there).
human_predictor = FakeAI(teaching_embeddings, teaching_embeddings, 15, 15,  (2,1), kmeans = None)
kmeans_pass = human_predictor.kmeans_pass

In [None]:
# Experiment switch. With "A" the simulated human is built with (1,1) and the AI with
# (2,1), and REJECT_EPSILON is 0.1; any other value swaps the two tuples and uses 0.9.
setting = "A"
# Selects how gamma is chosen for each teaching example and which variant of our method
# runs (teach_ours when False, teach_ours_doublegreedy when True).
greedy_gamma = False # if False: \alpha=0, if True: \alpha = 1

In [None]:
# Value passed to human_predictor.prior_rejector in the trial loop; it depends on the setting.
REJECT_EPSILON = 0.1 if setting == "A" else 0.9
# Number of independent repetitions of the whole experiment.
MAX_TRIALS = 1
# Upper bound (exclusive) on the teaching-loop counter `itt` used by every teach_* function.
MAX_SIZE = 30 + 1
# The validation score is recorded whenever itt is a multiple of this value.
PLOT_INTERVAL = 2

In [None]:
'''
At each trial, get new human and AI models, get optimal gammas and evaluate methods
'''


def _answers_from_correctness(correct_flags, gold_answers, raw_passages):
    """Turn per-example right/wrong flags into answer strings.

    A flag equal to 1 yields the gold answer of that example; any other value yields
    one sentence picked at random (random.randint) from the example's passage after
    replacing "#$" with "</br> </br>" and sentence-tokenizing it with nltk.
    """
    answers = []
    for i, flag in enumerate(correct_flags):
        if flag == 1:
            answers.append(gold_answers[i])
        else:
            sents = nltk.sent_tokenize(raw_passages[i].replace("#$","</br> </br>"))
            answers.append(sents[random.randint(0,len(sents)-1)])
    return answers


def _optimal_deferral(ai_model, human_model, ai_clusters, hum_clusters, n_examples):
    """Return one 0/1 deferral decision per example (1 = defer to the AI).

    The AI is chosen when its clust_err_p entry for the example's cluster is >= the
    human's entry for the example's cluster.
    NOTE(review): despite the attribute name, the comparison treats clust_err_p as a
    'higher is better' quantity - confirm against the FakeAI definition.
    """
    return [
        1 if ai_model.clust_err_p[ai_clusters[ex]] >= human_model.clust_err_p[hum_clusters[ex]] else 0
        for ex in range(n_examples)
    ]


# One entry per trial. For the teaching methods each entry is a list of validation
# exact-match values whose first element is the prior (untaught) score.
scores_ours = []
scores_medoid = []
scores_random = []
scores_aibaseline = []
scores_oracle = []
scores_lime = []
scores_human = []
for trial in range(MAX_TRIALS):
    # Get Human Predictions
    print(f' \n \n trial {trial}  \n \n')
    # A new simulated human per trial, reusing the shared clustering `kmeans_pass`.
    if setting == "A":
        human_predictor = FakeAI(teaching_embeddings, teaching_embeddings, 15, 15,  (1,1), kmeans = kmeans_pass)
    else:
        human_predictor = FakeAI(teaching_embeddings, teaching_embeddings, 15, 15,  (2,1), kmeans = kmeans_pass)
    hum_teaching_preds_b = human_predictor.predict_right_wrong(teaching_embeddings, teaching_embeddings)
    hum_validation_preds_b = human_predictor.predict_right_wrong(validation_embeddings, validation_embeddings)
    hum_teaching_preds = _answers_from_correctness(hum_teaching_preds_b, train_answers, train_raw_passages)
    hum_validation_preds = _answers_from_correctness(hum_validation_preds_b, test_answers, test_raw_passages)
    # Deferral decisions the human makes before any teaching.
    priorhum_teaching_preds = human_predictor.prior_rejector(teaching_embeddings, teaching_embeddings, REJECT_EPSILON)
    priorhum_validation_preds = human_predictor.prior_rejector(validation_embeddings, validation_embeddings, REJECT_EPSILON)

    # Get AI predictions
    # The AI gets the tuple the human did not get and shares the human's clustering.
    if setting == "A":
        fakeai = FakeAI(teaching_embeddings, teaching_embeddings, 15, 15, (2,1), kmeans = human_predictor.kmeans_pass )
    else:
        fakeai = FakeAI(teaching_embeddings, teaching_embeddings, 15, 15, (1,1), kmeans = human_predictor.kmeans_pass )
    ai_teaching_preds_b = fakeai.predict_right_wrong(teaching_embeddings, teaching_embeddings)
    ai_validation_preds_b = fakeai.predict_right_wrong(validation_embeddings, validation_embeddings)
    ai_teaching_preds = _answers_from_correctness(ai_teaching_preds_b, train_answers, train_raw_passages)
    ai_validation_preds = _answers_from_correctness(ai_validation_preds_b, test_answers, test_raw_passages)

    # Optimal deferral decisions
    clusts_p_ai = fakeai.kmeans_pass.predict(teaching_embeddings)
    clusts_p_hum = human_predictor.kmeans_pass.predict(teaching_embeddings)
    opt_defer_teaching = _optimal_deferral(fakeai, human_predictor, clusts_p_ai, clusts_p_hum, len(hum_teaching_preds))
    clusts_p_ai = fakeai.kmeans_pass.predict(validation_embeddings)
    clusts_p_hum = human_predictor.kmeans_pass.predict(validation_embeddings)
    opt_defer_validation = _optimal_deferral(fakeai, human_predictor, clusts_p_ai, clusts_p_hum, len(hum_validation_preds))

    # Reference scores on the validation set: oracle deferral, the human's prior
    # deferral, never deferring (human alone) and always deferring (AI alone).
    optimal_score = compute_metrics(hum_validation_preds, ai_validation_preds, opt_defer_validation, test_answers)
    prior_score = compute_metrics(hum_validation_preds, ai_validation_preds, priorhum_validation_preds, test_answers)[1]['exact_match']

    human_score = compute_metrics(hum_validation_preds, ai_validation_preds, [0]*len(ai_validation_preds), test_answers)
    ai_score = compute_metrics(hum_validation_preds, ai_validation_preds, [1]*len(ai_validation_preds), test_answers)

    hum = human_score[1]['exact_match']
    scores_human.append(hum)
    scores_oracle.append(optimal_score[1]['exact_match'])
    # optimal_gammas is read as a global by the teach_* functions.
    if not greedy_gamma:
        optimal_gammas = get_optimal_gammas()
    else:
        optimal_gammas = [1]*len(teaching_embeddings)
    print("running our method")
    if greedy_gamma:
        errors, indices_used = teach_ours_doublegreedy()
    else:
        errors, indices_used = teach_ours(False)

    print("running lime")
    # Explain the AI on every teaching example with LIME and collect, per example, the
    # absolute feature weights; their per-feature sums give a global importance.
    DATA_DIM = teaching_embeddings.shape[1]
    explainer = lime.lime_tabular.LimeTabularExplainer(teaching_embeddings, discretize_continuous= False)
    cutoff_size = len(teaching_embeddings)
    global_feature_importance = {feature: 0 for feature in range(DATA_DIM)}
    individual_feature_importance = []
    for ex in tqdm(range(cutoff_size)):
        exp = explainer.explain_instance(teaching_embeddings[ex], fakeai.predict_proba_encoded, num_samples = 500)
        dic_feat_weights = {}
        # local_exp[1] holds (feature id, weight) pairs for label 1.
        for feat_id, weight in exp.local_exp[1]:
            dic_feat_weights[feat_id] = abs(weight)
            global_feature_importance[feat_id] += abs(weight)
        individual_feature_importance.append(dic_feat_weights)

    # Global importance is the square root of the summed absolute weights.
    for i in range(DATA_DIM):
        global_feature_importance[i] = math.sqrt(global_feature_importance[i])
    errors_lime, indices_used_lime= teach_lime(greedy_gamma)

    print("running learnai")
    errors_learnai, indices_used_learnai = teach_learnai(greedy_gamma)

    print("running medoid")
    errors_medoid, indices_used_medoid = teach_medoids(greedy_gamma)

    print("running random")
    errors_random, indices_used_random = teach_random(greedy_gamma)

    # Every curve starts with the score of the untaught human (teaching set size 0).
    for curve in (errors, errors_learnai, errors_medoid, errors_random, errors_lime):
        curve.insert(0, prior_score)

    scores_ours.append(errors)
    scores_aibaseline.append(errors_learnai)
    scores_medoid.append(errors_medoid)
    scores_random.append(errors_random)
    scores_lime.append(errors_lime)

In [None]:
import matplotlib

# Font type 42 for both the PDF and PostScript backends (TrueType embedding per the
# matplotlib documentation), so text in exported figures remains real text.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})
# Use matplotlib's built-in text rendering (no LaTeX) with a serif font family.
plt.rcParams['text.usetex'] = False
plt.rcParams['font.family'] = 'serif'


In [None]:

# Maximum number of points drawn per curve (the first N_PLOT_POINTS recorded scores).
N_PLOT_POINTS = 6
actual_max_trials = len(scores_ours)

# Never index past the shortest recorded curve (e.g. when MAX_SIZE was reduced).
_method_curves = (scores_ours, scores_medoid, scores_aibaseline, scores_lime, scores_random)
_n_points = min([N_PLOT_POINTS] + [len(curve) for curves in _method_curves for curve in curves])
# NOTE(review): entry 0 of every curve is the prior score (teaching set size 0); for the
# baselines defined in this notebook entry k >= 1 is recorded after itt+1 = (k-1)*PLOT_INTERVAL+1
# examples were taught, so these x positions look shifted by one - confirm for teach_ours
# before changing the axis.
teaching_sizes = [i*PLOT_INTERVAL for i in range(_n_points)]


def _plot_gap_to_oracle(ax, method_scores, marker, label):
    """Draw mean +/- std over trials of (oracle score - method score) per teaching size."""
    gaps_per_size = [
        [scores_oracle[triall] - method_scores[triall][i] for triall in range(actual_max_trials)]
        for i in range(len(teaching_sizes))
    ]
    ax.errorbar(teaching_sizes,
                [np.average(gaps) for gaps in gaps_per_size],
                yerr=[np.std(gaps) for gaps in gaps_per_size],
                marker = marker, label=label)


# Set the size BEFORE the figure is created; previously it was changed after plotting
# and therefore only affected figures drawn by later runs.
plt.rcParams["figure.figsize"] = (6, 4)
fig, ax = plt.subplots()

_plot_gap_to_oracle(ax, scores_ours, "o", 'GREEDY-SELECT (Ours)')
_plot_gap_to_oracle(ax, scores_medoid, "s", 'K-Medoids')
_plot_gap_to_oracle(ax, scores_aibaseline, "*", 'AI-Behavior')
_plot_gap_to_oracle(ax, scores_lime, "+", 'LIME')
_plot_gap_to_oracle(ax, scores_random, "v", 'Random')

# The human working alone does not depend on the teaching set: a single point at size 0.
_human_gaps = [scores_oracle[triall] - scores_human[triall] for triall in range(actual_max_trials)]
_human_avg = np.average(_human_gaps)
_human_std = np.std(_human_gaps)
# '\\pm' keeps the emitted label identical while avoiding the invalid escape '\p'.
ax.errorbar(0, _human_avg, yerr=_human_std, marker = "x",
            label=f'Human alone {_human_avg:.2f} $\\pm$ {_human_std:.2f}')

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.xaxis.tick_bottom()
ax.yaxis.tick_left()
ax.grid()
# A single legend call: a second bare legend() used to discard the requested font size.
ax.legend(fontsize='large')
ax.set_ylabel('Difference to Oracle Accuracy',  fontsize='x-large')
ax.set_xlabel('Teaching set size', fontsize='x-large')
#plt.savefig("teaching_complexity_A_nongreedy", dpi = 1000)
plt.show()