<a href="https://colab.research.google.com/github/cs-deep-quickdraw/notebooks/blob/master/results/cnn_bitmaps_test_results.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

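This notebook reports the test-set accuracy of three models (a pretrained ResNet34, a pretrained MobileNet, and an adapted ResNet18) on 100 classes of the Quick, Draw! bitmap dataset.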

In [1]:
# Fetch the list of class names and create the download directory.
# -nc skips the download when the file is already present and -p makes mkdir
# succeed when the directory already exists, so this cell can be re-run safely.
!wget -nc 'https://raw.githubusercontent.com/cs-deep-quickdraw/notebooks/master/100_classes.txt'
!mkdir -p data

--2020-03-01 17:05:53--  https://raw.githubusercontent.com/cs-deep-quickdraw/notebooks/master/100_classes.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 760 [text/plain]
Saving to: ‘100_classes.txt’


2020-03-01 17:05:53 (208 MB/s) - ‘100_classes.txt’ saved [760/760]



In [0]:
import urllib.request

# Read the class names, one per line. The context manager closes the file even
# if reading fails, and blank lines are dropped so that a stray empty line
# cannot produce an empty class name.
with open("100_classes.txt", "r") as f:
    classes = [line.strip() for line in f if line.strip()]

def download(classes, dest_dir='data'):
  """Download the QuickDraw numpy-bitmap file of every class in `classes`.

  Each class `c` is saved as `<dest_dir>/<c>.npy`. A file that already exists
  is left untouched, so the function can be re-run without fetching everything
  again. One progress line (fraction done, class name, URL) is printed per
  class, whether or not the file had to be downloaded.

  Args:
    classes: sequence of class names; underscores stand for the spaces used in
      the remote file names.
    dest_dir: directory receiving the files; created when missing.
  """
  import os
  from urllib.parse import quote

  base = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'
  os.makedirs(dest_dir, exist_ok=True)

  for i, c in enumerate(classes):
    # quote() percent-encodes the space (and any other character that is not
    # URL-safe) instead of special-casing the underscore by hand.
    cls_url = quote(c.replace('_', ' '))
    path = base + cls_url + '.npy'
    target = os.path.join(dest_dir, c + '.npy')
    print((1+i)/len(classes), c, path)

    if os.path.exists(target):
      continue

    # Download under a temporary name and rename afterwards, so an interrupted
    # transfer never leaves a truncated file that a later run would skip.
    partial = target + '.part'
    urllib.request.urlretrieve(path, partial)
    os.replace(partial, target)

In [3]:
# Fetch the bitmap file of every class read from 100_classes.txt into data/.
download(classes)

0.01 drums https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/drums.npy
0.02 sun https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/sun.npy
0.03 laptop https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/laptop.npy
0.04 anvil https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/anvil.npy
0.05 baseball_bat https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/baseball%20bat.npy
0.06 ladder https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/ladder.npy
0.07 eyeglasses https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/eyeglasses.npy
0.08 grapes https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/grapes.npy
0.09 book https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/book.npy
0.1 dumbbell https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/dumbbell.npy
0.11 traffic_light https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/tra

In [4]:
# List the download directory to check which class files are present.
!ls data

airplane.npy	  circle.npy	    key.npy	      shorts.npy
alarm_clock.npy   clock.npy	    knife.npy	      shovel.npy
anvil.npy	  cloud.npy	    ladder.npy	      smiley_face.npy
apple.npy	  coffee_cup.npy    laptop.npy	      snake.npy
axe.npy		  cookie.npy	    light_bulb.npy    sock.npy
baseball_bat.npy  cup.npy	    lightning.npy     spider.npy
baseball.npy	  diving_board.npy  line.npy	      spoon.npy
basketball.npy	  donut.npy	    lollipop.npy      square.npy
beard.npy	  door.npy	    microphone.npy    star.npy
bed.npy		  drums.npy	    moon.npy	      stop_sign.npy
bench.npy	  dumbbell.npy	    mountain.npy      suitcase.npy
bicycle.npy	  envelope.npy	    moustache.npy     sun.npy
bird.npy	  eyeglasses.npy    mushroom.npy      sword.npy
book.npy	  eye.npy	    pants.npy	      syringe.npy
bread.npy	  face.npy	    paper_clip.npy    table.npy
bridge.npy	  fan.npy	    pencil.npy	      tennis_racquet.npy
broom.npy	  flower.npy	    pillow.npy	      tent.npy
butterfly.npy	  frying_pan.npy    pizza.np

In [0]:
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# Seed the global PyTorch and NumPy random number generators.
torch.manual_seed(42)
np.random.seed(42)

class DrawDataset(Dataset):
    """Dataset pairing each sample of `X` with the label at the same position in `Y`.

    A sample is converted to a `torch.FloatTensor` when it is accessed; the
    optional `transform` is then applied to the converted sample.
    """

    def __init__(self, X, Y, transform=None):
        # Samples and labels are matched by position, so their lengths must agree.
        self.X, self.Y = X, Y
        assert len(self.X) == len(self.Y)
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return `[sample, label]` for position `idx`."""
        # A tensor index is turned into plain Python values before indexing.
        index = idx.tolist() if torch.is_tensor(idx) else idx

        sample = torch.Tensor(self.X[index]).type(torch.FloatTensor)
        label = self.Y[index]

        if self.transform:
            sample = self.transform(sample)

        return [sample, label]

In [7]:
# Device configuration: use the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'device = {device}')

device = cuda


In [0]:
def evaluate_model(model, loader, device=None):
  """Return the top-1 accuracy of `model` over `loader`, as a percentage.

  An `nn.Module` is switched to evaluation mode for the duration of the call,
  so that layers such as BatchNorm and Dropout use their inference behaviour,
  and is put back into its previous mode afterwards. No gradients are
  recorded.

  Args:
    model: callable mapping a batch of inputs to per-class scores of shape
      (batch, n_classes).
    loader: iterable yielding `(inputs, targets)` tensor pairs.
    device: device the batches are moved to. When omitted, the device of the
      model's first parameter is used; a model without parameters leaves the
      batches where the loader produced them.

  Returns:
    100 * correct / total as a float, or NaN when `loader` yields no samples.
  """
  is_module = isinstance(model, torch.nn.Module)

  if device is None and is_module:
    first_param = next(iter(model.parameters()), None)
    if first_param is not None:
      device = first_param.device

  # Remember the mode so that evaluating a model in the middle of training
  # does not silently leave it in eval mode.
  was_training = model.training if is_module else False
  if is_module:
    model.eval()

  correct = 0
  total = 0
  try:
    with torch.no_grad():
      for x, target in loader:
        if device is not None:
          x, target = x.to(device), target.to(device)
        pred = model(x).argmax(dim=1)
        total += target.size(0)
        correct += (pred == target).sum().item()
  finally:
    if is_module:
      model.train(was_training)

  if total == 0:
    # Accuracy is undefined without samples; avoid a ZeroDivisionError.
    return float('nan')
  return 100. * correct / total

In [0]:
# Number of drawings taken from each class file when building the test sets.
# N_TEST = 20_000 for results on the full test set
N_TEST = 200

In [0]:
from torchvision import transforms

def load_test_sets(max_classes=10):
  """Build the two test datasets from the downloaded class files.

  For each of the first `max_classes` entries of `classes`, the first `N_TEST`
  rows of `data/<class>.npy` are kept and reshaped to (n, 1, 28, 28). The
  label of a sample is the position of its class in `classes`.

  Args:
    max_classes: number of classes, taken from the start of `classes`.

  Returns:
    A pair of `DrawDataset` objects over the same samples and labels:
    the first without any transform, the second scaling the values by 1/255,
    repeating the single channel three times and normalizing each channel.
  """
  per_class = []
  Y = []

  for i, cls in enumerate(classes[:max_classes]):
    # Memory-map the file so that only the rows that are kept are read,
    # instead of loading the whole class file to slice off its first rows.
    data = np.load(f'data/{cls}.npy', mmap_mode='r')[:N_TEST]
    # -1 lets a class holding fewer than N_TEST drawings contribute the rows
    # it has instead of failing the reshape.
    data = np.asarray(data).reshape(-1, 1, 28, 28)
    per_class.append(data)
    # Label count follows the rows actually loaded, keeping X and Y aligned.
    Y.extend([i] * len(data))

  # A single concatenation copies the data once; concatenating inside the loop
  # re-copied everything accumulated so far on every iteration.
  if per_class:
    X = np.concatenate(per_class)
  else:
    X = np.empty((0, 1, 28, 28))

  # NOTE(review): these look like the usual ImageNet channel statistics,
  # presumably what the pretrained models expect -- confirm against training.
  mean = np.array([0.485, 0.456, 0.406])
  std = np.array([0.229, 0.224, 0.225])

  return DrawDataset(X, Y, transform=None), DrawDataset(X, Y, transform=transforms.Compose([
        lambda x: x / 255,
        lambda x: x.repeat(3,1,1),
        transforms.Normalize(mean=mean, std=std)
    ]))

In [0]:
# Build both test datasets from the first 100 entries of `classes`.
test_set, test_set_rgb = load_test_sets(max_classes=100)

In [0]:
batch_size = 64

# One shuffled loader per dataset, both configured identically.
test_loader, test_loader_rgb = (
    torch.utils.data.DataLoader(dataset=ds, batch_size=batch_size, shuffle=True)
    for ds in (test_set, test_set_rgb)
)

In [0]:
# Upload the models to colab's file system before running this cell.
# NOTE: torch.load unpickles the whole model object, which can execute
# arbitrary code -- only load checkpoint files you trust.
# map_location places the weights on `device`, so checkpoints that were saved
# on a GPU also load on a CPU-only runtime.
resnet34_pretrained = torch.load('/content/resnet34_pretrained.model.epoch.9', map_location=device)
mobilenet_pretrained = torch.load('/content/mobilenet_pretrained.model.epoch.37', map_location=device)
resnet_experimental = torch.load('/content/resnet_experimental.model.epoch.8', map_location=device)

In [21]:
print("Test accuracies:")
# Each model is paired with the loader matching its input format: the two
# pretrained networks read the 3-channel normalized set, the adapted one the
# single-channel set.
for label, net, loader in (
    ("ResNet34", resnet34_pretrained, test_loader_rgb),
    ("MobileNet", mobilenet_pretrained, test_loader_rgb),
    ("Adapted ResNet18", resnet_experimental, test_loader),
):
  net.to(device)
  print(f"{label}: {evaluate_model(net, loader)}")

Test accuracies:
ResNet34: 84.805
MobileNet: 84.725
Adapted ResNet18: 84.16
