In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
from transformers import BertModel, DistilBertModel
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import datasets
import numpy as np
import os.path
import data
from datasets import load_from_disk
import pickle as pkl
from sklearn.linear_model import LogisticRegressionCV
from collections import defaultdict
from copy import deepcopy
from tqdm import tqdm
import dvu
dvu.set_style()
import sys
sys.path.append('..')
import pandas as pd
from auggam import analyze_helper
from auggam import config
from os.path import join as oj
import matplotlib.pyplot as plt
pd.set_option('display.max_rows', None)

In [None]:
# Aggregated fitted results are cached as a pickle in config.results_dir.
# Rebuild the cache from the per-run results only when it is missing, so the
# notebook also runs on a fresh checkout (Restart & Run All).
# NOTE: pd.read_pickle can execute arbitrary code -- only load caches you produced.
RESULTS_CACHE = oj(config.results_dir, 'fitted_results_aggregated.pkl')
if os.path.exists(RESULTS_CACHE):
    rs = pd.read_pickle(RESULTS_CACHE)
else:
    rs = data.load_fitted_results()
    rs.to_pickle(RESULTS_CACHE)

# rr: metrics averaged over seeds; r_sem: the matching standard errors
# (as returned by analyze_helper.average_seeds -- TODO confirm semantics)
rr, r_sem = analyze_helper.average_seeds(rs)
rs.head()

# Compute accuracies for the tables

In [None]:
def rename_checkpoint(checkpoint):
    """Collapse fine-tuned checkpoint names into a short family label.

    A checkpoint counts as fine-tuned when its lower-cased name contains a
    '/' or the substring 'finetune'. Such names map to 'distilbert-finetuned'
    or 'roberta-finetuned' when that family name appears in them, and to
    'bert-finetuned' otherwise. Every other name is returned unchanged
    (original casing preserved).
    """
    lowered = checkpoint.lower()
    looks_finetuned = '/' in lowered or 'finetune' in lowered
    if not looks_finetuned:
        return checkpoint

    # 'distilbert' must be tested before anything that would also match it
    for marker, family_label in (('distilbert', 'distilbert-finetuned'),
                                 ('roberta', 'roberta-finetuned')):
        if marker in lowered:
            return family_label
    return 'bert-finetuned'

def get_acc_table(r):
    """Build a dataset x model-variant table of the best 'acc_val_print'.

    Parameters
    ----------
    r : pd.DataFrame
        Must contain the columns 'dataset', 'checkpoint', 'layer', 'parsing'
        and 'acc_val_print'. It is not modified.

    Returns
    -------
    pd.DataFrame
        One row per dataset and one column per (checkpoint, layer, parsing)
        combination, named '<checkpoint>___<layer>___<parsing>' with the
        checkpoint passed through rename_checkpoint. Each cell holds the max
        'acc_val_print' over all other settings (e.g. ngram sizes).
    """
    keys = ['dataset', 'checkpoint', 'layer', 'parsing']

    # .assign returns a new frame: the caller's data is never touched and no
    # SettingWithCopyWarning is raised (assigning into the column-subset slice
    # with `r.checkpoint = ...` triggers one).
    r = r[keys + ['acc_val_print']].assign(
        checkpoint=lambda frame: frame['checkpoint'].apply(rename_checkpoint))

    # best accuracy per (dataset, checkpoint, layer, parsing)
    # NOTE(review): if acc_val_print is a formatted string rather than a number,
    # max() is lexicographic -- confirm its dtype.
    best = r.groupby(keys)['acc_val_print'].max().reset_index()

    # make acc table (dataset x [checkpoint, layer, parsing])
    table = best.pivot(index='dataset',
                       columns=['checkpoint', 'layer', 'parsing'],
                       values='acc_val_print')
    # flatten the column MultiIndex into '___'-joined names; str() guards
    # against non-string level values
    table.columns = ['___'.join(map(str, levels)) for levels in table.columns.to_flat_index()]
    return table

# Accuracy tables for runs with subsample == -1 (presumably no subsampling):
# one over all ngram sizes, plus a unigram-only version whose columns carry a
# '___ngrams=1' suffix so both can live side by side in `accs`.
r = rr.loc[rr['subsample'] == -1]
r1 = r.loc[r['ngrams'] == 1]
accs1 = get_acc_table(r1).add_suffix('___ngrams=1')
accs = get_acc_table(r).join(accs1)

In [None]:
# column names of the combined accuracy table (same object as accs.keys())
accs.columns

In [None]:
# one row per model variant, one column per dataset
accs.T

**Best-model accuracies**

In [None]:
# Main results table: best accuracy per dataset for Emb-GAM and the n-gram
# baselines. Keys follow get_acc_table's '<checkpoint>___<layer>___<parsing>'
# naming ('___ngrams=1' marks the unigram-only table); values are LaTeX labels.
# Backslashes are escaped explicitly: '\m' is an invalid escape sequence
# (DeprecationWarning, SyntaxWarning from Python 3.12).
columns = {
    'bert-finetuned___last_hidden_state_mean___': '\\textbf{Emb-GAM}',
    'countvectorizer___last_hidden_state_mean___': 'Bag of ngrams',
    'tfidfvectorizer___last_hidden_state_mean___': 'TF-IDF',
    'bert-finetuned___last_hidden_state_mean______ngrams=1': '\\makecell[l]{Emb-GAM\\\\(unigrams only)}',
}

tab = accs[list(columns.keys())].rename(columns=columns)

# rename index: dataset ids -> display names; ids without an entry are kept
tab.index = [analyze_helper.DSETS_RENAME_DICT.get(d, d) for d in tab.index]

tab2 = tab
print(tab2.transpose().to_latex(escape=False))

**Model-variations table**

In [None]:
# Model-variations table: compare embedding models, pooling (mean of the last
# hidden state vs. pooler output) and parsing (noun chunks). Keys follow
# get_acc_table's '<checkpoint>___<layer>___<parsing>' naming.
# Backslashes are escaped explicitly ('\m' is an invalid escape sequence).
columns = {
    'bert-finetuned___last_hidden_state_mean___': '\\makecell{BERT finetuned}',
    'bert-finetuned___pooler_output___': '\\makecell{BERT finetuned\\\\(pooler output)}',
    'bert-finetuned___last_hidden_state_mean___noun_chunks': '\\makecell{BERT finetuned\\\\(noun chunks)}',
    'bert-base-uncased___last_hidden_state_mean___': 'BERT',
    'bert-base-uncased___pooler_output___': '\\makecell{BERT\\\\(pooler output)}',
    # rename_checkpoint passes 'distilbert-base-uncased' through unchanged
    # (not classified as fine-tuned) and maps fine-tuned DistilBERT
    # checkpoints to 'distilbert-finetuned' -- label them accordingly.
    'distilbert-base-uncased___last_hidden_state_mean___': 'DistilBERT',
    'distilbert-finetuned___last_hidden_state_mean___': 'DistilBERT finetuned',
}

tab = accs[list(columns.keys())].rename(columns=columns)

# rename index: dataset ids -> display names; ids without an entry are kept
tab.index = [analyze_helper.DSETS_RENAME_DICT.get(d, d) for d in tab.index]
tab = tab.round(3)

In [None]:
# After transposing, rows are models and columns are datasets (tab's rows), so
# the LaTeX column spec needs one entry for the index plus one per dataset.
print(tab.transpose().to_latex(escape=False, column_format='c' + 'l' * tab.shape[0]))

# Accuracy vs. n-gram size curves

In [None]:
# Accuracy vs. ngram size (subsample == -1) on four datasets: the fine-tuned
# BERT model (mean of last hidden state) against the non-BERT baselines.
dvu.set_style()
plt.figure(figsize=(6, 4.75))
plt.rcParams['font.size'] = '14'

# Fixed colors for the known baselines; any other baseline label falls back to
# matplotlib's color cycle instead of an undefined or stale `color` variable.
BASELINE_COLORS = {'Bag of ngrams': '#555', 'TF-IDF': '#AAA'}

for i, dset in enumerate(['financial_phrasebank', 'rotten_tomatoes', 'sst2', 'emotion']):
    r = rr[rr.dataset == dset]
    r = r[r.parsing == '']
    # baselines: every checkpoint whose name does not mention 'bert'
    r_baselines = r[~r.checkpoint.str.lower().str.contains('bert')]
    # fine-tuned BERT, mean-pooled last hidden state
    r_bert = r[
        (r.checkpoint.apply(rename_checkpoint) == 'bert-finetuned') &
        (r.layer == 'last_hidden_state_mean')
    ]
    r = pd.concat((r_baselines, r_bert))

    plt.subplot(2, 2, i + 1)
    # fall back to the raw id so an unknown dataset does not crash on None.replace
    dset_name = analyze_helper.DSETS_RENAME_DICT.get(dset, dset)
    plt.title(dset_name.replace("Financial phrasebank", "FPB"), fontsize='large')
    d = r[(r.subsample == -1)]

    curve = sorted(d.groupby(['checkpoint', 'all', 'norm']),
                   key=lambda x: analyze_helper.COLUMNS_RENAME_DICT.get(x[0][0], 'BERT finetuned'))
    for key, group in curve:
        g = group.sort_values('ngrams')
        if 'distilbert' in key[0].lower():
            label = analyze_helper.COLUMNS_RENAME_DICT.get(key[0], 'Emb-GAM (DistilBERT finetuned)')
        else:
            label = analyze_helper.COLUMNS_RENAME_DICT.get(key[0], 'Aug-Linear')
        if 'Aug-Linear' in label:
            plt.errorbar(g.ngrams, g.acc_val, yerr=g.acc_val_sem, fmt='.-', label=label, lw=2.5, color='C0', ms=8)
        else:
            style = {'color': BASELINE_COLORS[label]} if label in BASELINE_COLORS else {}
            plt.errorbar(g.ngrams, g.acc_val, yerr=g.acc_val_sem, fmt='.-', label=label,
                         lw=1.5, alpha=0.8, ms=8, **style)
    # axis labels only on the outer panels of the 2x2 grid
    if i % 2 == 0:
        plt.ylabel('Accuracy', fontsize='large')
    if i >= 2:
        plt.xlabel('Ngram size', fontsize='large')

plt.tight_layout()
# one shared legend, anchored below the last panel
plt.legend(fontsize='medium', bbox_to_anchor=(0.5, -0.45))
plt.savefig('acc_by_ngrams_full.pdf', bbox_inches='tight')

## Older versions of the accuracy curves

In [None]:
# One figure per dataset: accuracy vs. ngram size for three training-set sizes
# (subsample = 100, 1000, -1), saved to results/acc_by_ngrams_<dset>.pdf.
for dset in ['financial_phrasebank', 'rotten_tomatoes', 'sst2']:  # rs.dataset.unique():
    r = rs[rs.dataset == dset]

    R = 1
    C = 3
    plt.figure(figsize=(12, 5))
    for i, subsample in enumerate([100, 1000, -1]):
        plt.subplot(R, C, i + 1)
        plt.title('Num train=' + str(subsample))
        d = r[r.subsample == subsample]

        curve = sorted(d.groupby(['checkpoint', 'all', 'norm']),
                       key=lambda x: data.COLUMNS_RENAME_DICT.get(x[0][0], 'BERT finetuned'))
        # move BERT to bottom; skip when this subsample has no rows, where
        # pop(0) on the empty list would raise IndexError
        if curve:
            curve.append(curve.pop(0))
        for key, group in curve:
            g = group.sort_values('ngrams')
            label = data.COLUMNS_RENAME_DICT.get(key[0], 'BERT finetuned')
            if label == 'BERT finetuned':
                plt.plot(g.ngrams, g.acc_val, '.-', label=label, lw=2, color='black')
            else:
                plt.plot(g.ngrams, g.acc_val, '.-', label=label, lw=1)
        # fall back to the raw id rather than printing "Accuracy (None)"
        plt.ylabel(f'Accuracy ({data.DSETS_RENAME_DICT.get(dset, dset)})')
        plt.xlabel('N-gram size')
        plt.legend()

    plt.tight_layout()
    # savefig does not create missing directories
    os.makedirs('results', exist_ok=True)
    plt.savefig(f'results/acc_by_ngrams_{dset}.pdf')

In [None]:
# Accuracy vs. ngram size for subsample == -1, one panel per dataset.
plt.figure(figsize=(12, 4))
for i, dset in enumerate(['financial_phrasebank', 'rotten_tomatoes', 'sst2']):
    r = rs[rs.dataset == dset]
    plt.subplot(1, 3, i + 1)
    plt.title(str(data.DSETS_RENAME_DICT.get(dset, dset)), fontsize='large')
    # Use the value explicitly instead of a loop variable left over from an
    # earlier cell, so this cell also runs on a fresh kernel.
    d = r[r.subsample == -1]

    curve = sorted(d.groupby(['checkpoint', 'all', 'norm']),
                   key=lambda x: data.COLUMNS_RENAME_DICT.get(x[0][0], 'BERT finetuned'))
    # move BERT to bottom; guard against an empty selection (pop would raise)
    if curve:
        curve.append(curve.pop(0))
    for key, group in curve:
        g = group.sort_values('ngrams')
        label = data.COLUMNS_RENAME_DICT.get(key[0], 'BERT finetuned')
        if label == 'BERT finetuned':
            plt.plot(g.ngrams, g.acc_val, '.-', label=label, lw=2.5, color='black')
        else:
            plt.plot(g.ngrams, g.acc_val, '.-', label=label, lw=1.5)
    plt.ylabel('Accuracy', fontsize='large')
    plt.xlabel('Ngram size', fontsize='large')

plt.tight_layout()
plt.legend(labelcolor='linecolor', fontsize='large')

## Save best models (print `cp` commands for the best run per dataset)

In [None]:
import embgam

# Best fine-tuned-BERT run per dataset among runs with subsample == -1,
# all == 'all' and norm == ''.
# The cleaned checkpoint name is kept as a standalone Series: assigning
# `df.checkpoint_clean = ...` does NOT create a column in pandas (it only sets
# a Python attribute and emits a UserWarning).
checkpoint_clean = rs.checkpoint.apply(rename_checkpoint)
keep = ((checkpoint_clean == 'bert-finetuned') &
        (rs.subsample == -1) &
        (rs['all'] == 'all') &
        (rs['norm'] == ''))
df = rs[keep]  # boolean indexing returns a new frame; rs is left untouched
# rows matching their dataset's maximum acc_val (ties keep every tied row);
# 'max' instead of the builtin avoids pandas' deprecation warning
idx = df.groupby('dataset')['acc_val'].transform('max') == df['acc_val']
df = df[idx]

In [None]:
import shlex


def print_fname(row):
    """Print a `cp` command copying one run's results.pkl to results/best___<run>.pkl.

    `row` is one row (a pandas Series) of the results frame. It is passed to
    embgam.data.get_dir_name, together with row.seed, to rebuild the run's
    output directory name under config.results_dir/<row.dataset>.
    """
    out_dir_name = embgam.data.get_dir_name(row, seed=row.seed)
    # the saved directory name has an '-all' suffix appended
    save_dir = oj(config.results_dir, row.dataset, out_dir_name) + '-all'
    results_file = oj(save_dir, 'results.pkl')
    out_name = oj('results', 'best___' + os.path.basename(save_dir) + '.pkl')
    # quote the paths so the printed command is safe to paste into a shell
    print('cp', shlex.quote(results_file), shlex.quote(out_name), '\n')


for _, row in df.iterrows():
    print_fname(row)