# Edge Probing Extract Results

In [1]:
import sys, os, re, json
import itertools
import collections
from importlib import reload
import pandas as pd
import numpy as np
from sklearn import metrics

In [2]:
import datetime
def get_compact_timestamp():
    """Return the current local time as a compact ``YYYYMMDD.HHMMSS`` string."""
    return datetime.datetime.now().strftime("%Y%m%d.%H%M%S")

In [3]:
import bokeh
import bokeh.plotting as bp
# Configure Bokeh to render plots inline in the notebook (used by bp.show below).
bokeh.io.output_notebook()

The latest runs are here:

In [4]:
# Root directory of this batch of experiments. Each entry is one experiment,
# named "<encoder-prefix>edges-<task>" (see the listing displayed below).
# NOTE(review): hard-coded absolute NFS path; the notebook only runs on a
# machine that mounts it.
top_expt_dir = "/nfs/jsalt/home/iftenney/exp/edges-20180902/"
# os.listdir returns entries in arbitrary order.
all_expt_dirs = os.listdir(top_expt_dir)
all_expt_dirs

['cove-edges-constituent-ontonotes',
 'glove-edges-constituent-ontonotes',
 'cove-edges-coref-ontonotes-conll',
 'glove-edges-coref-ontonotes-conll',
 'cove-edges-dep-labeling-ewt',
 'glove-edges-dep-labeling-ewt',
 'elmo-chars-edges-constituent-ontonotes',
 'elmo-full-edges-constituent-ontonotes',
 'elmo-chars-edges-coref-ontonotes-conll',
 'elmo-full-edges-coref-ontonotes-conll',
 'elmo-chars-edges-dep-labeling-ewt',
 'elmo-full-edges-dep-labeling-ewt',
 'cove-edges-dpr',
 'cove-edges-ner-ontonotes',
 'cove-edges-spr2',
 'cove-edges-srl-conll2012',
 'elmo-chars-edges-dpr',
 'elmo-chars-edges-ner-ontonotes',
 'elmo-chars-edges-spr2',
 'elmo-chars-edges-srl-conll2012',
 'elmo-full-edges-dpr',
 'elmo-full-edges-ner-ontonotes',
 'elmo-full-edges-spr2',
 'elmo-full-edges-srl-conll2012',
 'glove-edges-dpr',
 'glove-edges-ner-ontonotes',
 'glove-edges-spr2',
 'glove-edges-srl-conll2012']

In [5]:
# Recover the set of task names by stripping the leading encoder prefix
# ("elmo-<variant>-", "train-", "train-full-", "train-chars-", "cove-" or
# "glove-") from every experiment directory name.
# All alternatives sit inside a single group so that "^" anchors each of them;
# without the group "^" binds only to the first alternative, and "train-",
# "cove-" or "glove-" would also be deleted from the middle of a task name.
task_names = {re.sub(r"^(?:elmo-\w+-|train-(?:full-|chars-)?|(?:cove|glove)-)", "", s)
              for s in all_expt_dirs}
task_names

{'edges-constituent-ontonotes',
 'edges-coref-ontonotes-conll',
 'edges-dep-labeling-ewt',
 'edges-dpr',
 'edges-ner-ontonotes',
 'edges-spr2',
 'edges-srl-conll2012'}

In [6]:
# Encoder prefixes: replace everything from "-edges-" to the end of the name
# with "-". A name that does not contain "-edges-" is kept unchanged by re.sub.
prefixes = {re.sub("-edges-.*", "-", s) for s in all_expt_dirs}
prefixes

{'cove-', 'elmo-chars-', 'elmo-full-', 'glove-'}

In [7]:
from IPython.display import display
import ipywidgets as widgets

In [8]:
def get_log(prefix, task_name, run_name="run"):
    """Read the training log of one experiment run.

    The log is looked up at
    ``<top_expt_dir>/<prefix><task_name>/<run_name>/log.log``. When that file
    does not exist, ``<run_name>_seed_0/log.log`` in the same experiment
    directory is opened instead.

    Args:
        prefix: Encoder prefix of the experiment directory, e.g. ``"glove-"``.
        task_name: Task part of the experiment directory name.
        run_name: Name of the run sub-directory.

    Returns:
        The lines of the log file as a list of strings, line endings included.

    Raises:
        OSError: if the path that ends up being selected cannot be opened.
    """
    expt_dir = os.path.join(top_expt_dir, prefix + task_name)
    primary = os.path.join(expt_dir, run_name, "log.log")
    fallback = os.path.join(expt_dir, run_name + "_seed_0", "log.log")
    chosen = primary if os.path.exists(primary) else fallback
    with open(chosen) as log_file:
        return log_file.readlines()

def get_results(prefix, task_name):
    """Load an experiment's ``results.tsv`` as one item per line.

    The file lines (which keep their own line endings) are joined with an
    extra newline, then every tab and every ``", "`` is turned into a newline.

    Args:
        prefix: Encoder prefix of the experiment directory, e.g. ``"glove-"``.
        task_name: Task part of the experiment directory name.

    Returns:
        The reformatted file contents as a single string.

    Raises:
        OSError: if ``results.tsv`` cannot be opened.
    """
    expt_dir = os.path.join(top_expt_dir, prefix + task_name)
    with open(os.path.join(expt_dir, "results.tsv")) as results_file:
        text = "\n".join(results_file.readlines())
    # Same replacement order as before: tabs first, then ", ".
    for delimiter in ("\t", ", "):
        text = text.replace(delimiter, "\n")
    return text
    
def print_info(prefix, task_name):
    """Print the training-summary log lines and the results of one experiment.

    Every log line starting with ``"Trained <task_name> for"`` is printed
    (with its own line ending, so a blank line follows it), then the output of
    ``get_results`` is printed.

    Args:
        prefix: Encoder prefix of the experiment directory, e.g. ``"glove-"``.
        task_name: Task part of the experiment directory name.
    """
    marker = f"Trained {task_name} for"
    summary_lines = filter(lambda text: text.startswith(marker),
                           get_log(prefix, task_name))
    for summary in summary_lines:
        print(summary)
    print(get_results(prefix, task_name))
    

# Make a stupid little GUI
# One dropdown per keyword argument, populated from the `prefixes` and
# `task_names` sets; print_info is called with the current selection.
widgets.interact(print_info, prefix=prefixes,
                 task_name=task_names)

interactive(children=(Dropdown(description='prefix', options=('glove-', 'cove-', 'elmo-full-', 'elmo-chars-'),…

<function __main__.print_info(prefix, task_name)>

In [9]:
def format_info(prefix, task_name, sep=","):
    """Print one table row (training length and scores) for an experiment.

    The row consists of the batch count and epoch count parsed from the first
    parseable ``"Trained <task_name> for ..."`` log line, followed by the
    ``<task_name>_mcc`` and ``<task_name>_f1`` values from the results file.
    Fields are separated by ``sep``; the f1 value ends the line.

    Args:
        prefix: Encoder prefix of the experiment directory, e.g. ``"glove-"``.
        task_name: Task part of the experiment directory name.
        sep: Separator written between fields.

    Raises:
        OSError: propagated from ``get_log`` / ``get_results`` when the run's
            files cannot be opened.
    """
    for line in get_log(prefix, task_name):
        if not line.startswith(f"Trained {task_name} for"):
            continue
        m = re.match(r"Trained [\w-]+ for (\d+) batches or (.+) epochs\w*", line)
        if m is None:
            # Summary line in an unexpected format: keep scanning rather than
            # crashing with AttributeError on m.group().
            continue
        # Hand `sep` to print() as the separator. Passing it as a positional
        # value made print() pad it with spaces ("100 ,  2.5, ").
        print(m.group(1), m.group(2), sep=sep, end=sep)
        break
    for line in get_results(prefix, task_name).split("\n"):
        if line.startswith(f"{task_name}_mcc: "):
            print(line.replace(f"{task_name}_mcc: ", ""), end=sep)
        if line.startswith(f"{task_name}_f1: "):
            print(line.replace(f"{task_name}_f1: ", ""))

def print_table(prefix):
    """Print a comma-separated results table for every task of one encoder.

    Tasks are listed in sorted order, one row each: the task name followed by
    whatever ``format_info`` prints for it. A task named ``"junk"`` is
    skipped, and a run whose files cannot be opened is reported as
    ``<run not found>``.

    Args:
        prefix: Encoder prefix; must be a member of ``prefixes``.
    """
    assert prefix in prefixes
    sep = ", "
    for task_name in filter(lambda candidate: candidate != "junk", sorted(task_names)):
        print(task_name, end=sep)
        try:
            format_info(prefix, task_name, sep=sep)
        except IOError:
            print("<run not found>")
            
# prefix = "elmo-ortho-"
# Dropdown over the known encoder prefixes; print_table is called with the
# selected one.
widgets.interact(print_table, prefix=prefixes)

interactive(children=(Dropdown(description='prefix', options=('glove-', 'cove-', 'elmo-full-', 'elmo-chars-'),…

<function __main__.print_table(prefix)>

# Multi-Way Bar Plots

In [None]:
import datetime
import socket
def get_compact_timestamp():
    """Return the current local time formatted as ``YYYYMMDD.HHMMSS``.

    NOTE: behaves identically to the function of the same name defined near
    the top of this notebook; this definition rebinds the name.
    """
    stamp_format = "%Y%m%d.%H%M%S"
    return datetime.datetime.now().strftime(stamp_format)

In [None]:
import analysis
# Re-import the local `analysis` module so edits to it are picked up without
# restarting the kernel.
reload(analysis)

# Task whose runs are compared below; the commented lines are alternatives.
# NOTE(review): "edges-srl-conll2005" does not appear in the directory listing
# displayed earlier in this notebook (it shows "edges-srl-conll2012"), and of
# the prefixes looped over below only "elmo-chars-" and "elmo-full-" do. With
# that listing every path would be skipped - confirm top_expt_dir/task_name.
task_name = "edges-srl-conll2005"
# task_name = "edges-spr2"
# task_name = "edges-coref-ontonotes"
# task_name = "edges-dep-labeling"
# task_name = "edges-ner-conll2003"

# Short run name -> loaded predictions, in the order the prefixes are listed.
runs_by_name = collections.OrderedDict()
# for prefix in prefixes:
for prefix in ["elmo-chars-", "elmo-ortho-", "elmo-full-", "train-", "train-full-"]:
    # The "elmo-ortho-" experiment keeps its run under "run_seed_0".
    run_name = "run_seed_0" if prefix == "elmo-ortho-" else "run"
    run_path = os.path.join(top_expt_dir, prefix + task_name, run_name)
    if not os.path.exists(run_path):
        print("Path %s does not exist, skipping." % run_path)
        continue
    # Short display name: drop surrounding hyphens and the "elmo-" part,
    # e.g. "elmo-chars-" -> "chars", "train-full-" -> "train-full".
    name = prefix.strip('-').replace("elmo-", "")
    if name == "ortho":
        name = "random-ortho"
    # "val" presumably selects the validation split - TODO confirm in analysis.
    runs_by_name[name] = analysis.Predictions.from_run(run_path, task_name, "val")

In [None]:
# Build a per-label score comparison across all loaded runs, save it as a
# standalone HTML chart, upload it to a GCS bucket and show it inline.
reload(analysis)
if task_name.startswith("edges-srl"):
    # Filter out references and continuations, because these are mostly noise.
    label_filter = lambda label: not (label.startswith("R-") or label.startswith("C-"))
#                                       or label in ["AM-PRD", "AM-REC", "AM-TM"])
else:
    # Other tasks keep every label.
    label_filter = lambda label: True
mc = analysis.MultiComparison(runs_by_name, label_filter=label_filter)

# SORT_FIELD="f1_score"
SORT_FIELD = "label"
# SORT_FIELD = "true_count"
SORT_RUN="full"
# Color per run name; keys match the names assigned in runs_by_name.
cmap = {'chars': "#93C47D",
        'random-ortho': "#B4A7D6",
        'full' : "#6D9EEB",
        'train': "#F46D43",
        'train-full': "#D53E4F"}
# Sort ascending only when sorting by label name.
p = mc.plot_scores(task_name, sort_field=SORT_FIELD, sort_run=SORT_RUN, row_height=450,
                   sort_ascending=(SORT_FIELD == 'label'), cmap=cmap)

# comp = analysis.Comparison(base=runs_by_name['chars'],
#                            expt=runs_by_name['full'])
# p = comp.plot_scores(task_name, sort_field=SORT_FIELD, row_height=350,
#                      sort_ascending=(SORT_FIELD == 'label'), palette=[cmap['chars'], cmap['full']])

# Save plot to bucket
# File name encodes the task, the compared run names and a timestamp.
now = get_compact_timestamp()
key_string = ".".join(runs_by_name.keys())
fname = f"chart.{task_name}.{key_string}.{now:s}.html"
hostname = socket.gethostname()
title = f"{task_name}"
bp.save(p, os.path.join("/tmp", fname), title=title, resources=bokeh.resources.CDN)
# IPython shell escapes; $fname and $hostname are interpolated from the
# Python variables above. The second command grants read access to all users,
# which is what makes the URL printed below public.
!gsutil cp /tmp/$fname gs://jsalt-scratch/$hostname/plots/$fname
!gsutil acl ch -u AllUsers:R gs://jsalt-scratch/$hostname/plots/$fname
print(f"Public URL: https://storage.googleapis.com/jsalt-scratch/{hostname}/plots/{fname}")

bp.show(p)

## Look at examples for particular labels

In [None]:
# Pair up the predictions of two runs target-by-target.
base = runs_by_name['random-ortho']
expt = runs_by_name['full']

# Join on the target identity (example index, spans, label). Columns present
# in both frames get a ".base" / ".expt" suffix, e.g. 'preds.proba.base'.
# pd.merge defaults to an inner join, so only targets found in both runs remain.
mdf = pd.merge(base.target_df_long, expt.target_df_long, 
               on=['idx', 'span1', 'span2', 'label'], suffixes=(".base", ".expt"))
mdf.head()

In [None]:
# Re-import the local `analysis` module so recent edits to it take effect.
reload(analysis)

def show_target(row):
    """Print one target with the base and expt predicted probabilities.

    Builds a record with the example text and two targets over the same
    span(s), labelled ``["base", <proba>]`` and ``["expt", <proba>]`` with the
    probabilities formatted to two decimals, and prints it wrapped in
    ``analysis.EdgeProbingExample``.

    Args:
        row: Mapping-like row providing ``'text'``, ``'span1'``,
            ``'preds.proba.base'`` and ``'preds.proba.expt'``; ``'span2'`` is
            optional and defaults to ``[]``.
    """
    sources = (("base", 'preds.proba.base'), ("expt", 'preds.proba.expt'))
    record = {
        'text': row['text'],
        'targets': [
            {'span1': row['span1'],
             'span2': row.get('span2', []),
             'label': [tag, "{:.02f}".format(row[column])]}
            for tag, column in sources
        ],
    }
    print(analysis.EdgeProbingExample(record))

# Select the merged rows to browse: label "AM-LOC" ...
_mask = mdf['label'] == "AM-LOC"
# ... restricted to rows where 'label.true.base' is truthy (presumably the
# gold-label flag carried over from the base run - TODO confirm).
_mask &= mdf['label.true.base']
# _mask &= (mdf['preds.proba.base'] >= 0.5) != (mdf['preds.proba.expt'] >= 0.5)

# Copy so that adding a column does not write into a slice of mdf.
_selected_df = mdf[_mask].copy()
# Attach the example text, looked up by example index in the base run.
_selected_df['text'] = [base.example_df.loc[i, 'text'] for i in _selected_df['idx']]

def print_info(i):
    """Show the base/expt predictions for the ``i``-th row of ``_selected_df``.

    NOTE: this rebinds the name ``print_info``, shadowing the earlier
    ``print_info(prefix, task_name)`` defined in this notebook.

    Args:
        i: Positional (``iloc``) row index into ``_selected_df``.
    """
    chosen_row = _selected_df.iloc[i]
    show_target(chosen_row)

# Dropdown mapping each example text to its row position in _selected_df.
# Rows sharing the same text collapse to one dict key (the last position wins).
widgets.interact(print_info, i={t:i for i,t in enumerate(_selected_df['text'])})

## Confusion Matrices

In [None]:
# Cross-label correlation matrix for one run: entry (i, j) is the Matthews
# correlation coefficient between the true indicator of label i and the
# thresholded (>= 0.5) predicted probability of label j.
run = runs_by_name['chars']
# labels = [l for l in run.all_labels if l.startswith("AM-")]
labels = run.all_labels
N = len(labels)
pmat = np.zeros((N,N), dtype=np.float32)
for i, li in enumerate(labels):
#     ti = run.target_df_wide['preds.proba.' + li] >= 0.5
    # Row i: true labels for label li.
    ti = run.target_df_wide['label.true.' + li]
    for j, lj in enumerate(labels):
        # Column j: predicted probability for label lj.
        pj = run.target_df_wide['preds.proba.' + lj]
        pmat[i,j] = metrics.matthews_corrcoef(ti, pj >= 0.5)

print(pmat.shape)

In [None]:
# Render pmat as a heatmap with the label names on both (categorical) axes.
factors = labels
p = bp.figure(title="Per-label cross correlation (MCC)",
              tooltips=[("value", "@image{.00}")],
              tools="hover", x_range=factors, y_range=factors)
# Rotate the x tick labels by 90 degrees.
p.xaxis.major_label_orientation = np.pi/2

N = len(factors)
# xx / yy are only referenced by the commented-out rect() call below.
xx, yy = np.indices((len(factors), len(factors)))
# p.rect(xx.flat, yy.flat, color=pmat.flat, width=1, height=1)
# palette=bokeh.palettes.Spectral11
# NOTE: rebinds `cmap`, which an earlier cell uses for the run-name -> color dict.
# Colors span the full MCC range [-1, 1].
cmap = bokeh.models.LinearColorMapper(palette="RdYlBu11", low=-1, high=1)
# The image covers [0, N] x [0, N]. Presumably pmat rows (true label) run along
# the y axis and columns (predicted label) along the x axis - TODO confirm.
p.image(image=[pmat], x=0, y=0, dw=N, dh=N, color_mapper=cmap)

colorbar = bokeh.models.ColorBar(color_mapper=cmap, location=(0,0))
p.add_layout(colorbar, 'right')
bp.show(p)