### Import Packages

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.font_manager
from matplotlib.font_manager import findfont, FontProperties

In [None]:
# Edit the path below to point at your own BASE_DIR.
# Run this cell before any other cell: the next cell imports BASE_DIR from
# `constants`, which presumably reads this environment variable.
os.environ.update(BASE_DIR='/path/to/base_dir')

In [None]:
from constants import BASE_DIR, DATA_DIR
from helper_utils.helper_methods import list_datasets_and_their_splits, list_hardcode_datasets_and_their_splits

In [None]:
# Set plotting parameters
font = {'family' : 'serif',
        'size'   : 11}
# High DPI so figures stay sharp when exported.
mpl.rcParams['figure.dpi'] = 600
mpl.rc('font', **font)
# Prefer Nimbus Roman by listing it first among the serif fonts while keeping
# the generic 'serif' family. Setting font.family to 'Nimbus Roman' directly
# makes matplotlib fall back to its default sans-serif font (with a findfont
# warning) on machines without Nimbus Roman; this way another serif is used.
mpl.rcParams['font.serif'] = ['Nimbus Roman'] + [
    name for name in mpl.rcParams['font.serif'] if name != 'Nimbus Roman'
]
mpl.rc('xtick', labelsize=11)
mpl.rc('ytick', labelsize=11)

### Compute Dataset Statistics

Compute basic statistics of the different datasets: the number of instances, the raw sequence length, the sequence length after tokenization, and the lexical overlap within a subsample of instances.

Documentation for each function can be found in `dataset_stat.py`. The functions in this section assume access to the datasets contained in the `data/` directory of BASE_DIR.

In [None]:
from dataset_stat import build_table_for_all_datasets, compute_sample_overlap_all_datasets

In [None]:
# Outputs a PrettyTable for the requested data type. The data type is a str in
# {num_instances, raw_avg_length, tok_seq_length, lexical_overlap}.
# sub_datatype in {input, output}: used when computing raw_avg_length and tok_seq_length.
# model_name: an HF model name that has a fast tokenizer. It defaults to t5-base,
# and facebook/bart-base is another option. It is used only for the
# sequence length after tokenization.
print("Computing number of instances")
build_table_for_all_datasets("num_instances")

In [None]:
# Average raw (pre-tokenization) length of the input side of each dataset.
print("Computing average length of each dataset")
build_table_for_all_datasets('raw_avg_length', sub_datatype='input')

In [None]:
print("Computing average length of each dataset after tokenization")
# 'tok_seq_length' is the data type for sequence length after tokenization.
# 'raw_avg_length' would silently recompute the previous (untokenized) table.
# model_name selects the tokenizer; it must be an HF model with a fast tokenizer.
build_table_for_all_datasets('tok_seq_length', sub_datatype='input', model_name='t5-base')

In [None]:
# Overlap between a subsample of instances, for every dataset.
# NOTE(review): lex_type="Lev" presumably selects a Levenshtein-based measure — confirm in dataset_stat.py.
res = compute_sample_overlap_all_datasets(lex_type="Lev")

### Plot the training curve

Plot the training curves of HF models. This section assumes access to `trained_models/` in BASE_DIR. Each model directory should include a `trainer_state.json`.

In [None]:
from analysis_utils import load_training_curve_info

In [None]:
# Load the training-curve data for one (model, dataset, split) run.
dataset_name, split, model_name = 'geoquery', 'standard', 't5-base'
# splits_mapping is indexed by dataset name to get that dataset's splits
# (the all-splits plot further down iterates over it).
dataset_names, splits_mapping = list_hardcode_datasets_and_their_splits()

steps, ems, best_em = load_training_curve_info(model_name, dataset_name, split)

In [None]:
def plot_training_curve(steps, ems, best_em=-1, label=None):
    """Plot exact match (EM) against training steps for a single run.

    Args:
        steps: Training-step values for the x-axis.
        ems: EM scores, one per entry of ``steps``.
        best_em: If not -1, draw a horizontal reference line at this EM,
            spanning from the first to the last step.
        label: Legend label for the EM curve. Defaults to the notebook-level
            ``model_name`` so existing calls behave as before.
    """
    if label is None:
        label = model_name
    fig, ax = plt.subplots(nrows=1, ncols=1)
    ax.plot(steps, ems, label=label)
    # -1 is the "no best EM" sentinel. Also skip when there are no steps,
    # since steps[0] / steps[-1] would raise IndexError.
    if best_em != -1 and len(steps) > 0:
        ax.plot([steps[0], steps[-1]], [best_em, best_em], label="Best EM")

    ax.set_xlabel('Steps')
    ax.set_ylabel('EM')
    # The curves carry labels; without this call they are never shown.
    ax.legend()

In [None]:
# Plot the run loaded above, including its best-EM reference line.
plot_training_curve(steps, ems, best_em=best_em)

#### Plot training curve for all splits

In [None]:
# Reload analysis_utils so edits to the module take effect without restarting
# the kernel. A reload does not rebind names imported earlier via
# `from analysis_utils import ...`. The plotting cell below calls the module
# attribute (analysis_utils.load_avg_training_curve_info), so it picks up the reload.
import importlib
import analysis_utils
importlib.reload(analysis_utils)

In [None]:
# Palette for the per-split training curves, indexed by split position.
# Every entry is distinct so no two splits share a colour in the plot/legend.
colors = ['#D81B60', '#999999', '#7570B3', '#E66100', '#1E88E5', '#004D40']

In [None]:
# Plot the averaged training curve of every split of one dataset on shared
# axes, then save the figure as a PDF under results/analysis_res/.
# To plot a single run instead of the average, use
# analysis_utils.load_training_curve_info(model_name, dataset_name, split, checkpoint=None).
dataset_name = 'geoquery'
fig, ax = plt.subplots(nrows=1, ncols=1)
splits = []
for idx, split in enumerate(splits_mapping[dataset_name]):
    steps, ems, best_em = analysis_utils.load_avg_training_curve_info(model_name, dataset_name, split, checkpoint=None)
    # Cycle through the palette so a dataset with more splits than colours
    # does not raise IndexError.
    ax.plot(steps, ems, label=split, color=colors[idx % len(colors)], alpha=0.9, linewidth=2.0)
    splits.append(split)

ax.set_xlabel('Steps')
ax.set_ylabel('EM')
ax.grid(alpha=0.4)
# The legend shows square colour swatches rather than line samples, built
# from empty, marker-only lines.
handles = [
    ax.plot([], [], marker='s', color=colors[idx % len(colors)], ls="none")[0]
    for idx in range(len(splits))
]
ax.legend(handles, list(splits))
fig.suptitle("Training Curve of " + model_name + " on " + dataset_name)
# savefig does not create missing directories.
output_dir = f"{BASE_DIR}/results/analysis_res"
os.makedirs(output_dir, exist_ok=True)
fig.savefig(f"{output_dir}/{model_name}-{dataset_name}.pdf", format='pdf', bbox_inches="tight")
    

### Evaluation

Evaluate the models and save the results to CSV files. This section assumes access to the `pred/` directory, which contains `.txt` files of model predictions.

In [None]:
from evaluate_utils import evaluate_model, evaluate_all_model_for_dataset, evaluate_all, gen_performance_table

In [None]:
# Evaluate one run: t5-base on the geoquery 'standard' split, seed 42, test split.
# NOTE(review): random_seed is passed as the string '42' — confirm evaluate_model expects a str.
evaluate_model(dataset_name='geoquery', split='standard', model_name='t5-base', random_seed='42', eval_split='test')

In [None]:
# Evaluates all models for one dataset (geoquery). Slow: skip this cell unless
# you need fresh results for this dataset.
res = evaluate_all_model_for_dataset(dataset_name='geoquery')

In [None]:
# Evaluate all models on all datasets. The output covers each random seed
# plus the average/std across seeds.
# Slow: rerun only when predictions have changed; otherwise reuse the saved
# results/exact_match.csv.
res = evaluate_all()
results_dir = os.getenv('BASE_DIR') + '/results'
# DataFrame.to_csv does not create missing directories.
os.makedirs(results_dir, exist_ok=True)
res.to_csv(results_dir + '/exact_match.csv')

In [None]:
# Build the performance table, with each metric averaged across random seeds.
results_dir = os.getenv('BASE_DIR') + '/results'
res_table = pd.read_csv(results_dir + '/exact_match.csv')
# Metric columns to report in the table.
columns_to_keep = ['raw_exact_match', 'ignore_space', 'f1']
res = gen_performance_table(columns_to_keep, res_table)
res.to_csv(results_dir + '/perf_table.csv')

### Compute Concurrence

This section can be run without access to `data/`. Make sure `results/exact_match.csv` is present before executing it.

In [None]:
# Evaluating all COGS models takes a long time; the pre-computed
# results/exact_match.csv loaded in the next cell can be used instead.
# cogs_perf = evaluate_all_model_for_dataset('COGS')

In [None]:
# Load the per-run exact-match results used by the concurrence computations.
# NOTE(review): DataFrame.to_csv writes the index by default, so this file
# likely has an extra unnamed first column. Pass index_col=0 if that matters.
perf_table = pd.read_csv(BASE_DIR + '/results/exact_match.csv')

In [None]:
# Reload analysis_utils to pick up source edits, then re-import the
# concurrence helpers so these names refer to the reloaded module's functions.
from importlib import reload
import analysis_utils
reload(analysis_utils)
from analysis_utils import compute_concurrence, compute_concurr_all

In [None]:
# Sanity check: compare a split with itself
# (presumably this should give perfect agreement).
for self_pair in (
    ("COGS", "COGS", "random_cvcv", "random_cvcv"),
    ("geoquery", "geoquery", "tmcd_random_cvcv", "tmcd_random_cvcv"),
):
    print(compute_concurrence(perf_table, *self_pair))

In [None]:
# More sanity checks: pairs of different datasets/splits.
cross_pairs = [
    ("COGS", "SCAN", "random_cvcv", "addprim_jump"),
    ("COGS", "SCAN", "no_mod", "addprim_jump"),
    ("COGS", "geoquery", "no_mod", "standard"),
]
for pair_args in cross_pairs:
    print(compute_concurrence(perf_table, *pair_args))

In [None]:
# Compute concurrence between all datasets and their splits.
# metric_type selects the metric used to compute concurrence; corref
# presumably selects the correlation coefficient ('Kendall' here).
# Some entries will be None because not all models have finished training.
concurrences = compute_concurr_all(metric_type='ignore_space', corref='Kendall')

In [None]:
# Save the concurrences to CSV. The filename hard-codes "Kendall"; keep it in
# sync with the corref used when computing them.
concurrences_path = os.getenv('BASE_DIR') + '/results/Kendall_concurrences.csv'
concurrences.to_csv(concurrences_path)