In [1]:
import os, pickle
import numpy as np

import warnings
warnings.filterwarnings("ignore")

%matplotlib inline
%config InlineBackend.figure_format='retina'

import matplotlib.pyplot as plt
import evals.teaching_evals as evals
import pickle
import utils
import figs.plot_data as plot
import algorithms.teaching_algs as algs
# Binary labels: the first half of each split is class 0, the second half class 1.
y_train = np.repeat(np.array([0, 1]), 80)
y_valid = np.repeat(np.array([0, 1]), 20)
import seaborn as sns

## Load data

In [14]:
def _load_pickle(path):
    """Return the object unpickled from ``path``, closing the file afterwards."""
    # NOTE(review): pickle can execute arbitrary code while loading; only
    # point this at files produced by trusted code.
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Train/valid pickles for each embedding source; the ones under
# embeds/bm/human/ are the "TN" and "MTL" variants.
dwac_train = _load_pickle("embeds/bm/dwac_train_emb10.merged10.pkl")
dwac_valid = _load_pickle("embeds/bm/dwac_valid_emb10.merged10.pkl")
resn_train = _load_pickle("embeds/bm/resn_train_emb10.pkl")
resn_valid = _load_pickle("embeds/bm/resn_valid_emb10.pkl")
TN_train = _load_pickle("embeds/bm/human/TN_train_emb10.pkl")
TN_valid = _load_pickle("embeds/bm/human/TN_valid_emb10.pkl")
MTL_train = _load_pickle("embeds/bm/human/MTL.BCETN_train_emb10.pkl")
MTL_valid = _load_pickle("embeds/bm/human/MTL.BCETN_valid_emb10.pkl")

In [None]:
# Model name -> (train, valid) pair; insertion order fixes the order in which
# models are processed and plotted further down.
embeds = {}
embeds["dwac"] = (dwac_train, dwac_valid)
embeds["resn"] = (resn_train, resn_valid)
embeds["TN_human"] = (TN_train, TN_valid)
embeds["MTL_human"] = (MTL_train, MTL_valid)

# Baseline curve names followed by one entry per model.
legend = ['full', 'random', 'random_ci', *embeds]

## kNN using model embeddings

In [None]:
def get_model_knn_scores(data, m_range, selection_alg, k=1, args=None):
    """Score prototype sets chosen by ``selection_alg`` with a kNN evaluation.

    For every ``m`` in ``m_range`` the selection algorithm picks prototype
    indices from the training split, and ``evals.get_knn_score`` is evaluated
    on the validation split using only those prototypes.

    Parameters
    ----------
    data : tuple
        ``(x_train, y_train, x_valid, y_valid)``. ``x_train`` and ``y_train``
        must support indexing with whatever the selection algorithm returns.
    m_range : iterable
        Values of ``m`` handed to the selection algorithm; one score each.
    selection_alg : str or callable
        ``"protogreedy"`` or ``"prototriplet"`` (resolved to the function of
        the same name in ``algs``), or a callable invoked as
        ``selection_alg(x_train, m, args)``.
    k : int, optional
        Forwarded to ``evals.get_knn_score``.
    args : optional
        Forwarded unchanged as the third argument of the selection algorithm.

    Returns
    -------
    list
        One kNN score per entry of ``m_range``, in order. The list is also
        printed before it is returned.

    Raises
    ------
    ValueError
        If ``selection_alg`` is neither a recognised name nor a callable.
    """
    if selection_alg == "protogreedy":
        selection_alg = algs.protogreedy
    elif selection_alg == "prototriplet":
        selection_alg = algs.prototriplet
    elif not callable(selection_alg):
        # Previously an unknown name stayed a string and only failed later
        # with an opaque "'str' object is not callable".
        raise ValueError(
            f"unknown selection_alg {selection_alg!r}; "
            "expected 'protogreedy', 'prototriplet' or a callable"
        )

    x_train, y_train, x_valid, y_valid = data
    prototype_knn_scores = []
    for m in m_range:
        prototype_idx = selection_alg(x_train, m, args)
        knn_score = evals.get_knn_score(
            x_train[prototype_idx], y_train[prototype_idx], x_valid, y_valid, k=k
        )
        prototype_knn_scores.append(knn_score)

    print(prototype_knn_scores)
    return prototype_knn_scores

In [2]:
# Values of m to sweep (plotted as "number of examples"): 4 through 20 inclusive.
m_range = np.arange(4, 20 + 1)

In [None]:
# For every model: the full/random baselines plus one kNN curve per
# prototype-selection setting.
all_scores = {}
for model, (x_train, x_valid) in embeds.items():
    data = (x_train, y_train, x_valid, y_valid)
    full_knn, random_knn = evals.get_full_random(data, m_range)
    scores = {"full": full_knn, "random_scores": random_knn}
    scores["protogreedy"] = get_model_knn_scores(data, m_range, "protogreedy")
    for label, topk in (("prototriplet_topk=10", 10), ("prototriplet_topk=100", 100)):
        scores[label] = get_model_knn_scores(
            data, m_range, "prototriplet", args={"topk": topk}
        )
    all_scores[model] = scores

In [None]:

def vis_knn_scores_multiplot(m_range, allall_scores, subtitles=None, title=None, save=False, save_dir=None):
    """Plot kNN score curves, one panel per model, on a two-column grid.

    Parameters
    ----------
    m_range : sequence
        x-values shared by every curve.
    allall_scores : dict
        Maps a model name to a dict of curves. Each inner dict must hold
        ``"full"`` (drawn as a horizontal line) and ``"random_scores"`` (a
        ``(scores, ci)`` pair drawn as a dashed line with a band of total
        width ``ci``); every other key is drawn as a solid curve labelled
        with that key.
    subtitles : sequence of str, optional
        Panel titles, one per model in iteration order. When omitted the
        model names are used.
    title : str, optional
        Figure-level title.
    save : bool, optional
        Write the figure to a PDF when true.
    save_dir : str, optional
        Output path; defaults to ``figs/{title}.pdf``.
    """
    n_cols = 2
    n_models = len(allall_scores)
    # Keep the original 2x2 layout for up to four models and add rows beyond
    # that instead of indexing past the grid.
    n_rows = max(2, -(-n_models // n_cols))
    fig, ax = plt.subplots(
        n_rows, n_cols, figsize=(8 * n_cols, 6 * n_rows), sharey=True, squeeze=False
    )

    legend_source = ax[-1][-1]
    for j, (model, all_score) in enumerate(allall_scores.items()):
        panel = ax[j // n_cols][j % n_cols]
        panel.axhline(all_score["full"], c='black', linewidth=2, linestyle="solid", label="full score")

        random_knn_scores, random_knn_ci = all_score["random_scores"]
        # asarray makes the band arithmetic work for plain lists as well.
        random_knn_scores = np.asarray(random_knn_scores)
        half_ci = np.asarray(random_knn_ci) / 2
        panel.plot(m_range, random_knn_scores, linewidth=2, linestyle="dashed", label="random score")
        panel.fill_between(m_range, random_knn_scores + half_ci, random_knn_scores - half_ci, alpha=0.5)

        for key, score in all_score.items():
            if key in ("full", "random_scores"):
                continue
            panel.plot(m_range, score, linewidth=4, label=key)

        panel.set_title(subtitles[j] if subtitles else model, fontsize=25)
        legend_source = panel

    fig.supxlabel("number of examples", y=0.05, fontsize=25)
    fig.supylabel("acc", x=0.09, fontsize=25)
    # One shared legend anchored under the bottom-right panel; its entries
    # come from the last panel that was actually drawn, so it is not empty
    # when fewer than four models are given.
    handles, labels = legend_source.get_legend_handles_labels()
    ax[-1][-1].legend(handles, labels, loc='upper right', bbox_to_anchor=(1, -0.2),
                      fancybox=True, shadow=True, ncol=7, fontsize=15)

    if title:
        fig.suptitle(title, fontsize=30)
    if save:
        if not save_dir:
            save_dir = f"figs/{title}.pdf"
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(save_dir, format="pdf", bbox_inches="tight")

# Draw the per-model kNN curves and write them to a PDF.
vis_knn_scores_multiplot(
    m_range,
    all_scores,
    save=True,
    save_dir="figs/model_knn.pdf",
)

## kNN using human triplets

In [None]:
# Evaluation embeddings for get_htriplet_knn_scores, wrapped in ndarrays so
# they can be indexed with a list of prototype indices. The context managers
# close the file handles once each pickle has been read.
with open("embeds/bm/human/TN_train_emb10.pkl", "rb") as fh:
    train_embs = np.array(pickle.load(fh))
with open("embeds/bm/human/TN_valid_emb10.pkl", "rb") as fh:
    valid_embs = np.array(pickle.load(fh))
def get_htriplet_knn_scores(X, m_range, selection_alg, k=1, args=None):
    """Select prototypes on ``X`` and score them in the ``train_embs`` space.

    For every ``m`` in ``m_range`` the selection algorithm picks prototype
    indices using ``X``. Those prototypes are then evaluated with
    ``evals.get_knn_score`` on the module-level ``train_embs`` / ``valid_embs``
    (labels ``y_train`` / ``y_valid``) and with
    ``evals.human_1NN_align(X, prototype_idx)``.

    Parameters
    ----------
    X :
        Training embeddings handed to the selection algorithm and to
        ``evals.human_1NN_align``.
    m_range : iterable
        Values of ``m`` handed to the selection algorithm.
    selection_alg : str or callable
        ``"protogreedy"`` or ``"prototriplet"`` (resolved to the function of
        the same name in ``algs``), the legacy spelling ``"tripet_greedy"``,
        or a callable invoked as ``selection_alg(X, m, args)``.
    k : int, optional
        Forwarded to ``evals.get_knn_score``.
    args : optional
        Forwarded unchanged as the third argument of the selection algorithm.

    Returns
    -------
    tuple of (list, list)
        ``(prototype_knn_scores, align_scores)``, one entry per ``m``.

    Raises
    ------
    ValueError
        If ``selection_alg`` is neither a recognised name nor a callable.
    """
    if selection_alg == "protogreedy":
        selection_alg = algs.protogreedy
    elif selection_alg == "prototriplet":
        # The notebook passes "prototriplet"; it used to fall through
        # unresolved and fail with "'str' object is not callable".
        selection_alg = algs.prototriplet
    elif selection_alg == "tripet_greedy":
        # Legacy (misspelt) name kept so existing callers keep working.
        selection_alg = algs.tripet_greedy
    elif not callable(selection_alg):
        raise ValueError(
            f"unknown selection_alg {selection_alg!r}; "
            "expected 'protogreedy', 'prototriplet' or a callable"
        )

    prototype_knn_scores = []
    align_scores = []
    for m in m_range:
        prototype_idx = selection_alg(X, m, args)
        knn_score = evals.get_knn_score(
            train_embs[prototype_idx], y_train[prototype_idx], valid_embs, y_valid, k=k
        )
        prototype_knn_scores.append(knn_score)
        align_scores.append(evals.human_1NN_align(X, prototype_idx))

    return prototype_knn_scores, align_scores

In [20]:
# get_nn_greedy reads this module-level sweep: 3 through 20 inclusive.
m_range = np.arange(3, 20 + 1)
def get_nn_greedy(k=1):
    """Return one kNN score per ``m`` in the module-level ``m_range``.

    Prototypes are chosen by ``algs.nngreedy`` from the MTL train/valid
    embeddings and the train/valid labels, then scored with
    ``evals.get_knn_score`` on the TN embeddings.

    ``k`` is forwarded to ``evals.get_knn_score``.
    """
    def _score_for(m):
        # Select on the MTL embeddings, evaluate on the TN embeddings.
        chosen = algs.nngreedy((MTL_train, MTL_valid), m, (y_train, y_valid))
        return evals.get_knn_score(TN_train[chosen], y_train[chosen], TN_valid, y_valid, k=k)

    return [_score_for(m) for m in m_range]

In [21]:
# Score the nngreedy prototypes with the default k=1 for every m in m_range;
# the returned list is shown as this cell's output.
get_nn_greedy()

[0.8,
 0.8,
 0.8,
 0.8,
 0.8,
 0.8,
 0.8,
 0.8,
 0.8,
 0.6,
 0.6,
 0.6,
 0.6,
 0.6,
 0.6,
 0.6,
 0.6,
 0.6]

In [None]:
# Switch the sweep back to 4..20 and compute the full-set and random
# baselines on train_embs / valid_embs.
m_range = np.arange(4, 20 + 1)
data = (train_embs, y_train, valid_embs, y_valid)
baseline = evals.get_full_random(data, m_range)
full_score, random_scores = baseline

In [None]:
# protogreedy selection on each model's training embeddings; the shared
# baselines are stored next to the per-model curves.
protog_scores = {"full_score": full_score, "random_scores": random_scores}
protog_align = {}
selection_alg = "protogreedy"
for model, (train_emb, _valid_emb) in embeds.items():
    knn_curve, align_curve = get_htriplet_knn_scores(
        train_emb, m_range, selection_alg=selection_alg
    )
    protog_scores[model] = knn_curve
    protog_align[model] = align_curve

In [None]:
# prototriplet selection with topk=10 on each model's training embeddings.
protot10_scores = {"full_score": full_score, "random_scores": random_scores}
protot10_align = {}
selection_alg = "prototriplet"
args = {"topk": 10}
for model, (train_emb, _valid_emb) in embeds.items():
    knn_curve, align_curve = get_htriplet_knn_scores(
        train_emb, m_range, selection_alg, args=args
    )
    protot10_scores[model] = knn_curve
    protot10_align[model] = align_curve

In [None]:
# prototriplet selection with topk=100 on each model's training embeddings.
protot100_scores = {"full_score": full_score, "random_scores": random_scores}
protot100_align = {}
selection_alg = "prototriplet"
args = {"topk": 100}
for model, (train_emb, _valid_emb) in embeds.items():
    knn_curve, align_curve = get_htriplet_knn_scores(
        train_emb, m_range, selection_alg, args=args
    )
    protot100_scores[model] = knn_curve
    protot100_align[model] = align_curve

In [None]:
# One panel per selection setting, saved as a single PDF.
plot.vis_knn_scores_multiplot(
    m_range,
    [protog_scores, protot10_scores, protot100_scores],
    subtitles=["protogreedy", "tripletgreedy_topk=10", "tripletgreedy_topk=100"],
    save=True,
    save_dir="figs/KNN_TN.pdf",
)

In [None]:

def vis_align_multiplot(m_range, allall_scores, subtitles=None, title=None, save=False, save_dir=None):
    """Plot alignment-score curves side by side, one panel per entry.

    Parameters
    ----------
    m_range : sequence
        x-values shared by every curve.
    allall_scores : sequence of dict
        One dict per panel, mapping a curve label to its y-values.
    subtitles : sequence of str, optional
        Panel titles, one per entry of ``allall_scores``.
    title : str, optional
        Figure-level title.
    save : bool, optional
        Write the figure to a PDF when true.
    save_dir : str, optional
        Output path; defaults to ``figs/{title}.pdf``.

    Raises
    ------
    ValueError
        If ``allall_scores`` is empty.
    """
    n = len(allall_scores)
    if n == 0:
        raise ValueError("allall_scores must contain at least one panel")

    # squeeze=False keeps a 2-D axes array even for a single panel, where
    # plt.subplots would otherwise return a bare Axes that cannot be indexed.
    fig, ax = plt.subplots(1, n, figsize=(8 * n, 6), sharey=True, squeeze=False)
    panels = ax[0]
    for j, panel_scores in enumerate(allall_scores):
        for model, score in panel_scores.items():
            panels[j].plot(m_range, score, linewidth=4, label=model)

        if subtitles:
            panels[j].set_title(subtitles[j], fontsize=25)

    fig.supxlabel("number of examples", fontsize=25)
    fig.supylabel("align score", x=0.09, fontsize=25)
    # Shared legend attached to the last panel, which is the axes plt.legend
    # resolved to before.
    panels[-1].legend(loc='upper right', bbox_to_anchor=(0.15, -0.12),
                      fancybox=True, shadow=True, ncol=7, fontsize=20)

    if title:
        fig.suptitle(title, fontsize=30)
    if save:
        if not save_dir:
            save_dir = f"figs/{title}.pdf"
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(save_dir, format="pdf", bbox_inches="tight")


In [None]:
# Alignment curves for the three selection settings, one panel each.
vis_align_multiplot(
    m_range,
    [protog_align, protot10_align, protot100_align],
    subtitles=["protogreedy", "tripletgreedy_topk=10", "tripletgreedy_topk=100"],
    save=True,
    save_dir="figs/TN_align.pdf",
)

In [None]:
x = plot.tsne2(TN_train)
# prototype_idx is listed in the same order as subtitles; previously the
# protogreedy prototypes came last while its subtitle came first.
# NOTE(review): assumes vis_data_multiplot pairs subtitles[i] with
# prototype_idx[i] — confirm against figs.plot_data.
plot.vis_data_multiplot(
    [(x, y_train), (x, y_train), (x, y_train)],
    ["butterfly", "moth"],
    subtitles=["protogreedy", "prototriplet_topk=10", "prototriplet_topk=100"],
    prototype_idx=[
        algs.protogreedy(TN_train, 10),
        algs.prototriplet(TN_train, 10, {"topk": 10}),
        algs.prototriplet(TN_train, 10, {"topk": 100}),
    ],
    save=True,
    save_dir="figs/bm_TN_m=10.pdf",
)

In [None]:
# Reuses x, the plot.tsne2 projection of TN_train computed in the previous
# cell; only the prototypes are selected on the MTL embeddings.
# prototype_idx is listed in the same order as subtitles; previously the
# protogreedy prototypes came last while its subtitle came first.
# NOTE(review): assumes vis_data_multiplot pairs subtitles[i] with
# prototype_idx[i] — confirm against figs.plot_data.
plot.vis_data_multiplot(
    [(x, y_train), (x, y_train), (x, y_train)],
    ["butterfly", "moth"],
    subtitles=["protogreedy", "prototriplet_topk=10", "prototriplet_topk=100"],
    prototype_idx=[
        algs.protogreedy(MTL_train, 10),
        algs.prototriplet(MTL_train, 10, {"topk": 10}),
        algs.prototriplet(MTL_train, 10, {"topk": 100}),
    ],
    save=True,
    save_dir="figs/bm_MTL_m=10.pdf",
)