# In [None]:
import time
import json
import shutil

import h5py
import numpy as np
import matplotlib.pyplot as plt
import sklearn.metrics.cluster

import torch

def create_ami_matrix(gnd, pred):
    """
    Build the pairwise adjusted-mutual-information table between two label sets.

    assumptions:
    - gnd and pred are 2d matrices
    - gnd is [U_gnd][N]   (one row of labels per ground-truth hypothesis)
    - pred is [U_pred][N] (one row of labels per predicted hypothesis)

    returns a [U_pred][U_gnd] tensor whose entry [i, j] is the AMI score
    of pred[i] (as labels_pred) against gnd[j] (as labels_true)
    """
    num_gnd, num_pred = gnd.size(0), pred.size(0)

    # rows index predictions, columns index ground-truth hypotheses
    scores = torch.zeros(num_pred, num_gnd)
    for pred_idx in range(num_pred):
        for gnd_idx in range(num_gnd):
            # sklearn works on numpy arrays, so each row is converted per pair
            scores[pred_idx, gnd_idx] = sklearn.metrics.cluster.adjusted_mutual_info_score(
                labels_true=gnd[gnd_idx].numpy(),
                labels_pred=pred[pred_idx].numpy()
            )
    return scores

def run(filename):
    """
    Plot AMI between predicted and ground-truth hypotheses for one history file.

    `filename` is resolved relative to the parent directory ('../{filename}').
    The file is copied to /tmp/foo.h5 and the copy is opened read-only
    (presumably so a file still being written by a running job is left
    untouched -- TODO confirm).

    Datasets read from the file:
    - 'hypotheses_train': sliced as [U_pred][N][...]; the argmax over the last
      axis is used as the predicted label (assumes a 3d tensor -- TODO confirm)
    - 'gnd_hypotheses_train': [U_gnd][N] ground-truth labels, cast via uint8
    - 'dsrefs_train': per-example dataset index, cast via uint8
    - 'resdicts': JSON strings, each with an 'episode' key
    - 'meta': element 0 is a JSON string with 'ref' and 'ds_refs' keys

    For every dataset index and every rendered batch, the [U_pred][U_gnd] AMI
    matrix is printed and shown; then one strip per ground-truth hypothesis is
    shown. Nothing is returned.

    NOTE: the render window currently covers only the last entry of 'resdicts'.
    """
    shutil.copyfile(f'../{filename}', '/tmp/foo.h5')

    # Context manager so the HDF5 handle is released even when loading,
    # AMI computation or plotting raises part-way through.
    with h5py.File('/tmp/foo.h5', 'r') as f:
        hypotheses_train = torch.from_numpy(f['hypotheses_train'][:])
        hypotheses_gnd_train = torch.from_numpy(f['gnd_hypotheses_train'][:].astype(np.uint8)).long()
        dsrefs_train = torch.from_numpy(f['dsrefs_train'][:].astype(np.uint8)).long()
        resdicts = f['resdicts']

        meta = json.loads(f['meta'][0])
        ref = meta['ref']
        print(ref)
        dsrefs = meta['ds_refs']

        U_gnd = hypotheses_gnd_train.size(0)
        U_pred = hypotheses_train.size(0)

        # Half-open window [start, end) of render ids; only the last one for now.
        render_start_id = len(resdicts) - 1
        render_end_id_excl = len(resdicts)
        num_renders = render_end_id_excl - render_start_id

        batch_size = 128

        # Episode numbers at both ends of the window label the x axis of the
        # strip plots; they do not depend on the dataset index, so read them once.
        episode_start = json.loads(resdicts[render_start_id])['episode']
        episode_final = json.loads(resdicts[-1])['episode']

        num_dsrefs = dsrefs_train.max().item() + 1
        for dsref in range(num_dsrefs):
            amis = torch.zeros((num_renders, U_pred, U_gnd), dtype=torch.float32)
            for render_id in range(render_start_id, render_end_id_excl):
                # Each render id selects one batch-sized slice of the examples.
                b_start = render_id * batch_size
                b_end = b_start + batch_size
                _hypotheses_train = hypotheses_train[:, b_start:b_end]
                _hypotheses_gnd_train = hypotheses_gnd_train[:, b_start:b_end]
                _dsrefs_train = dsrefs_train[b_start:b_end]

                # Keep only the examples of this slice that belong to `dsref`.
                dsref_idxes = (_dsrefs_train == dsref).view(-1).nonzero().view(-1).long()
                _hypotheses_train = _hypotheses_train[:, dsref_idxes]
                _hypotheses_gnd_train = _hypotheses_gnd_train[:, dsref_idxes]

                # Predicted label = index of the maximum along the last axis.
                _, _pred = _hypotheses_train.max(dim=-1)
                print('_hypotheses_gnd_train', _hypotheses_gnd_train)
                print('_pred', _pred)
                ami_matrix = create_ami_matrix(
                    gnd=_hypotheses_gnd_train,
                    pred=_pred
                )
                print('ami_matrix', ami_matrix)
                plt.imshow(ami_matrix, vmin=0, vmax=1)
                plt.show()
                amis[render_id - render_start_id] = ami_matrix

            for u_gnd in range(U_gnd):
                plt.figure(figsize=(30, 0.2))
                plt.cla()
                # [num_renders][U_pred] slice for this ground-truth hypothesis
                ami = amis[:, :, u_gnd]
                print('min', ami.min().item(), 'max', ami.max().item())
                # Colour scale is left automatic here, unlike the per-render
                # plot above, which pins it to 0..1.
                plt.imshow(
                    ami.transpose(0, 1).numpy(),
                    extent=[episode_start, episode_final, 0, U_pred],
                    interpolation='none',
                )
                plt.title(f'{ref} episodes={episode_final} dsref={dsrefs[dsref]} u_gnd={u_gnd}')
                plt.show()

# History files to analyse, relative to the parent directory (run() prepends
# '../'). Only the uncommented entries are processed; the rest are kept so
# they can be toggled back on. An earlier separate list (whose only active
# entry was eeb113) was overwritten before use, so it is merged in here.
filenames = [
    # 'hists/hypprop_with_predictor_eeb99_hughtok1_20180921_191657.h5',
    # 'hists/hypprop_with_predictor_eeb100_hughtok2_20180921_193244.h5',
    # 'hists/hypprop_with_predictor_eeb101_hughtok3_20180921_193516.h5',
    # 'hists/hypprop_with_predictor_eeb102_hughtok4_20180921_193310.h5',

    # 'hists/hypprop_with_predictor_eeb103_hughtok1_20180921_212947.h5',

    # 'hists/hypprop_with_predictor_eeb105_hughtok5_20180921_230949.h5',
    # 'hists/hypprop_with_predictor_eeb106_hughtok6_20180921_231010.h5',

    # 'hists/hypprop_with_predictor_eeb107_hughtok2_20180922_013551.h5',
    # 'hists/hypprop_with_predictor_eeb108_hughtok1_20180922_170013.h5',
    # 'hists/hypprop_with_predictor_eeb109_hughtok2_20180922_184537.h5',
    # 'hists/hypprop_with_predictor_eeb113_hughtok6_20180922_211153.h5',

    # 'hists/hypprop_with_predictor_eeb128_hughtok7_20180924_145219.h5',
    # 'hists/hypprop_with_predictor_eeb129_hughtok1_20180924_145244.h5',

    # 'hists/hypprop_with_predictor_eeb130_hughtok1_20180925_014352.h5',
    # 'hists/hypprop_with_predictor_eeb131_hughtok2_20180925_014520.h5',
    # 'hists/hypprop_with_predictor_eeb132_hughtok3_20180925_014550.h5',
    # 'hists/hypprop_with_predictor_eeb134_hughtok5_20180925_082440.h5',
    # 'hists/hypprop_with_predictor_eeb135_hughtok6_20180925_082442.h5',
    'hists/hypprop_with_predictor_eeb133_hughtok4_20180925_082438.h5',
]
for filename in filenames:
    run(filename)