In [None]:
"""
marginalize over the datasets, and the ground truth positions
"""
%matplotlib inline

import time
import json
import shutil

import h5py
import numpy as np
import matplotlib.pyplot as plt
import sklearn.metrics.cluster

import torch

def create_ami_matrix(gnd, pred):
    """
    Score every predicted labelling against every ground-truth labelling.

    Both arguments are expected to be 2d tensors whose rows are label
    vectors over the same N samples:
    - gnd is [U_gnd][N]
    - pred is [U_pred][N]

    Returns a [U_pred][U_gnd] tensor in which entry [i, j] holds the
    adjusted mutual information between pred[i] and gnd[j].
    """
    n_gnd, n_pred = gnd.size(0), pred.size(0)

    scores = torch.zeros(n_pred, n_gnd)
    for pred_idx in range(n_pred):
        predicted_labels = pred[pred_idx]
        for gnd_idx in range(n_gnd):
            # rows index predictions, columns index ground truth
            scores[pred_idx, gnd_idx] = sklearn.metrics.cluster.adjusted_mutual_info_score(
                labels_true=gnd[gnd_idx].numpy(),
                labels_pred=predicted_labels.numpy()
            )
    return scores

def run(filename):
    """
    Plot an adjusted-mutual-information (AMI) history for one results file.

    Copies ``../{filename}`` to ``/tmp/foo.h5``, loads the stored hypotheses,
    ground-truth hypotheses, dataset-reference ids, result dicts and metadata
    from it, and then, for every dataset-reference id and each of the last
    (at most) 32 render steps, computes the AMI matrix between the predicted
    and ground-truth labellings of the samples belonging to that dataset.
    Finally a single heat-map is shown.

    NOTE(review): the plotting code sits outside both the ``dsref`` loop and
    the ``u_gnd`` loop, so only the slice for the last dataset reference and
    its last ground-truth position is ever drawn, although AMI matrices are
    computed for all of them. That behaviour is preserved here; confirm
    whether one plot per (dsref, u_gnd), or an average over them, was
    intended.

    Args:
        filename: path of the HDF5 results file, relative to the parent
            directory of the working directory.
    """
    # presumably copied so the original results file is never held open by
    # this process -- confirm
    shutil.copyfile(f'../{filename}', '/tmp/foo.h5')

    # The context manager guarantees the HDF5 handle is released even if
    # loading fails; everything needed later is read into memory here.
    with h5py.File('/tmp/foo.h5', 'r') as f:
        hypotheses_train = torch.from_numpy(f['hypotheses_train'][:])
        hypotheses_gnd_train = torch.from_numpy(
            f['gnd_hypotheses_train'][:].astype(np.uint8)).long()
        dsrefs_train = torch.from_numpy(
            f['dsrefs_train'][:].astype(np.uint8)).long()
        resdicts = f['resdicts']

        meta = json.loads(f['meta'][0])
        ref = meta['ref']
        print(ref)
        dsrefs = meta['ds_refs']

        # Only the most recent 32 render steps (or fewer) are analysed.
        render_end_id_excl = len(resdicts)
        render_start_id = 0
        if render_end_id_excl > 32:
            render_start_id = render_end_id_excl - 32

        # These do not depend on the dataset reference, so parse them once.
        resdict_start = json.loads(resdicts[render_start_id])
        resdict_final = json.loads(resdicts[-1])

    episode_start = resdict_start['episode']
    episode_final = resdict_final['episode']

    U_gnd = hypotheses_gnd_train.size(0)
    U_pred = hypotheses_train.size(0)

    # assumes each render step covers 128 consecutive samples -- TODO confirm
    batch_size = 128
    num_renders = render_end_id_excl - render_start_id

    num_dsrefs = dsrefs_train.max().item() + 1
    for dsref in range(num_dsrefs):
        amis = torch.zeros((num_renders, U_pred, U_gnd), dtype=torch.float32)
        for render_id in range(render_start_id, render_end_id_excl):
            b_start = render_id * batch_size
            b_end = b_start + batch_size
            _hypotheses_train = hypotheses_train[:, b_start:b_end]
            _hypotheses_gnd_train = hypotheses_gnd_train[:, b_start:b_end]
            _dsrefs_train = dsrefs_train[b_start:b_end]

            # keep only the samples of this batch that belong to `dsref`
            dsref_idxes = (_dsrefs_train == dsref).view(-1).nonzero().view(-1).long()
            _hypotheses_train = _hypotheses_train[:, dsref_idxes]
            _hypotheses_gnd_train = _hypotheses_gnd_train[:, dsref_idxes]

            # index of the maximum along the last axis is the predicted label
            _, _pred = _hypotheses_train.max(dim=-1)
            amis[render_id - render_start_id] = create_ami_matrix(
                gnd=_hypotheses_gnd_train,
                pred=_pred
            )

        # number of ground-truth positions to look at, chosen from the
        # dataset reference's name
        gnd_utt_len = 1
        if 'things' in dsrefs[dsref]:
            gnd_utt_len = 2
        elif 'rels' in dsrefs[dsref]:
            gnd_utt_len = 5
        for u_gnd in range(gnd_utt_len):
            ami = amis[:, :, u_gnd]

    # NOTE(review): `ami`, `dsref` and `u_gnd` hold the values left over from
    # the final loop iterations (see docstring).
    plt.figure(figsize=(30, 0.5))
    plt.cla()
    plt.imshow(
        ami.transpose(0, 1).numpy(),
        extent=[episode_start, episode_final, 0, U_pred],
        interpolation='none',
        vmin=0,
        vmax=1
    )
    plt.title(
        f'{ref} episodes={episode_final} dsref={dsrefs[dsref]} u_gnd={u_gnd}'
        f' min={ami.min().item():.3f} max={ami.max().item():.3f}')
    plt.show()

# Result files to analyse, one path per line.
# An earlier run used
# 'hists/hypprop_with_predictor_eeb99_hughtok1_20180921_191657.h5'.
filenames = """
../hists/hypprop_eec59_hughtok1_20180929_150429.h5
../hists/hypprop_eec60_hughtok2_20180929_150511.h5
../hists/hypprop_eec61_hughtok33_20180929_150540.h5
"""

# run() prepends '../' itself, so drop it here. Stripping each line keeps
# stray whitespace or carriage returns from ending up in a path, and blank
# lines are skipped.
filenames = [
    line.strip().replace('../', '')
    for line in filenames.splitlines()
    if line.strip()
]

for filename in filenames:
    run(filename)