In [None]:
# %matplotlib inline
# Reload imported modules automatically before each cell runs, so edits to the
# project modules imported below (run, run_db, run_utils) take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib
import matplotlib.pyplot as plt

import glob
import itertools
import os

import sys
sys.path.append('..')

from run import Run
from run_db import db as run_db
from run_utils import init_runs, find_runs

# Task and language identifiers used throughout the notebook.
tasks = ['dep', 'lmo', 'ner', 'pos']
langs = ['cs', 'de', 'en', 'es']

# Directory containing the training logs. The default is machine-specific; set
# the RUNS_LOG_PATH environment variable to point elsewhere instead of editing it.
log_path = os.environ.get('RUNS_LOG_PATH', '/home/fiit/logs/')

runs = init_runs(log_path, run_db)

In [None]:
def draw_graphs(runs, tasks, langs, role, metric=None, focused=False, label=None):
    """Plot each run's metric history in a grid: one row per task, one column per language.

    Args:
        runs: iterable of run objects exposing
            ``history(metric=..., task=..., language=..., role=...)`` and, when
            ``focused`` is True, ``config['focus_on']``. It is materialised into
            a list because it is traversed once per subplot.
        tasks: task identifiers; one subplot row each.
        langs: language codes; one subplot column each.
        role: forwarded unchanged to ``run.history`` (e.g. ``'test'``).
        metric: forwarded unchanged to ``run.history``.
        focused: if True, a run is drawn only in the subplot whose
            ``'{task}-{lang}'`` equals its ``config['focus_on']``.
        label: callable mapping a run to its legend label; defaults to
            ``'{run.name}-{run.type}'``.
    """
    if label is None:
        label = lambda run: f'{run.name}-{run.type}'

    # Iterated once per (task, lang) cell, so a one-shot iterator must be materialised.
    runs = list(runs)

    fig, axes = plt.subplots(
        len(tasks),
        len(langs),
        figsize=(5 * len(langs), 4 * len(tasks)),
        squeeze=False)  # always 2-D, so axes[0] and axes[:, 0] below are valid

    for ax_row, task in zip(axes, tasks):
        for ax, lang in zip(ax_row, langs):
            for run in runs:
                if focused and run.config['focus_on'] != f'{task}-{lang}':
                    continue
                history = run.history(
                    metric=metric,
                    task=task,
                    language=lang,
                    role=role)
                ax.plot(list(history), label=label(run))
            # Only add a legend where something labelled was drawn; with
            # focused=True most cells are empty and legend() would warn.
            handles, _labels = ax.get_legend_handles_labels()
            if handles:
                ax.legend()

    for ax, lang in zip(axes[0], langs):
        ax.set_title(lang)

    for ax, task in zip(axes[:, 0], tasks):
        ax.set_ylabel(task, rotation=0, size='large')

    plt.show()

    
def results(runs, tasks, langs, focused=True, values_only=False):
    """Collect the best score, and the epoch it was reached at, for each run.

    Args:
        runs: iterable of run objects exposing ``load()``, ``config`` and
            ``best(task=..., language=...)`` returning ``(score, epoch)``.
        tasks: task identifiers to query.
        langs: language codes to query.
        focused: if True, only the ``(task, lang)`` pair whose ``'{task}-{lang}'``
            equals the run's ``config['focus_on']`` is queried.
        values_only: if True, return a flat list holding, for each run, the
            score of its first collected ``(task, lang)`` pair.

    Returns:
        ``{run: {(task, lang): (score, epoch)}}``; with ``values_only``, a list
        of scores in the iteration order of ``runs``.

    Raises:
        ValueError: if ``values_only`` is set and a run produced no results
            (e.g. its ``focus_on`` matches none of the requested pairs).
    """

    def get_results(run):
        # Load once per run rather than once per (task, lang) pair.
        run.load()
        out = {}
        for task, lang in itertools.product(tasks, langs):
            if focused and run.config['focus_on'] != f'{task}-{lang}':
                continue
            res, epoch = run.best(task=task, language=lang)
            # Scores at or below ~1 are treated as fractions and rescaled to
            # percentages so every run is reported on the same scale.
            # NOTE(review): misfires for a percent-scale score that is <= 1.01.
            if res <= 1.01:
                res *= 100
            out[(task, lang)] = res, epoch
        return out

    out = {
        run: get_results(run)
        for run in runs}

    if not values_only:
        return out

    values = []
    for run, run_results in out.items():
        if not run_results:
            raise ValueError(
                f'run {run.name!r} has no results for the requested tasks/languages')
        # Dicts keep insertion order, so this is the first pair in product(tasks, langs) order.
        first_score, _epoch = next(iter(run_results.values()))
        values.append(first_score)
    return values

# tmp_runs = find_runs(runs, name='zero-shot-task-lang-ortho-10')
# res = results(tmp_runs, tasks, langs, focused=True)
# print(list(res.values()))

In [None]:
import numpy as np

def tempo(name, type='all'):
    """Print the best score of every run matching ``name``/``type``, then three group means.

    Scores come from ``results(..., values_only=True)`` over the notebook-level
    ``runs``, ``tasks`` and ``langs``. Each score is printed with two decimals,
    followed on the same line by the means of scores[:4], scores[4:8] and
    scores[8:], rounded to two decimals.

    NOTE: ``type`` shadows the builtin but is kept as part of the call interface.
    """
    selected = find_runs(runs, type=type, name=name)
    scores = results(runs=selected, tasks=tasks, langs=langs, values_only=True)

    print(''.join(f'{score:.2f} ' for score in scores), end='')

    # NOTE(review): the 4/8 split points are fixed; presumably they match the
    # order in which find_runs returns runs — confirm.
    chunks = (scores[:4], scores[4:8], scores[8:])
    group_means = [np.mean(chunk) for chunk in chunks]
    print(np.round(group_means, decimals=2))


In [None]:
# Best score per run for the 'var' two-by-two zero-shot runs and for their
# adversarial counterparts, computed in that order.
two_by_two, adv = (
    results(
        runs=find_runs(runs, type='var', name=run_name),
        tasks=tasks,
        langs=langs,
        values_only=True,
    )
    for run_name in ('zero-shot-two-by-two', 'zero-shot-two-by-two-adversarial')
)

# Six two-by-two scores per line; each line ends with the adversarial score
# at the same line index.
for idx, score in enumerate(two_by_two):
    print(f'{score:.2f}', end=' ')
    closes_line = idx % 6 == 5
    if closes_line:
        print(f'{adv[idx // 6]:.2f}')

In [None]:
# Each inner tuple holds the positional arguments of one tempo() call; the
# outer tuple groups related comparisons, separated by a blank line in the output.
comparison_groups = (
    (
        ('zero-shot', 'ml-3'),
        ('zero-shot', 'rel'),
        ('zero-shot',),
        ('zero-shot-400',),
    ),
    (
        ('zero-shot',),
        ('zero-shot-task',),
        ('zero-shot-lang',),
        ('zero-shot-task-lang-no-global',),
        ('zero-shot-task-lang',),
    ),
    (
        ('zero-shot-task-lang',),
        ('zero-shot-task-lang-ortho-10',),
        ('zero-shot-task-lang-ortho-25',),
        ('zero-shot-task-lang-ortho-50',),
        ('zero-shot-task-lang-ortho-100',),
        ('zero-shot-task-lang-ortho-200',),
    ),
    (
        ('zero-shot-task-lang-no-global',),
        ('zero-shot-task-lang-no-global-ortho-50',),
        ('zero-shot-task-lang-no-global-ortho-100',),
        ('zero-shot-task-lang-no-global-ortho-200',),
    ),
    (
        ('zero-shot',),
        ('zero-shot-task-emb',),
        ('zero-shot-lang-emb',),
        ('zero-shot-embs',),
        ('zero-shot-embs-400',),
        ('zero-shot-task-lang-both-embs',),
    ),
    (
        ('zero-shot', 'ml-3'),
        ('zero-shot',),
        ('zero-shot-embs',),
        ('zero-shot-task-lang',),
    ),
    (
        ('zero-shot-rotated', 'ml-3'),
        ('zero-shot-rotated',),
        ('zero-shot-embs-rotated',),
        ('zero-shot-task-lang-rotated',),
    ),
    (
        ('zero-shot-char-level', 'ml-3'),
        ('zero-shot-char-level',),
        ('zero-shot-embs-char-level',),
        ('zero-shot-task-lang-char-level',),
    ),
    (
        ('zero-shot',),
        ('zero-shot-adversarial',),
        ('zero-shot-adversarial-embs',),
        ('zero-shot-adversarial-task-lang',),
    ),
)

for group_no, group in enumerate(comparison_groups):
    if group_no:
        print()
    for call_args in group:
        tempo(*call_args)

In [None]:
# Full-data baselines first, then the same four variants with limited task
# data, then with limited language data; groups are separated by a blank line.
limited_data_groups = (
    ('zero-shot',
     'zero-shot-adversarial',
     'zero-shot-embs',
     'zero-shot-task-lang'),
    ('zero-shot-limited-task-200',
     'zero-shot-adversarial-limited-task-200',
     'zero-shot-embs-limited-task-200',
     'zero-shot-task-lang-limited-task-200'),
    ('zero-shot-limited-lang-200',
     'zero-shot-adversarial-limited-lang-200',
     'zero-shot-embs-limited-lang-200',
     'zero-shot-task-lang-limited-lang-200'),
)

for group_no, run_names in enumerate(limited_data_groups):
    if group_no:
        print()
    for run_name in run_names:
        tempo(run_name)


In [None]:
zero_shot_runs = find_runs(runs, name='zero-shot')

# Peek at the first data entry of the first matching run, then plot the
# test-role loss histories, each run only in the cell it focuses on.
print(zero_shot_runs[0].data[0])
draw_graphs(
    zero_shot_runs,
    tasks,
    langs,
    'test',
    metric='loss',
    focused=True,
)
