In [1]:
# %matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib
import matplotlib.pyplot as plt

import glob
import itertools
import os

from run import Run
from runs_db import db as runs_db

[('deepnet5', 216), ('deepnet2070', 144)]


In [3]:
log_path = '/home/fiit/logs'
runs = []

for server in runs_db:
    paths = glob.glob(os.path.join(log_path, server, '*'))
    paths = iter(sorted(paths))

    try:
        for (number, type_, code) in runs_db[server]:
            for _ in range(number):
                try:
                    path = next(paths)
                    runs.append(Run(path, type_, code))
                except KeyError:
                    print(path)
    except StopIteration:
        pass


In [22]:
tasks = ['dep', 'lmo', 'ner', 'pos']
langs = ['cs', 'de', 'en', 'es']

task_metr = {
    'dep': 'las',
    'lmo': 'perplexity',
    'ner': 'chunk_f1',
    'pos': 'acc'
}

task_max = {
    'dep': True,
    'lmo': False,
    'ner': True,
    'pos': True
}

def draw_graphs(metric_func, tasks, langs, role, run_codes=None, run_types=None, focused=False):
    fig, axes = plt.subplots(len(tasks), len(langs), figsize=(5*len(langs), 4*len(tasks)), squeeze=False)

    relevant_runs = ['vanilla']

    for task, lang in itertools.product(tasks, langs):
        for run in runs:
            if (
                (not focused or run.config['focus_on'] == f'{task}-{lang}') and
                (run_codes is None or run.code in run_codes) and
                (run_types is None or run.type in run_types)
            ):

                history = run.history(
                    metric=metric_func(task),
                    task=task,
                    language=lang,
                    role=role)
                axes[tasks.index(task), langs.index(lang)].plot(list(history), label=f'{run.code}-{run.type}')
                #axes[tasks.index(task), langs.index(lang)].plot(list(history), label=run.type)

    for ax, col in zip(axes[0], langs):
        ax.set_title(col)

    for ax, row in zip(axes[:, 0], tasks):
        ax.set_ylabel(row, rotation=0, size='large')
        
    for ax_row in axes:
        for ax in ax_row:
            ax.legend()

    plt.show()
    
def find_runs(run_code=None, run_type=None, contains=None, **config):
    
    if contains is None:
        contains = []
    
    return (run
           for run
           in runs
           if  (run_code is None or run_code == run.code) and
               (run_type is None or run_type == run.type) and
               all(run.contains(*task_lang) for task_lang in contains) and
               all(run.config[key] == value for key, value in config.items()))
    
    
def print_results(runs, tasks, langs, metric_func=None, metric_max_func=None, focused=True):
        
    if metric_func is None:
        metric_func = lambda task: task_metr[task]
        
    if metric_max_func is None:
        metric_max_func = lambda task: task_max[task]
        
    output = []

    for run in runs:
        for task, lang in itertools.product(tasks, langs):
            if not focused or run.config['focus_on'] == f'{task}-{lang}':
                res, epoch = run.metric_eval(
                    metric=metric_func(task),
                    max_=metric_max_func(task),
                    task=task,
                    language=lang)
                if res <= 1.01:
                    res *= 100
                print(task, lang, res, epoch)
                output.append((res, epoch, run))

    return output

In [40]:
tasks=['dep']
print_results(
    runs=find_runs(
        run_type='all',
        run_code='zero-shot'
    ),
    tasks=tasks,
    langs=langs,
)
print_results(
    runs=find_runs(
        run_type='all',
        run_code='zero-shot-task-lang-both-embs'
    ),
    tasks=tasks,
    langs=langs,
)

# draw_graphs(lambda t: task_metr[t], tasks, langs, 'test', run_types=['all'], run_codes=[
#     'zero-shot-task-emb',
#     'zero-shot-lang-emb', 
#     'zero-shot-embs',
#     'zero-shot',
# ], focused=True)
# draw_graphs(lambda t: task_metr[t], tasks, langs, 'test', run_types=['stsl'], run_codes=[
#     'normal-training',
# ])

#draw_graphs(lambda t: 'ortho', tasks, langs, 'test', run_codes=[code])
#print_results(find_runs(run_code=code), [task], [lang], lambda t: task_metr[t], lambda t: task_max[t])

dep cs 43.605280994560275 6
dep de 54.034936646573996 16
dep en 44.37745316230324 13
dep es 42.974757984257664 11
dep cs 50.79800022013799 16
dep de 59.945872801082544 29
dep en 59.96843766438717 54
dep es 62.354112005790284 17


[(50.79800022013799, 16, <run.Run at 0x7f3a6229e978>),
 (59.945872801082544, 29, <run.Run at 0x7f3a63dc8a58>),
 (59.96843766438717, 54, <run.Run at 0x7f3a621e7c88>),
 (62.354112005790284, 17, <run.Run at 0x7f3a64ad4128>)]