In [12]:
from typing import List, Optional
import urllib.request
from tqdm import tqdm
from pathlib import Path
import requests
import torch
import math
import numpy as np
import os
import glob
import matplotlib.pyplot as plt
import random
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader

import torch.nn as nn


# Pin the global RNG of every library imported above (NumPy, Python's
# ``random`` and PyTorch) to the same fixed seed so reruns are repeatable.
for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
    seed_fn(0)

def seed_worker(worker_id):
    """Seed NumPy's and Python's global RNGs from the current PyTorch seed.

    The signature matches what ``DataLoader(worker_init_fn=...)`` expects
    (NOTE(review): no visible cell actually passes it to a DataLoader).

    Args:
        worker_id: Index of the worker; accepted for interface compatibility
            but not used.
    """
    # Both NumPy and ``random`` are seeded with the PyTorch seed folded
    # into 32 bits.
    worker_seed = torch.initial_seed() % 2**32
    # The file imports NumPy as ``np``; the bare name ``numpy`` is undefined
    # and raised NameError here.
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Stand-alone torch generator with a fixed seed. ``manual_seed`` returns the
# generator itself, so creation and seeding can be chained.
g = torch.Generator().manual_seed(0)
g


<torch._C.Generator at 0x7ffe27df6f10>

In [3]:
from utils import *

In [6]:
# With identical random seeds and arguments this *should* reproduce the same
# train/validation split on every run.
# TODO: persist the train/validation split explicitly instead of relying on seeding.

data_root = "../data/npy"

# The class list is kept as two groups, exactly as originally written.
primary_classes = [
    'airplane', 'apple', 'wine bottle', 'car', 'mouth',
    'pineapple', 'umbrella', 'pear', 'moustache', 'smiley face',
]
additional_classes = ['train', 'mosquito', 'bee', 'dragon', 'piano']

download_quickdraw_dataset(
    root=data_root,
    class_names=primary_classes + additional_classes,
)
dataset = QuickDrawDataset(root=data_root, max_items_per_class=100000)

# NOTE(review): 0.2 is presumably the validation fraction - confirm in utils.
train_ds, val_ds = dataset.split(0.2)

# One sample per batch, in a fixed order.
validation_dataloader = DataLoader(val_ds, batch_size=1, shuffle=False)

Downloading Quickdraw Dataset...


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 1018.60it/s]


Loading 100000 examples for each class from the Quickdraw Dataset...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:22<00:00,  1.48s/it]


In [17]:
def _conv_stage(in_channels, out_channels):
    """Return the layers of one 3x3 'same' convolution -> ReLU -> 2x2 max-pool stage."""
    return [
        nn.Conv2d(in_channels, out_channels, 3, padding='same'),
        nn.ReLU(),
        nn.MaxPool2d(2),
    ]


# The stages are unpacked into a single flat Sequential so that the layer
# indices (0..12) - and therefore the state_dict keys - stay unchanged.
model = nn.Sequential(
    *_conv_stage(1, 16),
    *_conv_stage(16, 32),
    *_conv_stage(32, 32),
    nn.Flatten(),
    # NOTE(review): 288 presumably equals 32 channels * 3 * 3, i.e. 28x28
    # inputs halved three times - confirm against the dataset's image size.
    nn.Linear(288, 128),
    nn.ReLU(),
    # One output logit per dataset class.
    nn.Linear(128, len(dataset.classes)),
)

In [21]:
# Load the saved weights onto the CPU regardless of the device they were
# saved from. NOTE: torch.load unpickles the file, so only load checkpoints
# from a trusted source.
checkpoint = torch.load('./model_lessCapacity.pth', map_location='cpu')
model.load_state_dict(checkpoint)

# Put the model in evaluation mode; as the last expression this also
# displays the model summary.
model.eval()

Sequential(
  (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (1): ReLU()
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (4): ReLU()
  (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (7): ReLU()
  (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (9): Flatten(start_dim=1, end_dim=-1)
  (10): Linear(in_features=288, out_features=128, bias=True)
  (11): ReLU()
  (12): Linear(in_features=128, out_features=15, bias=True)
)

In [36]:
def _per_class_stats(model, dataloader, class_names):
    """Count evaluated and correctly classified samples for every class.

    Args:
        model: Callable mapping an input batch to per-class logits, with the
            class dimension last.
        dataloader: Iterable yielding ``(x, y)`` batches where ``y`` holds
            integer class indices. Any batch size is accepted.
        class_names: Sequence of class labels; position ``i`` is the label
            for class index ``i``.

    Returns:
        A list with one dict per class, in class-index order, with the keys
        ``"idx"``, ``"label"``, ``"count"`` and ``"correct"``.
    """
    per_class = [
        {"idx": idx, "label": label, "count": 0, "correct": 0}
        for idx, label in enumerate(class_names)
    ]

    # Inference only: skip autograd bookkeeping for every forward pass.
    with torch.no_grad():
        for x, y in dataloader:
            # argmax over the class dimension, flattened so that batches of
            # any size (including 1) are handled uniformly.
            predictions = model(x).argmax(dim=-1).reshape(-1).tolist()
            targets = y.reshape(-1).tolist()

            for target, y_hat in zip(targets, predictions):
                class_idx = int(target)
                per_class[class_idx]["count"] += 1
                if class_idx == y_hat:
                    per_class[class_idx]["correct"] += 1

    return per_class


stats = _per_class_stats(model, validation_dataloader, dataset.classes)
        

In [38]:
import pandas as pd
# One row per class, indexed by class index, with per-class accuracy
# (correct / count) added as an extra column.
df = (
    pd.DataFrame(stats)
    .set_index('idx')
    .assign(accuracy=lambda table: table['correct'] / table['count'])
)


In [57]:
## Tag every class as 'convergent' or 'divergent'.
## The mapping is keyed on the class label rather than on hard-coded row
## indices, so it stays correct if the class list or its ordering changes.
## These five labels are the ones previously selected by the indices
## 2, 4, 5, 9 and 12.

divergent_labels = {'bee', 'dragon', 'mosquito', 'piano', 'train'}

df['category'] = (
    df['label']
    .isin(divergent_labels)
    .map({True: 'divergent', False: 'convergent'})
)


In [66]:
# Unweighted mean of the per-class accuracies over the 'convergent' classes.
df.loc[df['category'] == 'convergent', 'accuracy'].mean()

0.9380084564409474

In [67]:
# Display the per-class table (label, count, correct, accuracy, category).
df

Unnamed: 0_level_0,label,count,correct,accuracy,category
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,airplane,19946,18126,0.908754,convergent
1,apple,19961,18964,0.950053,convergent
2,bee,19941,18041,0.904719,divergent
3,car,20118,18934,0.941147,convergent
4,dragon,19811,16466,0.831154,divergent
5,mosquito,20247,16779,0.828715,divergent
6,moustache,20114,18099,0.899821,convergent
7,mouth,20240,18619,0.919911,convergent
8,pear,20019,18181,0.908187,convergent
9,piano,19962,19009,0.952259,divergent
