# Training embcompr

In [None]:
# import pdb; pdb.set_trace()

# Colour palette used by all plots in this notebook.
# Each family holds five hex shades ordered from lightest (index 0) to
# darkest (index 4); the plotting code indexes into these lists by position.
CLR = {
    'blue': ['#e0f3ff', '#aadeff', '#2bb1ff', '#15587f', '#0b2c40'],
    'gold': ['#fff3dc', '#ffebc7', '#ffddab', '#b59d79', '#5C4938'],
    'red':  ['#ffd8e8', '#ff9db6', '#ff3e72', '#6B404C', '#521424'],
}

In [None]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as ppt

from tabulate import tabulate
from tqdm import tqdm_notebook as tqdm

import pathlib

In [None]:
def _rolling_mean(a: np.array, window: int):
    assert len(a.shape) == 1

    a_prev = np.repeat(a[0], window // 2)
    a_post = np.repeat(a[-1], window // 2 - 1)

    x = np.concatenate((a_prev, a, a_post))
    v = np.ones((window, )) / window

    return np.convolve(x, v, mode='valid')

In [None]:
def load_loss(path: pathlib.Path, skip: int):
    """Load the recorded training and validation losses from an HDF5 file.

    The file is expected to contain a 'train' and a 'valid' dataset. Rows of
    'train' whose values sum to something greater than zero are counted as
    trained epochs (remaining rows are presumably unused because of early
    stopping). The first ``skip`` rows are dropped.

    Args:
        path: location of the losses file.
        skip: number of leading epochs to drop.

    Returns:
        ``(data_train, data_valid, epochs)`` where ``epochs`` is the number of
        trained epochs left after skipping, or ``None`` (after printing a
        note) if the file cannot be opened or no rows remain.
    """
    try:
        with h5py.File(str(path), mode='r') as fd:
            # early stopping (TODO: may be solved more elegantly)
            trained = int(np.count_nonzero(fd['train'][:].sum(axis=1) > 0))

            # Slice up to the last trained row. Subtracting `skip` from the
            # end index as well would silently cut off the final `skip`
            # epochs and leave the data shorter than the returned count.
            data_train = fd['train'][skip:trained]
            data_valid = fd['valid'][skip:trained]

    except OSError:
        print('skipping {} - currently in use'.format(str(path)))
        return None

    if not len(data_train) or not len(data_valid):
        print('skipping {} - no data available'.format(str(path)))
        return None

    return data_train, data_valid, trained - skip

## Summary

In [None]:
def summarize(selection: str, display: bool, save: bool, skip: int = 20):
    """Tabulate the loss minima of every experiment matched by ``selection``.

    For each ``losses.h5`` found below ``..`` the epoch and value of the
    minimal mean training and validation loss are collected. If a sibling
    ``codes.h5`` exists, the entropy recorded for the epoch closest to the
    best validation epoch is added. Rows are sorted by validation loss.

    Args:
        selection: glob pattern (relative to ``..``) selecting experiment
            directories.
        display: print the table.
        save: write the table to ``summary.txt`` two directory levels above
            the last visited ``losses.h5``.
        skip: number of leading epochs ignored when searching the minima.

    Raises:
        AssertionError: if no experiment yielded any data.
    """

    def minimum(data: np.ndarray):
        # epoch number and value of the smallest per-epoch mean loss;
        # row i of the loaded data belongs to epoch skip + i
        means = data.mean(axis=1)
        idx_min = means.argmin()
        return skip + idx_min, means[idx_min]

    rows = []
    loss_file = None
    for loss_file in pathlib.Path('..').glob(selection + '/losses.h5'):

        data = load_loss(loss_file, skip)
        if data is None:
            continue

        data_train, data_valid, _ = data

        row = (loss_file.parts[-2], )
        row = row + minimum(data_train)
        row = row + minimum(data_valid)

        best_epoch = row[-2]  # epoch of the validation minimum

        # find code entropy
        try:
            # Open read-only and close deterministically: older h5py versions
            # default to append mode, which would create a missing file
            # instead of raising OSError.
            codes_path = loss_file.parents[0] / 'codes.h5'
            with h5py.File(str(codes_path), mode='r') as codes_fd:
                code_epochs = [int(key) for key in codes_fd['valid'].keys()]
                s = code_epochs[
                    np.argmin([abs(best_epoch - e) for e in code_epochs])]
                entropy = codes_fd['valid'][str(s)]['entropy'].attrs['mean']

            row = row + (s, entropy)

        except OSError:
            print('{} no codes.h5 found'.format(str(loss_file.parent.name)))
            row = row + ('-', '-')

        except KeyError:
            print('{} no "valid" dataset found in codes.h5'.format(
                str(loss_file)))
            row = row + ('-', '-')

        rows.append(row)

    headers = ['exp', 't epoch', 't loss', 'v epoch', 'v loss', 'entropy epoch', 'entropy']

    assert len(rows), 'no data found'
    rows.sort(key=lambda t: t[4])

    if display:
        print()
        print(tabulate(rows, headers=headers))
        print()

    if save:
        path = loss_file.parents[1] / 'summary.txt'
        print('writing', str(path))
        with path.open('w') as fd:
            fd.write(tabulate(rows, headers=headers, tablefmt='orgtbl'))

## Loss

In [None]:
def _loss_line_plot(ax, x, data: np.array, window: int, label_fmt: str,
                    color: str = None, show_bounds: bool = True,
                    marker = 2, baseline: float = None):
    """Draw the per-epoch mean loss of ``data`` onto ``ax``.

    Optionally adds the smoothed min/max envelope, a horizontal baseline
    and dashed markers at the first and last epoch. The minimum of the mean
    curve is highlighted with a scatter marker.

    Args:
        ax: axes to draw on.
        x: epoch numbers, one per row of ``data``.
        data: 2-D array reduced along axis 1 (min, max and mean).
        window: smoothing window handed to ``_rolling_mean``.
        label_fmt: format string receiving (minimal mean loss, its epoch).
        color: key into ``CLR`` selecting the colour family.
        show_bounds: draw the min/max envelope and the boundary lines.
        marker: matplotlib marker used for the minimum.
        baseline: y value of an optional horizontal reference line.

    Returns:
        List of legend patches: the curve's patch first, then the baseline
        patch if a baseline was given.
    """
    lower = _rolling_mean(data.min(axis=1), window)
    upper = _rolling_mean(data.max(axis=1), window)

    # envelope between the smoothed extremes
    if show_bounds:
        for bound in (lower, upper):
            ax.plot(x, bound, color=CLR[color][3], alpha=0.5, lw=0.6)
        ax.fill_between(x, lower, upper, color=CLR[color][0], alpha=0.2)

    mean_loss = data.mean(axis=1)
    ax.plot(x, mean_loss, color=CLR[color][2])

    # auxiliary dashed lines
    dashed = {
        'linestyle': 'dashed',
        'lw': 1,
        'alpha': 0.5,
        'color': CLR[color][3],
    }

    baseline_patches = []
    if baseline is not None:
        ax.axhline(baseline, 0, 1, **dashed)
        baseline_label = 'Baseline: {}'.format(baseline)
        baseline_patches.append(ppt.Patch(color='black', label=baseline_label))

    if show_bounds:
        for pos in (0, -1):
            ax.vlines(x[pos], lower[pos] - 1, upper[pos] + 1, **dashed)

    # highlight the best epoch
    best = mean_loss.argmin()
    ax.scatter(x[best], mean_loss[best], marker=marker, s=100, color=CLR[color][2])

    curve_patch = ppt.Patch(
        color=CLR[color][2],
        label=label_fmt.format(mean_loss[best], x[best]))

    return [curve_patch] + baseline_patches


def plot_loss(path: pathlib.Path, display: bool, save: bool, skip: int = 0, window: int = 50, baseline: float = None):
    """Plot training and validation loss of one experiment into one figure.

    Args:
        path: location of the experiment's ``losses.h5``.
        display: show the figure.
        save: write the figure as ``loss.png`` and ``loss.svg`` into the
            ``images`` directory next to ``path``.
        skip: number of leading epochs to leave out.
        window: smoothing window forwarded to ``_loss_line_plot``.
        baseline: optional reference loss; values above it are clipped to it
            and it is drawn as a horizontal line. ``None`` disables both.

    Nothing is plotted if the losses cannot be loaded or fewer than 50
    epochs are available.
    """
    name_exp = '/'.join(path.parts[-3:-1])
    out_dir = path.parents[0]/'images'
    out_dir.mkdir(exist_ok=True)

    # data aggregation
    data = load_loss(path, skip)
    if data is None:
        return

    data_train, data_valid, epochs = data

    if epochs < 50:
        print('not enough data:', epochs)
        return

    data_train = data_train[:epochs]
    data_valid = data_valid[:epochs]
    # derive the axis from the rows actually present so that x and the
    # plotted curves can never disagree in length
    x = np.arange(skip, skip + len(data_train))

    # clip above baseline (comparing against None would raise a TypeError)
    if baseline is not None:
        data_train[data_train > baseline] = baseline
        data_valid[data_valid > baseline] = baseline

    def fig_before(title: str):
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.set_title(title)
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Distance Loss')
        return fig, ax

    def fig_after(fig, ax, patches, fname):
        ax.legend(handles=patches)
        if display:
            # pyplot.show() takes no figure argument
            plt.show()
        if save:
            for out_file in [str(out_dir/fname) + s for s in ('.png', '.svg')]:
                # print('saving to', out_file)
                fig.savefig(out_file)

        plt.close(fig)

    # plot three plots

    # fig, ax = fig_before('Training Loss ({})'.format(name_exp))
    # patch = _loss_line_plot(
    #     ax, x, data_train, window, 'Min. Training: {:2.3f} (Epoch {})', color='blue', baseline=baseline)
    # fig_after(fig, ax, patch, 'loss-training')

    # fig, ax = fig_before('Validation Loss ({})'.format(name_exp))
    # patch = _loss_line_plot(
    #     ax, x, data_valid, window, 'Min. Validation: {:2.3f} (Epoch {})', color='red', baseline=baseline)
    # fig_after(fig, ax, patch, 'loss-validation')

    fig, ax = fig_before('Loss ({})'.format(name_exp))
    patch1 = _loss_line_plot(
        ax, x, data_train, window,
        'Min. Training: {:2.3f} (Epoch {})', color='blue',
        show_bounds=False, marker=3)
    patch2 = _loss_line_plot(
        ax, x, data_valid, window,
        'Min. Validation: {:2.3f} (Epoch {})', color='red',
        show_bounds=False, baseline=baseline)
    fig_after(fig, ax, patch1 + patch2, 'loss')

def plot_losses(selection, display: bool = True, save: bool = True, baseline: float = None):
    """Run ``plot_loss`` for every ``losses.h5`` matched below ``..``.

    Args:
        selection: glob pattern (relative to ``..``) selecting experiment
            directories; a trailing slash is optional.
        display: forwarded to ``plot_loss``.
        save: forwarded to ``plot_loss``.
        baseline: forwarded to ``plot_loss``.
    """
    # Join with exactly one separator: plain string concatenation would glue
    # the file name onto the last directory name whenever `selection` lacks
    # a trailing slash, and then silently match nothing.
    pattern = '/'.join(
        part for part in (selection.rstrip('/'), 'losses.h5') if part)

    for loss_file in pathlib.Path('..').glob(pattern):
        print('plot loss', str(loss_file))
        plot_loss(loss_file, display, save, baseline=baseline)

## Encoder Activations

### Training

In [None]:
def plot_activation_train(codefile: pathlib.Path, display: bool, save: bool):
    """Draw one pie chart per recorded training epoch of a ``codes.h5`` file.

    Every entry of the file's 'train' group must provide a 'histogram'
    dataset carrying a 'bin_edges' attribute. The chart shows the share of
    the first bin, the last bin and everything in between.

    Args:
        codefile: location of the ``codes.h5`` file.
        display: show each figure.
        save: write each figure as PNG and SVG into the ``images`` directory
            next to ``codefile``.
    """
    path = codefile.parents[0]
    name_exp = '/'.join(path.parts[-2:])

    def fig_before(title: str):
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.set_title(title)
        return fig, ax

    def fig_after(fig, ax, fname):
        if display:
            # pyplot.show() takes no figure argument
            plt.show()
        if save:
            for out_file in [str(out_dir/fname) + s for s in ('.png', '.svg')]:
                # print('saving to', out_file)
                fig.savefig(out_file)

        plt.close(fig)

    try:
        fd = h5py.File(str(codefile), mode='r')
    except OSError:
        print('skipping {}, currently in use'.format(str(path)))
        return

    # the context manager guarantees the file is closed on every exit path
    with fd:
        try:
            group = fd['train']
        except KeyError:
            print('no "train" datagroup: only "{}"'.format(str(list(fd.keys()))))
            return

        out_dir = path / 'images'
        out_dir.mkdir(exist_ok=True)

        # styling that does not change between epochs
        colors = CLR['blue'], CLR['red'], CLR['gold']

        wp_outer = dict(wedgeprops=dict(width=0.3, edgecolor='w'))
        wp_inner = dict(wedgeprops=dict(width=0.1, edgecolor='w'))

        pct_style = dict(pctdistance=0.4, autopct='%1.1f%%')

        kw_outer = {**wp_outer, **pct_style, 'colors': [c[2] for c in colors]}
        kw_inner = {**wp_inner, 'colors': [c[0] for c in colors]}

        for epoch in tqdm(sorted(int(key) for key in group.keys())):
            hist = group[str(epoch)]['histogram']

            bin_edges = hist.attrs['bin_edges']
            y = hist[:]

            # pie plot

            fig, ax = fig_before('Encoder Activations (Epoch {}) \n ({})\n'.format(
                epoch, name_exp))

            # The first slice is y[0] and the last is y[-1]; assuming one
            # count per pair of consecutive edges, these are bounded by
            # bin_edges[1] and bin_edges[-2] respectively.
            l0 = 'x < {:2.2f}'.format(bin_edges[1])
            l1 = 'other'
            l2 = '{:2.2f} < x'.format(bin_edges[-2])
            sizes = y[0], sum(y[1:-1]), y[-1]

            ax.pie(sizes, labels=(l0, l1, l2), startangle=0, **kw_outer)
            ax.pie(sizes, radius=0.7, startangle=0, **kw_inner)

            # circle = plt.Circle((0,0), 0.8, color='w', fc=CLR['blue'][0], lw=1)
            # ax.add_artist(circle)
            ax.axis('equal')

            fig_after(fig, ax, 'encoder-activation-train_{}'.format(epoch))


def plot_activations_train(selection, display: bool = True, save: bool = True):
    """Run ``plot_activation_train`` for every ``codes.h5`` matched below ``..``.

    Args:
        selection: glob pattern (relative to ``..``) selecting experiment
            directories; a trailing slash is optional.
        display: forwarded to ``plot_activation_train``.
        save: forwarded to ``plot_activation_train``.
    """
    # Join with exactly one separator: plain string concatenation would glue
    # the file name onto the last directory name whenever `selection` lacks
    # a trailing slash, and then silently match nothing.
    pattern = '/'.join(
        part for part in (selection.rstrip('/'), 'codes.h5') if part)

    for code_file in pathlib.Path('..').glob(pattern):
        print('training activations', str(code_file))
        plot_activation_train(code_file, display, save)

### Validation

In [None]:
def _plot_activation_bar(fig, ax, arr):
    """Draw the activation counts of all codebooks as one wide bar chart.

    ``arr`` is unpacked as a 2-D ``(M, K)`` array. Each of the ``M`` rows is
    drawn as a run of ``K`` bars followed by one zero-height separator bar.
    Runs alternate between red and blue, starting with red, and are backed
    by a translucent span and dashed separator lines.
    """

    def shade(block: int, kind: int):
        # even blocks are red, odd blocks are blue
        family = 'blue' if block % 2 else 'red'
        return CLR[family][kind]

    M, K = arr.shape
    stride = K + 1  # K bars plus one separator slot per codebook
    fig.set_size_inches(40, 10)

    # draw codebook separators
    separator = {'color': 'black', 'ls': 'dashed', 'lw': 1}
    for book in range(M):
        left = book * stride - 1
        ax.axvline(x=left, **separator)
        ax.axvspan(left, left + stride, color=shade(book, 0), alpha=0.2)

    ax.axvline(x=M * stride - 1, **separator)

    # append a [0] separator element to every codebook row, then flatten
    padded = np.array([np.concatenate((row, [0])) for row in arr])
    bar_colors = [shade(slot // stride, 2) for slot in range(M * stride)]
    ax.bar(range(len(padded.flat)), padded.flat, color=bar_colors, align='edge')


def plot_activation_valid(codefile: pathlib.Path, display: bool, save: bool,
                          start: int = 0, end: int = 500):
    """Draw one activation bar chart per recorded validation epoch.

    Every entry of the file's 'valid' group must provide a 'counts' dataset,
    which is handed to ``_plot_activation_bar``.

    Args:
        codefile: location of the ``codes.h5`` file.
        display: show each figure.
        save: write each figure as PNG and SVG into the ``images`` directory
            next to ``codefile``.
        start: index (into the sorted epoch list) of the first epoch plotted.
        end: index one past the last epoch plotted. Limiting the range is a
            workaround for a memory leak observed when plotting many
            figures; process larger files chunk-wise.
    """
    path = codefile.parents[0]

    try:
        fd = h5py.File(str(codefile), mode='r')
    except OSError:
        print('skipping {}, currently in use'.format(str(path)))
        return

    def fig_before(title: str):
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.set_title(title)
        return fig, ax

    def fig_after(fig, ax, fname):
        if display:
            # pyplot.show() takes no figure argument
            plt.show()
        if save:
            for out_file in [str(out_dir/fname) + s for s in ('.png', '.svg')]:
                # print('saving to', out_file)
                fig.savefig(out_file)

        fig.clear()
        plt.close(fig)

    # the context manager guarantees the file is closed on every exit path
    with fd:
        try:
            group = fd['valid']
        except KeyError:
            print('no "valid" datagroup: only "{}"'.format(str(list(fd.keys()))))
            return

        out_dir = path / 'images'
        out_dir.mkdir(exist_ok=True)

        epochs = sorted(int(key) for key in group.keys())
        if end - start < len(epochs):
            print('WARNING: only using data subset!')
            print('current epoch range: ', start, end)

        # the slice already bounds the loop, no extra break is needed
        epochs = epochs[start:end]
        for epoch in tqdm(epochs):
            counts = group[str(epoch)]['counts'][:]

            fig, ax = fig_before('Encoder Activations (Epoch {})'.format(epoch))
            _plot_activation_bar(fig, ax, counts)
            fig_after(fig, ax, 'encoder-activation-valid_{}'.format(epoch))


def plot_activations_valid(selection, display: bool = True, save: bool = True):
    """Run ``plot_activation_valid`` for every ``codes.h5`` matched below ``..``.

    Args:
        selection: glob pattern (relative to ``..``) selecting experiment
            directories; a trailing slash is optional.
        display: forwarded to ``plot_activation_valid``.
        save: forwarded to ``plot_activation_valid``.
    """
    # Join with exactly one separator: plain string concatenation would glue
    # the file name onto the last directory name whenever `selection` lacks
    # a trailing slash, and then silently match nothing.
    pattern = '/'.join(
        part for part in (selection.rstrip('/'), 'codes.h5') if part)

    for code_file in pathlib.Path('..').glob(pattern):
        print('code bars', str(code_file))
        plot_activation_valid(code_file, display, save)

## Play the organ

In [None]:
# NOTE: there is a small memory leak that becomes obvious when plotting many
# figures. Its cause is unknown, but it seems to come from pyplot.
# It does not help that it only occurs sometimes and not consistently...

# experiment directory below `../opt/` whose runs are evaluated
experiment = 'experiments/enwiki'

# `embedding` is a pair: (sub-directory name or glob below the experiment,
# loss baseline handed to plot_losses). Enable exactly one assignment.

# --- RAW

# embedding = 'glove', 20.17
# embedding = 'fasttext.en', 12.11
# embedding = 'fasttext.de', 11.47

# --- BOV

embedding = '**', 2

# shared keyword arguments for all summary/plot calls below
options = dict(display=True, save=True)

# glob pattern, resolved relative to `..` by the functions called below
selection = f'opt/{experiment}/{embedding[0]}/'

summarize(selection, **options)
plot_losses(selection, baseline=embedding[1], **options)
# plot_activations_train(selection, **options)
# plot_activations_valid(selection, **options)

print('done')