In [None]:
%matplotlib inline

import time
import json
import shutil

import h5py
import numpy as np
import matplotlib.pyplot as plt
import sklearn.metrics.cluster

import torch

# History (.h5) file inspected by run(); run() prepends '../' to this path.
# The commented-out paths below are earlier experiment runs, kept so a
# different run can be inspected by swapping which assignment is active.
# filename = 'hists/hypprop_with_predictor_eeb99_hughtok1_20180921_191657.h5'
# filename = 'hists/hypprop_with_predictor_eeb100_hughtok2_20180921_193244.h5'
# filename = 'hists/hypprop_with_predictor_eeb101_hughtok3_20180921_193516.h5'
# filename = 'hists/hypprop_with_predictor_eeb102_hughtok4_20180921_193310.h5'

# filename = 'hists/hypprop_with_predictor_eeb103_hughtok1_20180921_212947.h5'

# filename = 'hists/hypprop_with_predictor_eeb105_hughtok5_20180921_230949.h5'
# filename = 'hists/hypprop_with_predictor_eeb106_hughtok6_20180921_231010.h5'

# filename = 'hists/hypprop_with_predictor_eeb107_hughtok2_20180922_013551.h5'
# filename = 'hists/hypprop_with_predictor_eeb108_hughtok1_20180922_170013.h5'
# filename = 'hists/hypprop_with_predictor_eeb109_hughtok2_20180922_184537.h5'
# filename = 'hists/hypprop_with_predictor_eeb113_hughtok6_20180922_211153.h5'

# Currently inspected run.
filename = 'hists/hypprop_with_predictor_eeb115_hughtok1_20180923_173818.h5'

def create_ami_matrix(gnd, pred):
    """
    Build the pairwise adjusted-mutual-information (AMI) matrix between two
    sets of labelings of the same N samples.

    assumptions:
    - gnd and pred are 2d matrices (CPU torch tensors or numpy arrays)
    - gnd is [U_gnd][N]
    - pred is [U_pred][N]

    Returns:
        float32 torch tensor of shape [U_pred][U_gnd] whose entry [i, j] is
        sklearn's adjusted_mutual_info_score(labels_true=gnd[j],
        labels_pred=pred[i]). sklearn's default `average_method` is used, and
        that default differs between sklearn versions, so absolute values may
        not be comparable across environments.
    """
    # Convert every labeling to numpy once. The original converted each row
    # inside the double loop, i.e. once per (i, j) pair.
    gnd_rows = [np.asarray(row) for row in gnd]
    pred_rows = [np.asarray(row) for row in pred]

    ami_matrix = torch.zeros(len(pred_rows), len(gnd_rows))
    for i, pred_labels in enumerate(pred_rows):
        for j, gnd_labels in enumerate(gnd_rows):
            ami_matrix[i, j] = sklearn.metrics.cluster.adjusted_mutual_info_score(
                labels_true=gnd_labels,
                labels_pred=pred_labels,
            )
    return ami_matrix

def run(h5_filename=None, *, batch_size=128, max_renders=100, cols=5):
    """
    For each dsref, plot a grid of AMI matrices (predicted vs ground-truth
    hypotheses), with one subplot per render.

    Args:
        h5_filename: history file path relative to '..'; defaults to the
            module-level `filename`.
        batch_size: number of training examples per render. Render `render_id`
            covers examples [render_id * batch_size, (render_id + 1) * batch_size).
        max_renders: only the last `max_renders` renders are plotted.
        cols: number of subplot columns in each figure.
    """
    if h5_filename is None:
        h5_filename = filename

    # Read from a copy rather than the original file. Presumably this lets the
    # file be inspected while the training process still has it open; TODO confirm.
    shutil.copyfile(f'../{h5_filename}', '/tmp/foo.h5')

    # `with` guarantees the HDF5 handle is closed even if loading/plotting raises.
    with h5py.File('/tmp/foo.h5', 'r') as f:
        print('f.keys()', f.keys())

        hypotheses_train = torch.from_numpy(f['hypotheses_train'][:])
        # NOTE(review): astype(np.uint8) wraps values outside 0..255; confirm
        # that ground-truth labels and dsref ids always fit in a byte.
        hypotheses_gnd_train = torch.from_numpy(f['gnd_hypotheses_train'][:].astype(np.uint8)).long()
        dsrefs_train = torch.from_numpy(f['dsrefs_train'][:].astype(np.uint8)).long()
        # Dataset handle; entries are read lazily, so the file must stay open below.
        resdicts = f['resdicts']

        print(hypotheses_train.shape)
        print(hypotheses_gnd_train.shape)
        print(dsrefs_train.shape)

        N = dsrefs_train.size(0)
        print('N', N)

        num_renders = len(resdicts)
        print('len(resdicts)', len(resdicts))
        print('holdout dsrefs len', f['dsrefs_holdout'].shape[0])

        if N == 0 or num_renders == 0:
            # dsrefs_train.max() raises on an empty tensor, and there is
            # nothing to plot without renders.
            print('nothing to plot')
            return

        num_dsrefs = dsrefs_train.max().item() + 1
        print('num_dsrefs', num_dsrefs)

        # Loop-invariant layout: the most recent `max_renders` renders, `cols` per row.
        start_render_id = max(0, num_renders - max_renders)
        rows = (num_renders - start_render_id + cols - 1) // cols

        for dsref in range(num_dsrefs):
            print('dsref', dsref)
            plt.figure(figsize=(4 * cols, 5 * rows))
            for render_id in range(start_render_id, num_renders):
                print('render_id', render_id)
                b_start = render_id * batch_size
                b_end = b_start + batch_size

                # Restrict this render's batch to examples of the current dsref.
                batch_dsrefs = dsrefs_train[b_start:b_end]
                dsref_idxes = (batch_dsrefs == dsref).view(-1).nonzero().view(-1).long()
                batch_hyps = hypotheses_train[:, b_start:b_end][:, dsref_idxes]
                batch_gnd = hypotheses_gnd_train[:, b_start:b_end][:, dsref_idxes]

                resdict = json.loads(resdicts[render_id])
                episode = resdict['episode']
                r_acc1 = resdict['r_acc1']

                # Predicted label per example = index of the max over the last axis.
                _, pred = batch_hyps.max(dim=-1)
                ami_matrix = create_ami_matrix(gnd=batch_gnd, pred=pred)

                plt.subplot(rows, cols, render_id - start_render_id + 1)
                plt.imshow(ami_matrix.numpy())
                plt.title(f'e={episode} r_acc1={r_acc1:.3f}')
            plt.show()

# Run the analysis when executed as a notebook cell or script. Notebooks also
# have __name__ == '__main__'; importing the exported module does not trigger it.
if __name__ == '__main__':
    run()