In [None]:
import torch
import os
import pickle
import pytorch_lightning
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

from oticon_utils.training_module import TrainingModule


In [None]:
models_dir = 'models/cnn-None'
logs_dir = os.path.join(models_dir, 'lightning_logs')

# Pick the highest-numbered `version_<N>` run directory. Scanning the directory
# (instead of probing version_1, version_2, ... until the first missing one)
# still finds the newest run when intermediate versions have been deleted.
version_prefix = 'version_'
version_numbers = [
    int(entry[len(version_prefix):])
    for entry in os.listdir(logs_dir)
    if entry.startswith(version_prefix)
    and entry[len(version_prefix):].isdecimal()
    and os.path.isdir(os.path.join(logs_dir, entry))
]
if not version_numbers:
    raise FileNotFoundError(f'No {version_prefix}<N> run directories found in {logs_dir!r}')
version_path = os.path.join(logs_dir, f'{version_prefix}{max(version_numbers)}')

# NOTE(review): the column reversal assumes test_predictions.npy stores class
# indices in ascending-score order (e.g. the output of np.argsort), so that after
# [:, ::-1] column 0 holds the top-ranked class — confirm against the writer.
test_predictions = np.load(os.path.join(version_path, 'test_predictions.npy'))[:, ::-1]
test_labels = np.load(os.path.join(version_path, 'test_labels.npy'))

In [None]:
# Cumulative top-k accuracy for k = 1..n_labels: a test sample counts as a hit
# at rank k when its true label equals any of the first k columns of
# test_predictions. `top_n_accuracy` is the per-sample hit mask for the largest k.
n_labels = 5
top_n_accuracy = np.zeros_like(test_labels, dtype=bool)
for rank in range(1, n_labels + 1):
    hit_at_rank = test_predictions[:, rank - 1] == test_labels
    top_n_accuracy = top_n_accuracy | hit_at_rank
    print(f'Top {rank} acc: {top_n_accuracy.mean():.3}')

In [None]:
classes = ['Other', 'Music', 'Human speech', 'Traffic', 'Alarms']

# Row-normalised confusion matrix of the top-1 predictions: entry [i, j] is the
# fraction of samples with true class i that were predicted as class j.
# Passing `labels` pins the matrix to len(classes) x len(classes) even when a
# class never occurs in this test split, so the tick labels below stay aligned.
# NOTE(review): assumes labels are integer class indices 0..len(classes)-1 in
# the same order as `classes` — confirm against the dataset's label encoding.
cm_counts = confusion_matrix(
    test_labels, test_predictions[:, 0], labels=np.arange(len(classes))
)
row_totals = cm_counts.sum(axis=1, keepdims=True)
# A class with no true samples has an all-zero row; report 0.0 for it rather
# than dividing by zero (which would warn and yield NaN).
cm = np.divide(
    cm_counts,
    row_totals,
    out=np.zeros(cm_counts.shape, dtype=float),
    where=row_totals != 0,
)
cm = np.around(cm, decimals=2)

fig, ax = plt.subplots(figsize=(4, 4))
image = ax.imshow(cm, cmap='viridis')
fig.colorbar(image, ax=ax)
tick_marks = np.arange(len(classes))
ax.set_xticks(tick_marks)
ax.set_xticklabels(classes, rotation=45, ha='right')
ax.set_yticks(tick_marks)
ax.set_yticklabels(classes)
ax.set(title='Confusion Matrix', xlabel='Predicted label', ylabel='True label')
for i, j in np.ndindex(cm.shape):
    ax.text(j, i, f'{cm[i, j]:.2f}', horizontalalignment='center', color='gray')
# Lay out only after every label exists; otherwise tight_layout ignores them
# and the axis labels can be clipped.
fig.tight_layout()
cm