In [None]:
import os

import h5py as h5
import matplotlib.pyplot as pp
import matplotlib.transforms as mpt
import numpy as np

%matplotlib inline

In [None]:
# Location of the HDF5 file with the logits/labels to evaluate. Override with the
# GISTS_DATASET environment variable so the notebook is not tied to one machine.
dataset_path = os.environ.get("GISTS_DATASET", "/home/cqql/data/gists.h5")

# Only groups whose path ends with one of these suffixes are evaluated.
# NOTE(review): presumably these select a particular subset/split — confirm.
GROUP_SUFFIXES = ("-17", "-18")

with h5.File(dataset_path, "r") as f:
    # h5py >= 3 reads variable-length strings as bytes; decode them so the tick
    # labels in the plot render as plain text instead of "b'...'".
    label_index = [
        entry.decode("utf-8") if isinstance(entry, bytes) else str(entry)
        for entry in f["label_index"]
    ]
    logits = []
    labels = []
    collected_groups = []

    def collect_logits(name, obj):
        """Visitor for `File.visititems`: gather `data`/`labels` from matching groups.

        `name` is the object's path inside the file and `obj` the object itself.
        Only groups whose path ends with one of GROUP_SUFFIXES and that contain
        both a `data` and a `labels` member are collected; `logits[k]` and
        `labels[k]` always come from the same group.

        Always returns None, because a non-None return value would stop
        `visititems` early.
        """
        if not name.endswith(GROUP_SUFFIXES):
            return None

        if isinstance(obj, h5.Group) and "data" in obj and "labels" in obj:
            logits.append(np.array(obj["data"]))
            labels.append(np.array(obj["labels"]))
            collected_groups.append(name)
        return None

    f.visititems(collect_logits)

print(f"Collected {len(collected_groups)} group(s): {', '.join(collected_groups)}")

In [None]:
# Stack the per-group chunks into single arrays. The isinstance guards keep this
# cell idempotent: re-running it must not concatenate the already-stacked arrays
# again (that would flatten `logits` into a 1-D array).
if isinstance(logits, list):
    if not logits:
        raise ValueError("no matching groups with 'data' and 'labels' were found in the dataset")
    logits = np.concatenate(logits, axis=0)
if isinstance(labels, list):
    labels = np.concatenate(labels, axis=0)

# Label -1 is remapped onto an extra "<blank>" class appended after the real labels.
# Guarded so a re-run does not append "<blank>" a second time.
if not label_index or label_index[-1] != "<blank>":
    label_index.append("<blank>")
blank_class = len(label_index) - 1
labels[labels == -1] = blank_class

In [None]:
# Row-normalised confusion matrix: cm[i, j] is the fraction of samples with true
# label i whose predicted class (argmax over the last axis of `logits`) is j.
n_classes = len(label_index)
predictions = np.argmax(logits, axis=-1)

# Flatten so the computation does not depend on labels/predictions being 1-D.
true_flat = np.ravel(labels).astype(np.int64)
pred_flat = np.ravel(predictions)

# Number of samples per true class; this is each row's normaliser.
true_in_range = (true_flat >= 0) & (true_flat < n_classes)
support = np.bincount(true_flat[true_in_range], minlength=n_classes)

# Joint (true, predicted) counts. A prediction outside the label set is not counted
# in any column but still contributes to its row's support, as in a per-class loop.
pair_in_range = true_in_range & (pred_flat < n_classes)
pair_counts = np.bincount(
    true_flat[pair_in_range] * n_classes + pred_flat[pair_in_range],
    minlength=n_classes * n_classes,
).reshape(n_classes, n_classes)

# Classes without any samples (e.g. "<blank>" when no label was -1) would divide
# by zero; they get an all-zero row instead.
cm = np.zeros((n_classes, n_classes))
np.divide(pair_counts, support[:, None], out=cm, where=support[:, None] > 0)

In [None]:
# Plot the row-normalised confusion matrix: rows are true labels, columns are
# predicted labels, both in `label_index` order.
fig, ax = pp.subplots(1, 1, figsize=(4, 4), dpi=300)

ax.imshow(cm)

ticks = np.arange(len(label_index))

# Small rotated labels; the alignment keeps each rotated label anchored at its tick.
ax.set_xticks(ticks)
ax.set_xticklabels(label_index, fontdict={"size": 5}, rotation=-45, horizontalalignment="right")
ax.set_yticks(ticks)
ax.set_yticklabels(label_index, fontdict={"size": 5}, rotation=-45, verticalalignment="bottom")

# Put the x tick marks and labels on top only. These must be booleans: passing the
# strings "on"/"off" was deprecated in Matplotlib 2.2, and any non-empty string
# (including "off") is truthy, which would leave the bottom ticks visible.
ax.xaxis.set_tick_params(labeltop=True, labelbottom=False, top=True, bottom=False)

In [None]:
# Export the confusion-matrix figure for the documentation.
FIGURE_PATH = "../doc/figures/dataset/confusion-matrix.png"

# savefig does not create missing directories; make sure the target exists.
os.makedirs(os.path.dirname(FIGURE_PATH), exist_ok=True)

# bbox_inches takes a Bbox in inches given as (x0, y0, width, height).
# NOTE(review): presumably hand-tuned to fit the rotated tick labels on the
# 4x4 inch figure — re-check if the labels or figure size change.
fig.savefig(FIGURE_PATH, bbox_inches=mpt.Bbox.from_bounds(-0.2, 0.0, 4.0, 4.5))