# Imports 

In [None]:
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from collections import Counter
import string
import re
import argparse
import json
from sklearn import neighbors, datasets
import os
import sys
import random
import shutil
import time
import torch.nn.parallel
from tqdm import tqdm
import torch.optim
import torch.utils.data
import torchvision.datasets as datasets
from torch.autograd import Variable
from PIL import Image
import torch.utils.data as data
from torchvision.datasets.utils import download_url, check_integrity
from sklearn.decomposition import PCA
from matplotlib import pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.gaussian_process.kernels import RBF
from scipy.stats import multivariate_normal
import  scipy.stats as st
from matplotlib import cm
from __future__ import print_function
from spacy.lang.en import English
from sklearn.cluster import KMeans
import pickle
import matplotlib
from sklearn.metrics.pairwise import rbf_kernel
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as tfunc
from torch.utils.data import Dataset
from torch.utils.data.dataset import random_split
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split
import sklearn.metrics as metrics
import scipy.stats as st
#import lime
#import lime.lime_tabular


In [None]:
class TeacherExplainer():
    """Returns top examples that best teach a learner when to defer to a classifier.

    Given a tabular dataset with classifier predictions, human predictions and a
    similarity metric, the method returns the top k examples that best describe
    when to defer to the AI.
    """

    def __init__(self,
                 data_x,
                 data_y,
                 hum_preds,
                 ai_preds,
                 prior_rejector_preds,
                 sim_kernel,
                 metric_y,
                 teaching_points = 10):
        """Init function.
        Args:
            data_x: 2d numpy array of the features
            data_y: 1d numpy array of labels
            hum_preds:  1d array of the human predictions
            ai_preds:  1d array of the AI predictions
            prior_rejector_preds: 1d binary array of the prior rejector preds
            sim_kernel: function that takes as input two inputs and returns a positive number
            metric_y: stored unchanged; presumably a scoring function over
                (prediction, label) pairs -- TODO confirm, it is not used yet
            teaching_points: number of teaching points to select (default 10)
        """
        self.data_x = data_x
        self.data_y = data_y
        self.hum_preds = hum_preds
        self.ai_preds = ai_preds
        self.sim_kernel = sim_kernel
        # Store the constructor argument (reading the attribute from itself
        # would fail, since it does not exist yet at this point).
        self.prior_rejector_preds = prior_rejector_preds
        self.metric_y = metric_y
        self.teaching_points = teaching_points

    def get_teaching_examples(self, teaching_points):
        """ obtains teaching points.
        Args:
            teaching_points: number of teaching points; overwrites the value
                given to the constructor
        Return (in this order):
            teaching_x: 2d numpy array of teaching points features
            teaching_gammas: 1d numpy of gamma values used
            teaching_labels: 1d array of deferral labels where 1 signifies defer to AI and 0 signifies don't defer to AI
            teaching_indices: indices of the teaching points in self.data_x

        The selection algorithm is not implemented yet, so all four results
        are currently empty lists.
        """
        self.teaching_points = teaching_points
        teaching_x = []
        teaching_indices = []
        teaching_gammas = []
        teaching_labels = []
        # TODO: ALGORITHM here
        #   - define human learner
        #   - run algorithm to get teaching points
        return teaching_x, teaching_gammas, teaching_labels, teaching_indices

In [None]:
class HumanLearner:
    """ Model of Human Learner.
    Learner has a list of training points each with a radius and label.
    Learner follows the radius nearest neighbor assumption.
    """
    def __init__(self, kernel):
        '''
        Args:
            kernel: function that takes two inputs and returns a similarity.
                It is called as kernel(x_row, points) with a (1, d) array and
                an array of teaching points, and row 0 of the result is read.
                May be None, in which case a module-level `kernel` is used.
        '''
        self.teaching_set = []
        self.kernel = kernel
        # Not read anywhere in this class; kept because it is a public attribute.
        self.rejector_tresh = 0.8

    def _similarity(self, x, points):
        """Return the similarities between a single point x and each of points."""
        # Use the kernel handed to the constructor. Callers in this file build
        # HumanLearner(None) and rely on a module-level `kernel`, so keep that
        # as the fallback.
        kernel_fn = self.kernel if self.kernel is not None else kernel
        return kernel_fn(x.reshape(1, -1), np.asarray(points))[0]

    def predict(self, xs, prior_rejector_preds, to_print = False):
        '''
        Args:
            xs: points to predict on (each must support .reshape)
            prior_rejector_preds: predictions of prior rejector, indexed like xs
            to_print: currently unused; kept for interface compatibility
        Return:
            preds: posterior human learner rejector predictions. For a point
                that falls in no teaching ball the prior prediction is used;
                otherwise the similarity-weighted vote of the balls containing
                it, thresholded at 0.5.
        '''
        # With nothing taught yet there is no ball to fall into, and the kernel
        # would be handed an empty array: answer with the prior directly.
        if not self.teaching_set:
            return [prior_rejector_preds[idx] for idx in range(len(xs))]

        teaching_xs = np.asarray([example[0] for example in self.teaching_set])
        preds = []
        for idx, x in enumerate(xs):
            similarities = self._similarity(x, teaching_xs)
            # A teaching example covers x when x is at least gamma-similar to it.
            ball_at_x = [
                example
                for example, similarity in zip(self.teaching_set, similarities)
                if similarity >= example[2]
            ]
            if not ball_at_x:
                preds.append(prior_rejector_preds[idx])
                continue
            ball_similarities = self._similarity(x, [example[0] for example in ball_at_x])
            normalization = np.sum(ball_similarities)
            score_one = np.sum([
                similarity * example[1]
                for similarity, example in zip(ball_similarities, ball_at_x)
            ])
            preds.append(1 if score_one / normalization >= 0.5 else 0)

        return preds

    def add_to_teaching(self, teaching_example):
        '''
        adds teaching_example to training set
        args:
            teaching_example: (x, label, gamma)
        '''
        self.teaching_set.append(teaching_example)

    def remove_last_teaching_item(self):
        """ removes last placed teaching example from training set (no-op when empty)"""
        self.teaching_set = self.teaching_set[:-1]


In [None]:
def compute_predictions_humanai(hum_preds, hum_rejector, ai_preds, data_x, prior_rejector_preds=None):
    '''
    Combine human and AI predictions according to a rejector's deferral decisions.

    hum_preds: array of human predictions
    hum_rejector: either a callable mapping data_x to per-example reject
        decisions, or an object exposing predict(data_x, prior_rejector_preds)
        such as HumanLearner
    ai_preds: array of AI predictions
    data_x: array of inputs
    prior_rejector_preds: prior rejector predictions; required only when
        hum_rejector is not callable, ignored otherwise

    Returns array of final predictions and deferrals. A decision equal to 1
    means defer to the AI; anything else keeps the human prediction.

    Raises ValueError if hum_rejector is not callable and no
    prior_rejector_preds were given.
    '''
    # no_grad is kept for rejectors that are torch modules.
    with torch.no_grad():
        if callable(hum_rejector):
            reject_decisions = hum_rejector(data_x)
        else:
            # HumanLearner is not callable; it predicts through .predict and
            # needs the prior decisions for points outside every teaching ball.
            if prior_rejector_preds is None:
                raise ValueError(
                    "prior_rejector_preds is required when hum_rejector is not callable")
            reject_decisions = hum_rejector.predict(data_x, prior_rejector_preds)
        predictions = [
            ai_preds[i] if reject_decisions[i] == 1 else hum_preds[i]
            for i in range(len(data_x))
        ]
    return predictions, reject_decisions

def get_metrics(preds, truths):
    '''
    Score predictions against targets (to be implemented for each method,
    higher is better).

    preds: array of predictions
    truths:  target array

    Returns a dict with the single key "score" holding the accuracy.
    '''
    return {"score": metrics.accuracy_score(truths, preds)}

def compute_metrics(human_preds, ai_preds, reject_decisions, truths, to_print = False):
    '''
    Evaluate the human-AI team under the given deferral decisions.

    human_preds: array of human predictions
    ai_preds: array of AI predictions
    reject_decisions: array of deferral decisions, 1 meaning the AI prediction is used
    truths: target array
    to_print: when true, print coverage and the three metric dicts

    Returns (coverage, humanai_metrics, human_metrics, ai_metrics) where
    coverage is one minus the mean of reject_decisions, humanai_metrics scores
    the combined predictions on all examples, human_metrics scores the human on
    the examples that were not deferred, and ai_metrics scores the AI on the
    deferred ones.
    '''
    coverage = 1 - np.sum(reject_decisions)/len(reject_decisions)

    deferred = [decision == 1 for decision in reject_decisions]
    ai_rows = [i for i, is_deferred in enumerate(deferred) if is_deferred]
    human_rows = [i for i, is_deferred in enumerate(deferred) if not is_deferred]

    # Team prediction: AI where deferred, human everywhere else.
    humanai_preds = [
        ai_preds[i] if is_deferred else human_preds[i]
        for i, is_deferred in enumerate(deferred)
    ]

    humanai_metrics = get_metrics(humanai_preds, truths)
    human_metrics = get_metrics(
        [human_preds[i] for i in human_rows],
        [truths[i] for i in human_rows],
    )
    ai_metrics = get_metrics(
        [ai_preds[i] for i in ai_rows],
        [truths[i] for i in ai_rows],
    )

    if to_print:
        print(f'Coverage is {coverage*100:.2f}')
        print(f' metrics of system are: {humanai_metrics}')
        print(f' metrics of human are: {human_metrics}')
        print(f' metrics of AI are: {ai_metrics}')
    return coverage, humanai_metrics, human_metrics, ai_metrics

In [None]:
# get optimal gammas, ONLY FOR TEACHING
def get_optimal_consistent_gammas(teaching_embeddings, opt_defer_teaching ):
    '''
    Compute one gamma (similarity threshold) per teaching point.

    Args:
        teaching_embeddings: teaching points
        opt_defer_teaching: binary deferral label
    Return:
        optimal_gammas: list with one gamma per teaching point. For point i the
            other points are ranked by RBF similarity to i and scanned from the
            most similar downwards; gamma is the similarity of the first point
            that shares i's label while the next less similar point does not.
            If no such position exists gamma stays 1.
    '''
    n_points = len(teaching_embeddings)
    # Positions are visited from the most similar end of the ranking down to 1.
    scan_order = range(len(opt_defer_teaching) - 1, 0, -1)
    optimal_gammas = []
    with tqdm(total=n_points) as progress:
        pairwise = rbf_kernel( np.asarray(teaching_embeddings), np.asarray(teaching_embeddings))
        for i in range(n_points):
            own_label = opt_defer_teaching[i]
            # (similarity, label) pairs in ascending order of similarity to point i
            ranked = sorted(
                ((pairwise[i][k], opt_defer_teaching[k]) for k in range(n_points)),
                key=lambda pair: pair[0],
            )
            gamma = next(
                (
                    ranked[k][0]
                    for k in scan_order
                    if ranked[k][1] == own_label and ranked[k - 1][1] != own_label
                ),
                1,
            )
            optimal_gammas.append(gamma)
            progress.update(1)
    return optimal_gammas

In [None]:
import multiprocessing
from multiprocessing.dummy import Pool as ThreadPool

# Positions len(teaching_embeddings)-2 down to 1, in descending order.
# get_improvement_defer_greedy walks these as positions into each point's
# sorted similarity list; position 0 and the last position are left out.
indicess = list(range(len(teaching_embeddings) - 2, 0, -1))
def get_improvement_defer_greedy(current_defer_preds, opt_defer_preds, xs, seen_indices):
    """Score every candidate teaching point together with its best gamma.

    For candidate i the points at the positions listed in the module-level
    `indicess` are visited in that order through `sorted_sims[i]`, whose
    entries are read as (similarity, point index) pairs (presumably sorted by
    ascending similarity -- TODO confirm against where `sorted_sims` is built).
    A running improvement is accumulated: a visited point contributes when its
    current deferral decision disagrees with candidate i's optimal label, by
    the difference between the module-level `ai_teaching_preds_b` and
    `hum_teaching_preds_b` values (sign depending on the label). The gamma
    giving the largest running total is kept, capped by `optimal_gammas[i]`.

    Args:
        current_defer_preds: current deferral decision per teaching point
        opt_defer_preds: optimal deferral label per teaching point
        xs: unused; kept for interface compatibility
        seen_indices: indices already used as teaching points; they receive
            the sentinel improvement -1000 and their optimal gamma
    Return:
        error_improvements: best improvement found per candidate
        found_gammas: gamma achieving that improvement per candidate
    """
    # One set lookup per candidate instead of a list scan on every inner step.
    already_taught = set(seen_indices)

    error_improvements = []
    found_gammas = []
    for i in range(len(opt_defer_preds)):
        # The drawn value is not used; the draw itself is kept so the global
        # RNG stream (and therefore any seeded run) is unchanged.
        random.random()

        max_improve = -1000
        gamma_value = optimal_gammas[i]
        if i not in already_taught:
            sorted_sim = sorted_sims[i]
            wants_defer = opt_defer_preds[i] == 1
            current_improve = 0
            for j in indicess:
                idx = int(sorted_sim[j][1])
                f1_hum = hum_teaching_preds_b[idx]
                f1_ai = ai_teaching_preds_b[idx]
                if wants_defer:
                    if current_defer_preds[idx] == 0:
                        current_improve += f1_ai - f1_hum
                else:
                    if current_defer_preds[idx] == 1:
                        current_improve += f1_hum - f1_ai

                # ">=" on purpose: on ties the later position in the walk wins.
                if current_improve >= max_improve:
                    max_improve = current_improve
                    gamma_value = min(optimal_gammas[i], sorted_sim[j][0])

        error_improvements.append(max_improve)
        found_gammas.append(gamma_value)
    return error_improvements, found_gammas


In [None]:
def teach_ours_doublegreedy(eval_train=False):
    """Greedily build a teaching set, choosing both the point and its gamma.

    Runs MAX_SIZE rounds. Each round predicts the learner's current deferral
    decisions on the teaching data (the prior decisions in round 0), scores all
    candidates with get_improvement_defer_greedy, and teaches the best one with
    the gamma found for it. Every PLOT_INTERVAL rounds the learner is evaluated
    on the testing data. All data comes from notebook-level globals.

    Args:
        eval_train: when true, also print metrics on the teaching data every
            PLOT_INTERVAL rounds (off by default)
    Return:
        errors: "score" of the human-AI system on the testing data at each
            evaluation round
        indices_used: indices of the chosen teaching points, in order
    """
    # HumanLearner is built without a kernel, as before; it relies on the
    # notebook-level `kernel` being defined.
    human_learner = HumanLearner(None)

    errors = []
    indices_used = []
    for itt in range(MAX_SIZE):
        print(f'New size {itt}')
        # predict with current human learner
        if itt == 0:
            preds_teach = priorhum_teaching_preds
        else:
            preds_teach = human_learner.predict(teaching_embeddings, priorhum_teaching_preds)
        error_improvements, best_gammas = get_improvement_defer_greedy(preds_teach, opt_defer_teaching,  teaching_embeddings, indices_used)
        print(f'got improvements with max {max(error_improvements)}')
        best_index = np.argmax(error_improvements)
        indices_used.append(best_index) # add found element to set used
        ex_embed = teaching_embeddings[best_index]
        ex_label = opt_defer_teaching[best_index]
        gamma = best_gammas[best_index]
        human_learner.add_to_teaching([ex_embed, ex_label, gamma])

        if eval_train and itt % PLOT_INTERVAL == 0:
            print("####### train eval " +str(itt)+ " ###########")
            preds_teach = human_learner.predict(teaching_embeddings, priorhum_teaching_preds)
            compute_metrics(hum_teaching_preds, ai_teaching_preds, preds_teach, teaching_target, True)
            print("##############################")

        if itt % PLOT_INTERVAL == 0:
            print("####### val eval " +str(itt)+ " ###########")
            preds_teach = human_learner.predict(testing_embeddings, priorhum_testing_preds)
            _, metricsc, __, ___ = compute_metrics(hum_testing_preds, ai_testing_preds, preds_teach, testing_target, True)
            # get_metrics reports its value under the "score" key.
            errors.append(metricsc['score'])
            print("##############################")
    return errors, indices_used
#errors_doublegreedy, indices_used_doublegreedy = teach_ours_doublegreedy()

In [None]:
def get_improvement_defer(current_defer_preds, opt_defer_preds, gammas, xs, coin_prob = 0.1):
    """Score each candidate teaching point for a fixed gamma per point.

    For candidate i, every point j whose similarity to i (row i of the
    module-level `similarities_embeds_all`) is at least gammas[i] is inspected.
    If j's current deferral decision disagrees with candidate i's optimal
    label, the difference between the module-level `ai_teaching_preds_b[j]`
    and `hum_teaching_preds_b[j]` is added (AI minus human when the optimal
    label is 1 and j is currently 0; human minus AI when the label is not 1 and
    j is currently 1).

    Args:
        current_defer_preds: current deferral decision per teaching point
        opt_defer_preds: optimal deferral label per teaching point
        gammas: similarity threshold per candidate
        xs: unused
        coin_prob: unused
    Return:
        error_improvements: list with one accumulated improvement per candidate
    """
    improvements = []
    for i, gamma in enumerate(gammas):
        # Drawn but never used; the call still advances the global RNG.
        _coin = random.random()
        gain = 0
        for j, similarity in enumerate(similarities_embeds_all[i]):
            if not (similarity >= gamma):
                continue  # j lies outside candidate i's ball
            f1_hum = hum_teaching_preds_b[j]
            f1_ai = ai_teaching_preds_b[j]
            if opt_defer_preds[i] == 1:
                disagrees = current_defer_preds[j] == 0
                delta = f1_ai - f1_hum
            else:
                disagrees = current_defer_preds[j] == 1
                delta = f1_hum - f1_ai
            if disagrees:
                gain += delta
        improvements.append(gain)
    return improvements



def teach_ours(greedy_gamma = False, eval_train = False):
    """Greedily build a teaching set using per-point gammas.

    Runs MAX_SIZE rounds. Each round predicts the learner's current deferral
    decisions on the teaching data (the prior decisions in round 0), scores all
    candidates with get_improvement_defer under `optimal_gammas`, and teaches
    the best one. Every PLOT_INTERVAL rounds the chosen example is shown and
    the learner is evaluated on the testing data. All data comes from
    notebook-level globals.

    Args:
        greedy_gamma: when true, the gamma of the chosen point is taken from
            get_greedy_gamma (defined outside this section; it is unpacked as
            an (improvement, gamma) pair); otherwise optimal_gammas is used
        eval_train: when true, also print metrics on the teaching data every
            3 rounds (off by default)
    Return:
        errors: "score" of the human-AI system on the testing data at each
            evaluation round
        indices_used: indices of the chosen teaching points, in order
    """
    # HumanLearner is built without a kernel, as before; it relies on the
    # notebook-level `kernel` being defined.
    human_learner = HumanLearner(None)
    errors = []
    indices_used = []
    for itt in range(MAX_SIZE):
        print(f'New size {itt}')
        # predict with current human learner
        if itt == 0:
            preds_teach = priorhum_teaching_preds
        else:
            preds_teach = human_learner.predict(teaching_embeddings, priorhum_teaching_preds)
        error_improvements = get_improvement_defer(preds_teach, opt_defer_teaching, optimal_gammas, teaching_embeddings)
        best_index = np.argmax(error_improvements)
        indices_used.append(best_index) # add found element to set used
        ex_embed = teaching_embeddings[best_index]
        ex_label = opt_defer_teaching[best_index]

        if greedy_gamma:
            # Keep the result in its own names: assigning it to `greedy_gamma`
            # would replace the flag with a number, and a gamma of 0.0 would
            # silently switch this mode off for all later rounds.
            best_improvement, gamma = get_greedy_gamma(best_index, preds_teach, opt_defer_teaching, optimal_gammas, teaching_embeddings)
            print(f'got improvements with max {best_improvement}')
        else:
            gamma = optimal_gammas[best_index]
            print(f'got improvements with max {max(error_improvements)}')

        human_learner.add_to_teaching([ex_embed, ex_label, gamma])

        if eval_train and itt % 3 == 0:
            print("####### train eval " +str(itt)+ " ###########")
            preds_teach = human_learner.predict(teaching_embeddings, priorhum_teaching_preds)
            compute_metrics(hum_teaching_preds, ai_teaching_preds, preds_teach, teaching_target, True)
            print("##############################")

        if itt % PLOT_INTERVAL == 0:
            # assumes train_dataset rows line up with teaching_embeddings and
            # hold channel-first image tensors -- TODO confirm
            plt.imshow(  train_dataset[best_index][0].permute(1, 2, 0)  )
            plt.show()
            print("####### val eval " +str(itt)+ " ###########")
            preds_teach = human_learner.predict(testing_embeddings, priorhum_testing_preds)
            # Named metricsc so the module alias `metrics` is not shadowed.
            _, metricsc, __, ___ = compute_metrics(hum_testing_preds, ai_testing_preds, preds_teach, testing_target, True)
            # get_metrics reports its value under the "score" key.
            errors.append(metricsc['score'])
            print("##############################")
    return errors, indices_used
#errors, indices_used = teach_ours(True)

GOAL

general:

given a dataloader and costs, retrieve a set of points together with their indices and gammas, and the distance metric

input:
- 
- 
- 

test case
- images with cifar 

- adult dataset fake expert

- 

