In [4]:
!pip install quickdraw

Collecting quickdraw
  Downloading quickdraw-1.0.0-py3-none-any.whl.metadata (1.3 kB)
Downloading quickdraw-1.0.0-py3-none-any.whl (11 kB)
Installing collected packages: quickdraw
Successfully installed quickdraw-1.0.0


In [8]:
import os

import torch
from pytorch_lightning import LightningModule, Trainer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
import torchvision
from torchvision import transforms

import numpy as np
from matplotlib import patches, path, pyplot as plt
from tqdm.notebook import tqdm

import quickdraw as QD

CLASSES = [
    "aircraft carrier",
    "airplane",
    "alarm clock",
    "ambulance",
    "angel",
    "animal migration",
    "ant",
    "anvil",
    "apple",
    "arm",
    "asparagus",
    "axe",
    "backpack",
    "banana",
    "bandage",
    "barn",
    "baseball",
    "baseball bat",
    "basket",
    "basketball",
    "bat",
    "bathtub",
    "beach",
    "bear",
    "beard",
    "bed",
    "bee",
    "belt",
    "bench",
    "bicycle",
    "binoculars",
    "bird",
    "birthday cake",
    "blackberry",
    "blueberry",
    "book",
    "boomerang",
    "bottlecap",
    "bowtie",
    "bracelet",
    "brain",
    "bread",
    "bridge",
    "broccoli",
    "broom",
    "bucket",
    "bulldozer",
    "bus",
    "bush",
    "butterfly",
    "cactus",
    "cake",
    "calculator",
    "calendar",
    "camel",
    "camera",
    "camouflage",
    "campfire",
    "candle",
    "cannon",
    "canoe",
    "car",
    "carrot",
    "castle",
    "cat",
    "ceiling fan",
    "cello",
    "cell phone",
    "chair",
    "chandelier",
    "church",
    "circle",
    "clarinet",
    "clock",
    "cloud",
    "coffee cup",
    "compass",
    "computer",
    "cookie",
    "cooler",
    "couch",
    "cow",
    "crab",
    "crayon",
    "crocodile",
    "crown",
    "cruise ship",
    "cup",
    "diamond",
    "dishwasher",
    "diving board",
    "dog",
    "dolphin",
    "donut",
    "door",
    "dragon",
    "dresser",
    "drill",
    "drums",
    "duck",
    "dumbbell",
    "ear",
    "elbow",
    "elephant",
    "envelope",
    "eraser",
    "eye",
    "eyeglasses",
    "face",
    "fan",
    "feather",
    "fence",
    "finger",
    "fire hydrant",
    "fireplace",
    "firetruck",
    "fish",
    "flamingo",
    "flashlight",
    "flip flops",
    "floor lamp",
    "flower",
    "flying saucer",
    "foot",
    "fork",
    "frog",
    "frying pan",
    "garden",
    "garden hose",
    "giraffe",
    "goatee",
    "golf club",
    "grapes",
    "grass",
    "guitar",
    "hamburger",
    "hammer",
    "hand",
    "harp",
    "hat",
    "headphones",
    "hedgehog",
    "helicopter",
    "helmet",
    "hexagon",
    "hockey puck",
    "hockey stick",
    "horse",
    "hospital",
    "hot air balloon",
    "hot dog",
    "hot tub",
    "hourglass",
    "house",
    "house plant",
    "hurricane",
    "ice cream",
    "jacket",
    "jail",
    "kangaroo",
    "key",
    "keyboard",
    "knee",
    "knife",
    "ladder",
    "lantern",
    "laptop",
    "leaf",
    "leg",
    "light bulb",
    "lighter",
    "lighthouse",
    "lightning",
    "line",
    "lion",
    "lipstick",
    "lobster",
    "lollipop",
    "mailbox",
    "map",
    "marker",
    "matches",
    "megaphone",
    "mermaid",
    "microphone",
    "microwave",
    "monkey",
    "moon",
    "mosquito",
    "motorbike",
    "mountain",
    "mouse",
    "moustache",
    "mouth",
    "mug",
    "mushroom",
    "nail",
    "necklace",
    "nose",
    "ocean",
    "octagon",
    "octopus",
    "onion",
    "oven",
    "owl",
    "paintbrush",
    "paint can",
    "palm tree",
    "panda",
    "pants",
    "paper clip",
    "parachute",
    "parrot",
    "passport",
    "peanut",
    "pear",
    "peas",
    "pencil",
    "penguin",
    "piano",
    "pickup truck",
    "picture frame",
    "pig",
    "pillow",
    "pineapple",
    "pizza",
    "pliers",
    "police car",
    "pond",
    "pool",
    "popsicle",
    "postcard",
    "potato",
    "power outlet",
    "purse",
    "rabbit",
    "raccoon",
    "radio",
    "rain",
    "rainbow",
    "rake",
    "remote control",
    "rhinoceros",
    "rifle",
    "river",
    "roller coaster",
    "rollerskates",
    "sailboat",
    "sandwich",
    "saw",
    "saxophone",
    "school bus",
    "scissors",
    "scorpion",
    "screwdriver",
    "sea turtle",
    "see saw",
    "shark",
    "sheep",
    "shoe",
    "shorts",
    "shovel",
    "sink",
    "skateboard",
    "skull",
    "skyscraper",
    "sleeping bag",
    "smiley face",
    "snail",
    "snake",
    "snorkel",
    "snowflake",
    "snowman",
    "soccer ball",
    "sock",
    "speedboat",
    "spider",
    "spoon",
    "spreadsheet",
    "square",
    "squiggle",
    "squirrel",
    "stairs",
    "star",
    "steak",
    "stereo",
    "stethoscope",
    "stitches",
    "stop sign",
    "stove",
    "strawberry",
    "streetlight",
    "string bean",
    "submarine",
    "suitcase",
    "sun",
    "swan",
    "sweater",
    "swing set",
    "sword",
    "syringe",
    "table",
    "teapot",
    "teddy-bear",
    "telephone",
    "television",
    "tennis racquet",
    "tent",
    "The Eiffel Tower",
    "The Great Wall of China",
    "The Mona Lisa",
    "tiger",
    "toaster",
    "toe",
    "toilet",
    "tooth",
    "toothbrush",
    "toothpaste",
    "tornado",
    "tractor",
    "traffic light",
    "train",
    "tree",
    "triangle",
    "trombone",
    "truck",
    "trumpet",
    "t-shirt",
    "umbrella",
    "underwear",
    "van",
    "vase",
    "violin",
    "washing machine",
    "watermelon",
    "waterslide",
    "whale",
    "wheel",
    "windmill",
    "wine bottle",
    "wine glass",
    "wristwatch",
    "yoga",
    "zebra",
    "zigzag"
]

In [14]:
datagroup = QD.QuickDrawDataGroup(CLASSES[0], max_drawings=10000)

loading aircraft carrier drawings
load complete


In [190]:
from typing import Union

class QuickDrawDataSet(torch.utils.data.Dataset):
    def __init__(self, name, max_drawings, transform, recognized=True, classes=CLASSES):
        self.id = classes.index(name)
        self.datagroup = QD.QuickDrawDataGroup(name,
                                               max_drawings=max_drawings, 
                                               recognized=recognized)
        self.max_drawings = max_drawings
        self.transform = transform

    def __len__(self):
        return self.datagroup.drawing_count
    
    def _get_single_item(self, index: int):
        img = self.datagroup.get_drawing(index).image
        return (self.transform(img), self.id)

    def __getitem__(self, index: Union[int, slice, np.ndarray]):
        if type(index) == slice or type(index) == np.ndarray:
            if type(index) == slice:
                index = range(index.start, index.stop, index.step or 1)
            return [self._get_single_item(i) for i in index]
        return self._get_single_item(index)

In [87]:
from torchvision import transforms

ds = QuickDrawDataSet(CLASSES[0], max_drawings=10000, transform=transforms.ToTensor())

loading aircraft carrier drawings
load complete


In [88]:
ds[np.array([1, 3, 5, 7])]

[(<PIL.Image.Image image mode=RGB size=255x255>, 0),
 (<PIL.Image.Image image mode=RGB size=255x255>, 0),
 (<PIL.Image.Image image mode=RGB size=255x255>, 0),
 (<PIL.Image.Image image mode=RGB size=255x255>, 0)]

In [184]:
import bisect

class QuickDrawDataAllSet(torch.utils.data.Dataset):
    def __init__(self, classes, max_drawings, transform, recognized=True):
        params = dict(
            max_drawings=max_drawings, transform=transform, classes=classes, recognized=recognized
        )
        self.groups = [QuickDrawDataSet(cls, **params) for cls in classes]
        self.offset = [0]
        self.count = 0
        for g in self.groups:
            self.count += len(g)
            self.offset.append(self.count)

    def __len__(self):
        return self.count
    
    def get_single_item(self, index: int):
        gi = bisect.bisect_right(self.offset, index) - 1
        return self.groups[gi][index - self.offset[gi]]

    def __getitem__(self, index: Union[int, slice, np.ndarray]):
        if type(index) == slice or type(index) == np.ndarray:
            if type(index) == slice:
                index = range(index.start, index.stop, index.step or 1)
            return [self.get_single_item(i) for i in index]
        return self.get_single_item(index)

In [193]:
def transfom(data):
    print(data)
    return transforms.ToTensor()(data)

dataset_all = QuickDrawDataAllSet(CLASSES[:5], max_drawings=1000, transform=transforms.ToTensor())
dataset_all

loading aircraft carrier drawings
load complete
loading airplane drawings
load complete
loading alarm clock drawings
load complete
loading ambulance drawings
load complete
loading angel drawings
load complete


<__main__.QuickDrawDataAllSet at 0x29f9f8a30>

In [194]:
dataset_all[0]

(tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],
 
         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],
 
         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]]]),
 0)

In [195]:
dataset_all[0:3]

[(tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           ...,
           [1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.]],
  
          [[1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           ...,
           [1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.]],
  
          [[1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           ...,
           [1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.]]]),
  0),
 (tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           [1., 1., 1.,  ..., 1., 1., 1.],
           ...,
           [1., 

In [174]:
def split_index(indices, train_size: float) -> tuple:
    n = len(indices)
    i = int(n * train_size)
    return (indices[:i], indices[i:])

In [196]:
from torch.utils.data import DataLoader, SequentialSampler, SubsetRandomSampler

BATCH_SIZE = 256

index_all = np.arange(len(dataset_all))
train_valid_idx, test_idx = split_index(index_all, train_size=0.8)
train_idx, valid_idx = split_index(train_valid_idx, train_size=0.7)

train_subsampler = SubsetRandomSampler(train_idx)
valid_subsampler = SequentialSampler(valid_idx)
test_subsampler = SequentialSampler(test_idx)

train_dataloader = DataLoader(dataset_all, batch_size=BATCH_SIZE, sampler=train_subsampler)
valid_dataloader = DataLoader(dataset_all, batch_size=BATCH_SIZE, sampler=valid_subsampler)
test_dataloader = DataLoader(dataset_all, batch_size=BATCH_SIZE, sampler=test_subsampler)

In [198]:
image_batch, label_batch = next(iter(train_dataloader))

In [200]:
image_batch.shape

torch.Size([256, 3, 255, 255])

In [202]:
label_batch.shape, label_batch

(torch.Size([256]),
 tensor([1, 1, 2, 1, 1, 1, 1, 2, 2, 0, 0, 2, 0, 2, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0,
         1, 1, 1, 2, 1, 1, 2, 1, 1, 0, 1, 0, 2, 1, 0, 1, 1, 0, 1, 2, 2, 2, 0, 1,
         2, 2, 2, 1, 1, 0, 2, 0, 1, 1, 0, 0, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 2, 1,
         1, 2, 0, 2, 1, 1, 2, 0, 1, 0, 1, 0, 2, 1, 2, 0, 2, 2, 0, 2, 1, 1, 1, 2,
         1, 0, 0, 2, 1, 0, 2, 0, 1, 0, 0, 0, 1, 2, 0, 2, 0, 0, 1, 0, 0, 0, 0, 0,
         1, 0, 1, 2, 1, 1, 0, 0, 0, 0, 2, 0, 2, 0, 1, 2, 1, 1, 1, 2, 2, 1, 1, 1,
         2, 1, 0, 0, 0, 2, 1, 1, 1, 0, 1, 0, 1, 0, 1, 2, 1, 1, 1, 2, 2, 2, 0, 2,
         1, 0, 0, 2, 1, 1, 1, 1, 0, 1, 2, 1, 2, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 1,
         2, 0, 0, 1, 2, 0, 0, 1, 0, 1, 2, 0, 0, 1, 0, 2, 1, 0, 2, 1, 1, 1, 2, 2,
         0, 0, 0, 1, 2, 1, 0, 0, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2, 0, 1, 2, 1, 2,
         0, 1, 0, 0, 2, 1, 0, 1, 0, 1, 2, 1, 1, 2, 0, 1]))