# Quick Draw Model - PyTorch

## Environment Setup

### Imports

In [14]:
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import json
from datetime import datetime
import functools

# Pick the compute device once and make it the default for every tensor created
# afterwards, so later cells (collate_fn, DataLoader generators, the model) need no
# explicit .to(device). `torch.set_default_device` returns None, so keep the device
# object ourselves.
if torch.cuda.is_available():
    print("Full power!")
    dev = torch.device("cuda")
else:
    print("Regular power..")
    dev = torch.device("cpu")
torch.set_default_device(dev)

# Seed torch so weight initialisation is reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

Full power!


### Checks

In [2]:
# Sanity check: show the device new tensors are allocated on by default (chosen in the setup cell).
torch.get_default_device()

device(type='cuda', index=0)

## Manipulating The Data

### Reading the data

In [3]:
# Notebook-wide settings shared by the data pipeline and the model.

# Model hyper-parameters (read by QuickDrawRNN when it is constructed).
num_layers = 3          # stacked bidirectional LSTM layers
num_nodes = 128         # hidden units per LSTM direction
dropout_rate = 0.3      # conv dropout and LSTM inter-layer dropout

# Data pipeline.
batch_size = 8          # samples per training/validation batch
classes = {}            # class name -> sample count; replaced by readData()'s result

In [4]:
from os import listdir
from os.path import isfile, join

classes = {}

def parseLine(ndjsonLine):
    """Turn one Quick, Draw! ndjson record into (ink tensor, class name).

    The ink is a float32 tensor of shape (N - 1, 3), where N is the total number
    of points over all strokes. Columns 0-1 hold the two stroke coordinates,
    first scaled into the unit box and then turned into differences from the
    previous point. Column 2 is 1 on the last point of each stroke, else 0.
    The first point is dropped because it has no predecessor.
    """
    record = json.loads(ndjsonLine)
    label = record["word"]
    strokes = record["drawing"]

    n_points = sum(len(stroke[0]) for stroke in strokes)
    points = np.zeros((n_points, 3), dtype=np.float32)

    offset = 0
    for stroke in strokes:
        end = offset + len(stroke[0])
        points[offset:end, 0] = stroke[0]
        points[offset:end, 1] = stroke[1]
        points[end - 1, 2] = 1  # flag the final point of this stroke
        offset = end

    # Scale both coordinates into [0, 1]; a zero extent is left unscaled.
    coords = points[:, :2]
    low = coords.min(axis=0)
    extent = coords.max(axis=0) - low
    extent[extent == 0] = 1
    points[:, :2] = (coords - low) / extent

    # Replace absolute positions by deltas from the previous point.
    points[1:, :2] = points[1:, :2] - points[:-1, :2]
    return torch.from_numpy(points[1:]), label

def readData(files, train_data, test_data, limit = -1):
    """Parse Quick, Draw! ndjson files into per-sample feature dicts.

    Samples are appended to ``train_data`` / ``test_data`` in place using a fixed
    interleaved split counted across all files: of every 11000 consecutive
    samples, the first 10000 go to training and the remaining 1000 to test.

    Arguments:
        files (list[str]): paths of ndjson files, parsed in the given order.
        train_data (list): receives training samples (mutated in place).
        test_data (list): receives test samples (mutated in place).
        limit (int): parse only the first ``limit`` files; negative means all.

    Returns:
        dict: class name -> number of samples parsed for that class.

    Each sample dict has the keys "ink", "className", "shape" and "classIndex".
    Class indices are assigned in order of first appearance, so they depend on
    the order of ``files``.
    """
    # Per-call statistics (a local, not the notebook-level `classes`); the
    # caller stores the returned dict.
    classes = {}

    filesToParse = files if limit < 0 else files[:limit]

    classNameToIndex = {}
    sampleCnt = 0

    for cnt, filePath in enumerate(filesToParse, start=1):
        # ndjson is UTF-8 text; don't depend on the platform's default encoding.
        with open(filePath, encoding="utf-8") as file:
            for line in file:
                if not line.strip():
                    continue  # tolerate blank lines instead of crashing json.loads

                features = {}
                features["ink"], features["className"] = parseLine(line)
                features["shape"] = features["ink"].shape

                className = features["className"]
                if className not in classNameToIndex:
                    classNameToIndex[className] = len(classNameToIndex)
                features["classIndex"] = classNameToIndex[className]

                classes[className] = classes.get(className, 0) + 1

                if sampleCnt % 11000 < 10000:
                    train_data.append(features)
                else:
                    test_data.append(features)
                sampleCnt += 1

        # Report progress against the files actually being parsed (honours `limit`).
        print("Finished parsing {0}/{1}: {2}".format(cnt, len(filesToParse), filePath))

    print("Finished parsing all the data!")
    return classes


In [5]:
qd_train_raw_data = []
qd_test_raw_data = []

root_dir = "datasets"

# os.listdir returns entries in arbitrary order, and readData assigns class
# indices in order of first appearance. Sorting keeps both the `limit` selection
# and the label -> index mapping stable across machines and runs.
dataFiles = sorted(join(root_dir, f) for f in listdir(root_dir) if f.endswith(".ndjson"))
classes = readData(dataFiles, qd_train_raw_data, qd_test_raw_data, limit=4)

print()
print(classes)
print("Train data len:", len(qd_train_raw_data))
print("Test data len:", len(qd_test_raw_data))

Finished parsing 1/16: datasets/full_simplified_airplane.ndjson
Finished parsing 2/16: datasets/full_simplified_ant.ndjson
Finished parsing 3/16: datasets/full_simplified_axe.ndjson
Finished parsing 4/16: datasets/full_simplified_bed.ndjson
Finished parsing all the data!

{'airplane': 151623, 'ant': 124612, 'axe': 124122, 'bed': 113862}
Train data len: 468219
Test data len: 46000


### Creating the Dataset and DataLoader

In [6]:
class QuickDrawDataset(Dataset):
    """Map-style dataset over samples already parsed by readData().

    Arguments:
        data (list): parsed sample dicts; indexing returns them unchanged.
        classes (dict): class name -> sample count, stored for reference.
        train (bool): whether this wraps the training split (stored only).
    """

    def __init__(self, data, classes, train):
        self.data, self.classes, self.train = data, classes, train

    def __len__(self):
        count = len(self.data)
        return count

    def __getitem__(self, idx):
        item = self.data[idx]
        return item

# Wrap both parsed splits; they share the same class-statistics dict.
qd_train_dataset = QuickDrawDataset(data=qd_train_raw_data, classes=classes, train=True)
qd_test_dataset = QuickDrawDataset(data=qd_test_raw_data, classes=classes, train=False)

In [7]:
print("Classes:", qd_train_dataset.classes)
print()

# Show the first sample of each split.
for split in (qd_train_dataset, qd_test_dataset):
    for i, sample in enumerate(split):
        print(i, sample)
        break

Classes: {'airplane': 151623, 'ant': 124612, 'axe': 124122, 'bed': 113862}

0 {'ink': tensor([[-0.2292,  0.2328,  0.0000],
        [-0.1146,  0.1422,  0.0000],
        [-0.0435,  0.0216,  0.0000],
        [-0.0435, -0.0129,  0.0000],
        [-0.1067,  0.0000,  0.0000],
        [ 0.1028, -0.0991,  0.0000],
        [ 0.2372, -0.3534,  0.0000],
        [-0.0711, -0.0043,  0.0000],
        [-0.1858,  0.1121,  0.0000],
        [-0.0870,  0.0345,  0.0000],
        [-0.0949,  0.0086,  0.0000],
        [-0.0198, -0.0259,  0.0000],
        [ 0.0040, -0.1853,  0.0000],
        [ 0.2530, -0.1207,  0.0000],
        [ 0.1265, -0.0216,  0.0000],
        [ 0.6126, -0.0129,  0.0000],
        [ 0.0040,  0.1509,  0.0000],
        [-0.0316,  0.0517,  0.0000],
        [-0.2530,  0.0259,  0.0000],
        [-0.0672, -0.0302,  1.0000],
        [ 0.1660, -0.1983,  0.0000],
        [ 0.0000,  0.1250,  0.0000],
        [ 0.0119,  0.0474,  0.0000],
        [ 0.0435,  0.0603,  0.0000],
        [ 0.0672,  0.0302,

In [8]:
def quickDrawCollateFn(batch, batch_size=None):
    """Collate a list of parsed samples into one zero-padded batch dict.

    Arguments:
        batch (list[dict]): samples produced by readData() with keys "ink",
            "shape", "className" and "classIndex".
        batch_size (int, optional): accepted for backward compatibility only.
            The output always has exactly len(batch) rows. This prevents the
            last, smaller batch of an epoch from being filled with fake
            length-0 samples labelled class 0.

    Returns:
        dict with
            "ink": float tensor (B, maxLen, 3), zero-padded along time,
            "shape": int64 tensor (B, 2), each sample's ink shape,
            "length": int64 tensor (B,), number of real timesteps per sample,
            "className": list of class names,
            "classIndex": int64 tensor (B,),
            "maxLen": int, longest sequence in the batch.
    """
    rows = len(batch)
    maxLen = max(sample["shape"][0] for sample in batch)

    newBatch = {
        "ink": torch.zeros((rows, maxLen, 3)),
        "shape": torch.zeros((rows, 2), dtype=torch.long),
        "length": torch.zeros(rows, dtype=torch.long),
        "className": [],
        "classIndex": torch.zeros(rows, dtype=torch.long),
        "maxLen": maxLen,
    }
    for i, sample in enumerate(batch):
        steps = sample["shape"][0]
        newBatch["className"].append(sample["className"])
        newBatch["classIndex"][i] = sample["classIndex"]
        newBatch["shape"][i] = torch.tensor(tuple(sample["shape"]), dtype=torch.long)
        newBatch["length"][i] = steps
        # Copy the real timesteps; the remaining rows stay zero (the padding).
        newBatch["ink"][i, :steps] = sample["ink"]

    return newBatch

# The samplers' generators must live on the default device. RandomSampler calls
# torch.randperm with this generator, and randperm allocates on the default
# device. A hard-coded 'cuda' generator would break the CPU fallback chosen in
# the setup cell.
_loader_device = torch.get_default_device()


def _makeQuickDrawLoader(dataset, loader_batch_size, shuffle):
    """Build a DataLoader over `dataset` that pads batches with quickDrawCollateFn."""
    return DataLoader(
        dataset,
        batch_size=loader_batch_size,
        shuffle=shuffle,
        num_workers=0,
        generator=torch.Generator(device=_loader_device),
        collate_fn=functools.partial(quickDrawCollateFn, batch_size=loader_batch_size),
    )


qd_train_dataloader = _makeQuickDrawLoader(qd_train_dataset, batch_size, shuffle=True)
qd_test_dataloader = _makeQuickDrawLoader(qd_test_dataset, batch_size, shuffle=False)
qd_eval_dataloader = _makeQuickDrawLoader(qd_test_dataset, 1, shuffle=True)

# Peek at the first two training batches.
for i_batch, sample_batched in enumerate(qd_train_dataloader):
    print(i_batch, sample_batched["ink"][0], sample_batched["shape"], sample_batched["length"], sample_batched["classIndex"])
    print()

    if i_batch == 1:
        break

0 tensor([[-0.2205,  0.0000,  0.0000],
        [-0.0591,  0.0136,  0.0000],
        [-0.0354,  0.0682,  0.0000],
        [-0.0039,  0.0636,  0.0000],
        [ 0.0787,  0.0409,  0.0000],
        [ 0.0984,  0.0682,  0.0000],
        [ 0.1575,  0.0500,  0.0000],
        [ 0.0354,  0.0227,  0.0000],
        [ 0.2283,  0.0136,  0.0000],
        [ 0.0709, -0.0500,  0.0000],
        [ 0.0236, -0.1136,  0.0000],
        [-0.0039, -0.0500,  0.0000],
        [-0.0906, -0.1727,  0.0000],
        [-0.0906, -0.0545,  0.0000],
        [-0.0669, -0.0045,  0.0000],
        [-0.0276,  0.0045,  0.0000],
        [-0.0276,  0.0273,  0.0000],
        [-0.0394,  0.0636,  1.0000],
        [ 0.5079,  0.0500,  0.0000],
        [-0.0551, -0.0091,  0.0000],
        [-0.0630,  0.0182,  0.0000],
        [-0.0866,  0.1091,  0.0000],
        [-0.0039,  0.1091,  0.0000],
        [ 0.0630,  0.0545,  0.0000],
        [ 0.0630,  0.0182,  0.0000],
        [ 0.1299,  0.0000,  0.0000],
        [ 0.0276, -0.0136,  0.0000],

## Defining The Model

In [9]:
# Utils

# Thanks to: https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/39036/2
def sequence_mask(lengths, maxlen = None, dtype=torch.bool):
    """Return a (len(lengths), maxlen) mask with mask[i, j] = (j < lengths[i]).

    PyTorch counterpart of tf.sequence_mask.

    Arguments:
        lengths (Tensor): 1-D integer tensor of sequence lengths.
        maxlen (int, optional): number of columns; defaults to lengths.max()
            (0 for an empty `lengths`).
        dtype (torch.dtype): dtype of the returned mask; bool by default.

    Returns:
        Tensor of `dtype` on the same device as `lengths`.
    """
    if maxlen is None:
        maxlen = int(lengths.max()) if lengths.numel() > 0 else 0
    positions = torch.arange(maxlen, device=lengths.device)
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
    # Tensor.to returns a new tensor, so the cast result must be what we return.
    return mask.to(dtype)


In [10]:
class QuickDrawRNN(torch.nn.Module):
    """Conv1d front-end + stacked bidirectional LSTM + linear classifier.

    Reads the notebook-level hyper-parameters ``dropout_rate``, ``num_layers``
    and ``num_nodes`` when constructed.

    Arguments:
        classes (dict): class name -> count; only its length (the number of
            output logits) is used by the layers.
    """

    def __init__(self, classes):
        super().__init__()
        numClasses = len(classes)

        # Three "same"-padded 1-D convolutions over time: 3 -> 48 -> 64 -> 96
        # channels with kernel sizes 5, 5, 3 and dropout in between. No
        # activation function is applied between the convolutions.
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(3, 48, 5, stride=1, padding=2),
            torch.nn.Dropout(p=dropout_rate),
            torch.nn.Conv1d(48, 64, 5, stride=1, padding=2),
            torch.nn.Dropout(p=dropout_rate),
            torch.nn.Conv1d(64, 96, 3, stride=1, padding=1),
        )

        # Stacked bidirectional LSTM over the 96 conv channels; nn.LSTM applies
        # its inter-layer dropout only in training mode.
        self.lstm = torch.nn.LSTM(
            input_size=96,
            hidden_size=num_nodes,
            num_layers=num_layers,
            bias=True,
            batch_first=True,
            dropout=dropout_rate,
            bidirectional=True,
        )

        print("Classes:", classes)
        print("Class count:", numClasses)

        # The two LSTM directions are concatenated, hence 2 * num_nodes features.
        self.fc = torch.nn.Linear(2 * num_nodes, numClasses)

    def forward(self, inks, lengths):
        """Return class logits of shape (B, num_classes).

        Arguments:
            inks (Tensor): (B, L, 3) zero-padded stroke sequences.
            lengths (Tensor): (B,) number of valid timesteps per sequence.
        """
        # Conv1d wants channels first: (B, L, 3) -> (B, 3, L) -> (B, 96, L) -> (B, L, 96).
        features = self.conv(inks.transpose(1, 2)).transpose(1, 2)

        # (B, L, 2 * num_nodes)
        features, _ = self.lstm(features)

        # Zero every timestep at or past each sequence's length, then sum over time.
        # NOTE(review): the LSTM still runs over the padded tail. As a result, the
        # backward direction's outputs at valid positions depend on the padding.
        # Consider pack_padded_sequence if this matters.
        valid = sequence_mask(lengths, features.shape[1]).unsqueeze(2)
        pooled = torch.where(valid, features, torch.zeros_like(features)).sum(dim=1)

        return self.fc(pooled)

# One logit per class present in the training split.
qd_model = QuickDrawRNN(classes=qd_train_dataset.classes)


Classes: {'airplane': 151623, 'ant': 124612, 'axe': 124122, 'bed': 113862}
Class count: 4


In [11]:
def train_one_epoch(model, qd_loader, optimizer, loss_fn, epoch_index, tb_writer, log_every=1000):
    """Run one optimisation pass over ``qd_loader`` and return a recent mean loss.

    Arguments:
        model: called as ``model(batch["ink"], batch["length"])``; returns logits.
        qd_loader: iterable of collated batches that supports ``len()``.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: called as ``loss_fn(logits, batch["classIndex"])``.
        epoch_index (int): 0-based epoch number, used for the TensorBoard x-axis.
        tb_writer: object with ``add_scalar(tag, value, step)`` (e.g. SummaryWriter).
        log_every (int): print and log the mean batch loss every this many batches.

    Returns:
        float: mean loss of the last complete ``log_every``-batch window. If the
        epoch is shorter than one window, returns the mean over all batches
        (0.0 for an empty loader) instead of a misleading 0.
    """
    num_batches = len(qd_loader)
    running_loss = 0.
    total_loss = 0.
    seen = 0
    last_loss = None

    for i, batch in enumerate(qd_loader):
        optimizer.zero_grad()

        logits = model(batch["ink"], batch["length"])
        loss = loss_fn(logits, batch["classIndex"])
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        total_loss += batch_loss
        seen += 1

        if (i + 1) % log_every == 0:
            last_loss = running_loss / log_every  # mean loss per batch over the window
            print('  batch {}/{} loss: {}'.format(i + 1, num_batches, last_loss))
            tb_x = epoch_index * num_batches + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    if last_loss is None:
        last_loss = total_loss / seen if seen else 0.
    return last_loss


In [12]:
# CrossEntropyLoss consumes raw logits (QuickDrawRNN.forward applies no softmax).
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=qd_model.parameters(), lr=1e-4)

In [13]:
import os

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/quick_draw_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 2

best_vloss = float('inf')

# torch.save does not create parent directories.
os.makedirs('models', exist_ok=True)

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Training mode enables the dropout layers (conv and between LSTM layers).
    qd_model.train(True)
    avg_loss = train_one_epoch(qd_model, qd_train_dataloader, optimizer, loss_fn, epoch_number, writer)

    # Evaluation mode disables dropout; torch.no_grad below disables autograd.
    qd_model.train(False)

    print()
    print("Beginning validation!")

    running_vloss = 0.0
    num_vbatches = 0
    with torch.no_grad():
        for i, vbatch in enumerate(qd_test_dataloader):
            voutputs = qd_model(vbatch["ink"], vbatch["length"])
            vloss = loss_fn(voutputs, vbatch["classIndex"])
            # Accumulate a Python float, so no tensors stay alive across batches.
            running_vloss += vloss.item()
            num_vbatches += 1

            if i % 1000 == 999:
                last_vloss = running_vloss / (i + 1)
                print('  batch {}/{} vloss: {}'.format(i + 1, len(qd_test_dataloader), last_vloss))

    avg_vloss = running_vloss / max(num_vbatches, 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'models/model_{}_{}'.format(timestamp, epoch_number)
        torch.save(qd_model.state_dict(), model_path)

    epoch_number += 1

    print()
    print()

writer.close()


EPOCH 1:
58528
  batch 1000/58528 loss: 1.2873169038891792
  batch 2000/58528 loss: 1.051373823940754
  batch 3000/58528 loss: 0.9791011981368065
  batch 4000/58528 loss: 0.8934804668426514
  batch 5000/58528 loss: 0.8226561665982008
  batch 6000/58528 loss: 0.7095819232165813
  batch 7000/58528 loss: 0.6718349764049053
  batch 8000/58528 loss: 0.6006119401976466
  batch 9000/58528 loss: 0.5546294914782047
  batch 10000/58528 loss: 0.4940850682742894
  batch 11000/58528 loss: 0.48000131288170816
  batch 12000/58528 loss: 0.4424793520085514
  batch 13000/58528 loss: 0.42730687057971956
  batch 14000/58528 loss: 0.39149605928966774
  batch 15000/58528 loss: 0.37618063345970587
  batch 16000/58528 loss: 0.35091781592415644
  batch 17000/58528 loss: 0.3617684913915582
  batch 18000/58528 loss: 0.3085946528383065
  batch 19000/58528 loss: 0.34734545040223747
  batch 20000/58528 loss: 0.32087643685610967
  batch 21000/58528 loss: 0.30544851732684764
  batch 22000/58528 loss: 0.29678082144004