In [None]:
"""
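How does the distribution of hypothesis ("hyp") values differ by symbol?

In this notebook we plot statistics of the argmax symbols rather than the raw
hypothesis values. For each utterance position, we show the proportion of
samples whose argmax symbol is non-zero, as a function of training episode.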
"""
%matplotlib inline

import time
import json
import shutil
from collections import defaultdict

import h5py
import numpy as np
import matplotlib.pyplot as plt

import torch

def _nonzero_argmax_proportions(hypotheses_train, resdicts, batch_size, sample_render_ids):
    """Compute, per window of render ids, the fraction of non-zero argmax symbols.

    The render ids are grouped into consecutive windows of ``sample_render_ids``
    ids each. The windows are aligned to the END of ``resdicts``, so up to
    ``sample_render_ids - 1`` leading ids are skipped. For each window, the
    columns ``[start * batch_size, (start + sample_render_ids) * batch_size)``
    of ``hypotheses_train`` are reduced with argmax over the last axis. The
    proportion of argmaxes that are not symbol 0 is then computed separately
    for every utterance position.

    Args:
        hypotheses_train: tensor assumed to be (U_pred, total_samples, V). The
            rank is inferred from how it is indexed here; confirm against the
            writer of the HDF5 file.
        resdicts: indexable sequence of JSON strings (e.g. an h5py dataset),
            one per render id, each decoding to a dict with an 'episode' key.
        batch_size: number of samples contributed by each render id.
        sample_render_ids: number of render ids aggregated into one window.

    Returns:
        ``(episodes, results)``. ``episodes`` is a list with the episode of the
        first render id of each window. ``results`` is a float32 tensor of
        shape (num_windows, U_pred). A row is NaN when its window contains no
        samples.
    """
    num_render_ids = len(resdicts)
    num_windows = num_render_ids // sample_render_ids
    # Align windows to the end so the most recent render ids are always used.
    first_id = num_render_ids - num_windows * sample_render_ids
    window_starts = range(first_id, num_render_ids, sample_render_ids)

    U_pred = hypotheses_train.size(0)
    results = torch.zeros(num_windows, U_pred, dtype=torch.float32)
    episodes = []
    for i, render_id in enumerate(window_starts):
        episodes.append(json.loads(resdicts[render_id])['episode'])
        n_start = render_id * batch_size
        n_end = n_start + sample_render_ids * batch_size
        window = hypotheses_train[:, n_start:n_end]

        total_count = window.size(1)
        if total_count == 0:
            # The window lies past the end of hypotheses_train. Report NaN
            # instead of dividing by zero.
            results[i] = float('nan')
            continue

        # Argmax over the symbol axis for all utterance positions at once.
        # .max(dim=-1)[1] keeps the same tie-breaking as the per-row version.
        preds = window.max(dim=-1)[1]
        non_zero_counts = (preds != 0).sum(dim=1)
        # Divide in float64 and then store into float32, as the scalar version did.
        results[i] = non_zero_counts.double() / total_count
    return episodes, results


def run(filename):
    """Plot, per utterance position, how often the argmax symbol is non-zero.

    Copies ``../{filename}`` to ``/tmp/foo.h5`` and reads from it the
    hypotheses, the per-render-id result dicts and the run metadata. Then it
    shows one line per utterance position: the proportion of argmax symbols
    that are not 0, plotted against training episode.

    Args:
        filename: path of the HDF5 file, relative to the parent directory.
    """
    batch_size = 128
    sample_render_ids = 4

    # NOTE(review): copying first presumably avoids reading a file that is
    # still open for writing elsewhere; confirm before removing.
    shutil.copyfile(f'../{filename}', '/tmp/foo.h5')
    # The context manager closes the file even if reading or parsing raises.
    with h5py.File('/tmp/foo.h5', 'r') as f:
        hypotheses_train = torch.from_numpy(f['hypotheses_train'][:])
        meta = json.loads(f['meta'][0])
        # resdicts is read lazily from the file, so this must run inside `with`.
        episodes, results = _nonzero_argmax_proportions(
            hypotheses_train, f['resdicts'], batch_size, sample_render_ids)

    ref = meta['ref']
    U_pred = hypotheses_train.size(0)

    plt.figure(figsize=(6, 4))
    for u in range(U_pred):
        plt.plot(episodes, results[:, u].numpy(), label=f'utterance pos {u}')
    plt.legend()
    plt.ylim([0, 1])
    plt.xlabel('episode')
    plt.ylabel('proportion of non-zero symbol chosen')
    plt.title(f'{ref} proportion of non-zero symbol chosen, per position')
    plt.show()

# HDF5 dumps to plot, one path per line; blank lines are ignored below.
filenames = """
../hists/hypprop_eec59_hughtok1_20180929_150429.h5
../hists/hypprop_eec60_hughtok2_20180929_150511.h5
../hists/hypprop_eec61_hughtok33_20180929_150540.h5
../hists/hypprop_eec64_hughtok4_20180930_135141.h5
../hists/hypprop_eec65_hughtok5_20180930_135510.h5
../hists/hypprop_eec66_hughtok6_20180930_141031.h5
../hists/hypprop_eec67_hughtok7_20180930_141448.h5
../hists/hypprop_eec68_hughtok8_20180930_141821.h5
"""

# run() prepends '../' itself, so strip it from each listed path here.
filenames = [
    path.replace('../', '')
    for path in filenames.split('\n')
    if path
]
# To plot only the most recent dump, uncomment the next line:
# filenames = [filenames[-1]]

for filename in filenames:
    run(filename)