In [1]:
# %matplotlib inline
# Enable the IPython autoreload extension in mode 2, so that edits to the
# project modules imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [18]:
import matplotlib
import matplotlib.pyplot as plt

import glob
import itertools
import os

import sys
sys.path.append('..')

from run import Run
from run_db import db as run_db
from run_utils import init_runs, find_runs

# Task and language identifiers used to build the (task, language) grid that
# the plotting and reporting helpers in this notebook iterate over.
tasks = ['dep', 'lmo', 'ner', 'pos']
langs = ['cs', 'de', 'en', 'es']

# Directory handed to init_runs. The original machine-specific location stays
# the default, but it can be overridden through the RUN_LOG_PATH environment
# variable so the notebook can be re-run on another machine without editing it.
log_path = os.environ.get('RUN_LOG_PATH', '/home/fiit/logs/')

# All runs found under log_path; every later cell filters this collection
# with find_runs.
runs = init_runs(log_path, run_db)

In [11]:
def draw_graphs(runs, tasks, langs, role, metric=None, focused=False, label=None):
    """Plot run histories on a grid of subplots, one row per task and one column per language.

    Args:
        runs: Run objects to plot. The collection is iterated once per subplot,
            so it must be re-iterable (e.g. a list, not a generator). Each run
            must provide ``history(metric=, task=, language=, role=)`` and,
            when ``focused`` is true, ``config['focus_on']``.
        tasks: Sequence of task identifiers (grid rows).
        langs: Sequence of language identifiers (grid columns).
        role: Forwarded unchanged to ``run.history`` as ``role``.
        metric: Forwarded unchanged to ``run.history`` as ``metric``.
        focused: When true, a run is drawn only in the subplot whose
            ``f'{task}-{lang}'`` equals ``run.config['focus_on']``; otherwise
            every run is drawn in every subplot.
        label: Callable mapping a run to its legend label. Defaults to
            ``f'{run.name}-{run.type}'``.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    if label is None:
        label = lambda run: f'{run.name}-{run.type}'

    # squeeze=False keeps `axes` two-dimensional even for a single row or
    # column, so the row/column indexing below always works.
    fig, axes = plt.subplots(
        len(tasks),
        len(langs),
        figsize=(5*len(langs), 4*len(tasks)),
        squeeze=False)

    for ax_row, task in zip(axes, tasks):
        for ax, lang in zip(ax_row, langs):
            for run in runs:
                if not focused or run.config['focus_on'] == f'{task}-{lang}':
                    history = run.history(
                        metric=metric,
                        task=task,
                        language=lang,
                        role=role)
                    ax.plot(list(history), label=label(run))

    # Column headers: language names on the top row only.
    for ax, lang in zip(axes[0], langs):
        ax.set_title(lang)

    # Row headers: task names on the left-most column only.
    for ax, task in zip(axes[:, 0], tasks):
        ax.set_ylabel(task, rotation=0, size='large')

    for ax_row in axes:
        for ax in ax_row:
            # With focused=True some subplots may receive no lines at all;
            # calling legend() there only triggers a "no handles" warning,
            # so the legend is drawn only where something was plotted.
            handles, _ = ax.get_legend_handles_labels()
            if handles:
                ax.legend()

    plt.show()

    
def results(runs, tasks, langs, focused=True, values_only=False):
    """Collect ``metric_eval`` results of each run for the selected (task, language) pairs.

    Args:
        runs: Iterable of run objects. Runs are used as dictionary keys, so
            they must be hashable. Each run must provide ``load()``,
            ``metric_eval(task=, language=)`` returning a ``(value, epoch)``
            pair and, when ``focused`` is true, ``config['focus_on']``.
        tasks: Task identifiers.
        langs: Language identifiers.
        focused: When true, only the pair whose ``f'{task}-{lang}'`` equals
            ``run.config['focus_on']`` is evaluated for a run; otherwise every
            pair from ``tasks`` x ``langs`` is evaluated.
        values_only: When true, return a flat list with one value per run
            (the value of the first collected pair) instead of the full mapping.

    Returns:
        ``{run: {(task, lang): (value, epoch)}}``, or a list of values when
        ``values_only`` is true. Values not greater than 1.01 are multiplied
        by 100 (presumably to put fractions on a percentage scale -- confirm
        against the metrics in use).

    Raises:
        IndexError: with ``values_only=True`` when some run produced no
            result for any of the requested pairs.
    """

    def get_results(run):
        # Loading does not depend on the (task, lang) pair, so do it once per
        # run instead of once per pair.
        run.load()
        collected = {}
        for task, lang in itertools.product(tasks, langs):
            if not focused or run.config['focus_on'] == f'{task}-{lang}':
                res, epoch = run.metric_eval(task=task, language=lang)
                if res <= 1.01:
                    res *= 100
                collected[(task, lang)] = res, epoch
        return collected

    out = {
        run: get_results(run)
        for run in runs}

    if values_only:
        # First collected pair of every run, value part only (epoch dropped).
        return [list(run_results.values())[0][0] for run_results in out.values()]
    else:
        return out

# tmp_runs = find_runs(runs, name='zero-shot-task-lang-ortho-10')
# res = results(tmp_runs, tasks, langs, focused=True)
# print(list(res.values()))

In [12]:
import numpy as np

def tempo(name, type='all'):
    """Print one summary line for the runs matching ``name`` and ``type``.

    The line contains every value returned by ``results(..., values_only=True)``
    formatted with two decimals, followed by the rounded means of three
    positional groups of those values: the first four, the next four, and
    everything from the ninth value on.

    Args:
        name: Run name passed to ``find_runs``.
        type: Run type passed to ``find_runs`` (shadows the builtin inside
            this function only; kept for compatibility with existing calls).

    Returns:
        None. Output is written with ``print``.
    """
    out = results(
        runs=find_runs(
            runs,
            type=type,
            name=name
        ),
        tasks=tasks,
        langs=langs,
        values_only=True
    )
    for value in out:
        print(f'{value:.2f} ', end='')

    # When fewer than nine values are available some groups are empty.
    # np.mean on an empty slice yields nan but also emits RuntimeWarnings,
    # so nan is produced directly to keep the printed line unchanged and quiet.
    groups = (out[:4], out[4:8], out[8:])
    means = [np.mean(group) if len(group) else np.nan for group in groups]
    print(np.round(means, decimals=2))


In [13]:
# Values of the 'var'-type two-by-two runs and of their adversarial variant.
two_by_two = results(
    runs=find_runs(runs, type='var', name='zero-shot-two-by-two'),
    tasks=tasks,
    langs=langs,
    values_only=True,
)

adv = results(
    runs=find_runs(runs, type='var', name='zero-shot-two-by-two-adversarial'),
    tasks=tasks,
    langs=langs,
    values_only=True,
)

# Print six two-by-two values per row; each complete row is closed by the
# adversarial value with the same row index.
for row_start in range(0, len(two_by_two), 6):
    row_values = two_by_two[row_start:row_start + 6]
    for score in row_values:
        print(f'{score:.2f}', end=' ')
    if len(row_values) == 6:
        print(f'{adv[row_start // 6]:.2f}')

49.37 57.57 48.39 81.51 84.35 86.67 73.18
60.08 66.73 61.53 67.83 68.99 64.72 74.82
50.60 52.56 47.21 54.21 56.60 43.19 66.19
22.65 23.97 22.60 24.86 27.45 25.11 37.19
46.28 47.66 47.70 47.37 46.99 47.68 48.10
20.13 16.41 20.04 26.26 35.77 29.42 47.25
22.70 22.05 21.74 23.67 33.09 21.93 40.75
44.87 41.23 42.86 45.09 45.47 49.44 47.23
38.92 42.72 39.72 48.34 48.79 30.72 58.35
62.04 67.09 62.05 64.51 67.51 65.12 71.72
30.36 32.69 30.28 47.93 46.32 22.29 51.61
56.61 64.42 62.70 66.67 66.54 63.83 71.50
58.53 63.15 59.64 78.45 80.83 81.18 80.64
22.37 20.42 22.53 24.13 31.67 16.91 53.89
31.31 35.42 30.60 37.00 52.58 49.25 52.51
31.08 27.65 31.58 27.37 37.67 33.11 42.51
44.00 43.05 42.02 45.82 46.98 43.64 44.64
62.49 69.67 59.84 69.72 69.68 55.69 73.61
70.70 77.97 72.97 82.37 82.51 87.24 86.27
24.46 23.35 23.05 33.84 45.11 45.67 51.33


In [22]:
# Each inner list is one comparison table; every entry holds the positional
# arguments for a single tempo() call. Tables are separated by a blank line.
comparison_tables = [
    [
        ('zero-shot', 'ml-3'),
        ('zero-shot', 'rel'),
        ('zero-shot',),
        ('zero-shot-400',),
    ],
    [
        ('zero-shot',),
        ('zero-shot-task',),
        ('zero-shot-lang',),
        ('zero-shot-task-lang-no-global',),
        ('zero-shot-task-lang',),
    ],
    [
        ('zero-shot-task-lang',),
        ('zero-shot-task-lang-ortho-10',),
        ('zero-shot-task-lang-ortho-25',),
        ('zero-shot-task-lang-ortho-50',),
        ('zero-shot-task-lang-ortho-100',),
        ('zero-shot-task-lang-ortho-200',),
    ],
    [
        ('zero-shot-task-lang-no-global',),
        ('zero-shot-task-lang-no-global-ortho-50',),
        ('zero-shot-task-lang-no-global-ortho-100',),
        ('zero-shot-task-lang-no-global-ortho-200',),
    ],
    [
        ('zero-shot',),
        ('zero-shot-task-emb',),
        ('zero-shot-lang-emb',),
        ('zero-shot-embs',),
        ('zero-shot-embs-400',),
        ('zero-shot-task-lang-both-embs',),
    ],
    [
        ('zero-shot', 'ml-3'),
        ('zero-shot',),
        ('zero-shot-embs',),
        ('zero-shot-task-lang',),
    ],
    [
        ('zero-shot-rotated', 'ml-3'),
        ('zero-shot-rotated',),
        ('zero-shot-embs-rotated',),
        ('zero-shot-task-lang-rotated',),
    ],
    [
        ('zero-shot-char-level', 'ml-3'),
        ('zero-shot-char-level',),
        ('zero-shot-embs-char-level',),
        ('zero-shot-task-lang-char-level',),
    ],
    [
        ('zero-shot',),
        ('zero-shot-adversarial',),
        ('zero-shot-adversarial-embs',),
    ],
]

for table_index, table in enumerate(comparison_tables):
    if table_index > 0:
        print()
    for tempo_args in table:
        tempo(*tempo_args)

30.47 44.76 47.75 32.02 40.17 55.91 47.83 53.27 65.78 77.66 75.05 57.72 [38.75 49.3  69.05]
33.39 43.37 42.19 30.59 45.61 60.49 50.18 56.22 76.32 78.95 80.87 73.21 [37.38 53.13 77.34]
43.61 54.03 44.38 42.97 49.11 62.71 53.12 61.54 83.12 85.02 83.40 84.81 [46.25 56.62 84.09]
44.83 56.08 50.71 46.38 52.53 63.76 52.87 60.99 81.49 86.25 83.30 86.39 [49.5  57.54 84.36]

43.61 54.03 44.38 42.97 49.11 62.71 53.12 61.54 83.12 85.02 83.40 84.81 [46.25 56.62 84.09]
49.54 57.50 53.45 52.26 46.78 60.81 52.86 60.94 86.73 86.62 82.47 89.78 [53.19 55.35 86.4 ]
43.55 52.70 54.30 52.26 45.55 64.55 54.03 61.14 74.89 82.79 78.18 76.49 [50.7  56.32 78.09]
51.29 59.22 58.05 59.85 48.45 64.84 51.38 60.87 83.85 86.84 82.43 87.46 [57.1  56.38 85.14]
52.35 59.90 58.58 60.65 49.63 64.72 53.51 62.35 85.55 87.19 82.35 88.71 [57.87 57.55 85.95]

52.35 59.90 58.58 60.65 49.63 64.72 53.51 62.35 85.55 87.19 82.35 88.71 [57.87 57.55 85.95]
52.62 60.34 59.05 61.11 48.68 63.43 53.70 62.73 84.94 87.71 82.69 89.00 [58.28

In [21]:
# One summary line per experiment: the baseline first, then the adversarial
# variants.
for experiment_name in (
    'zero-shot',
    'zero-shot-adversarial',
    'zero-shot-adversarial-embs',
    'zero-shot-adversarial-task-lang',
):
    tempo(experiment_name)

43.61 54.03 44.38 42.97 49.11 62.71 53.12 61.54 83.12 85.02 83.40 84.81 [46.25 56.62 84.09]
55.44 60.04 57.80 62.83 49.47 64.84 53.09 61.64 82.62 86.56 83.61 86.38 [59.03 57.26 84.79]
41.18 54.91 44.05 53.30 49.03 63.94 52.87 61.77 80.00 85.04 80.23 83.18 [48.36 56.9  82.11]
50.57 52.74 51.41 52.94 49.25 65.75 50.09 61.31 81.02 86.71 83.03 86.41 [51.92 56.6  84.29]


In [20]:
# The same four experiments are reported three times: without a suffix, then
# with each of the two '-limited-...-200' suffixes. Blocks are separated by a
# blank line.
base_experiments = (
    'zero-shot',
    'zero-shot-adversarial',
    'zero-shot-embs',
    'zero-shot-task-lang',
)
name_suffixes = ('', '-limited-task-200', '-limited-lang-200')

for block_index, name_suffix in enumerate(name_suffixes):
    if block_index > 0:
        print()
    for base_name in base_experiments:
        tempo(base_name + name_suffix)


43.61 54.03 44.38 42.97 49.11 62.71 53.12 61.54 83.12 85.02 83.40 84.81 [46.25 56.62 84.09]
55.44 60.04 57.80 62.83 49.47 64.84 53.09 61.64 82.62 86.56 83.61 86.38 [59.03 57.26 84.79]
54.41 60.81 58.09 60.51 50.44 64.50 53.49 61.75 82.40 85.72 84.21 87.17 [58.45 57.54 84.88]
52.35 59.90 58.58 60.65 49.63 64.72 53.51 62.35 85.55 87.19 82.35 88.71 [57.87 57.55 85.95]

40.72 48.44 43.66 40.00 45.78 56.23 48.61 53.78 78.65 82.17 83.31 82.21 [43.2  51.1  81.59]
49.81 54.56 52.13 55.22 46.07 58.93 48.33 50.12 78.40 80.81 81.00 83.29 [52.93 50.86 80.87]
46.48 51.56 51.79 53.23 44.44 56.90 50.85 52.68 78.36 82.45 83.07 81.95 [50.76 51.22 81.46]
41.47 48.86 47.09 52.06 47.31 56.52 51.64 47.46 81.94 82.56 79.40 84.78 [47.37 50.73 82.17]

41.31 49.11 45.63 47.50 40.44 49.13 52.16 51.73 80.76 81.61 83.58 80.49 [45.89 48.37 81.61]
50.23 57.92 55.10 58.59 46.40 61.05 51.44 58.40 81.56 85.01 84.19 85.74 [55.46 54.32 84.12]
45.77 54.28 53.78 53.95 42.00 46.02 50.15 51.99 78.40 81.61 82.29 80.55 [51.95

In [19]:
# Summary line for the 'var'-type runs named 'no-dropout'.
tempo(name='no-dropout', type='var')

50.02 50.85 [50.44   nan   nan]


  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


In [None]:
# Select the runs named 'zero-shot', show the first data entry of the first
# match, then draw the 'test' loss histories of the focused runs.
tmp = find_runs(runs, name='zero-shot')
first_match = tmp[0]
print(first_match.data[0])
draw_graphs(
    runs=tmp,
    tasks=tasks,
    langs=langs,
    role='test',
    metric='loss',
    focused=True,
)


In [16]:
# Values of the 'var'-type runs named 'low-resource'.
oink = results(
    runs=find_runs(runs, type='var', name='low-resource'),
    tasks=tasks,
    langs=langs,
    values_only=True,
)

# Print five values per row; only complete rows are terminated by a newline.
for row_start in range(0, len(oink), 5):
    row_values = oink[row_start:row_start + 5]
    for score in row_values:
        print(f'{score:.2f}', end=' ')
    if len(row_values) == 5:
        print()

89.57 90.68 92.92 91.81 92.38 
86.09 88.15 88.15 89.51 90.00 
89.34 90.58 93.15 92.38 92.30 
51.66 52.67 55.20 59.28 57.85 
56.24 64.49 64.80 63.51 61.68 
62.74 64.54 64.22 69.40 69.90 
50.15 57.01 59.87 61.39 64.67 
60.93 59.97 59.11 61.47 62.06 
58.26 64.59 68.82 67.83 70.30 
86.32 87.66 87.79 90.06 90.23 
