In [None]:
from typing import List, Optional
import urllib.request
from tqdm import tqdm
from pathlib import Path
import requests
import torch
import math
import numpy as np
import os
import glob
import matplotlib.pyplot as plt
import random
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import pandas as pd
import sklearn.metrics
import seaborn as sb

import torch.nn as nn


# Seed every RNG the notebook touches so shuffling/splitting is reproducible.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)


def seed_worker(worker_id):
    """Seed numpy and `random` inside a DataLoader worker process.

    Meant to be passed as ``DataLoader(worker_init_fn=seed_worker)`` (the pattern
    from the PyTorch reproducibility notes); the DataLoader built in this notebook
    does not currently pass it.

    Args:
        worker_id: index of the worker; unused, but required by the
            ``worker_init_fn`` call signature.
    """
    # torch gives each worker its own initial seed; numpy needs a 32-bit value.
    worker_seed = torch.initial_seed() % 2**32
    # numpy is imported as `np` in this notebook (a bare `numpy.` is a NameError).
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# Generator intended for DataLoader(generator=g) so shuffle order is reproducible.
g = torch.Generator()
g.manual_seed(0)


In [None]:
import utils

In [None]:
import importlib
importlib.reload(utils)

In [None]:
# Assumes that, with identical seeds and arguments, utils reproduces the same
# train/validation split on every run.
# TODO: persist the split explicitly (e.g. save validation indices) instead of relying on seeding.
DATA_ROOT = "../data/npy"
CLASS_NAMES = (
    ['airplane', 'apple', 'wine bottle', 'car', 'mouth', 'pineapple',
     'umbrella', 'pear', 'moustache', 'smiley face']
    + ['train', 'mosquito', 'bee', 'dragon', 'piano']
)

utils.download_quickdraw_dataset(root=DATA_ROOT, class_names=CLASS_NAMES)
dataset = utils.QuickDrawDataset(root=DATA_ROOT, max_items_per_class=100000)

train_ds, val_ds = dataset.split(0.2)
# batch_size=1 is relied on by later cells, which take argmax over the whole logits array.
validation_dataloader = DataLoader(val_ds, batch_size=1, shuffle=False)

In [None]:
# Small CNN for 1x28x28 inputs. Each 'same' conv keeps the spatial size and each
# MaxPool2d(2) halves it (floored): 28 -> 14 -> 7 -> 3, so the flattened feature
# vector has 32 channels * 3 * 3 = 288 entries.
conv_layers = []
for in_ch, out_ch in [(1, 16), (16, 32), (32, 32)]:
    conv_layers += [nn.Conv2d(in_ch, out_ch, 3, padding='same'), nn.ReLU(), nn.MaxPool2d(2)]

model = nn.Sequential(
    *conv_layers,
    nn.Flatten(),
    nn.Linear(288, 128),
    nn.ReLU(),
    nn.Linear(128, len(dataset.classes)),
)

In [None]:
# Load the trained weights onto the CPU, whatever device they were saved from.
# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint = torch.load('./model_lessCapacity.pth', map_location='cpu')
model.load_state_dict(checkpoint)

# Inference mode (this model has no dropout/batchnorm, but it keeps intent explicit).
model.eval()

In [None]:
# Per-class tally of validation samples and correct top-1 predictions.
# Sized from dataset.classes rather than a hard-coded 15, so it tracks the class list.
stats = [{"idx": i, "label": label, "count": 0, "correct": 0}
         for i, label in enumerate(dataset.classes)]

# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for x, y, idx in validation_dataloader:
        logits = model(x)
        # batch_size=1, so argmax over all logits is the predicted class index
        y_hat = logits.argmax().item()

        class_idx = y.item()
        stats[class_idx]["count"] += 1
        if class_idx == y_hat:
            stats[class_idx]["correct"] += 1
        

In [None]:
# One row per class (indexed by class id) with its top-1 accuracy.
df = (
    pd.DataFrame(stats)
    .set_index('idx')
    .assign(accuracy=lambda frame: frame['correct'] / frame['count'])
)


In [None]:
# Manual grouping of classes into 'convergent' / 'divergent' by class index.
# NOTE(review): the criterion for 'divergent' is not stated here, and the indices
# depend on the order of dataset.classes; mapping by class name would be more robust.
DIVERGENT_CLASS_IDX = [2, 4, 5, 9, 12]

# Every class defaults to 'convergent'. Deriving the length from df (not a hard-coded
# 15) keeps this valid if the class list changes.
df['category'] = np.where(df.index.isin(DIVERGENT_CLASS_IDX), 'divergent', 'convergent')


In [None]:
# Mean of the per-class accuracies over 'convergent' classes (unweighted by class size).
df.loc[df['category'] == 'convergent', 'accuracy'].mean()

In [None]:
# Per-class table: label, sample count, correct count, accuracy and category.
df

In [None]:
# Export the first 30 drawings of each of the first 13 classes as PNG files.
# Assumes `dataset` stores each class contiguously in blocks of
# max_items_per_class (=100000) items; TODO confirm against utils.QuickDrawDataset.
# NOTE(review): only 13 of the 15 classes are exported; confirm this is intended.
ITEMS_PER_CLASS = 100000
N_EXPORT_CLASSES = 13
N_PER_CLASS = 30

# Ordering: sample 0 of every class, then sample 1 of every class, and so on.
indices = [c * ITEMS_PER_CLASS + j for j in range(N_PER_CLASS) for c in range(N_EXPORT_CLASSES)]
for idx in indices:
    # Fetch the sample once instead of indexing the dataset twice.
    image, label, _ = dataset[idx]
    fig, ax = plt.subplots()
    ax.imshow(image.reshape(28, 28, 1), cmap='Greys')
    ax.axis('off')
    fig.savefig(dataset.classes[label] + str(idx) + '.png', bbox_inches='tight')
    # Close each figure; otherwise all ~390 stay open (memory growth, matplotlib warning).
    plt.close(fig)

In [None]:
# Starting dataset index for stepping through drawings one at a time.
idx = 300_000

In [None]:
# Step to the next drawing and show it; re-run this cell to browse.
idx = idx + 1
fig, ax = plt.subplots()
ax.imshow(dataset[idx][0].reshape(28, 28, 1), cmap='Greys')
ax.axis('off')
plt.show()



## Confusion Matrix

In [None]:
# One row per validation sample: sample index, true label, predicted label, raw
# logits and softmax probabilities (arrays keep their leading batch dim of 1).
columns = ['index', 'y', 'y_hat', 'logits', 'probs']
predictions = []

# Inference only: no_grad avoids building autograd graphs and allows .numpy() directly.
with torch.no_grad():
    for x, y, idx in tqdm(validation_dataloader):
        logits = model(x).numpy()
        # Numerically stable softmax: subtracting the max leaves the result unchanged
        # but prevents overflow in exp for large logits.
        exp_shifted = np.exp(logits - logits.max())
        probs = exp_shifted / exp_shifted.sum()

        # batch_size=1, so argmax over all logits is the predicted class index
        y_hat = np.argmax(logits)

        predictions.append([idx, y, y_hat, logits, probs])
        

In [None]:
# One row per validation sample; unwrap the label from its batch-of-1 tensor to a Python int.
df2 = pd.DataFrame(predictions, columns=columns)
df2['y'] = df2['y'].apply(lambda label_batch: label_batch[0].item())

In [None]:
# Predicted class index (argmax of the logits) for each validation sample.
df2.y_hat

In [None]:
# Row-normalised confusion matrix: entry (i, j) is the fraction of true-class-i samples
# predicted as class j, so each row that has samples sums to 1.
# `labels` pins one row/column per class (as the metrics cell does), keeping the matrix
# aligned with display_labels even if some class never occurs in y or y_hat.
class_ids = list(range(len(dataset.classes)))
cm = sklearn.metrics.confusion_matrix(list(df2.y), list(df2.y_hat), labels=class_ids, normalize='true')
disp = sklearn.metrics.ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=dataset.classes)
fig, ax = plt.subplots(figsize=(15, 15))
disp.plot(xticks_rotation='vertical', ax=ax, cmap='plasma')

In [None]:
# Column 4 summed over true classes: total fraction of each class predicted as class 4
# (dataset.classes[4]). Not a probability; it exceeds 1 if class 4 is over-predicted.
cm[:,4].sum()

In [None]:
# Row 4 of the row-normalised matrix: ~1.0 when class 4 has validation samples
# (sanity check of normalize='true').
cm[4,:].sum()

## Metrics

In [None]:
# Per-class precision, recall, F1 and support; `labels` guarantees one entry per class.
all_class_ids = range(len(dataset.classes))
metrics = sklearn.metrics.precision_recall_fscore_support(
    df2.y.tolist(), df2.y_hat.tolist(), labels=all_class_ids
)


In [None]:
# `metrics` is a 4-tuple of per-class arrays; stack them as columns -> one row per class.
df3 = pd.DataFrame(np.column_stack(metrics), columns=['precision', 'recall', 'fscore', 'support'])


In [None]:
# Attach class names; rows are in class-index order because metrics used labels=range(...).
df3 = df3.assign(className=dataset.classes)

In [None]:
# Per-class precision, recall, F1 score, support and class name.
df3

## ROC

In [None]:
# One-hot encode the TRUE labels (y) as ROC ground truth; the scores are the logits.
# Encoding the predicted labels (y_hat) here would score the model against its own
# argmax and inflate the ROC/AUC. The width is the number of classes (not
# max(label) + 1), so it always matches the width of the logits matrix.
n_classes = len(dataset.classes)
one_hot = np.zeros((len(df2), n_classes))
one_hot[np.arange(len(df2)), df2.y.to_numpy()] = 1

In [None]:
# Drop each sample's batch dimension and stack into an (n_samples, n_classes) matrix.
logits_as_list = np.array([sample_logits[0] for sample_logits in df2.logits])

In [None]:
# Micro-averaged one-vs-rest ROC: flatten labels and scores so that every
# (sample, class) pair counts as one binary decision.
# NOTE(review): pooled raw logits are not calibrated across samples; softmax
# probabilities are the more common score for micro-averaging.
y_true_flat = one_hot.ravel()
y_score_flat = logits_as_list.ravel()
sklearn.metrics.RocCurveDisplay.from_predictions(
    y_true_flat,
    y_score_flat,
    name="micro-average OvR",
    color="darkorange",
)

In [None]:
import itertools

# Per-class one-vs-rest ROC curves overlaid on one axis.
# NOTE(review): these ids were hand-picked as the worst-AUC classes; re-check the
# choice whenever the ROC inputs change.
WORST_AUC_CLASS_IDS = [4, 5, 6, 8, 12]

fig, ax = plt.subplots(figsize=(10, 10))
for class_id in WORST_AUC_CLASS_IDS:
    class_label = dataset.classes[class_id]
    sklearn.metrics.RocCurveDisplay.from_predictions(
        one_hot[:, class_id],
        logits_as_list[:, class_id],
        name=f"ROC curve for {class_label}",
        ax=ax,
    )

## Predicted probability of the true class (violin and box plots)

In [None]:
# Softmax probabilities per validation sample; each entry keeps the batch dimension of 1.
df2['probs']

In [None]:
# Add one predicted-probability column per class name, plus the true class name.
prob_columns = pd.DataFrame([p[0] for p in df2['probs']], columns=dataset.classes)
dfx = df2.join(prob_columns)
dfx['y_class'] = dfx['y'].map(lambda label: dataset.classes[label])

In [None]:
def _prob_of_true_class(row):
    """Return the probability the model assigned to this row's true class."""
    return row.probs[0][row.y]

dfx['pred_prob_for_true_class'] = dfx.apply(_prob_of_true_class, axis=1)

In [None]:
# Distribution of the probability assigned to the true class, per true class.
chart = sb.violinplot(
    data=dfx,
    x='y_class',
    y='pred_prob_for_true_class',
    inner='stick',
    cut=0,
    linewidth=0,
)
chart.set_xticklabels(chart.get_xticklabels(), rotation=90)
chart.set(xlabel='Classes', ylabel='Predicted Probabilities')

In [None]:
# Same distribution as a box plot (quartiles and outliers instead of a density).
chart = sb.boxplot(data=dfx, x='y_class', y='pred_prob_for_true_class')
chart.set_xticklabels(chart.get_xticklabels(), rotation=90)
chart.set(xlabel='Classes', ylabel='Predicted Probabilities')

## Predicted probability of a single class, by true class

In [None]:
# Probability the model assigns to 'mouth', grouped by true class.
chart = sb.violinplot(
    data=dfx,
    x='y_class',
    y='mouth',
    inner='stick',
    cut=0,
    linewidth=0,
)
chart.set_xticklabels(chart.get_xticklabels(), rotation=90)
chart.set(xlabel='Classes', ylabel='Predicted Probabilities')

In [None]:
def boxplot_for_class(class_name, data=None):
    """Box plot of the probability assigned to `class_name`, grouped by TRUE class.

    For each true class on the x-axis, shows how much probability the model put on
    `class_name`; high boxes for other classes reveal what `class_name` is confused with.

    Args:
        class_name: a class name; must be a probability column of `data`.
        data: frame with a 'y_class' column and one probability column per class.
            Defaults to the notebook-level `dfx`.
    """
    if data is None:
        data = dfx
    chart = sb.boxplot(
        x='y_class',
        y=class_name,
        data=data,
        flierprops={"marker": "x"},
        medianprops={"color": "coral"},
        fliersize=1,
        # whiskers at the 1st and 99th percentiles instead of 1.5 * IQR
        whis=[1, 99])
    chart.set_xticklabels(chart.get_xticklabels(), rotation=90)
    chart.set(xlabel='Classes',
              ylabel='Predicted Probabilities',
              title="Predicted probability of class '" + class_name + "' for each true class"
                    "\nWhiskers are at the 1st and 99th percentiles")


In [None]:
# Probability the model assigns to 'apple', broken down by true class.
boxplot_for_class("apple")

In [None]:
# Export a per-sample evaluation table: drop the array-valued logits/probs columns,
# convert sample indices to plain ints and add the predicted class name.
# .copy() makes dfe an independent frame, so the column assignments below don't
# hit pandas' SettingWithCopyWarning / chained-assignment ambiguity.
dfe = dfx[dfx.columns.difference(["logits", "probs"])].copy()
dfe["index"] = dfe["index"].apply(lambda x: x.item())
dfe["y_hat_class"] = dfe["y_hat"].apply(lambda x: dataset.classes[x])
dfe.to_csv("evaluation_main_stats.csv")

## Single Predictions

In [None]:
# Random index into the full dataset (may be a training sample, not only validation).
index = np.random.randint(len(dataset))

In [None]:
def get_prediction(index):
    """Plot the model's class probabilities next to the drawing at `dataset[index]`.

    Left panel: bar chart of the softmax probability for each class, titled with `index`.
    Right panel: the 28x28 drawing, titled with its true and predicted class.

    Args:
        index: index into `dataset` (the full dataset, not only the validation split).
    """
    # Load the sample once (the label is reused for the title).
    x, y, _ = dataset[index]
    with torch.no_grad():
        # Add a batch dimension of 1 for the model.
        logits = model(x.unsqueeze(dim=0)).numpy()

    # Numerically stable softmax over the single row of logits.
    exp_shifted = np.exp(logits - logits.max())
    probs = (exp_shifted / exp_shifted.sum())[0]
    y_hat = int(np.argmax(probs))
    class_name = dataset.classes[y]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    ax1.bar(range(len(probs)), probs, tick_label=dataset.classes)
    ax1.tick_params(axis='x', labelrotation=90)
    ax1.set_title(index)

    # Address ax2 explicitly instead of relying on pyplot's "current axes".
    ax2.imshow(x.reshape(28, 28, 1), cmap='Greys')
    ax2.set_title(f"true: {class_name} / predicted: {dataset.classes[y_hat]}")
    ax2.axis('off')

In [None]:
# Inspect a specific, hard-coded sample index.
get_prediction(1033675)

In [None]:
# Inspect a random sample from the full dataset.
random_sample_index = np.random.randint(len(dataset))
get_prediction(random_sample_index)

In [None]:
# True class name of the randomly drawn sample `index`. (A bare `class_name` only
# exists inside get_prediction / boxplot_for_class and raises NameError at top level.)
dataset.classes[dataset[index][1]]

In [None]:
# Shape of `x`, the input batch left over from the last evaluation loop
# (hidden state: its value depends on which loop cell ran last).
x.size()

### Confidently wrong predictions (true-class probability < 1%)

In [None]:
# Sample indices (the `idx` returned alongside each sample; used below as indices into
# `dataset`) where the true class received < 1% probability, i.e. confidently wrong.
is_completely_wrong = dfx.pred_prob_for_true_class < 0.01
completely_wrong = list(dfx.loc[is_completely_wrong, 'index'].apply(lambda t: t.item()))

In [None]:
# Show one randomly chosen confidently-wrong sample.
wrong_pick = np.random.randint(len(completely_wrong))
get_prediction(completely_wrong[wrong_pick])

## Correlation

In [None]:
# Correlation between per-class predicted probabilities, restricted to samples whose
# true class received < 0.8 probability.
uncertain_probs = dfx.loc[dfx.pred_prob_for_true_class < 0.8, dataset.classes]
corr = uncertain_probs.corr()
corr.style.background_gradient(cmap='coolwarm')


In [None]:
# resulting pdf would be 500MB+
#fig = pd.plotting.scatter_matrix(dfx[dataset.classes], alpha=0.2, diagonal='kde', figsize=(20,20))
#plt.savefig("scattermatrix.pdf")

In [None]:
#https://towardsdatascience.com/better-heatmaps-and-correlation-matrix-plots-in-python-41445d0f2bec
def heatmap(x, y, **kwargs):
    """Scatter-based heatmap: one marker per (x, y) pair, coloured and sized by value.

    `x` and `y` are equal-length sequences of category labels; each pair is drawn
    as a marker at its integer grid cell.

    Keyword args (all optional):
        color: per-point values mapped onto `palette` (default: all 1).
        palette: list of colours (default: 256-step seaborn "Blues").
        color_range: (min, max) mapped to the palette ends (default: min/max
            of `color`); values outside are clamped.
        size: per-point values mapped to marker area (default: all 1).
        size_range: (min, max) for `size` (default: min/max of `size`);
            values outside are clamped.
        size_scale: marker area for the largest size (default 500).
        marker: matplotlib marker (default 's', a square).
        x_order / y_order: explicit category order along each axis
            (default: sorted unique values of x / y).
        Any other kwargs are forwarded to `Axes.scatter`.

    Draws into the current figure. When color_min < color_max, a colour-bar legend
    is added in the rightmost 1/15 of the grid.
    """
    if 'color' in kwargs:
        color = kwargs['color']
    else:
        color = [1]*len(x)

    if 'palette' in kwargs:
        palette = kwargs['palette']
        n_colors = len(palette)
    else:
        n_colors = 256 # Default: 256-step sequential "Blues" palette
        palette = sb.color_palette("Blues", n_colors) 

    if 'color_range' in kwargs:
        color_min, color_max = kwargs['color_range']
    else:
        color_min, color_max = min(color), max(color) # Default range: observed min and max of `color`

    def value_to_color(val):
        # Map a value to a palette entry; a degenerate range uses the last colour.
        if color_min == color_max:
            return palette[-1]
        else:
            val_position = float((val - color_min)) / (color_max - color_min) # position of value in the input range, relative to the length of the input range
            val_position = min(max(val_position, 0), 1) # bound the position between 0 and 1
            ind = int(val_position * (n_colors - 1)) # target index in the color palette
            return palette[ind]

    if 'size' in kwargs:
        size = kwargs['size']
    else:
        size = [1]*len(x)

    if 'size_range' in kwargs:
        size_min, size_max = kwargs['size_range'][0], kwargs['size_range'][1]
    else:
        size_min, size_max = min(size), max(size)

    size_scale = kwargs.get('size_scale', 500)

    def value_to_size(val):
        # Map a value to a marker area; the range is mapped onto [0.01, 1] * size_scale
        # so the smallest values still get a tiny visible marker.
        if size_min == size_max:
            return 1 * size_scale
        else:
            val_position = (val - size_min) * 0.99 / (size_max - size_min) + 0.01 # position of value in the input range, relative to the length of the input range
            val_position = min(max(val_position, 0), 1) # bound the position between 0 and 1
            return val_position * size_scale
    # Map each category label to its integer grid position along each axis.
    if 'x_order' in kwargs: 
        x_names = [t for t in kwargs['x_order']]
    else:
        x_names = [t for t in sorted(set([v for v in x]))]
    x_to_num = {p[1]:p[0] for p in enumerate(x_names)}

    if 'y_order' in kwargs: 
        y_names = [t for t in kwargs['y_order']]
    else:
        y_names = [t for t in sorted(set([v for v in y]))]
    y_to_num = {p[1]:p[0] for p in enumerate(y_names)}

    plot_grid = plt.GridSpec(1, 15, hspace=0.2, wspace=0.1) # Set up a 1x15 grid
    ax = plt.subplot(plot_grid[:,:-1]) # Use the left 14/15ths of the grid for the main plot

    marker = kwargs.get('marker', 's')

    # Everything not consumed by this function is forwarded to ax.scatter.
    kwargs_pass_on = {k:v for k,v in kwargs.items() if k not in [
         'color', 'palette', 'color_range', 'size', 'size_range', 'size_scale', 'marker', 'x_order', 'y_order'
    ]}

    ax.scatter(
        x=[x_to_num[v] for v in x],
        y=[y_to_num[v] for v in y],
        marker=marker,
        s=[value_to_size(v) for v in size], 
        c=[value_to_color(v) for v in color],
        **kwargs_pass_on
    )
    ax.set_xticks([v for k,v in x_to_num.items()])
    ax.set_xticklabels([k for k in x_to_num], rotation=45, horizontalalignment='right')
    ax.set_yticks([v for k,v in y_to_num.items()])
    ax.set_yticklabels([k for k in y_to_num])

    # Hide major grid lines (they'd cross the markers) and put minor ticks at cell
    # boundaries (+0.5) so the minor grid separates the cells.
    ax.grid(False, 'major')
    ax.grid(True, 'minor')
    ax.set_xticks([t + 0.5 for t in ax.get_xticks()], minor=True)
    ax.set_yticks([t + 0.5 for t in ax.get_yticks()], minor=True)

    ax.set_xlim([-0.5, max([v for v in x_to_num.values()]) + 0.5])
    ax.set_ylim([-0.5, max([v for v in y_to_num.values()]) + 0.5])
    ax.set_facecolor('#F1F1F1')

    # Add color legend on the right side of the plot
    if color_min < color_max:
        ax = plt.subplot(plot_grid[:,-1]) # Use the rightmost column of the plot

        col_x = [0]*len(palette) # Fixed x coordinate for the bars
        bar_y=np.linspace(color_min, color_max, n_colors) # y coordinates for each of the n_colors bars

        bar_height = bar_y[1] - bar_y[0] # NOTE(review): assumes n_colors >= 2
        ax.barh(
            y=bar_y,
            width=[5]*len(palette), # Make bars 5 units wide
            left=col_x, # Make bars start at 0
            height=bar_height,
            color=palette,
            linewidth=0
        )
        ax.set_xlim(1, 2) # Bars go from 0 to 5, so crop the plot somewhere in the middle
        ax.grid(False) # Hide grid
        ax.set_facecolor('white') # Make background white
        ax.set_xticks([]) # Remove horizontal ticks
        ax.set_yticks(np.linspace(min(bar_y), max(bar_y), 3)) # Show vertical ticks for min, middle and max
        ax.yaxis.tick_right() # Show vertical ticks on the right 


def corrplot(data, size_scale=500, marker='s'):
    """Draw a correlation matrix with `heatmap`: colour and marker size encode the value.

    Args:
        data: square correlation matrix (e.g. from DataFrame.corr()). Assumes its
            index is unnamed, so reset_index() yields a column called 'index'.
        size_scale: marker area used for |corr| == 1.
        marker: matplotlib marker drawn in each cell.
    """
    # Long form: one row per (row label, column label, correlation) triple.
    long_form = data.reset_index().melt(id_vars='index')
    long_form.columns = ['x', 'y', 'value']
    # Colour follows the signed correlation on a fixed [-1, 1] diverging scale;
    # marker size follows its magnitude.
    heatmap(
        long_form['x'],
        long_form['y'],
        color=long_form['value'],
        color_range=[-1, 1],
        palette=sb.diverging_palette(20, 220, n=256),
        size=long_form['value'].abs(),
        size_range=[0, 1],
        marker=marker,
        x_order=data.columns,
        y_order=data.columns[::-1],
        size_scale=size_scale,
    )

In [None]:
# Correlation matrix of the per-class predicted probabilities over all validation samples.
dfx[dataset.classes].corr()

In [None]:
# Sized correlation heatmap of per-class probabilities for samples whose true class
# received < 0.8 probability.
uncertain_corr = dfx[dataset.classes][dfx.pred_prob_for_true_class < 0.8].corr()
plt.figure(figsize=(10, 10))
corrplot(uncertain_corr)