In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
from sklearn.metrics import accuracy_score, log_loss
import torch
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
%matplotlib inline

import sys
sys.path.append('../training_utils/')
from data_utils import get_folders
from train_utils import predict
from diagnostic_tools import top_k_accuracy, per_class_accuracy,\
    entropy, model_calibration, show_errors, most_confused_classes,\
    most_inaccurate_k_classes
    
torch.cuda.is_available()  # sanity check displayed as cell output: True if PyTorch can see a CUDA device

In [None]:
# Let cuDNN benchmark candidate convolution algorithms and reuse the fastest one;
# this pays off when input shapes do not change between batches.
torch.backends.cudnn.benchmark = True

# Model

In [None]:
from get_densenet import get_model

In [None]:
# `get_model` (project helper from get_densenet) returns a 3-tuple; only the model is used here.
# NOTE(review): the two discarded elements are presumably training objects (loss/optimizer) — confirm.
# NOTE(review): presumably `get_model` loads trained weights; verify, since no checkpoint is loaded here.
# NOTE(review): nothing in this notebook calls `model.eval()`; confirm that `predict` does,
# otherwise dropout / batch-norm would run in training mode during evaluation.
model, _, _ = get_model()

# Error analysis

### get all predictions and all misclassified images 

In [None]:
# Iterate the validation set in a fixed order (no shuffling) so predictions stay aligned
# with the dataset's sample order.
# NOTE(review): `val_folder` is not created anywhere in this notebook (hidden kernel state);
# presumably it comes from `get_folders()` imported above — TODO add that call explicitly.
val_iterator_no_shuffle = DataLoader(
    dataset=val_folder,
    batch_size=256,
    shuffle=False,
)

In [None]:
# Run the model over the whole validation set. With return_erroneous=True, `predict` also
# hands back the misclassified subset:
#   erroneous_samples     - images that were misclassified
#   erroneous_targets     - their true labels
#   erroneous_predictions - the model's predictions for them
(
    val_predictions,
    val_true_targets,
    erroneous_samples,
    erroneous_targets,
    erroneous_predictions,
) = predict(model, val_iterator_no_shuffle, return_erroneous=True)

### get human readable class names

In [None]:
# Build `decode`: model output index -> human-readable class name.
# `val_folder.class_to_idx` maps each class folder name (a numeric id, hence `int(folder)`)
# to the index the model predicts; that mapping lives on the dataset, not on the DataLoader.
# The source mapping (numeric folder id -> name) is kept under its own name so that
# re-running this cell does not re-index the already converted `decode`.
# NOTE(review): the source mapping is not created anywhere in this notebook; it is the `decode`
# left in the kernel by an earlier / deleted cell. TODO load it explicitly.
if 'raw_decode' not in globals():
    raw_decode = decode
decode = {idx: raw_decode[int(folder)] for folder, idx in val_folder.class_to_idx.items()}

### number of misclassified images (the val dataset contains 5120 images in total)

In [None]:
# Number of misclassified validation images.
n_errors = len(erroneous_targets)
# Sanity check: the error subset returned by `predict` should be exactly the samples
# whose top-1 prediction differs from the true label.
assert n_errors == int((val_predictions.argmax(1) != val_true_targets).sum()), \
    'erroneous subset from predict() disagrees with argmax misclassifications'
n_errors

### log loss and accuracies

In [None]:
# Cross-entropy of the predicted class probabilities on the validation set.
log_loss(y_true=val_true_targets, y_pred=val_predictions)

In [None]:
# Top-1 accuracy: the most probable class is taken as the prediction.
predicted_labels = val_predictions.argmax(1)
accuracy_score(y_true=val_true_targets, y_pred=predicted_labels)

In [None]:
# Top-k accuracies for several values of k (project helper).
TOP_KS = (2, 3, 4, 5, 10)
print(top_k_accuracy(val_true_targets, val_predictions, k=TOP_KS))

### entropy of predictions

In [None]:
# Boolean mask over validation samples: True where the top-1 prediction is correct.
top1_labels = val_predictions.argmax(1)
hits = top1_labels == val_true_targets

In [None]:
# Distribution of prediction entropy for correctly vs. incorrectly classified images.
# Histograms are density-normalised so the two groups are comparable despite their different sizes
# (`density=` replaces the `normed=` keyword, which Matplotlib removed in 3.1).
plt.hist(
    entropy(val_predictions[hits]), bins=30,
    density=True, alpha=0.7, label='correct prediction'
);
plt.hist(
    entropy(val_predictions[~hits]), bins=30,
    density=True, alpha=0.5, label='misclassification'
);
plt.legend();
plt.xlabel('entropy of predictions');
plt.ylabel('density');

### confidence of predictions

In [None]:
# Distribution of the top-1 probability (model confidence) for correct vs. wrong predictions.
# Density-normalised so the two groups are comparable (`density=` replaces the removed `normed=`).
plt.hist(
    val_predictions[hits].max(1), bins=30,
    density=True, alpha=0.7, label='correct prediction'
);
plt.hist(
    val_predictions[~hits].max(1), bins=30,
    density=True, alpha=0.5, label='misclassification'
);
plt.legend();
plt.xlabel('confidence of predictions');
plt.ylabel('density');

### difference between the largest and second-largest predicted probability

In [None]:
# Margin between the largest and second-largest predicted probability, correct vs. wrong.
# Each row is sorted ascending, so columns -1 and -2 hold the top-1 and top-2 probabilities.
sorted_correct = np.sort(val_predictions[hits], 1)
sorted_incorrect = np.sort(val_predictions[~hits], 1)

# Density-normalised histograms (`density=` replaces the `normed=` keyword removed in Matplotlib 3.1).
plt.hist(
    sorted_correct[:, -1] - sorted_correct[:, -2], bins=30,
    density=True, alpha=0.7, label='correct prediction'
);
plt.hist(
    sorted_incorrect[:, -1] - sorted_incorrect[:, -2], bins=30,
    density=True, alpha=0.5, label='misclassification'
);
plt.legend();
plt.xlabel('difference');
plt.ylabel('density');

### probabilistic calibration of the model

In [None]:
# Probabilistic calibration check (project helper); presumably compares predicted confidence
# with observed accuracy across N_CALIBRATION_BINS bins — confirm in diagnostic_tools.
N_CALIBRATION_BINS = 10
model_calibration(val_true_targets, val_predictions, n_bins=N_CALIBRATION_BINS)

### per class accuracies

In [None]:
# Accuracy computed separately for each class (project helper), and the spread of those values.
# NOTE(review): assumes `per_class_acc` holds one accuracy per class, so histogram counts are classes.
per_class_acc = per_class_accuracy(val_true_targets, val_predictions)
plt.hist(per_class_acc);
plt.xlabel('accuracy');
plt.ylabel('number of classes');

In [None]:
# The N_WORST_CLASSES classes with the lowest accuracy, shown with readable names.
N_WORST_CLASSES = 15
most_inaccurate_k_classes(per_class_acc, N_WORST_CLASSES, decode)

### class accuracy vs. number of samples in the class

In [None]:
# Does a class's accuracy depend on how many samples it has?
# NOTE(review): `w` is not defined in this notebook (hidden kernel state). The axis label implies
# it holds per-class weights equal to 1 / (number of samples) — TODO confirm and define it explicitly.
n_samples_per_class = 1.0 / w
plt.scatter(n_samples_per_class, per_class_acc);
plt.xlabel('number of available samples');
plt.ylabel('class accuracy');

### most confused pairs of classes

In [None]:
# Pairs of classes the model mixes up; presumably only pairs confused at least
# MIN_CONFUSIONS times are kept (project helper) — confirm in diagnostic_tools.
MIN_CONFUSIONS = 4
confused_pairs = most_confused_classes(
    val_true_targets,
    val_predictions,
    decode,
    min_n_confusions=MIN_CONFUSIONS,
)
confused_pairs

### show some low-entropy (confident) errors

In [None]:
# Split misclassified images by the entropy of their predicted distribution, using the mean
# entropy as the threshold. `low_entropy` is True for errors the model made confidently
# (entropy strictly below the mean); `~low_entropy` selects the uncertain ones.
erroneous_entropy = entropy(erroneous_predictions)
mean_entropy = erroneous_entropy.mean()
low_entropy = erroneous_entropy < mean_entropy
mean_entropy

In [None]:
# Display the misclassified images selected by the `low_entropy` mask, together with
# their predictions, true labels and readable class names (project helper).
selected_errors = low_entropy
show_errors(
    erroneous_samples[selected_errors],
    erroneous_predictions[selected_errors],
    erroneous_targets[selected_errors],
    decode,
)

### show some high-entropy (uncertain) errors

In [None]:
# Display the misclassified images NOT selected by the `low_entropy` mask, together with
# their predictions, true labels and readable class names (project helper).
selected_errors = ~low_entropy
show_errors(
    erroneous_samples[selected_errors],
    erroneous_predictions[selected_errors],
    erroneous_targets[selected_errors],
    decode,
)

# Save

In [None]:
# Move the weights to CPU before saving so the checkpoint does not reference CUDA devices.
MODEL_STATE_PATH = 'model121.pytorch_state'
model.cpu();
torch.save(model.state_dict(), MODEL_STATE_PATH)