# Quick Draw Model Loader - PyTorch

## <font color='yellow'>Attention! This is the LOADER version of the model. Use `QuickDrawModel.ipynb` if you want to train the model!</font>

## Setup

### Imports

In [1]:
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import json
from datetime import datetime
import functools
import sklearn as sk

# Pick the compute device once and make it the default for every tensor and
# module created afterwards. torch.set_default_device() returns None, so the
# device object is built separately to keep `dev` usable later on.
if torch.cuda.is_available():
    print("Full power!")
    dev = torch.device("cuda")
else:
    print("Regular power..")
    dev = torch.device("cpu")

torch.set_default_device(dev)

Full power!


### Checks

In [2]:
# Sanity check: display the device that new tensors are created on by default.
torch.get_default_device()

device(type='cuda', index=0)

## Manipulating The Data

### Reading the data

In [3]:
# Notebook-wide settings shared by the cells below.

# Per-class sample counts; replaced later by the value readData() returns.
classes = dict()

# Model hyperparameters, read by QuickDrawRNN when it builds its layers.
num_layers = 3
num_nodes = 128
dropout_rate = 0.3

# Samples per batch: train/test loaders and the evaluation loader.
batch_size = 8
eval_batch_size = 8

In [4]:
from os import listdir
from os.path import isfile, join

# Reset the class statistics (repeats the initialisation from the settings cell).
classes = {}

def parseLine(ndjsonLine):
    """Parse one Quick, Draw! ndjson line into a delta-encoded ink tensor.

    Arguments:
        ndjsonLine (str): A JSON document with a "word" entry (the class name)
            and a "drawing" entry: a list of strokes, each stroke being a pair
            of equally long lists [xs, ys].

    Returns:
        tuple: (ink, class_name). `ink` is a float32 torch tensor of shape
        (total_points - 1, 3) whose columns are delta-x, delta-y (computed on
        coordinates normalised to [0, 1]) and an end-of-stroke flag.
        `class_name` is the "word" value.

    Raises:
        ValueError: If the drawing contains no points at all.
    """
    sample = json.loads(ndjsonLine)
    class_name = sample["word"]
    inkarray = sample["drawing"]
    stroke_lengths = [len(stroke[0]) for stroke in inkarray]
    total_points = sum(stroke_lengths)
    if total_points == 0:
        # Without this guard the min/max reductions below fail with an opaque
        # "zero-size array" error.
        raise ValueError("drawing for '{}' contains no points".format(class_name))

    np_ink = np.zeros((total_points, 3), dtype=np.float32)
    current_t = 0
    for stroke, stroke_len in zip(inkarray, stroke_lengths):
        for i in [0, 1]:
            np_ink[current_t:(current_t + stroke_len), i] = stroke[i]
        current_t += stroke_len
        np_ink[current_t - 1, 2] = 1  # stroke_end

    # Preprocessing.
    # 1. Size normalization: map x and y independently onto [0, 1].
    lower = np.min(np_ink[:, 0:2], axis=0)
    upper = np.max(np_ink[:, 0:2], axis=0)
    scale = upper - lower
    scale[scale == 0] = 1  # degenerate axis (all values equal): avoid dividing by zero
    np_ink[:, 0:2] = (np_ink[:, 0:2] - lower) / scale

    # 2. Compute deltas; the first point has no predecessor and is dropped.
    np_ink[1:, 0:2] -= np_ink[0:-1, 0:2]
    np_ink = np_ink[1:, :]
    return torch.from_numpy(np_ink), class_name

def readData(files, train_data, test_data, limit = -1):
    """Parse Quick, Draw! ndjson files into train and test sample lists.

    Arguments:
        files (list): Paths of the ndjson files to read.
        train_data (list): Output list; training samples are appended in place.
        test_data (list): Output list; test samples are appended in place.
        limit (int): Only the first `limit` files are parsed; a negative value
            parses all of them.

    Returns:
        dict: Class name -> number of samples parsed for that class.

    Every sample is a dict with the keys "ink", "className", "shape" and
    "classIndex". Class indices are handed out in order of first appearance,
    so they depend on the order of `files`.
    """
    # Per-call statistics; this is a local and does not touch the
    # notebook-level `classes` (the caller assigns the returned dict).
    classes = {}

    filesToParse = files if limit < 0 else files[:limit]

    classNameToIndex = {}
    sampleCnt = 0
    for cnt, filePath in enumerate(filesToParse, start=1):
        with open(filePath, encoding="utf-8") as file:
            for line in file:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                if not line.strip():
                    continue

                ink, className = parseLine(line)

                # Index the class on first sight.
                if className not in classNameToIndex:
                    classNameToIndex[className] = len(classNameToIndex)

                features = {
                    "ink": ink,
                    "className": className,
                    "shape": ink.shape,
                    "classIndex": classNameToIndex[className],
                }

                # Keep a class statistic.
                classes[className] = classes.get(className, 0) + 1

                # Out of every 11000 consecutive samples, the first 10000 go to
                # training and the remaining 1000 to testing.
                if sampleCnt % 11000 < 10000:
                    train_data.append(features)
                else:
                    test_data.append(features)

                sampleCnt += 1

        # Report progress against the files actually being parsed, not the
        # full candidate list.
        print("Finished parsing {0}/{1}: {2}".format(cnt, len(filesToParse), filePath))

    print("Finished parsing all the data!")
    return classes


In [5]:
qd_train_raw_data = []
qd_test_raw_data = []

root_dir = "datasets"

# listdir() returns entries in arbitrary order, and readData() assigns class
# indices in file order. Sorting makes the index assignment reproducible, which
# matters here because the loaded checkpoint presumably expects the same
# class order it was trained with (verify against the training notebook).
dataFiles = [root_dir + "/" + f for f in sorted(listdir(root_dir)) if f.endswith(".ndjson")]
classes = readData(dataFiles, qd_train_raw_data, qd_test_raw_data, limit=4)

print()
print(classes)
print("Train data len:", len(qd_train_raw_data))
print("Test data len:", len(qd_test_raw_data))

Finished parsing 1/16: datasets/full_simplified_airplane.ndjson
Finished parsing 2/16: datasets/full_simplified_ant.ndjson
Finished parsing 3/16: datasets/full_simplified_axe.ndjson
Finished parsing 4/16: datasets/full_simplified_bed.ndjson
Finished parsing all the data!

{'airplane': 151623, 'ant': 124612, 'axe': 124122, 'bed': 113862}
Train data len: 468219
Test data len: 46000


### Creating the Dataset and DataLoader

In [6]:
class QuickDrawDataset(Dataset):
    """Thin Dataset wrapper around a list of parsed Quick, Draw! samples."""

    def __init__(self, data, classes, train):
        """
        Store the samples and their metadata without copying them.

        Arguments:
            data (list): Samples as produced by the readData() function.
            classes (dict): Class name -> number of samples of that class.
            train (bool): True for the training split, False for the test split.
        """
        self.data, self.classes, self.train = data, classes, train

    def __len__(self):
        # Number of samples in this split.
        sample_count = len(self.data)
        return sample_count

    def __getitem__(self, idx):
        # Samples are handed back exactly as stored.
        sample = self.data[idx]
        return sample

# Wrap both splits; they share the same class statistics dictionary.
qd_train_dataset = QuickDrawDataset(data=qd_train_raw_data, classes=classes, train=True)
qd_test_dataset = QuickDrawDataset(data=qd_test_raw_data, classes=classes, train=False)

In [7]:
def quickDrawCollateFn(batch, batch_size=None):
    """Collate variable-length ink samples into one zero-padded batch.

    Arguments:
        batch (list): Sample dicts with the keys "ink", "shape", "className"
            and "classIndex".
        batch_size (int, optional): Number of rows to allocate. Defaults to
            len(batch). If it is larger than len(batch) (e.g. the last,
            partial batch of a DataLoader), the extra rows stay all-zero with
            length 0 and classIndex 0.

    Returns:
        dict: "ink" (batch_size, maxLen, 3) float tensor, "shape"
        (batch_size, 2) and "length" / "classIndex" (batch_size,) integer
        tensors, "className" (list of str, one per real sample) and "maxLen"
        (int, longest ink in the batch).
    """
    if batch_size is None:
        batch_size = len(batch)

    maxLen = max((sample["shape"][0] for sample in batch), default=0)

    ## Makes a dictionary of lists
    newBatch = {
        "ink": torch.zeros((batch_size, maxLen, 3)),
        "shape": torch.zeros((batch_size, 2), dtype=torch.long),
        "length": torch.zeros((batch_size), dtype=torch.long),
        "className": [],
        "classIndex": torch.zeros((batch_size), dtype=torch.long),
        "maxLen": maxLen
    }
    for i, sample in enumerate(batch):
        length = sample["shape"][0]

        newBatch["className"].append(sample["className"])
        newBatch["classIndex"][i] = sample["classIndex"]
        # Integer tensor to match the destination dtype (the legacy
        # FloatTensor constructor produced floats for an integer slot).
        newBatch["shape"][i] = torch.tensor(list(sample["shape"]), dtype=torch.long)
        newBatch["length"][i] = length

        # Copy the ink into the leading rows; the rest stays zero padding.
        newBatch["ink"][i, :length] = sample["ink"]

    return newBatch

# The sampling generator is created on the current default device instead of a
# hard-coded 'cuda', so the loaders also work when the notebook fell back to
# the CPU.
loader_device = torch.get_default_device()

# Shuffled training batches.
qd_train_dataloader = DataLoader(qd_train_dataset, batch_size=batch_size, shuffle=True,
                                 num_workers=0, generator=torch.Generator(device=loader_device),
                                 collate_fn=functools.partial(quickDrawCollateFn, batch_size=batch_size))

# Test batches in their stored order.
qd_test_dataloader = DataLoader(qd_test_dataset, batch_size=batch_size, shuffle=False,
                                num_workers=0, generator=torch.Generator(device=loader_device),
                                collate_fn=functools.partial(quickDrawCollateFn, batch_size=batch_size))

# Shuffled test batches, used by the evaluation cells below.
qd_eval_dataloader = DataLoader(qd_test_dataset, batch_size=eval_batch_size, shuffle=True,
                                num_workers=0, generator=torch.Generator(device=loader_device),
                                collate_fn=functools.partial(quickDrawCollateFn, batch_size=eval_batch_size))

# for i_batch, sample_batched in enumerate(qd_train_dataloader):
#     print(i_batch, sample_batched["ink"][0], sample_batched["shape"], sample_batched["length"], sample_batched["classIndex"])
#     print()

#     # observe 4th batch and stop.
#     if i_batch == 1:
#         break

## Defining The Model

In [8]:
# Utils

# Thanks to: https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/39036/2
def sequence_mask(lengths, maxlen = None, dtype=torch.bool):
    """Build a (len(lengths), maxlen) mask marking the valid steps of each sequence.

    Arguments:
        lengths: 1-D tensor (or sequence) of sequence lengths.
        maxlen (int, optional): Number of mask columns. Defaults to the largest
            entry of `lengths` (0 if `lengths` is empty).
        dtype (torch.dtype): Type of the returned mask.

    Returns:
        Tensor: mask[b, t] is truthy exactly when t < lengths[b].
    """
    lengths = torch.as_tensor(lengths)
    if maxlen is None:
        maxlen = int(lengths.max()) if lengths.numel() > 0 else 0

    # Compare every position index against each length; the positions are
    # created on the same device as `lengths`.
    positions = torch.arange(int(maxlen), device=lengths.device)
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)

    # Tensor.type() returns a new tensor, so its result has to be returned
    # (discarding it left the `dtype` argument without effect).
    return mask.type(dtype)


In [9]:
class QuickDrawRNN(torch.nn.Module):
    """Conv1d stack -> bidirectional LSTM -> masked sum over time -> linear classifier.

    Layer sizes are taken from the notebook-level settings `dropout_rate`,
    `num_layers` and `num_nodes`.
    """

    def __init__(self, classes):
        """
        Arguments:
            classes (dict): One entry per class. It is printed, and its length
                sets the size of the output layer.
        """
        super().__init__()

        n_classes = len(classes)

        # Three 1D convolutions with 48, 64 and 96 filters of length 5, 5 and 3,
        # separated by dropout. Padding keeps the sequence length unchanged.
        conv_stack = [
            torch.nn.Conv1d(3, 48, 5, stride=1, padding=2),
            torch.nn.Dropout(p=dropout_rate),
            torch.nn.Conv1d(48, 64, 5, stride=1, padding=2),
            torch.nn.Dropout(p=dropout_rate),
            torch.nn.Conv1d(64, 96, 3, stride=1, padding=1),
        ]
        self.conv = torch.nn.Sequential(*conv_stack)

        # Bidirectional, batch-first LSTM over the 96 convolutional features.
        self.lstm = torch.nn.LSTM(
            input_size=96,
            hidden_size=num_nodes,
            num_layers=num_layers,
            bias=True,
            batch_first=True,
            dropout=dropout_rate,
            bidirectional=True,
        )

        print("Classes:", classes)
        print("Class count:", n_classes)

        # Both LSTM directions are concatenated, hence 2 * num_nodes inputs.
        self.fc = torch.nn.Linear(2 * num_nodes, n_classes)

    def forward(self, inks, lengths):
        """Return class logits of shape (B, num_classes).

        Arguments:
            inks: Padded ink batch of shape (B, L, 3).
            lengths: Length of each ink before padding, shape (B,).
        """
        # Conv1d works on (B, C, L): swap the axes going in and coming out.
        features = self.conv(inks.permute(0, 2, 1)).permute(0, 2, 1)

        # (B, L, 2 * num_nodes)
        outputs, _ = self.lstm(features)

        # Zero every time step past the true length of its drawing. The
        # (B, L, 1) mask broadcasts across the feature axis.
        valid = torch.unsqueeze(sequence_mask(lengths, outputs.shape[1]), 2)
        masked = torch.where(valid, outputs, torch.zeros_like(outputs))

        # Sum over time -> (B, 2 * num_nodes), then classify.
        pooled = torch.sum(masked, dim=1)
        return self.fc(pooled)

# Build the network; the output layer gets one unit per entry of the dataset's class dictionary.
qd_model = QuickDrawRNN(qd_train_dataset.classes)


Classes: {'airplane': 151623, 'ant': 124612, 'axe': 124122, 'bed': 113862}
Class count: 4


## Load the model

In [10]:
# Restore the trained weights. map_location sends the stored tensors to the
# current default device, so a checkpoint saved from a GPU can still be loaded
# when the notebook fell back to the CPU.
qd_model.load_state_dict(
    torch.load(
        "models/model_20250411_222609_1",
        map_location=torch.get_default_device(),
        weights_only=True,
    )
)

# Switch to inference mode (disables dropout); the cell displays the model summary.
qd_model.eval()

QuickDrawRNN(
  (conv): Sequential(
    (0): Conv1d(3, 48, kernel_size=(5,), stride=(1,), padding=(2,))
    (1): Dropout(p=0.3, inplace=False)
    (2): Conv1d(48, 64, kernel_size=(5,), stride=(1,), padding=(2,))
    (3): Dropout(p=0.3, inplace=False)
    (4): Conv1d(64, 96, kernel_size=(3,), stride=(1,), padding=(1,))
  )
  (lstm): LSTM(96, 128, num_layers=3, batch_first=True, dropout=0.3, bidirectional=True)
  (fc): Linear(in_features=256, out_features=4, bias=True)
)

In [11]:
def evalInput(batch, y_true, y_pred):
    """Classify one collated batch and record the true and predicted labels.

    Arguments:
        batch (dict): Collated batch providing "ink", "length" and "classIndex".
        y_true (list): Extended in place with the true class indices (ints).
        y_pred (list): Extended in place with the predicted class indices (ints).
    """
    # Inference mode for the modules (no dropout) ...
    qd_model.train(False)

    # ... and no autograd bookkeeping: gradients are not needed for reporting.
    with torch.no_grad():
        logits = qd_model(batch["ink"], batch["length"])

    predicted_labels = torch.argmax(logits, dim=1)
    actual_labels = batch["classIndex"]

    # tolist() moves the labels to the host as plain ints, so the lists hold
    # Python numbers rather than 0-dim tensors.
    y_true += actual_labels.tolist()
    y_pred += predicted_labels.tolist()


In [12]:
# Evaluate on a sample of shuffled test batches and report the running accuracy.
y_true, y_pred = [], []
for i, batch in enumerate(qd_eval_dataloader):
    evalInput(batch, y_true, y_pred)

    if i % 10 == 9:
        # Intermediate accuracy, computed only when it is actually reported
        # (it rescans every label collected so far).
        int_acc = sk.metrics.accuracy_score(y_true, y_pred)
        print("Accuracy after {}/{}: {:.4f}%".format(i + 1, len(qd_eval_dataloader), int_acc * 100.))

    # Stop after a sample of the test set (batches 0..151).
    if i > 150:
        break

acc = sk.metrics.accuracy_score(y_true, y_pred)

print()
print("Final accuracy:", acc)


Accuracy after 10/5750: 97.5000%
Accuracy after 20/5750: 97.5000%
Accuracy after 30/5750: 97.9167%
Accuracy after 40/5750: 96.8750%
Accuracy after 50/5750: 97.0000%
Accuracy after 60/5750: 96.8750%
Accuracy after 70/5750: 97.1429%
Accuracy after 80/5750: 97.0312%
Accuracy after 90/5750: 96.9444%
Accuracy after 100/5750: 97.0000%
Accuracy after 110/5750: 97.1591%
Accuracy after 120/5750: 97.1875%
Accuracy after 130/5750: 97.3077%
Accuracy after 140/5750: 97.0536%
Accuracy after 150/5750: 97.1667%

Final accuracy: 0.9712171052631579


### Per-class accuracy

In [33]:
def evalInput(batch, y_true, y_pred):
    """Classify one collated batch and update the per-class tallies.

    This definition replaces the earlier evalInput() of the same name. It
    updates the notebook-level dictionaries `class_cnt` (samples seen per true
    class index) and `class_correct_cnt` (correct predictions per true class
    index), which must exist before it is called.

    Arguments:
        batch (dict): Collated batch providing "ink", "length" and "classIndex".
        y_true (list): Unused; kept so existing calls keep working.
        y_pred (list): Unused; kept so existing calls keep working.
    """
    # Inference mode for the modules (no dropout) ...
    qd_model.train(False)

    # ... and no autograd bookkeeping: gradients are not needed for reporting.
    with torch.no_grad():
        logits = qd_model(batch["ink"], batch["length"])

    # One host transfer per tensor instead of one .item() call per element.
    predicted_labels = torch.argmax(logits, dim=1).tolist()
    actual_labels = batch["classIndex"].tolist()

    for true_cls, pred_cls in zip(actual_labels, predicted_labels):
        if true_cls not in class_cnt:
            class_cnt[true_cls] = 0
            class_correct_cnt[true_cls] = 0

        class_cnt[true_cls] += 1
        if pred_cls == true_cls:
            class_correct_cnt[true_cls] += 1


In [48]:
# Fresh per-class tallies; evalInput() fills them in as a side effect.
class_cnt = {}
class_correct_cnt = {}
for i, batch in enumerate(qd_eval_dataloader):
    evalInput(batch, y_true, y_pred)

    # Periodic report: one line per class index.
    if i % 50 == 9:
        for j in range(len(classes)):
            # Intermediate accuracy (0 while the class has not been seen yet)
            int_acc = class_correct_cnt[j] / class_cnt[j] if j in class_cnt else 0

            print("Accuracy after {}/{} for class {}: {:.4f}%".format(i + 1, len(qd_eval_dataloader), j, int_acc * 100.))

        print()

    # Only a sample of the test set is evaluated.
    if i > 1000:
        break

print()
for i in range(len(classes)):
    # Final accuracy (0 if the class never showed up)
    int_acc = class_correct_cnt[i] / class_cnt[i] if i in class_cnt else 0

    print("Final accuracy for class {}: {:.4f}%".format(i, int_acc * 100.))


Accuracy after 10/5750 for class 0: 100.0000%
Accuracy after 10/5750 for class 1: 100.0000%
Accuracy after 10/5750 for class 2: 96.2963%
Accuracy after 10/5750 for class 3: 100.0000%

Accuracy after 60/5750 for class 0: 97.0588%
Accuracy after 60/5750 for class 1: 99.2537%
Accuracy after 60/5750 for class 2: 95.6522%
Accuracy after 60/5750 for class 3: 97.8947%

Accuracy after 110/5750 for class 0: 97.4170%
Accuracy after 110/5750 for class 1: 99.0868%
Accuracy after 110/5750 for class 2: 94.8837%
Accuracy after 110/5750 for class 3: 97.1429%

Accuracy after 160/5750 for class 0: 97.8836%
Accuracy after 160/5750 for class 1: 99.0881%
Accuracy after 160/5750 for class 2: 95.1923%
Accuracy after 160/5750 for class 3: 96.9349%

Accuracy after 210/5750 for class 0: 97.9508%
Accuracy after 210/5750 for class 1: 98.3982%
Accuracy after 210/5750 for class 2: 95.7816%
Accuracy after 210/5750 for class 3: 96.5909%

Accuracy after 260/5750 for class 0: 97.7966%
Accuracy after 260/5750 for class 