In [None]:
import os

from wandb.apis.public import Api

# Credentials are taken from the environment so that a real key never has to be
# typed into (and committed with) the notebook. When the variables are unset the
# values fall back to the empty-string placeholders used previously.
WANDB_KEY = os.environ.get("WANDB_API_KEY", "")  # YOUR WANDB KEY
PROJECT = os.environ.get("WANDB_PROJECT", "")  # YOUR PROJECT

# `project` is the collection of runs returned by the W&B public API for PROJECT;
# later cells iterate over it.
wandb_api = Api(api_key=WANDB_KEY)
project = wandb_api.runs(PROJECT)

In [None]:
import os
from tqdm import tqdm
import numpy as np
import pandas as pd

# Build one summary record per finished W&B run (each run path is recorded at most once).
all_runs = []
all_runs_paths = set()
for run in tqdm(project):
    if run.state != 'finished':
        continue

    x_path = os.path.join(*run.path)
    if x_path in all_runs_paths:
        continue

    config = run.config
    history = run.history()
    if len(history) == 0:
        # Runs with an empty history are skipped and are NOT marked as seen.
        continue

    # `.item()` requires the `scores` column to hold exactly one value.
    scores = history.scores.item()
    dataset = config['dataset']
    model = config['model']
    prediction_method = config['prediction_method']

    record = {
        "dataset": dataset,
        "model": model,
        'seed': config['seed'],
        "selection_method": config['example_selection_method'],
        # keep only the last two '_'-separated tokens of the prediction method
        "prediction_method": '_'.join(prediction_method.split("_")[-2:]),
        "n_shots": config['n_shots'],
        'scores': scores,
        'worst_case': min(scores),
        'mean': np.mean(scores),
        'std': np.std(scores),
    }
    all_runs.append(record)
    all_runs_paths.add(x_path)

all_runs_df = pd.DataFrame(all_runs)
all_runs_df.head()

In [None]:
import torch
# Persist the collected results in the current working directory:
#   - the set of processed run paths (serialized with torch.save),
#   - the per-run summary table as CSV (without the index column).
# NOTE(review): the 'scores' column is written in its string form; the cell that
# reloads this CSV parses it back before use.
torch.save(all_runs_paths, 'all_runs_paths.pt')
all_runs_df.to_csv('all_runs_df.csv', index=False)

In [None]:
from collections import defaultdict
# dataset -> seed -> templates logged by the first finished run with that (dataset, seed).
dataset_templates = defaultdict(dict)
for dataset in ['sst2', 'dbpedia', 'agnews', 'trec']:
    for seed in [59, 13, 21]:
        matching_runs = (
            candidate for candidate in project
            if candidate.state == 'finished'
            and candidate.config.get('dataset', 's') == dataset
            and candidate.config.get('seed', 0) == seed
        )
        first_match = next(matching_runs, None)
        if first_match is None:
            # no finished run for this (dataset, seed): leave the entry unset
            continue
        templates = first_match.history().templates.item()
        dataset_templates[dataset][seed] = templates
        print(templates)

In [None]:
# Show the template table of the current `dataset`, sorted by every key of `table`.
# NOTE(review): `dataset_template_tables` and `table` are only created by the next
# cell, so this cell fails on a fresh kernel (Restart & Run All) — it relies on
# state left over from an earlier out-of-order execution.
dataset_template_tables[dataset].sort_values(list(table.keys()))

In [None]:
# One DataFrame per dataset with a row per (seed, template index).
dataset_template_tables = {}

all_templates = []
for dataset in datasets:
    table = {
        column: []
        for column in ('inp_verbalizer', 'out_verbalizer', 'sep', 'big_sep', 'seed', 'n')
    }
    for seed, seed_templates in dataset_templates[dataset].items():
        for i, template in enumerate(seed_templates):
            # separators are stored via repr() so whitespace such as '\n' stays visible
            row_values = (template[0], template[1], repr(template[2]), repr(template[3]), seed, i)
            for column, value in zip(table, row_values):
                table[column].append(value)
    dataset_template_tables[dataset] = pd.DataFrame(table)
    print(dataset)
    # sort by the four template fields (all columns except the trailing 'seed' and 'n')
    display(dataset_template_tables[dataset].sort_values(list(table)[:-2]))

In [None]:
# Take the two first rows of the sorted dbpedia template table for manual inspection.
# NOTE(review): the sort columns come from the `table` object left behind by a
# previous cell (all of its keys except the last two); this statement then rebinds
# `table` to the sorted DataFrame.
table = dataset_template_tables['dbpedia'].sort_values(list(table.keys())[:-2])
row = table.iloc[0]
next_row = table.iloc[1]

In [None]:
# Field-by-field comparison of the two rows picked above (True means identical).
for k in ('inp_verbalizer', 'out_verbalizer', 'sep'):
    fields_match = row[k] == next_row[k]
    print(fields_match, k)

In [None]:
# For every dataset, find template pairs that are adjacent after sorting and that
# share input verbalizer, output verbalizer and `sep` but differ in `big_sep`.
# Each pair is stored as ((seed, n), (seed, n)).
dataset_good_pairs = {}

# Explicit sort key: previously taken from whatever `table` object a prior cell
# left in the kernel, which made the result depend on execution order.
template_sort_columns = ['inp_verbalizer', 'out_verbalizer', 'sep', 'big_sep']

for dataset in datasets:
    good_pairs = []
    table = dataset_template_tables[dataset].sort_values(template_sort_columns)
    for i in range(len(table) - 1):
        row = table.iloc[i]
        next_row = table.iloc[i + 1]
        same_except_big_sep = all(
            row[k] == next_row[k] for k in ['inp_verbalizer', 'out_verbalizer', 'sep']
        )
        if same_except_big_sep and row['big_sep'] != next_row['big_sep']:
            good_pairs.append(((row.seed, row.n), (next_row.seed, next_row.n)))

    dataset_good_pairs[dataset] = good_pairs

In [None]:
# `scipy.stats.stats` is a deprecated private namespace; `pearsonr` is part of the
# public `scipy.stats` API.
from scipy.stats import pearsonr

# Correlation between the scores of two templates over the keys of `template_a_scores`.
# NOTE(review): this indexes both containers by the items yielded when iterating
# `template_a_scores`, i.e. it expects mappings with the same keys. The next cell
# rebinds these names to plain lists, for which this cell does not work — confirm
# whether it is still needed.
a = [template_a_scores[k] for k in template_a_scores]
b = [template_b_scores[k] for k in template_a_scores]
pearsonr(a, b)

In [None]:
# For every "good" template pair, collect the scores both templates obtained under
# the same (model, prediction method, selection method, n_shots) configuration and
# report their Pearson correlation.
for dataset in datasets:
    selected_runs = all_runs_df.loc[all_runs_df['dataset'] == dataset]
    for pair in dataset_good_pairs[dataset]:
        # each template is identified by (seed, index into that run's `scores`)
        template_a, template_b = pair
        template_a_scores, template_b_scores = [], []
        for model in names_to_checkpoints:
            for prediction_method in ['direct_False', 'channel_True', 'calibrate_True']:
                for selection_method in ['random', 'implicitly_topic_models', 'z-ICL']:
                    for n_shots in [2, 4, 8]:
                        # One mask built on `selected_runs` itself, instead of
                        # applying masks of `selected_runs` to an already filtered
                        # frame (which only works through implicit index alignment).
                        config_mask = (
                            (selected_runs['n_shots'] == n_shots)
                            & (selected_runs['selection_method'] == selection_method)
                            & (selected_runs['prediction_method'] == prediction_method)
                            & (selected_runs['model'] == model)
                        )
                        pair_runs = selected_runs.loc[config_mask]
                        a_run = pair_runs.loc[pair_runs['seed'] == template_a[0]]
                        b_run = pair_runs.loc[pair_runs['seed'] == template_b[0]]
                        if len(a_run) > 0 and len(b_run) > 0:
                            template_a_scores.append(a_run['scores'].values[0][template_a[1]])
                            template_b_scores.append(b_run['scores'].values[0][template_b[1]])

        # pearsonr needs at least two paired observations; skip the pair otherwise
        # instead of aborting the whole loop with an exception.
        if len(template_a_scores) < 2:
            print(f"{template_a} vs {template_b}: only {len(template_a_scores)} paired scores, skipping")
            continue

        corr, p_value = pearsonr(template_a_scores, template_b_scores)
        print(f"{template_a} correlation with {template_b}: {corr:.3f} with p_value: {p_value}")

# 3.1. Baseline results

In [None]:
from collections import defaultdict

# Baseline setting: random demonstration selection + 'direct_False' prediction.
# For each n_shots build a (model x dataset) table of "mean ± std" strings computed
# over the scores of all available seeds.
n_shots_tables = {}
baseline_mask = (
    (all_runs_df['selection_method'] == 'random')
    & (all_runs_df['prediction_method'] == 'direct_False')
)
selected_runs = all_runs_df.loc[baseline_mask]
for n_shots in [2, 4, 8]:
    n_shots_runs = selected_runs.loc[selected_runs['n_shots'] == n_shots]
    table = {dataset: defaultdict(str) for dataset in datasets}
    for model in names_to_checkpoints:
        model_runs = n_shots_runs.loc[n_shots_runs['model'] == model]
        for dataset in datasets:
            dataset_runs = model_runs.loc[model_runs['dataset'] == dataset]
            scores = []
            for seed in [59, 13, 21]:
                seed_runs = dataset_runs.loc[dataset_runs['seed'] == seed]
                if len(seed_runs) > 0:
                    # only the first matching run of a seed contributes
                    scores.extend(seed_runs['scores'].values[0])
            mean_score, std_score = np.mean(scores), np.std(scores)
            table[dataset][model] = f"{mean_score:.3f} ± {std_score:.3f}"
    table = pd.DataFrame(table, index=names_to_checkpoints.keys(), columns=datasets)
    n_shots_tables[n_shots] = table
    print(f"{n_shots}-shot")
    display(table)

In [None]:
# Print the baseline table as LaTeX rows: one row per model, and for each dataset
# the 2/4/8-shot results formatted as mean\textsubscript{std}.
latex_row_models = ['gpt2-large', 'gpt2-xl', 'llama-7b', 'llama-13b', 'llama-30b', 'llama-65b']
for model in latex_row_models:
    print(model, end=' & ')
    for dataset in datasets:
        for n_shots in [2, 4, 8]:
            res = n_shots_tables[n_shots].loc[model, dataset]
            mean, std = (float(part) for part in res.split(" ± "))
            # the row is terminated after the 8-shot entry of 'trec'
            closes_row = (n_shots == 8) and (dataset == 'trec')
            end = '\\\\\n' if closes_row else ' & '
            print(f"{mean:.2f}\\textsubscript{{{std:.2f}}}", end=end)
    if model == 'gpt2-xl':
        # rule between the GPT-2 rows and the LLaMA rows
        print("\\midrule")

In [None]:
# One panel per dataset: baseline mean accuracy with a ± std band for the LLaMA and
# Falcon models, one curve per number of shots.
plt.figure(figsize=[16, 12])
size_sweep_models = ['llama-7b', 'llama-13b', 'llama-30b', 'llama-65b',
                     'falcon-1b', 'falcon-7b', 'falcon-40b']
for i, dataset in enumerate(datasets, 1):
    plt.subplot(2, 2, i)
    for n_shots in [2, 4, 8]:
        table = n_shots_tables[n_shots]
        # every table cell is a "mean ± std" string; parse it back into two floats
        parsed = [
            [float(part) for part in table.loc[model, dataset].split(' ± ')]
            for model in size_sweep_models
        ]
        x = list(size_sweep_models)
        y = [mean for mean, _ in parsed]
        y_plus = [mean + std for mean, std in parsed]
        y_minus = [mean - std for mean, std in parsed]
        plt.plot(x, y, label=f"{n_shots}-shot")
        plt.fill_between(x, y_minus, y_plus, alpha=0.1)
    plt.legend()
    plt.xticks(rotation=45)
    plt.title(dataset)

# 3.2. Prediction methods

In [None]:
# Inspect the 'direct_False' runs for the current `n_shots` / `dataset` setting.
# NOTE(review): `n_shots` must be a string such as '2-shot' here (it is split on
# '-'), but it is only assigned in that form by the next cell; on a fresh kernel
# this cell fails. It also duplicates the filtering done by the next cell.
selected_runs = all_runs_df.loc[all_runs_df['n_shots'] == int(n_shots.split('-')[0])]
selected_runs = selected_runs.loc[selected_runs['dataset'] == dataset]
# anything other than the literal '2-shot' is treated as the '0-shot' selection method
selection_method = 'random' if n_shots == '2-shot' else '0-shot'
selected_runs = selected_runs.loc[selected_runs['selection_method'] == selection_method]
selected_runs.loc[selected_runs['prediction_method'] == 'direct_False']

In [None]:
# Manual inspection of the runs for one (n_shots, dataset, prediction method) setting.
n_shots = '0-shot'
dataset = 'sst2'
method = 'direct_False'
# anything other than the literal '2-shot' is treated as the '0-shot' selection method
selection_method = 'random' if n_shots == '2-shot' else '0-shot'

shots_count = int(n_shots.split('-')[0])
selected_runs = all_runs_df.loc[
    (all_runs_df['n_shots'] == shots_count)
    & (all_runs_df['dataset'] == dataset)
    & (all_runs_df['selection_method'] == selection_method)
]
method_runs = selected_runs.loc[selected_runs['prediction_method'] == method]
method_runs

In [None]:
methods = ['direct_False', 'channel_True',  'calibrate_True',]

def combine_scores(runs, aggregation_method='all', prediction_method='direct_False'):
    """Flatten the per-run `scores` lists of `runs` into a single list.

    Parameters
    ----------
    runs : pandas.DataFrame
        Rows with a 'scores' entry holding an iterable of numbers.
    aggregation_method : {'all', 'worst', 'direct_best'}
        'all'         -- keep every score of every run;
        'worst'       -- keep only the minimum score of each run;
        'direct_best' -- keep only the maximum score of each run, except for the
                         'calibrate_True' / 'channel_True' prediction methods,
                         for which every score is kept.
    prediction_method : str
        Prediction method the runs belong to; only consulted for 'direct_best'.

    Returns
    -------
    list
        The combined scores (empty when `runs` has no rows).

    Raises
    ------
    ValueError
        If `aggregation_method` is not one of the supported values (previously an
        unknown value silently produced an empty list).
    """
    if aggregation_method not in ('all', 'worst', 'direct_best'):
        raise ValueError(f"Unknown aggregation_method: {aggregation_method!r}")

    all_scores = []
    for _, run in runs.iterrows():
        scores = run['scores']
        if aggregation_method == 'worst':
            all_scores.append(min(scores))
        elif (aggregation_method == 'direct_best'
              and prediction_method not in ['calibrate_True', 'channel_True']):
            # The decision uses the `prediction_method` argument; the previous code
            # read an unrelated global named `method` here.
            all_scores.append(max(scores))
        else:
            all_scores.extend(scores)

    return all_scores

# Legend labels for the prediction methods.
method_name = {'direct_False': 'Direct', 'channel_True': 'Channel', 'calibrate_True': 'Calibrate'}

def plot(n_shots_range, aggregation_method='all'):
    """Plot accuracy per model for every prediction method, one figure per shot setting.

    Each figure has one panel per entry of the global `datasets` (laid out as
    1 x 4). In a panel, every prediction method in `methods` is drawn as a line of
    mean scores over the models in `names_to_checkpoints`, with a ± std band.

    Parameters
    ----------
    n_shots_range : iterable of str
        Settings of the form '<k>-shot', e.g. ['0-shot', '2-shot'].
    aggregation_method : str
        Forwarded to `combine_scores`.

    Returns
    -------
    list
        One matplotlib figure per entry of `n_shots_range`, in the same order.
    """
    res = []
    model_names = list(names_to_checkpoints.keys())
    for n_shots in n_shots_range:
        n_shots_count = int(n_shots.split('-')[0])
        # Zero-shot runs are stored under the '0-shot' selection method, every
        # few-shot setting under 'random' (same rule as the table-building cell).
        # Previously only the literal '2-shot' mapped to 'random'.
        selection_method = '0-shot' if n_shots_count == 0 else 'random'

        fig = plt.figure(figsize=[24, 6])
        for i, dataset in enumerate(datasets, 1):
            ax = fig.add_subplot(1, 4, i)

            selected_runs = all_runs_df.loc[all_runs_df['n_shots'] == n_shots_count]
            selected_runs = selected_runs.loc[selected_runs['dataset'] == dataset]
            selected_runs = selected_runs.loc[selected_runs['selection_method'] == selection_method]

            for method in methods:
                method_runs = selected_runs.loc[selected_runs['prediction_method'] == method]
                mean, std = [], []
                for model in model_names:
                    model_runs = method_runs.loc[method_runs['model'] == model]
                    model_scores = combine_scores(model_runs, prediction_method=method,
                                                  aggregation_method=aggregation_method)
                    # models without runs yield NaN and show up as gaps in the line
                    mean.append(np.mean(model_scores))
                    std.append(np.std(model_scores))
                mean = np.array(mean)
                std = np.array(std)
                ax.plot(model_names, mean, label=method_name[method])
                ax.fill_between(model_names, mean - std, mean + std, alpha=0.1)
            if i == 1:
                # a single legend, on the first panel only
                ax.legend(loc='lower right')
            ax.set_xticks(ticks=list(range(len(model_names))),
                          labels=model_names,
                          rotation=45, ha="right")
            ax.set_title(f"{n_shots} accuracy on {dataset_name[dataset]}")
        res.append(fig)
    return res

In [None]:
%config InlineBackend.figure_format = 'retina'

# Figure styling for the prediction-method plots (text is rendered through LaTeX).
prediction_plot_rc = {
    "font.family": "Times New Roman",
    "axes.labelsize": 18,
    "font.size": 20,
    "legend.fontsize": 16,
    "xtick.labelsize": 13,
    "ytick.labelsize": 14,
    "text.usetex": True,
}
matplotlib.rcParams.update(prediction_plot_rc)

In [None]:
# Main-paper figure: 2-shot results with every score kept ('all' aggregation).
fig = plot(n_shots_range=['2-shot'])[0]
fig.savefig(
    'figs/prediction_methods_main.pdf',
    format='pdf',
    bbox_inches='tight',
    pad_inches=0,
)

In [None]:
# 0-shot and 2-shot figures using the 'direct_best' aggregation.
figs = plot(n_shots_range=['0-shot', '2-shot'], aggregation_method='direct_best')

In [None]:
# 0-shot and 2-shot figures using the 'worst' (per-run minimum) aggregation.
figs = plot(n_shots_range=['0-shot', '2-shot'], aggregation_method='worst')

## Tables

## Ensembles

In [None]:
import pandas as pd

# Reload the cached per-run summary table from the working directory.
runs_csv_path = "all_runs_df.csv"
all_runs_df = pd.read_csv(runs_csv_path)
all_runs_df.head()

In [None]:
import ast


def _parse_scores(raw):
    """Turn the CSV string form of a scores list back into a Python list.

    `ast.literal_eval` only accepts Python literals, so unlike `eval` it cannot
    execute code embedded in the file. Values that are not strings (the column was
    already parsed) are returned unchanged, which makes the cell safe to re-run.
    """
    return ast.literal_eval(raw) if isinstance(raw, str) else raw


all_runs_df["scores"] = all_runs_df["scores"].apply(_parse_scores)

In [None]:
# Short model name (as used in the 'model' column of the results and as row labels
# in tables/plots) -> checkpoint identifier string.
# The insertion order of this dict determines the model order in tables and plots.
names_to_checkpoints = {'gpt2-large': 'gpt2-large',
                        'gpt2-xl': 'gpt2-xl',
                        'gptj': 'EleutherAI/gpt-j-6B',
                        'gpt-neox': 'EleutherAI/gpt-neox-20b',
                        'opt-1.3b': 'facebook/opt-1.3b',
                        'opt-6.7b': "facebook/opt-6.7b",
                        'opt-30b': "facebook/opt-30b",
                        'opt-66b': "facebook/opt-66b",
                        'bloom-1.7b': 'bigscience/bloom-1b7',
                        'bloom-3b': 'bigscience/bloom-3b',
                        'bloom-7.1b': 'bigscience/bloom-7b1',
                        'pythia-6.9b': 'EleutherAI/pythia-6.9b',
                        'pythia-12b': 'EleutherAI/pythia-12b',
                        'cerebras-6.7b': 'cerebras/Cerebras-GPT-6.7B',
                        'cerebras-13b': 'cerebras/Cerebras-GPT-13B',
                        'llama-7b': 'Neko-Institute-of-Science/LLaMA-7B-HF',
                        'llama-13b': 'Neko-Institute-of-Science/LLaMA-13B-HF',
                        'llama-30b': 'Neko-Institute-of-Science/LLaMA-30B-hf',
                        'llama-65b': 'Neko-Institute-of-Science/LLaMA-65B-hf',
                        'falcon-1b': 'tiiuae/falcon-rw-1b',
                        'falcon-7b': 'tiiuae/falcon-7b',
                        'falcon-40b': 'tiiuae/falcon-40b',
}

# Values of the 'prediction_method' column that the result tables are built for.
prediction_methods = ['direct_False', 
                      'channel_True', 
                      'calibrate_True',]

def aggregate_scores(df, seeds, conditions, method='str', scores_per_seed=10):
    """Aggregate the scores of the runs in `df` that satisfy `conditions`.

    Parameters
    ----------
    df : pandas.DataFrame
        Run table with (at least) 'seed' and 'scores' columns plus every column
        named in `conditions`.
    seeds : sequence
        Seeds to pool. For each seed only the first matching run is used; a
        missing seed is reported by printing `conditions` and the seed.
    conditions : iterable of dict
        Equality filters `{column: value}`; every key of every dict is applied
        (previously only the first key of each dict was used).
    method : {'str', 'mean', 'worst'}
        'str'   -- "mean ± std" string, followed by " (<count>)" when the number of
                   pooled scores differs from `len(seeds) * scores_per_seed`;
        'mean'  -- mean score rounded to 3 decimals;
        'worst' -- minimum score.
    scores_per_seed : int
        Expected number of scores per seed, used only for the count annotation of
        the 'str' output.

    Returns
    -------
    str or number
        The aggregate, or the string 'NaN' when no scores were found.

    Raises
    ------
    ValueError
        If `method` is not supported (previously this surfaced as an
        UnboundLocalError, and only when scores were found).
    """
    if method not in ('str', 'mean', 'worst'):
        raise ValueError(f"Unknown method: {method!r}")

    selected_runs = df
    for condition in conditions:
        for k, v in condition.items():
            selected_runs = selected_runs[selected_runs[k] == v]

    scores = []
    for seed in seeds:
        run = selected_runs[selected_runs['seed'] == seed]
        if len(run) == 0:
            # make missing runs visible instead of silently shrinking the pool
            print(conditions, seed)
            continue
        scores.extend(run['scores'].values[0])

    if len(scores) == 0:
        return 'NaN'

    if method == 'str':
        out = f"{np.mean(scores):.3f} ± {np.std(scores):.2f}"
        if len(scores) != len(seeds) * scores_per_seed:
            out += f" ({len(scores)})"
        return out
    if method == 'mean':
        return round(np.mean(scores), 3)
    return min(scores)

In [None]:
from IPython.display import display

# tables1['<k>-shot'][dataset] is a DataFrame indexed by model with one
# "mean ± std" column per prediction method.
datasets = ['sst2', 'dbpedia', 'agnews', 'trec']
tables1 = {'0-shot': {}, '2-shot': {}}
for n_shots in [0, 2]:
    default_seeds = [59] if n_shots == 0 else [59, 13, 21]
    seeds = default_seeds
    selection_method = '0-shot' if n_shots == 0 else 'random'
    shared_conditions = [{'selection_method': selection_method}, {'n_shots': n_shots}]

    for dataset in datasets:
        table = []
        for model in names_to_checkpoints:
            # opt-30b is always aggregated over seed 13 only
            seeds = [13] if model == 'opt-30b' else default_seeds
            entry = {'model': model}
            for method in prediction_methods:
                conditions = shared_conditions + [
                    {'dataset': dataset},
                    {'model': model},
                    {'prediction_method': method},
                ]
                # the result is a "mean ± std" string (or 'NaN' when nothing matched)
                entry[method] = aggregate_scores(all_runs_df, seeds, conditions)
            table.append(entry)
        table = pd.DataFrame(table).set_index('model')
        tables1[f"{n_shots}-shot"][dataset] = table


for n_shots in [0, 2]:
    print(f'{n_shots}-shot')
    for d in datasets:
        print(d)
        display(tables1[f'{n_shots}-shot'][d])
    print("-------")

In [None]:
import torch
import numpy as np
import pandas as pd
from scipy import stats
from collections import defaultdict

from utils import load_split_dataset, names_to_checkpoints

import warnings
warnings.filterwarnings("ignore")


# Accuracy of template ensembles of size 1..10 for the 'direct' method:
# mega_mega_res_df[dataset][model][size] = "mean ± std" over seeds.
mega_mega_res_df = {}
for dataset in ["sst2", "agnews", "dbpedia", "trec"]:
    base_path = f"templates_ensemble_2-shot_{dataset}/"
    train, val, labels_mp = load_split_dataset(dataset, seed=59)

    mega_res_df = {}

    for name in names_to_checkpoints:
        # the seed-13 file serves as the marker that results exist for this model
        if os.path.exists(base_path + f"{name}_direct_13"):
            mega_res_df[name] = {}
            mode_means = defaultdict(list)
            mean_means = defaultdict(list)
            seeds = [13, 21, 59]
            for seed in seeds:
                res = torch.load(base_path + f"{name}_direct_{seed}")
                for size in range(1, 11):
                    # Voting ensemble over the first `size` entries of res['results'].
                    # NOTE(review): the shape of `.mode` differs between SciPy
                    # versions, so `.mode[0]` may not be the per-example vote; the
                    # value is collected but not reported below.
                    mode = stats.mode(np.array(res['results'][:size])).mode[0]
                    mode_means[size].append((mode == val['target']).mean())

                    # Averaging ensemble: mean of the first `size` probability
                    # tensors, argmax over dim 1, mapped to labels via labels_mp.
                    n_members = len(res['probs'][:size])
                    probs = torch.stack([res['probs'][member] for member in range(n_members)])
                    answers = [labels_mp[x.item()] for x in probs.mean(dim=0).argmax(dim=1)]
                    mean_means[size].append((answers == val['target']).mean())

            stds = []
            for size in range(1, 11):
                size_mean, size_std = np.mean(mean_means[size]), np.std(mean_means[size])
                mega_res_df[name][size] = f"{size_mean:.3f} ± {size_std:.2f}"
                stds.append(size_std)

    mega_mega_res_df[dataset] = mega_res_df

In [None]:
# Append the direct-ensemble results (sizes 3, 5, 10) to the 2-shot tables;
# models without ensemble results get the "???" placeholder.
for dataset in mega_mega_res_df:
    ensemble_results = mega_mega_res_df[dataset]
    table_models = list(tables1["2-shot"][dataset].index)
    new_columns = {
        size: [ensemble_results[model][size] if model in ensemble_results else "???"
               for model in table_models]
        for size in (10, 3, 5)
    }
    tables1["2-shot"][dataset]["3_direct"] = new_columns[3]
    tables1["2-shot"][dataset]["5_direct"] = new_columns[5]
    tables1["2-shot"][dataset]["10_direct"] = new_columns[10]

In [None]:
# Display every 2-shot table, preceded by its dataset name.
for d, dataset_table in tables1['2-shot'].items():
    print(d)
    display(dataset_table)

In [None]:
# Accuracy of template ensembles of size 1..10 for the 'channel' method:
# mega_mega_res_df[dataset][model][size] = "mean ± std" over seeds.
mega_mega_res_df = {}
for dataset in ["sst2", "agnews", "dbpedia", "trec"]:
    base_path = f"ensembles_channel_true/{dataset}/2_shot/"
    train, val, labels_mp = load_split_dataset(dataset, seed=59)

    mega_res_df = {}

    for name in names_to_checkpoints:
        # the seed-13 file serves as the marker that results exist for this model
        if os.path.exists(base_path + f"{name}_channel_13"):
            mega_res_df[name] = {}
            mode_means = defaultdict(list)
            mean_means = defaultdict(list)
            seeds = [13, 21, 59]
            for seed in seeds:
                res = torch.load(base_path + f"{name}_channel_{seed}")
                for size in range(1, 11):
                    # Voting ensemble over the first `size` entries of res['results'].
                    # NOTE(review): the shape of `.mode` differs between SciPy
                    # versions, so `.mode[0]` may not be the per-example vote; the
                    # value is collected but not reported below.
                    mode = stats.mode(np.array(res['results'][:size])).mode[0]
                    mode_means[size].append((mode == val['target']).mean())

                    # Averaging ensemble: mean of the first `size` probability
                    # tensors, argmax over dim 1, mapped to labels via labels_mp.
                    n_members = len(res['probs'][:size])
                    probs = torch.stack([res['probs'][member] for member in range(n_members)])
                    answers = [labels_mp[x.item()] for x in probs.mean(dim=0).argmax(dim=1)]
                    mean_means[size].append((answers == val['target']).mean())

            stds = []
            for size in range(1, 11):
                size_mean, size_std = np.mean(mean_means[size]), np.std(mean_means[size])
                mega_res_df[name][size] = f"{size_mean:.3f} ± {size_std:.2f}"
                stds.append(size_std)

    mega_mega_res_df[dataset] = mega_res_df

In [None]:
# Append the channel-ensemble results (sizes 3, 5, 10) to the 2-shot tables;
# models without ensemble results get the "???" placeholder.
for dataset in mega_mega_res_df:
    ensemble_results = mega_mega_res_df[dataset]
    table_models = list(tables1["2-shot"][dataset].index)
    new_columns = {
        size: [ensemble_results[model][size] if model in ensemble_results else "???"
               for model in table_models]
        for size in (10, 3, 5)
    }
    tables1["2-shot"][dataset]["3_channel"] = new_columns[3]
    tables1["2-shot"][dataset]["5_channel"] = new_columns[5]
    tables1["2-shot"][dataset]["10_channel"] = new_columns[10]

In [None]:
# Display only the models that have channel-ensemble results.
for d in tables1['2-shot']:
    print(d)
    dataset_table = tables1['2-shot'][d]
    has_channel_ensemble = dataset_table['5_channel'] != "???"
    display(dataset_table[has_channel_ensemble])

In [None]:
from scipy import stats

# Accuracy of template ensembles of size 1..10 for the 'calibrate' method:
# mega_mega_res_df[dataset][model][size] = "mean ± std" over seeds.
mega_mega_res_df = {}
for dataset in ["sst2", "agnews", "dbpedia", "trec"]:
    base_path = f"ensembles_calibrate_true/{dataset}/2_shot/"
    train, val, labels_mp = load_split_dataset(dataset, seed=59)

    mega_res_df = {}

    for name in names_to_checkpoints:
        # the seed-13 file serves as the marker that results exist for this model
        if os.path.exists(base_path + f"{name}_calibrate_13"):
            mega_res_df[name] = {}
            mode_means = defaultdict(list)
            mean_means = defaultdict(list)
            seeds = [13, 21, 59]
            for seed in seeds:
                res = torch.load(base_path + f"{name}_calibrate_{seed}")
                for size in range(1, 11):
                    # Voting ensemble over the first `size` entries of res['results'].
                    # NOTE(review): the shape of `.mode` differs between SciPy
                    # versions, so `.mode[0]` may not be the per-example vote; the
                    # value is collected but not reported below.
                    mode = stats.mode(np.array(res['results'][:size])).mode[0]
                    mode_means[size].append((mode == val['target']).mean())

                    # Averaging ensemble: mean of the first `size` probability
                    # tensors, argmax over dim 1, mapped to labels via labels_mp.
                    n_members = len(res['probs'][:size])
                    probs = torch.stack([res['probs'][member] for member in range(n_members)])
                    answers = [labels_mp[x.item()] for x in probs.mean(dim=0).argmax(dim=1)]
                    mean_means[size].append((answers == val['target']).mean())

            stds = []
            for size in range(1, 11):
                size_mean, size_std = np.mean(mean_means[size]), np.std(mean_means[size])
                mega_res_df[name][size] = f"{size_mean:.3f} ± {size_std:.2f}"
                stds.append(size_std)

    mega_mega_res_df[dataset] = mega_res_df

In [None]:
# Append the calibrate-ensemble results (sizes 3, 5, 10) to the 2-shot tables;
# models without ensemble results get the "???" placeholder.
for dataset in mega_mega_res_df:
    ensemble_results = mega_mega_res_df[dataset]
    table_models = list(tables1["2-shot"][dataset].index)
    new_columns = {
        size: [ensemble_results[model][size] if model in ensemble_results else "???"
               for model in table_models]
        for size in (10, 3, 5)
    }
    tables1["2-shot"][dataset]["3_calibrate"] = new_columns[3]
    tables1["2-shot"][dataset]["5_calibrate"] = new_columns[5]
    tables1["2-shot"][dataset]["10_calibrate"] = new_columns[10]

In [None]:
# Side-by-side view of single-template vs. 5-template-ensemble results for the
# models that have calibrate-ensemble results.
comparison_columns = ["direct_False", "5_direct", "channel_True", "5_channel", "calibrate_True", "5_calibrate"]
for d in tables1['2-shot']:
    print(d)
    dataset_table = tables1['2-shot'][d]
    has_calibrate_ensemble = dataset_table['5_calibrate'] != "???"
    display(dataset_table[has_calibrate_ensemble][comparison_columns])

In [None]:
# Column order for the LaTeX export: single-template calibrate result first, then
# the calibrate ensembles of size 3, 5 and 10.
col_order = ["calibrate_True", *(f"{x}_calibrate" for x in (3, 5, 10))]

In [None]:
# LaTeX export of the SST-2 rows that have calibrate-ensemble results.
sst2_table = tables1['2-shot']["sst2"]
latex_df = sst2_table[sst2_table['5_calibrate'] != "???"]
print(latex_df[col_order].to_latex())

In [None]:
import matplotlib

%config InlineBackend.figure_format = 'retina'

# Render all figure text with LaTeX in Times New Roman ...
latex_text_settings = {
    "text.usetex": True,
    "font.family": "Times New Roman",
}
# ... and use one common size (14) for axis labels, body text, legend and ticks.
uniform_font_sizes = dict.fromkeys(
    ["axes.labelsize", "font.size", "legend.fontsize", "xtick.labelsize", "ytick.labelsize"],
    14,
)
nice_fonts = {**latex_text_settings, **uniform_font_sizes}
matplotlib.rcParams.update(nice_fonts)

In [None]:
import torch
import numpy as np
import pandas as pd
from scipy import stats
from collections import defaultdict

from utils import load_split_dataset, names_to_checkpoints

import warnings
warnings.filterwarnings("ignore")


# Ensemble accuracy curves on SST-2 for all three prediction methods:
#   mega_res_df[method][model][size] -> mean accuracy over seeds
#   stds[method][model]              -> list of std over seeds for sizes 1..10
dataset = "sst2"
base_path = {
    "direct": f"templates_ensemble_2-shot_{dataset}/",
    "calibrate": f"ensembles_calibrate_true/{dataset}/2_shot/",
    "channel": f"ensembles_channel_true/{dataset}/2_shot/",
}
train, val, labels_mp = load_split_dataset(dataset, seed=59)

mega_res_df = {}
stds = {}

ensemble_models = ["gpt2-large", "gpt2-xl", "llama-7b", "llama-13b", "llama-30b", "llama-65b"]
for method in ["direct", "calibrate", "channel"]:
    mega_res_df[method] = {}
    stds[method] = {}
    for name in ensemble_models:
        stds[method][name] = []
        mega_res_df[method][name] = {}
        mode_means = defaultdict(list)
        mean_means = defaultdict(list)
        seeds = [13, 21, 59]
        for seed in seeds:
            res = torch.load(base_path[method] + f"{name}_{method}_{seed}")
            for size in range(1, 11):
                # Voting ensemble over the first `size` entries of res['results'].
                # NOTE(review): the shape of `.mode` differs between SciPy versions,
                # so `.mode[0]` may not be the per-example vote; the value is
                # collected but not reported below.
                mode = stats.mode(np.array(res['results'][:size])).mode[0]
                mode_means[size].append((mode == val['target']).mean())

                # Averaging ensemble: mean of the first `size` probability tensors,
                # argmax over dim 1, mapped to labels via labels_mp.
                n_members = len(res['probs'][:size])
                probs = torch.stack([res['probs'][member] for member in range(n_members)])
                answers = [labels_mp[x.item()] for x in probs.mean(dim=0).argmax(dim=1)]
                mean_means[size].append((answers == val['target']).mean())

        for size in range(1, 11):
            per_seed_accuracy = mean_means[size]
            mega_res_df[method][name][size] = np.mean(per_seed_accuracy)
            stds[method][name].append(np.std(per_seed_accuracy))


In [None]:
import matplotlib.pyplot as plt

# One panel per model: ensemble accuracy vs. ensemble size for each prediction
# method (solid line with ± std band), plus the single-template result of the same
# method from the 2-shot table as a dashed horizontal line in the same colour.
x = range(1, 11)

fig, axs = plt.subplots(3, 2, figsize=(6.4 * 2, 4 * 3))
fig.tight_layout()

# ensemble method -> (column of the 2-shot table with the single-template result,
#                     legend label)
ensemble_methods = {
    "direct": ("direct_False", "Direct"),
    "calibrate": ("calibrate_True", "Calibrate"),
    "channel": ("channel_True", "Channel"),
}

for ax, name in zip(axs.flat, mega_res_df["channel"]):
    line_colors = {}
    for method, (_, label) in ensemble_methods.items():
        values = np.array(list(mega_res_df[method][name].values()))
        spread = np.array(stds[method][name])
        # Each curve is labelled with its own method; previously all three curves
        # were labelled "Channel + Ensemble".
        (line,) = ax.plot(x, values, label=f"{label} + Ensemble")
        line_colors[method] = line.get_color()
        ax.fill_between(x, values + spread, values - spread, alpha=0.2)

    for method, (table_column, label) in ensemble_methods.items():
        # table cells look like "0.750 ± 0.02"; the first five characters hold the mean
        baseline = float(tables1["2-shot"][dataset][table_column][name][:5])
        ax.plot(x, [baseline] * 10, label=label, linestyle="dashed",
                color=line_colors[method])

    ax.set_title(f"{spelling[name]}")
    ax.set_xlabel("Ensemble size")
    ax.set_ylabel("Accuracy")
    ax.legend(loc="lower right")

plt.subplots_adjust(hspace=0.3)
plt.savefig(f'../pictures/all_ensemble_2shot.pdf', format='pdf', bbox_inches='tight', pad_inches=0)

## 0-shot to 2-shot transfer

In [None]:
# dataset -> seed -> list of template strings, each obtained by concatenating all
# template parts except the last one.
lena_templates = {}

for dataset in dataset_templates:
    lena_templates[dataset] = {
        seed: [''.join(t[:-1]) for t in seed_templates]
        for seed, seed_templates in dataset_templates[dataset].items()
    }

In [None]:
# Preview the run table before the per-template score mapping is added.
all_runs_df.head()

In [None]:
def make_t_to_s(df):
    """Map each template string of a run to the score obtained with it.

    `df` is a single row of `all_runs_df` (the function is applied with axis=1).
    Its 'dataset' and 'seed' select the template list in `lena_templates`, which is
    paired position-wise with the row's 'scores'. Pairing stops at the shorter of
    the two sequences, and a repeated template string keeps its last score.
    """
    templates = lena_templates[df["dataset"]][df["seed"]]
    return dict(zip(templates, df["scores"]))

all_runs_df["templates_to_scores"] = all_runs_df.apply(make_t_to_s, axis=1)

In [None]:
# Preview the run table again, now including the 'templates_to_scores' column.
all_runs_df.head()

In [None]:
# Exploration: for every seed of the 2-shot runs in `few_shot`, rank the templates
# by score and collect the top score (`bests`) and all scores (`others`).
# NOTE(review): `few_shot` is only assigned by the next cell, so this cell depends
# on state from an earlier out-of-order execution.
bests = []
others = []
for seed in [13, 21, 59]:
    seed_rows = few_shot[(few_shot["n_shots"] == 2) & (few_shot["seed"] == seed)]
    s = seed_rows["templates_to_scores"].values[0]
    # best-scoring template first
    s = dict(sorted(s.items(), key=lambda x: x[1], reverse=True))
    ranked_scores = list(s.values())
    bests.append(ranked_scores[0])
    others.extend(ranked_scores)
    print(s)
    print()
print(bests)
print(others)

In [None]:
# few_shot_res[model][n_shot] holds (mean, std) tuples for:
#   'best'        -- the top template score of each seed,
#   'others'      -- every remaining template score,
#   'zero_on_few' -- values loaded from the gpt2_llama_fewshot files (2- and 4-shot only).
dataset = "sst2"
method = "direct_False"

few_shot_res = {}
for model in ["gpt2-large", "gpt2-xl", "llama-7b", "llama-13b", "llama-30b", "llama-65b"]:
    try:
        zero_shot = torch.load(f"template_selection/{dataset}/{model}_formats_stats_zero_shot_{method}")
    except FileNotFoundError as e:
        # models without a zero-shot stats file are reported and skipped
        print(e)
        continue

    model_mask = (
        (all_runs_df["prediction_method"] == method)
        & (all_runs_df["model"] == model)
        & (all_runs_df["dataset"] == dataset)
        & (all_runs_df["selection_method"] == "random")
    )
    few_shot = all_runs_df[model_mask].sort_values(by=["n_shots", "seed"])

    few_shot_res[model] = {}
    for n_shot in [2, 4, 8]:
        few_shot_res[model][n_shot] = {}
        bests, others, zero_on_few = [], [], []
        for seed in [13, 21, 59]:
            seed_rows = few_shot[(few_shot["n_shots"] == n_shot) & (few_shot["seed"] == seed)]
            s = seed_rows["templates_to_scores"].values[0]
            # best-scoring template first
            s = dict(sorted(s.items(), key=lambda x: x[1], reverse=True))
            ranked_scores = list(s.values())
            bests.append(ranked_scores[0])
            others += ranked_scores[1:]

            # no file is read for 8-shot, so 'zero_on_few' stays empty there
            if n_shot != 8:
                transfer_stats = torch.load(f"gpt2_llama_fewshot/{model}_{dataset}_{method}_{n_shot}shot_{seed}")
                zero_on_few.append(transfer_stats["\n"])

        for key, values in (("best", bests), ("others", others), ("zero_on_few", zero_on_few)):
            few_shot_res[model][n_shot][key] = (np.mean(values), np.std(values))

In [None]:
# Replace every (mean, std) tuple in few_shot_res by its LaTeX representation
# "mean\textsubscript{std}" (same format as the baseline LaTeX rows).
for model in few_shot_res:
    for n_shot in few_shot_res[model]:
        for k in few_shot_res[model][n_shot]:
            entry = few_shot_res[model][n_shot][k]
            if isinstance(entry, str):
                # already formatted by a previous run of this cell
                continue
            mean = entry[0]
            std = entry[1]
            # The backslash is escaped (a bare "\t" is a TAB character) and the
            # braces are doubled so that they survive the f-string.
            few_shot_res[model][n_shot][k] = f"{mean:.2f}\\textsubscript{{{std:.2f}}}"

In [None]:
# Print one LaTeX table per model: rows are n_shot values, columns the three statistics.
for model, per_shot_results in few_shot_res.items():
    print(model)
    summary = pd.DataFrame(per_shot_results).transpose()
    print(summary[["zero_on_few", "best", "others"]].to_latex())

In [None]:
from collections import defaultdict


# Load zero-shot and few-shot template statistics for AG News ("direct_False" setting) and
# aggregate the few-shot numbers over the three seeds.
# NOTE(review): this cell rebinds the global `few_shot_res`, which other cells in this
# notebook use with a different (model -> n_shot -> summary) layout.
for model in [
    "bloom-1.7b", "bloom-3b", "bloom-7.1b",
    "cerebras-6.7b", "cerebras-13b",
    "gpt2-large", "gpt2-xl", "gptj",
    "opt-1.3b", "opt-6.7b",
    "pythia-6.9b", "pythia-12b"
]:
    zero_shot = torch.load(f"template_selection/agnews/{model}_formats_stats_zero_shot_direct_False")
    few_shot = torch.load(f"agnews_random_shot_direct_false/{model}_res/stats")
    # Re-key the zero-shot stats: the last character of each key is moved to sit right after
    # the first "{}" (presumably to match the few-shot key format -- TODO confirm).
    zero_shot = {k.replace("{}", "{}" + k[-1], 1)[:-1]: v for k, v in zero_shot.items()}
    # Highest score first.
    zero_shot = dict(sorted(zero_shot.items(), key=lambda x: x[1], reverse=True))

    few_shot_res = {}
    for n_shot in [2, 4, 8]:
        # Collect, per key, one array per seed.
        few_shot[n_shot] = defaultdict(list)
        for seed in [13, 21, 59]:
            for key in few_shot[f"{n_shot}_shot_{seed}"]:
                few_shot[n_shot][key].append(np.array(few_shot[f"{n_shot}_shot_{seed}"][key]))
        few_shot_res[f"{n_shot}_mean"] = {k: np.mean(v, axis=0) for k, v in few_shot[n_shot].items()}
        # Spread across seeds: this entry previously repeated np.mean, so "_std" held the mean.
        few_shot_res[f"{n_shot}_std"] = {k: np.std(v, axis=0) for k, v in few_shot[n_shot].items()}

        # Highest mean first.
        few_shot_res[f"{n_shot}_mean"] = dict(sorted(few_shot_res[f"{n_shot}_mean"].items(),
                                                     key=lambda x: x[1],
                                                     reverse=True))
    # NOTE(review): only the first model in the list is processed because of this break.
    break

In [None]:
method = "direct_False"
dataset = "sst2"

# best_templates[dataset][method][model] -> pieces of the top-scoring zero-shot template.
best_templates = {dataset: {method: {}}}

for model in names_to_checkpoints:
    stats_path = f"template_selection/{dataset}/{model}_formats_stats_zero_shot_{method}"
    try:
        zero_shot = torch.load(stats_path)
    except FileNotFoundError as e:
        # No statistics saved for this model: report the missing file and move on.
        print(e)
        continue

    # Order the templates by score, best first, and keep the winner.
    zero_shot = dict(sorted(zero_shot.items(), key=lambda item: item[1], reverse=True))
    best_template = list(zero_shot)[0]

    # Split on " {}" and put the placeholders back onto the first two pieces.
    template = best_template.split(" {}")
    if best_template.startswith("{}"):
        # The leading placeholder becomes its own piece and is stripped from the next one.
        template.insert(0, "{}")
        template[1] = template[1][2:]
    else:
        template[0] += " {}"
    template[1] += " {}"

    best_templates[dataset][method][model] = template

In [None]:
# Show the nested dataset -> method -> model mapping of best templates.
best_templates

# Methods

In [None]:
# Worst case across models: for every dataset and prediction method, report the minimum of
# the per-model column in `tables1` together with that column's standard deviation.
# The accumulator is `worst_table` throughout; the loop used to read and append to
# `mean_table`, which is undefined on a fresh kernel and keeps growing when it is stale.
# NOTE(review): `tables1` and `display` are defined/imported by a later cell of this
# notebook, so that cell has to be run first.
methods = ['direct_True', 'direct_False', 'channel_True', 'channel_False', 'calibrate_True', 'calibrate_False']
for n_shots in ['0-shot', '2-shot']:
    worst_table = {'sst2':[], 'dbpedia':[], 'agnews':[], 'trec':[]}
    for dataset in worst_table:
        table = tables1[n_shots][dataset]
        for method in methods:
            worst_table[dataset].append(f"${table[method].min():.3f}_{{\\pm{table[method].std():.3f}}}$")

    print(n_shots)
    # Rows are prediction methods, columns are datasets.
    table = pd.DataFrame(worst_table, index=methods)
    display(table)
    print(table.to_latex())

In [None]:
# Average across models: for every dataset and prediction method, report the mean of the
# per-model column in `tables1` together with that column's standard deviation.
methods = [f"{family}_{flag}" for family in ('direct', 'channel', 'calibrate') for flag in (True, False)]
for n_shots in ['0-shot', '2-shot']:
    mean_table = {name: [] for name in ('sst2', 'dbpedia', 'agnews', 'trec')}
    for dataset, cells in mean_table.items():
        table = tables1[n_shots][dataset]
        for method in methods:
            column = table[method]
            cells.append("${:.3f}_{{\\pm{:.3f}}}$".format(column.mean(), column.std()))

    print(n_shots)
    # Rows are prediction methods, columns are datasets.
    table = pd.DataFrame(mean_table, index=methods)
    display(table)
    print(table.to_latex())

In [None]:
from IPython.display import display

# Build one table per (n_shots, dataset): rows are models, columns are prediction methods and
# each cell is whatever `aggregate_scores(..., method='worst')` returns for that combination.
datasets = ['sst2', 'dbpedia', 'agnews', 'trec']
tables1 = {'0-shot': {}, '2-shot': {}}
for n_shots in [0, 2]:
    # `conditions` is a running list of single-key filters: an entry is appended before the
    # loop level it belongs to and removed again when that level is finished.
    conditions = []
    
    # NOTE(review): this `seeds` value is reassigned for every model below before it is used.
    seeds = [59] if n_shots == 0 else [59, 13, 21]
    # Zero-shot runs are selected with the '0-shot' selection method, few-shot ones with 'random'.
    selection_method = '0-shot' if n_shots == 0 else 'random'
    
    conditions.append({'selection_method': selection_method})
    conditions.append({'n_shots': n_shots})
    for dataset in datasets:
        conditions.append({'dataset': dataset})
        table = []
        for model in names_to_checkpoints:
            # opt-30b is aggregated over seed 13 only, in both shot settings.
            if model == 'opt-30b':
                seeds = [13]
            else:
                seeds = [59] if n_shots == 0 else [59, 13, 21]
            conditions.append({'model': model})
            entry = {'model': model}
            for method in prediction_methods:
                conditions.append({'prediction_method': method})
                res = aggregate_scores(all_runs_df, seeds, conditions, method='worst')
                entry.update({method: res})
                conditions.remove({'prediction_method': method})
            conditions.remove({'model': model})
            table.append(entry)
        # One row per model, indexed by model name.
        table = pd.DataFrame(table)
        table.set_index('model', inplace=True)
        tables1[f"{n_shots}-shot"][dataset] = table
        conditions.remove({'dataset': dataset})
        
# Summarise the tables built above: per dataset and prediction method, the mean of the
# per-model column and its standard deviation, formatted as "$mean_{\pm std}$".
methods = ['direct_True', 'direct_False', 'channel_True', 'channel_False', 'calibrate_True', 'calibrate_False']
for n_shots in ['0-shot', '2-shot']:
    worst_table = {name: [] for name in ('sst2', 'dbpedia', 'agnews', 'trec')}
    for dataset, cells in worst_table.items():
        table = tables1[n_shots][dataset]
        for method in methods:
            column_mean = table[method].mean()
            column_std = table[method].std()
            cells.append("${:.3f}_{{\\pm{:.3f}}}$".format(column_mean, column_std))

    print(n_shots)
    # Rows are prediction methods, columns are datasets.
    table = pd.DataFrame(worst_table, index=methods)
    display(table)
    print(table.to_latex())

In [None]:
def get_mean(str_score, return_std=False):
    """Parse a table cell of the form "<mean> ± <std>", optionally followed by " (...)".

    The literal cell "NaN" is treated as mean 0 and std 1.

    Args:
        str_score: cell text to parse.
        return_std: when true, return the standard deviation as well.

    Returns:
        The mean, or the pair (mean, std) when `return_std` is true.
    """
    head = str_score.partition(" (")[0]
    if head == "NaN":
        parsed = (0, 1)
    else:
        mean_text, std_text = head.split(" ± ")
        parsed = (float(mean_text), float(std_text))
    return parsed if return_std else parsed[0]
    
    
# For each shot setting and dataset, print the share of models for which the "<method>_True"
# cell has a higher mean than the "<method>_False" cell, as one LaTeX table row.
for n_shots in [0, 2]:
    for dataset in ['sst2', 'dbpedia', 'agnews', 'trec']:
        # Index with the loop variable; the lookup used `d`, which this cell never defines.
        table = tables1[f'{n_shots}-shot'][dataset]
        method_stats = {'direct': 0, 'channel': 0, 'calibrate': 0}
        
        for model, row in table.iterrows():
            for method in ['direct', 'channel', 'calibrate']:
                method_stats[method] += int(get_mean(row[f"{method}_True"]) > get_mean(row[f"{method}_False"]))
        # Turn win counts into fractions of the number of models (rows).
        for method in method_stats:
            method_stats[method] = round(method_stats[method] / len(table), 2)
        
        print(f"{n_shots}-shot & {dataset} & {' & '.join([str(method_stats[k]) for k in ['direct', 'channel', 'calibrate']])}\\\\")

In [None]:
def get_min(model, dataset, n_shots, seed, method='direct_False'):
    """Return the lowest score of the run matching the given setting, or 0 if there is none.

    Args:
        model, dataset, n_shots, seed: values matched against the equally named columns of
            the global `all_runs_df`.
        method: value matched against the `prediction_method` column.

    Returns:
        min() of the `scores` entry of the first matching row, or 0 when no row matches.
    """
    # Zero-shot tables in this notebook are selected with the '0-shot' selection method and
    # few-shot ones with 'random'; filtering on 'random' alone made every 0-shot lookup
    # fall through to the "no run" default.
    # NOTE(review): confirm that 0-shot runs never carry selection_method == 'random'.
    selection_method = '0-shot' if n_shots == 0 else 'random'
    wanted = {
        'model': model,
        'dataset': dataset,
        'n_shots': n_shots,
        'seed': seed,
        'prediction_method': method,
        'selection_method': selection_method,
    }
    selected_runs = all_runs_df
    for column, value in wanted.items():
        selected_runs = selected_runs.loc[selected_runs[column] == value]
    if len(selected_runs) == 0:
        return 0
    return min(selected_runs['scores'].values[0])

# For each shot setting and dataset, print the share of (model, seed) pairs in which a
# prediction method's worst-case score beats the 'direct_False' baseline, as a LaTeX row.
for n_shots in [0, 2]:
    seeds = [59] if n_shots == 0 else [59, 13, 21]
    
    for dataset in ['sst2', 'dbpedia', 'agnews', 'trec']:
        # Take the model list from this setting's own table (indexed by model) rather than
        # from whatever `table` a previously executed cell happened to leave behind.
        table = tables1[f"{n_shots}-shot"][dataset]
        method_stats = {k: 0 for k in ['direct_True', 'channel_True', 'channel_False', 'calibrate_True', 'calibrate_False']}
        for model in table.index:
            for seed in seeds:
                baseline_score = get_min(model, dataset, n_shots, seed, 'direct_False')
                for method in method_stats:
                    method_score = get_min(model, dataset, n_shots, seed, method)
                    if method_score > baseline_score:
                        method_stats[method] += 1
        # Normalise the win counts by the number of (model, seed) pairs.
        for k in method_stats:
            method_stats[k] = round(method_stats[k] / len(table.index) / len(seeds), 2)
        print(f"{n_shots}-shot & {dataset} & {' & '.join([str(v) for k, v in method_stats.items()])}\\\\")

In [None]:
def print_table(table):
    """Print every row of `table` as a LaTeX table row.

    Each cell is expected to read "<mean> ± <std>", optionally followed by " (...)"; the
    literal "NaN" counts as mean 0 and std 1. Within a row, the first cell with the highest
    mean (above 0) is set in bold and the first cell with the lowest std (below 1) is
    underlined; when no cell qualifies, the first column gets the highlight.

    Args:
        table: DataFrame whose index holds the row labels and whose cells are strings.
    """
    def _parse(cell):
        # "<mean> ± <std> (...)" -> (mean, std); "NaN" -> (0, 1)
        head = cell.split(" (")[0]
        if head == "NaN":
            return 0, 1
        mean, std = map(float, head.split(" ± "))
        return mean, std

    # keyed by (is best mean, is best std)
    cell_formats = {
        (True, True): "$\\mathbf{{{:.3f}}}_{{\\underline{{\\pm{:.3f}}}}}$",
        (True, False): "$\\mathbf{{{:.3f}}}_{{\\pm{:.3f}}}$",
        (False, True): "${:.3f}_{{\\underline{{\\pm{:.3f}}}}}$",
        (False, False): "${:.3f}_{{\\pm{:.3f}}}$",
    }

    columns = table.columns.values
    for model, row in table.iterrows():
        parsed = [_parse(row[col]) for col in columns]

        # Strict comparisons: on ties the earliest column keeps the highlight.
        best_mean, best_std, mean_at, std_at = 0, 1, 0, 0
        for pos, (mean, std) in enumerate(parsed):
            if mean > best_mean:
                mean_at, best_mean = pos, mean
            if std < best_std:
                std_at, best_std = pos, std

        cells = [
            cell_formats[(pos == mean_at, pos == std_at)].format(mean, std)
            for pos, (mean, std) in enumerate(parsed)
        ]
        res = f"{model} & " + '& '.join(cells)
        print(f"{res} \\\\")
# Emit the LaTeX rows of every 0-shot table, separated by a run of blank lines.
for dataset in ['sst2', 'dbpedia', 'agnews', 'trec']:
    print_table(tables1['0-shot'][dataset])
    print('\n\n\n\n')

# 4.3. Example selection methods

* they are not resistant to the choice of template
* they are not transferable from one model to another

In [None]:
# model_to_method[dataset][model] -> name of the prediction method assigned to that model.
# For sst2 / dbpedia / agnews every model in `model_specs` gets the default method unless it
# is listed as an exception; trec enumerates its models explicitly instead.
_listed_exceptions = {
    # dataset: (method for the listed models, listed models, method for everyone else)
    'sst2': ('calibrate_True',
             ['falcon-7b', 'falcon-40b', 'llama-65b'],
             'channel_True'),
    'dbpedia': ('channel_True',
                ['gpt2-xl', 'bloom-1.7b', 'bloom-3b', 'bloom-7.1b', 'cerebras-13b', 'llama-7b'],
                'calibrate_True'),
    'agnews': ('calibrate_True',
               ['opt-6.7b', 'falcon-1b', 'llama-13b', 'llama-30b', 'llama-65b'],
               'channel_True'),
}
_trec_groups = [
    ('channel_True', ['gpt2-xl', 'opt-1.3b', 'opt-30b', 'opt-66b', 'falcon-1b',
                      'pythia-6.9b', 'pythia-12b', 'cerebras-6.7b', 'cerebras-13b']),
    ('calibrate_True', ['bloom-1.7b', 'llama-30b', 'llama-7b', 'llama-13b',
                        'llama-65b', 'falcon-40b', 'falcon-7b']),
    ('calibrate_False', ['gptj', 'gpt-neox', 'bloom-7.1b', 'bloom-3b', 'opt-6.7b', 'gpt2-large']),
]

model_to_method = {}
for _dataset_name, (_listed_method, _listed_models, _default_method) in _listed_exceptions.items():
    model_to_method[_dataset_name] = {
        k: (_listed_method if k in _listed_models else _default_method) for k in model_specs
    }

model_to_method['trec'] = {}
for _group_method, _group_models in _trec_groups:
    model_to_method['trec'].update(dict.fromkeys(_group_models, _group_method))

In [None]:
def get_scores(dataset, model, n_shots, method, prediction_method='direct_False'):
    """Collect the scores of all seeds for one (dataset, model, n_shots, method) setting.

    Rows of the global `all_runs_df` are filtered on model, dataset, selection method,
    prediction method and shot count; for every seed present, the `scores` entry of its
    first matching row is appended to the result. A message is printed when the number of
    distinct seeds is not exactly three.

    NOTE: a later cell of this notebook defines another `get_scores` with a different
    signature, which replaces this one once that cell has been run.

    Args:
        dataset, model, n_shots: values matched against the equally named columns.
        method: value matched against the `selection_method` column.
        prediction_method: value matched against the `prediction_method` column.

    Returns:
        A flat list of scores; empty when no run matches.
    """
    wanted = {
        'model': model,
        'dataset': dataset,
        'selection_method': method,
        'prediction_method': prediction_method,
        'n_shots': n_shots,
    }
    selected_runs = all_runs_df
    for column, value in wanted.items():
        selected_runs = selected_runs[selected_runs[column] == value]
    seeds = selected_runs['seed'].value_counts()

    seed_count = len(seeds)
    if seed_count == 0:
        print(f"!!!!!no seeds {dataset}_{model}_{method}_{n_shots} no seeds!!!!")
    elif seed_count != 3:
        print(f"{dataset}_{model}_{method}_{n_shots} seeds: {seeds.keys()}")

    scores = []
    for seed in seeds.index:
        seed_rows = selected_runs[selected_runs['seed'] == seed]
        scores.extend(seed_rows['scores'].values[0])
    return scores
    
# Build, per dataset, a model x "<selection method>-<n_shots>" table whose cells read
# "<mean> ± <std>" over all scores returned by get_scores, or "NaN" when there are none.
tables_dataset = {}
# The trailing [:-1] slices drop 'CEIL' and the 8-shot setting from this comparison.
methods = ['random', 'implicitly_topic_models', 'z-ICL', 'CEIL'][:-1]
n_shots_range = [2, 4, 8][:-1]
for dataset in datasets:
    print(dataset)
    table = {f"{m}-{k}": {} for m in methods for k in n_shots_range}

    for method in methods:
        for n_shots in n_shots_range:
            # No z-ICL entry is produced for 8 shots (only relevant if 8 is re-enabled above).
            if method == 'z-ICL' and n_shots == 8:
                continue
            for model in names_to_checkpoints:
                method_scores = get_scores(model=model, dataset=dataset, n_shots=n_shots, method=method)
                if len(method_scores) > 0:
                    mean = np.mean(method_scores)
                    std = np.std(method_scores)
                    res = f"{mean:.3f} ± {std:.3f}"
                else:
                    res = "NaN"
                table[f"{method}-{n_shots}"][model] = res
    table = pd.DataFrame(table, index=list(names_to_checkpoints), columns=['random-2', 'random-4',
                                                                           'implicitly_topic_models-2',
                                                                           'implicitly_topic_models-4',
                                                                           'z-ICL-2', 'z-ICL-4'])
    # Shorten the column labels for display.
    table.columns = ['random-2', 'random-4', 'ITM-2', 'ITM-4', 'z-ICL-2', 'z-ICL-4']
    display(table)
    tables_dataset[dataset] = table
    # The three trailing \multirow prints were removed: they read `method_res`, which this
    # cell never builds, and indexed the disabled 'CEIL' / 8-shot entries, so the loop could
    # not get past the first dataset on a fresh kernel.

In [None]:
# Spot-check a single row: the sst2 results for gpt2-large.
tables_dataset['sst2'].loc['gpt2-large']

In [None]:
def get_scores(model, method, n_shots, dataset):
    """Format one cell of `tables_dataset[dataset]` for the LaTeX table.

    NOTE: this replaces the `get_scores` defined in an earlier cell of this notebook, which
    has a different signature and returns raw scores.

    Args:
        model: row label in the table.
        method, n_shots: together select the column "<method>-<n_shots>".
        dataset: key into the global `tables_dataset`.

    Returns:
        "<mean>\\textsubscript{<std>}" with two decimals, or the cell unchanged when it is
        the "NaN" marker used for settings without any scores.
    """
    table = tables_dataset[dataset]
    score = table.loc[model, f"{method}-{n_shots}"]
    if score == "NaN":
        # Nothing to parse: keep the missing-value marker visible instead of failing on unpacking.
        return score
    mean, std = map(float, score.split(" ± "))
    res = f"{mean:.2f}\\textsubscript{{{std:.2f}}}"
    return res
    
# Print two LaTeX rows per model (2-shot and 4-shot); each row lists, for every dataset, the
# scores of the three selection methods.
for model in ['gpt2-large', 'gpt2-xl', 'llama-7b', 'llama-13b', 'llama-30b', 'llama-65b']:
    for n_shots in [2, 4]:
        # Both rows carry a shot-count column. The 2-shot row used to omit it, leaving that
        # row one column short of the 4-shot row.
        if n_shots == 2:
            print(f"\\multirow{{2}}{{*}}{{{model}}} & {n_shots} &", end=' ')
        else:
            # The model column stays empty because the \multirow above spans both rows.
            print(f"& {n_shots} &", end=' ')
        for dataset in datasets:
            for method in ['random', 'ITM', 'z-ICL']:    
                # No column separator after the very last cell of the row.
                end = '' if (method == 'z-ICL') and (dataset == 'trec') else " & "
                print(f"{get_scores(model, method, n_shots, dataset)}", end=end)
        print('\\\\')
    if model == 'gpt2-xl':
        print('\\midrule')

In [None]:
def get_toptemplates(dataset, model, n_shots, method, prediction_method='direct_False', return_k=5):
    """Return the `return_k` best-scoring templates for one setting, pooled over seeds.

    Rows of the global `all_runs_df` are filtered on model, dataset, selection method,
    prediction method and shot count. For every seed of (59, 13, 21) that has a run, its
    `scores` and the matching `dataset_templates[dataset][seed]` list are concatenated in
    that order, and the templates at the highest-scoring positions are returned.
    Assumes scores[i] belongs to templates[i] within a seed -- TODO confirm.

    Previously the partial-seed case always read seeds 59 and 13, which failed whenever the
    available seeds were anything else.

    Args:
        dataset, model, n_shots: values matched against the equally named columns.
        method: value matched against the `selection_method` column.
        prediction_method: value matched against the `prediction_method` column.
        return_k: maximum number of templates to return.

    Returns:
        A list of at most `return_k` templates, best first.

    Raises:
        NotImplementedError: when no run matches the setting.
    """
    selected_runs = all_runs_df.loc[
        (all_runs_df['model'] == model)
        & (all_runs_df['dataset'] == dataset)
        & (all_runs_df['selection_method'] == method)
        & (all_runs_df['prediction_method'] == prediction_method)
        & (all_runs_df['n_shots'] == n_shots)
    ]
    seeds = selected_runs['seed'].value_counts()
    if len(seeds) == 0:
        print(f"!!!!!no seeds {dataset}_{model}_{method}_{n_shots} no seeds!!!!")
        raise NotImplementedError
    if len(seeds) != 3:
        print(f"{dataset}_{model}_{method}_{n_shots} seeds: {seeds.keys()}")

    scores, all_templates = [], []
    for seed in (59, 13, 21):
        if seed not in seeds.index:
            continue
        scores.extend(selected_runs.loc[selected_runs['seed'] == seed, 'scores'].values[0])
        all_templates.extend(dataset_templates[dataset][seed])

    # Positions sorted by descending score; slicing tolerates return_k > number of templates.
    toptemplates = np.argsort(scores)[::-1]
    return [all_templates[idx] for idx in toptemplates[:return_k]]

def calc_iou(list_a, list_b):
    """Overlap ratio between two template lists.

    Every element of `list_a` found in `list_b` counts towards the intersection; the union is
    `len(list_b)` plus the elements of `list_a` that are not in `list_b`. Elements are
    compared with `in`, so repeated items are counted once per occurrence.

    Returns:
        intersection / union as a float.
    """
    hits = [item in list_b for item in list_a]
    shared = sum(hits)
    only_in_a = len(hits) - shared
    return shared / (len(list_b) + only_in_a)

pred_methods = ["direct_False", "calibrate_True", "channel_True"]

# heatmaps[dataset][i][j] -> overlap (calc_iou) between the top templates of models[i] and
# models[j], using random example selection and the "direct_False" prediction method.
heatmaps = {}
models = list(names_to_checkpoints.keys())
models = [x for x in models if not x.startswith("pythia") and not x.startswith("cerebras")]
for return_k in [10]:
    for n_shots in [2]:
        for dataset in ['sst2', 'dbpedia', 'agnews', 'trec']:
            # Look the top templates up once per model instead of once per (i, j) pair.
            top_by_model = {
                name: get_toptemplates(model=name, dataset=dataset,
                                       n_shots=n_shots,
                                       method="random",
                                       return_k=return_k,
                                       prediction_method="direct_False")
                for name in models
            }
            heatmap = [
                [calc_iou(top_by_model[row_model], top_by_model[col_model]) for col_model in models]
                for row_model in models
            ]

            heatmaps[dataset] = heatmap
            # Label the summary line with the dataset being processed; it used to print the
            # global `model` left over from an unrelated cell.
            print(dataset, n_shots, heatmap[0][1], heatmap[0][2], heatmap[1][2], sep=" & ")


In [None]:
# Display names used on figure axes, keyed by the short model identifiers.
spelling = {
    # GPT family
    'gpt2-large': 'GPT-2 Large',
    'gpt2-xl': 'GPT-2 XL',
    'gptj': 'GPT-J',
    'gpt-neox': 'GPT-NeoX',
    # OPT
    'opt-1.3b': 'OPT 1.3B',
    'opt-6.7b': 'OPT 6.7B',
    'opt-30b': 'OPT 30B',
    'opt-66b': 'OPT 66B',
    # BLOOM
    'bloom-1.7b': 'BLOOM 1.7B',
    'bloom-3b': 'BLOOM 3B',
    'bloom-7.1b': 'BLOOM 7.1B',
    # Pythia
    'pythia-6.9b': 'Pythia 6.9B',
    'pythia-12b': 'Pythia 12B',
    # Cerebras
    'cerebras-6.7b': 'Cerebras 6.7B',
    'cerebras-13b': 'Cerebras 13B',
    # LLaMA
    'llama-7b': 'LLaMA 7B',
    'llama-13b': 'LLaMA 13B',
    'llama-30b': 'LLaMA 30B',
    'llama-65b': 'LLaMA 65B',
    # Falcon
    'falcon-1b': 'Falcon 1B',
    'falcon-7b': 'Falcon 7B',
    'falcon-40b': 'Falcon 40B',
}

In [None]:
import matplotlib

%config InlineBackend.figure_format = 'retina'

# Global matplotlib style applied to every figure drawn after this cell.
nice_fonts = {
        # Use LaTeX to write all text
        "text.usetex": True,
        "font.family": "Times New Roman",
        # Base text and axis labels are set to 14
        "axes.labelsize": 14,
        "font.size": 14,
        # Legend matches the base size; tick labels are slightly smaller (12)
        "legend.fontsize": 14,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
}
matplotlib.rcParams.update(nice_fonts)

In [None]:
import seaborn as sns

# Draw the model-by-model template-overlap heatmap for one dataset and save it as a PDF.
dataset = "trec"
model_labels = [spelling[name] for name in models]
coolwarm = sns.color_palette("coolwarm", as_cmap=True)
sns.heatmap(
    heatmaps[dataset],
    cmap=coolwarm,
    xticklabels=model_labels,
    yticklabels=model_labels,
)
# Slanted x labels, right-aligned to their ticks.
plt.xticks(rotation=60, ha="right")
plt.savefig(f'../pictures/heatmap_{dataset}.pdf', format='pdf', bbox_inches='tight', pad_inches=0)

In [None]:
# Build one table per dataset: rows are models and the columns "<selection method>-<n_shots>"
# hold whatever `aggregate_scores` returns for that setting (prediction method fixed to
# 'direct_False', seeds 59/13/21).
tables3 = {}
seeds = [59, 13, 21]
# `conditions` is a running list of single-key filters: an entry is appended before the loop
# level it belongs to and removed again when that level is finished.
conditions = [{'prediction_method': 'direct_False'}]
for dataset in datasets:
    table = []
    conditions.append({'dataset': dataset})
    for model in names_to_checkpoints:
        conditions.append({'model': model})
        entry = {'model': model}
        for n_shots in [2, 4, 8]:
            conditions.append({'n_shots': n_shots})
            for method in ['random', 'implicitly_topic_models']:
                conditions.append({'selection_method': method})
                
                res = aggregate_scores(all_runs_df, seeds, conditions)
                entry.update({f"{method}-{n_shots}": res})
                conditions.remove({'selection_method': method})
            conditions.remove({'n_shots': n_shots})
        conditions.remove({"model": model})
        table.append(entry)
    # One row per model, indexed by model name.
    table = pd.DataFrame(table)
    table.set_index('model', inplace=True)
    tables3[dataset] = table
    conditions.remove({'dataset': dataset})

In [None]:
# Show the per-dataset tables stored in `tables3`, each preceded by its dataset name.
for d in datasets:
    print(d)
    display(tables3[d])
# Printed once, after the last table (this line sits outside the loop).
print("-------")

In [None]:
# Inspect the raw runs whose selection method is 'implicitly_topic_models'.
all_runs_df[all_runs_df['selection_method'] == 'implicitly_topic_models']