In [1]:
import struct
from struct import unpack
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Pick the compute device: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('device = {}'.format(device))

device = cuda


In [0]:
# Helper from: https://github.com/googlecreativelab/quickdraw-dataset/blob/master/examples/binary_file_parser.py
def unpack_drawing(file_handle):
    """Read one drawing from a QuickDraw binary stream as delta-encoded points.

    Adapted from the QuickDraw example parser:
    https://github.com/googlecreativelab/quickdraw-dataset/blob/master/examples/binary_file_parser.py

    The record consumed here is: 15 header bytes (skipped), an unsigned-short
    stroke count, then for every stroke an unsigned-short point count followed
    by that many x bytes and that many y bytes.

    Args:
        file_handle: binary file-like object positioned at the start of a record.

    Returns:
        A float tensor of shape ``(n_points - 1, 3)``. Columns 0 and 1 hold the
        x/y offset from the previous point, computed after rescaling each
        coordinate to [0, 1] over the whole drawing; column 2 is 1 on the last
        point of each stroke and 0 elsewhere. A drawing without any point
        yields an empty ``(0, 3)`` tensor instead of raising.

    Raises:
        struct.error: if the stream ends before a complete record was read
            (this is how end of file shows up to the caller).
    """
    # Skip key_id: 8, countrycode: 2, recognized: 1, timestamp: 4 = 15
    file_handle.read(15)
    n_strokes, = unpack('H', file_handle.read(2))

    N = 0
    strokes = []
    for _ in range(n_strokes):
        n_points, = unpack('H', file_handle.read(2))
        fmt = str(n_points) + 'B'
        x = unpack(fmt, file_handle.read(n_points))
        y = unpack(fmt, file_handle.read(n_points))
        # An empty stroke contributes no rows; keeping it would make the
        # end-of-stroke marker below index row -1.
        if n_points:
            N += n_points
            strokes.append((x, y))

    if N == 0:
        # Nothing to normalise: np.min / np.max raise on an empty array.
        return torch.zeros((0, 3), dtype=torch.float32)

    image = np.zeros((N, 3), dtype=np.float32)

    # Fill one row per point (x, y, end-of-stroke flag), like here:
    # https://github.com/tensorflow/docs/blob/master/site/en/r1/tutorials/sequences/recurrent_quickdraw.md#optional-converting-the-data
    idx = 0
    for x, y in strokes:
        n_points = len(x)
        image[idx:idx + n_points, 0] = np.asarray(x)
        image[idx:idx + n_points, 1] = np.asarray(y)
        idx += n_points
        # Mark stroke end with a 1
        image[idx - 1, 2] = 1

    # Preprocessing.
    # 1. Size normalization.
    lower = np.min(image[:, 0:2], axis=0)
    upper = np.max(image[:, 0:2], axis=0)
    scale = upper - lower
    # A flat axis (all values equal) would divide by zero; leave it unscaled.
    scale[scale == 0] = 1
    image[:, 0:2] = (image[:, 0:2] - lower) / scale
    # 2. Compute deltas. The first row has no predecessor, so it is dropped.
    image[1:, 0:2] -= image[0:-1, 0:2]
    image = image[1:, :]

    return torch.FloatTensor(image)


def unpack_drawings(filename):
    """Lazily yield every drawing stored in the binary file `filename`.

    Drawings are decoded one at a time with `unpack_drawing`; iteration ends
    as soon as that call raises `struct.error`.
    """
    with open(filename, 'rb') as stream:
        try:
            while True:
                yield unpack_drawing(stream)
        except struct.error:
            # Raised once no complete record is left to read.
            return

In [3]:
!wget 'https://raw.githubusercontent.com/cs-deep-quickdraw/notebooks/master/10_classes.txt'
!mkdir data

--2020-02-13 14:44:33--  https://raw.githubusercontent.com/cs-deep-quickdraw/notebooks/master/10_classes.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 75 [text/plain]
Saving to: ‘10_classes.txt’


2020-02-13 14:44:33 (19.9 MB/s) - ‘10_classes.txt’ saved [75/75]



In [0]:
import urllib.request

# Read the class names, one per line. The file is closed even if reading
# fails, and blank lines are skipped so they never become an empty class name.
with open("10_classes.txt", "r") as class_file:
    classes = [line.strip() for line in class_file if line.strip()]

def download(classes):
  """Fetch the QuickDraw binary file of every class in `classes` into `data/`.

  Underscores in a class name are replaced by `%20` to build the remote URL,
  while the local file keeps the class name unchanged. Each URL is printed
  before it is retrieved.
  """
  url_prefix = 'https://storage.googleapis.com/quickdraw_dataset/full/binary/'
  for name in classes:
    remote = url_prefix + name.replace('_', '%20') + '.bin'
    print(remote)
    urllib.request.urlretrieve(remote, 'data/' + name + '.bin')

In [5]:
# Download the binary files of the first ten classes into data/.
download(classes[:10])

https://storage.googleapis.com/quickdraw_dataset/full/binary/drums.bin
https://storage.googleapis.com/quickdraw_dataset/full/binary/sun.bin
https://storage.googleapis.com/quickdraw_dataset/full/binary/laptop.bin
https://storage.googleapis.com/quickdraw_dataset/full/binary/anvil.bin
https://storage.googleapis.com/quickdraw_dataset/full/binary/baseball%20bat.bin
https://storage.googleapis.com/quickdraw_dataset/full/binary/ladder.bin
https://storage.googleapis.com/quickdraw_dataset/full/binary/eyeglasses.bin
https://storage.googleapis.com/quickdraw_dataset/full/binary/grapes.bin
https://storage.googleapis.com/quickdraw_dataset/full/binary/book.bin
https://storage.googleapis.com/quickdraw_dataset/full/binary/dumbbell.bin


In [6]:
# List the contents of the data directory.
!ls data

anvil.bin	  book.bin   dumbbell.bin    grapes.bin  laptop.bin
baseball_bat.bin  drums.bin  eyeglasses.bin  ladder.bin  sun.bin


In [0]:
# Generator over the drawings of the anvil file; nothing is read until it is iterated.
i_drawings = unpack_drawings("data/anvil.bin")

In [8]:
from pprint import pprint
# Show the first two rows of the next two drawings (each call advances the generator).
pprint(next(i_drawings)[:2])
pprint(next(i_drawings)[:2])

tensor([[ 0.2235, -0.0404,  0.0000],
        [ 0.3176, -0.0101,  0.0000]])
tensor([[-0.0118,  0.3372,  0.0000],
        [ 0.0984, -0.0116,  0.0000]])


In [0]:
class StrokeClassifier(nn.Module):
  """Classify a sequence of stroke points with an LSTM followed by a linear layer.

  Args:
    hidden_dim: size of the LSTM hidden state.
    n_classes: number of output scores produced per sequence.
  """

  def __init__(self, hidden_dim, n_classes):
    super().__init__()
    self.hidden_dim = hidden_dim

    # The LSTM takes 3 things as input (x, y, isLastPoint) and outputs hidden states with dimensionality hidden_dim
    self.lstm = nn.LSTM(3, hidden_dim, batch_first=True)

    # The linear layer maps an LSTM output vector to one score per class
    self.linear = nn.Linear(hidden_dim, n_classes)

  def forward(self, strokes):
    """Return class scores of shape (batch, n_classes) for `strokes` of shape (batch, seq, 3)."""
    # No explicit initial state is built: nn.LSTM starts from zero hidden and
    # cell states when none are passed, so the module does not depend on a
    # global `device`.
    out, _ = self.lstm(strokes)
    # Keep only the output of the last time step
    out = out[:, -1, :]
    return self.linear(out)


In [0]:
from torch.utils.data import Dataset

class DrawDataset(Dataset):
    """Dataset pairing the sample at each position of X with the label at the same position of Y."""

    def __init__(self, X, Y):
        self.X, self.Y = X, Y
        assert len(self.X) == len(self.Y)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # A tensor index is turned into plain Python value(s) before indexing.
        position = idx.tolist() if torch.is_tensor(idx) else idx
        sample = torch.Tensor(self.X[position]).type('torch.FloatTensor')
        return [sample, self.Y[position]]

In [0]:
# Config: every tunable value used by the cells below.
# -- data
train_classes = list(classes)  # independent copy of the class list
N_train = 10000                # training drawings read per class
N_test = 2000                  # test drawings read per class
max_padding = 50               # drawings are cut to at most this many points
# -- model / optimisation
hidden_size = 64
batch_size = 256
learning_rate = 0.01

In [0]:
from itertools import islice
from torch.nn.utils.rnn import pad_sequence

def extract_train_test(samples_train, samples_test, classes, max_padding=100):
  """Build padded train and test datasets from the per-class files in `data/`.

  For every class, the first `samples_train` drawings of `data/<class>.bin`
  go to the training set and the following `samples_test` drawings go to the
  test set. The label of a drawing is the position of its class in `classes`.

  Args:
    samples_train: number of training drawings taken per class.
    samples_test: number of test drawings taken per class.
    classes: class names, each with a matching `data/<name>.bin` file.
    max_padding: maximum number of points kept per drawing.

  Returns:
    `(train_dataset, test_dataset)`, two `DrawDataset` objects whose samples
    are zero-padded to a common length of at most `max_padding` points.

  Raises:
    ValueError: if a class file holds fewer drawings than requested.
  """
  X_train, X_test = [], []
  y_train, y_test = [], []

  for c, cls in enumerate(classes):
    drawings = unpack_drawings('data/' + cls + '.bin')
    try:
      # Cut each drawing to max_padding right away: the result equals padding
      # first and slicing afterwards, but never pads to the longest drawing.
      train_part = [d[:max_padding] for d in islice(drawings, samples_train)]
      test_part = [d[:max_padding] for d in islice(drawings, samples_test)]
    finally:
      # Finish the generator now so its file is closed deterministically.
      drawings.close()

    if len(train_part) < samples_train or len(test_part) < samples_test:
      raise ValueError(
          f"class '{cls}' has only {len(train_part) + len(test_part)} drawings, "
          f"{samples_train + samples_test} requested")

    X_train.extend(train_part)
    y_train.extend([c] * len(train_part))
    X_test.extend(test_part)
    y_test.extend([c] * len(test_part))

  X_train = pad_sequence(X_train, batch_first=True)
  X_test = pad_sequence(X_test, batch_first=True)
  print("training shape", X_train.shape)
  print("testing shape", X_test.shape)
  print("classes", len(classes))

  return DrawDataset(X_train, y_train), DrawDataset(X_test, y_test)

In [20]:
from torch.nn.utils.rnn import pad_sequence

# TODO: take the test samples from the end of each class file (the last 2k drawings)
# instead of the drawings that directly follow the training samples.
train_dataset, test_dataset = extract_train_test(N_train, N_test, train_classes, max_padding=max_padding)

training shape torch.Size([100000, 50, 3])
testing shape torch.Size([20000, 50, 3])
classes 10


In [0]:
# Mini-batch iterators: training batches are shuffled, test batches keep their order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [0]:
import torch.optim as optim

# Model, loss and optimiser used by the training and test cells.
model = StrokeClassifier(hidden_dim=hidden_size, n_classes=len(train_classes))
model = model.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [23]:
# NOTE(author): this is too many epochs - lower the count.
last_loss = None
for epoch in range(50):
  # Report the loss of the last mini-batch of the previous epoch (None before the first one).
  print(f"Epoch: {epoch}, last_loss: {last_loss}")

  for inputs, targets in train_loader:
    inputs = inputs.to(device)
    targets = torch.LongTensor(targets).to(device)

    # Clear the gradients left by the previous step, then forward / backward / update.
    optimizer.zero_grad()
    batch_loss = loss_function(model(inputs), targets)
    batch_loss.backward()
    optimizer.step()

    last_loss = batch_loss.item()

Epoch: 0, last_loss: None
Epoch: 1, last_loss: 0.955832839012146
Epoch: 2, last_loss: 0.6931900978088379
Epoch: 3, last_loss: 0.5070900917053223
Epoch: 4, last_loss: 0.5635804533958435
Epoch: 5, last_loss: 0.39240333437919617
Epoch: 6, last_loss: 0.3495434522628784
Epoch: 7, last_loss: 0.49082279205322266
Epoch: 8, last_loss: 0.26588380336761475
Epoch: 9, last_loss: 0.2726176679134369
Epoch: 10, last_loss: 0.2209448367357254
Epoch: 11, last_loss: 0.20589976012706757
Epoch: 12, last_loss: 0.27800363302230835
Epoch: 13, last_loss: 0.2930254638195038
Epoch: 14, last_loss: 0.23719331622123718
Epoch: 15, last_loss: 0.1916734278202057
Epoch: 16, last_loss: 0.3294978439807892
Epoch: 17, last_loss: 0.274522602558136
Epoch: 18, last_loss: 0.18168215453624725
Epoch: 19, last_loss: 0.25937822461128235
Epoch: 20, last_loss: 0.28454890847206116
Epoch: 21, last_loss: 0.1985488384962082
Epoch: 22, last_loss: 0.17741376161575317
Epoch: 23, last_loss: 0.30256128311157227
Epoch: 24, last_loss: 0.1874100

In [24]:
# Test: accuracy of the trained model on the test set.

# Put the model in evaluation mode before measuring; layers that behave
# differently during training (e.g. dropout) are switched to inference behaviour.
model.eval()

with torch.no_grad():
  correct = 0
  total = 0

  for img, label in test_loader:
    img = img.to(device)
    label = label.to(device)

    out = model(img)

    # Predicted class = index of the largest score along the class dimension.
    # Gradients are already disabled here, so the legacy `.data` access is not needed.
    _, pred = torch.max(out, 1)

    total += label.size(0)
    correct += (pred == label).sum().item()

  print('Test Accuracy: {}%'.format(100. * correct / total))

Test Accuracy: 90.655%
