In [1]:
# Mount Google Drive: the dataset directories and the checkpoint path used
# later in this notebook all live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
import torch
print(torch.__version__)
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset
from torchvision import transforms as T
import random
import cv2
import numpy as np
!pip install pyefd
import pyefd
from google.colab.patches import cv2_imshow
!pip install cairocffi
import cairocffi as cairo
import struct
from struct import unpack
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cu113.html
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

1.12.1+cu113
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyefd
  Downloading pyefd-1.6.0-py2.py3-none-any.whl (7.7 kB)
Installing collected packages: pyefd
Successfully installed pyefd-1.6.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting cairocffi
  Downloading cairocffi-1.3.0.tar.gz (88 kB)
[K     |████████████████████████████████| 88 kB 4.0 MB/s 
Building wheels for collected packages: cairocffi
  Building wheel for cairocffi (setup.py) ... [?25l[?25hdone
  Created wheel for cairocffi: filename=cairocffi-1.3.0-py3-none-any.whl size=89668 sha256=45d0bf39cf5733527f79731161ae69303df9774ce55370c7c5200c65c2e20f72
  Stored in directory: /root/.cache/pip/wheels/4e/ca/e1/5c8a9692a27f639a07c949044bec943f26c81cd53d3805319f
Successfully built cairocffi
Installing collected packages: cairocffi
Successfully installed cairocffi-1.3.0
Looking in indexes: https://pypi.o

In [3]:
# Env vars
# Determinism is not enforced, so runs may differ slightly even with fixed seeds.
torch.use_deterministic_algorithms(False)

# Const vars
SAVE_PATH = '/content/drive/My Drive/Fourier/Saved Models/GNN conn.pt'  # checkpoint loaded/overwritten by the training cell
RAND_SEED = 0
# Fall back to the CPU so the notebook still runs on a runtime without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

IMG_SIDE = 28  # side length (pixels) of the per-stroke rasters
NUM_CLASSES = 343  # must equal len(list_of_classes)
# Default number of elliptic Fourier harmonics per stroke. The training cell
# rebinds this name in its loop; defining it here means fourier_transform and
# GCN do not depend on that loop having run first.
FOURIER_ORDER = 6
EPOCHS = 10
LEARNING_RATE = 0.01
BATCH_SIZE = 500
LOSS_FN = nn.CrossEntropyLoss()
EDGE_ATTR_DIM = 1  # width of each edge feature (get_edges emits a single 0/1 "strokes touch" flag)

In [4]:
# convert raw vector image to list of raster images, one for each stroke
def vector_to_raster(vector_image, side=IMG_SIDE, line_diameter=16, padding=96, bg_color=(0,0,0), fg_color=(1,1,1)):
  """Render each stroke of a vector sketch into its own ``side`` x ``side`` raster.

  padding and line_diameter are relative to the original 256x256 image.

  Args:
    vector_image: list of strokes, each an ``(xs, ys)`` pair of coordinate
      sequences in the original 256x256 drawing space.
    side: output raster side length in pixels.
    line_diameter: pen width (original-image units).
    padding: margin added on every side (original-image units).
    bg_color, fg_color: RGB colours (components in [0, 1]) for the background
      and the stroke.

  Returns:
    A list with one ``(side, side)`` array per stroke, in stroke order. Each
    array holds the first byte of every 4-byte ARGB32 pixel; with the default
    black/white colours every colour byte carries the same intensity.
  """
  original_side = 256.

  surface = cairo.ImageSurface(cairo.FORMAT_ARGB32, side, side)
  ctx = cairo.Context(surface)
  ctx.set_antialias(cairo.ANTIALIAS_BEST)
  ctx.set_line_cap(cairo.LINE_CAP_ROUND)
  ctx.set_line_join(cairo.LINE_JOIN_ROUND)
  ctx.set_line_width(line_diameter)

  # scale to match the new size
  # add padding at the edges for the line_diameter
  # and add additional padding to account for antialiasing
  total_padding = padding * 2. + line_diameter
  new_scale = float(side) / float(original_side + total_padding)
  ctx.scale(new_scale, new_scale)
  ctx.translate(total_padding / 2., total_padding / 2.)

  # Centre the drawing using only its max x/y extent.
  # NOTE(review): this assumes the coordinates start at 0; confirm for the source data.
  bbox = np.hstack(vector_image).max(axis=1)
  offset = ((original_side, original_side) - bbox) / 2.
  offset = offset.reshape(-1,1)
  centered = [stroke + offset for stroke in vector_image]

  # Each surface row is `stride` bytes, which cairo may pad beyond side * 4.
  # Index rows explicitly instead of assuming a packed 28x28 layout.
  stride = surface.get_stride()

  stroke_rasters = []
  for xv, yv in centered:
    # clear background (the same surface is reused for every stroke)
    ctx.set_source_rgb(*bg_color)
    ctx.paint()

    # draw the stroke; this is the most cpu-intensive part
    ctx.set_source_rgb(*fg_color)
    ctx.move_to(xv[0], yv[0])
    for x, y in zip(xv, yv):
      ctx.line_to(x, y)
    ctx.stroke()

    # cairo requires a flush before the pixel buffer is read directly
    surface.flush()
    data = surface.get_data()
    pixels = np.asarray(data).reshape(side, stride)[:, :side * 4:4]
    # copy: the shared buffer is overwritten when the next stroke is drawn
    stroke_rasters.append(pixels.copy())

  return stroke_rasters

def _dilate_3x3(mask):
  """Return boolean ``mask`` OR-ed with its 8 neighbours (outside the image counts as False)."""
  height, width = mask.shape
  padded = np.pad(mask, 1, mode='constant', constant_values=False)
  grown = np.zeros_like(mask)
  for dr in range(3):
    for dc in range(3):
      grown |= padded[dr:dr + height, dc:dc + width]
  return grown


def get_edges(stroke_rasters):
  """Build a fully connected stroke graph with a "strokes touch" edge feature.

  Every unordered pair of strokes (i, j) with i < j contributes the two directed
  edges i->j and j->i. Both carry a one-element feature: 1.0 when some nonzero
  pixel of stroke i lies on or 8-adjacent to a nonzero pixel of stroke j,
  otherwise 0.0.

  Args:
    stroke_rasters: sequence of equally shaped 2-D arrays, one per stroke.
      Nonzero pixels belong to the stroke.

  Returns:
    edge_index: LongTensor of shape (2, E) with E = n * (n - 1).
    edge_attr: FloatTensor of shape (E, 1). It is (0, 1) rather than (0,) when
      there are fewer than two strokes, so the feature width stays consistent
      for batching and the edge projection.
  """
  # Compare against 0 rather than summing pixel values: summing uint8
  # neighbours can wrap around to 0 and hide a real contact.
  masks = [np.asarray(raster) != 0 for raster in stroke_rasters]
  # A pixel of stroke i touches stroke j iff it falls inside stroke j's mask
  # grown by one pixel in all 8 directions.
  dilated = [_dilate_3x3(mask) for mask in masks]

  sources, targets, touching = [], [], []
  for i in range(len(masks)):
    for j in range(i + 1, len(masks)):
      flag = 1.0 if np.any(masks[i] & dilated[j]) else 0.0
      sources += [i, j]
      targets += [j, i]
      touching += [[flag], [flag]]

  edge_index = torch.tensor([sources, targets], dtype=torch.long)
  edge_attr = torch.tensor(touching, dtype=torch.float32).reshape(-1, 1)
  return edge_index, edge_attr

# transform functions - take sketch image, return torch tensor of descriptors
def _random_affine_strokes(stroke_rasters):
  """Apply one shared random rotation/translation to every stroke raster.

  Draws, in this order, an angle in [-30, 30) degrees and integer x/y shifts in
  [-3, 3] from Python's global ``random`` module, so results follow ``random.seed``.
  Returns a list of 2-D float32 arrays, one per stroke.
  """
  stacked = torch.from_numpy(np.stack(stroke_rasters)).float()

  angle = random.random()*60 - 30
  delta_x = random.randint(-3, 3)
  delta_y = random.randint(-3, 3)

  moved = T.functional.affine(stacked, angle, [delta_x, delta_y], 1, 0,
                              interpolation=T.InterpolationMode.BILINEAR)
  moved = np.asarray(moved)
  return [np.squeeze(a) for a in np.split(moved, moved.shape[0])]


def fourier_transform(vector_img, is_test, fourier_order=None):
  """Turn a vector sketch into per-stroke Fourier descriptors plus a stroke graph.

  Args:
    vector_img: drawing as a list of ``(xs, ys)`` strokes.
    is_test: when True, one random rotation/translation is applied to all
      stroke rasters before feature extraction (test-time perturbation).
    fourier_order: number of elliptic Fourier harmonics. Defaults to the
      module-level ``FOURIER_ORDER``, read at call time.

  Returns:
    (x, edge_index, edge_attr). ``x`` is a float32 tensor with one row of
    ``4 * fourier_order`` normalized descriptor coefficients per kept stroke.
    ``edge_index`` and ``edge_attr`` come from ``get_edges`` over the same kept
    strokes, so node i of the graph is row i of ``x``.

  Raises:
    ValueError: if no stroke yields a contour with more than one point.
  """
  if fourier_order is None:
    fourier_order = FOURIER_ORDER

  stroke_rasters = vector_to_raster(vector_img)
  if is_test:
    stroke_rasters = _random_affine_strokes(stroke_rasters)

  kept_rasters = []
  descriptors = []
  for raster in stroke_rasters:
    raster_binary = cv2.threshold(raster, 100, 255, cv2.THRESH_BINARY)[1].astype(np.float32)
    contours, _ = cv2.findContours(raster_binary.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    # Describe the stroke by its longest outer contour (the first one wins ties).
    # A stroke whose best contour has <= 1 point is dropped together with its
    # raster, keeping graph nodes aligned with descriptor rows.
    largest = max(contours, key=len, default=None)
    if largest is None or len(largest) <= 1:
      continue
    coeffs = pyefd.elliptic_fourier_descriptors(np.squeeze(largest), order=fourier_order, normalize=True)
    descriptors.append(coeffs.flatten())
    kept_rasters.append(raster_binary)

  if not descriptors:
    raise ValueError("fourier_transform: no stroke produced a usable contour")

  edge_indices, edge_attr = get_edges(kept_rasters)
  stroke_fourier_descriptors = torch.from_numpy(np.stack(descriptors)).float()
  return stroke_fourier_descriptors, edge_indices, edge_attr

# helper method to find class based on imgset index
def find_class(idx, num_list):
  """Map a flat dataset index to its class id.

  ``num_list[k]`` is the number of consecutive samples belonging to class k.
  The result is the first class whose cumulative count exceeds ``idx``.
  Raises IndexError when ``idx`` is past the total count.
  """
  upper_bound = 0
  for class_id, count in enumerate(num_list):
    upper_bound += count
    if idx < upper_bound:
      return class_id
  raise IndexError("list index out of range")

# deterministic worker re-seeding
def seed_worker(worker_id):
  """DataLoader ``worker_init_fn``: seed Python's and NumPy's global RNGs.

  Both RNGs are seeded from torch's initial seed, reduced to 32 bits because
  ``np.random.seed`` only accepts 32-bit values. ``worker_id`` is unused.
  """
  base_seed = torch.initial_seed() & 0xFFFFFFFF
  random.seed(base_seed)
  np.random.seed(base_seed)

# custom dataset for quickdraw
class QuickdrawDataset(Dataset):
  """Map-style dataset turning Quickdraw vector drawings into graph samples.

  Args:
    imgs: flat list of drawings, grouped class by class.
    nums: number of drawings per class, in the same class order as ``imgs``.
      Used by ``find_class`` to recover each sample's label.
    is_test: forwarded to ``fourier_transform`` (enables its random
      test-time affine perturbation).
  """

  def __init__(self, imgs, nums, is_test):
    self.imgs, self.nums = imgs, nums
    self.is_test = is_test
    self.len = sum(nums)  # dataset size is the total count over all classes

  def __len__(self):
    return self.len

  def __getitem__(self, idx):
    features, edge_index, edge_attr = fourier_transform(self.imgs[idx], self.is_test)
    return Data(x=features, edge_index=edge_index, edge_attr=edge_attr,
                y=find_class(idx, self.nums))


# pytorch-geometric GCN
class GCN(torch.nn.Module):
  """Stroke-graph classifier: graph convolutions -> mean pooling -> MLP head.

  Node features are per-stroke Fourier descriptors (``FOURIER_ORDER * 4``
  values). Each edge's feature vector (``EDGE_ATTR_DIM`` wide) is projected and
  squashed by a sigmoid into a scalar weight used by every graph convolution.
  The forward pass returns raw logits over ``NUM_CLASSES`` classes.
  """

  def __init__(self):
    super().__init__()
    # Layers are created in the same order and under the same attribute names as
    # before. conv5-conv8 and fc2 are not used in forward(), but they are part
    # of the state_dict that the training cell saves and reloads with
    # load_state_dict, so they must stay.
    self.conv1 = GCNConv(FOURIER_ORDER * 4, 128)
    for layer_no in range(2, 9):
      setattr(self, f"conv{layer_no}", GCNConv(128, 128))
    self.fc1 = nn.Linear(128, 512)
    self.fc2 = nn.Linear(512, 512)
    self.head = nn.Linear(512, NUM_CLASSES)
    self.edge_proj = nn.Linear(EDGE_ATTR_DIM, 1)
    self.relu = nn.ReLU()
    self.sigmoid = nn.Sigmoid()

  def forward(self, data):
    edge_attr = data.edge_attr
    if EDGE_ATTR_DIM > 1:
      edge_attr = edge_attr.squeeze(dim=2)
    # learned per-edge weight in (0, 1)
    edge_weight = self.sigmoid(self.edge_proj(edge_attr))

    # only the first four graph convolutions are active
    x = data.x
    for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
      x = self.relu(conv(x, data.edge_index, edge_weight))

    pooled = global_mean_pool(x, data.batch)
    return self.head(self.relu(self.fc1(pooled)))


def train_loop(dataloader, model, loss_fn, optimizer, device=None):
  """Run one training epoch and print progress plus epoch-average metrics.

  Args:
    dataloader: iterable of batches exposing ``.to(device)`` and ``.y``
      (e.g. a torch_geometric DataLoader). ``dataloader.dataset`` must support len().
    model: module mapping a batch to class logits of shape (num_samples, num_classes).
    loss_fn: criterion called as ``loss_fn(logits, targets)``.
    optimizer: optimizer over ``model``'s parameters.
    device: where batches are moved. Defaults to the global ``DEVICE``.
  """
  if device is None:
    device = DEVICE
  size = len(dataloader.dataset)
  model.train()  # put the model in train mode
  total_loss = 0.0
  total_correct = 0
  num_batches = 0
  seen = 0
  for batch_idx, data in enumerate(dataloader):
    data = data.to(device)
    # Compute prediction and loss
    out = model(data)
    loss = loss_fn(out, data.y)

    # Backpropagation to update model parameters
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # .item() copies only scalars to the host; the batch stays on `device`.
    loss_val = loss.item()
    num_batches += 1
    seen += data.y.numel()
    if batch_idx % 100 == 0:
      print(f"loss: {loss_val:>7f}  [{seen:>5d}/{size:>5d}]")

    pred = out.argmax(dim=1)
    total_correct += (pred == data.y.view(-1)).sum().item()
    total_loss += loss_val

  # Average over the batches actually seen. The old (size // BATCH_SIZE) + 1
  # over-counted by one when size was a multiple of the batch size.
  avg_loss = total_loss / max(num_batches, 1)
  avg_acc = total_correct / max(size, 1)
  print(f"\nepoch avg train loss: {avg_loss:>7f}   epoch avg train accuracy: {avg_acc:>7f}")
      
def eval_loop(dataloader, model, device=None):
  """Compute, print and return classification accuracy over ``dataloader``.

  Args:
    dataloader: iterable of batches exposing ``.to(device)`` and ``.y``.
      ``dataloader.dataset`` must support len().
    model: module mapping a batch to class logits. Left in eval mode.
    device: where batches are moved. Defaults to the global ``DEVICE``.

  Returns:
    Accuracy as a float; 0.0 for an empty dataset.
  """
  if device is None:
    device = DEVICE
  model.eval()
  size = len(dataloader.dataset)
  total_correct = 0
  with torch.no_grad():
    for data in dataloader:
      data = data.to(device)
      # compare on-device; only the scalar count is copied to the host
      pred = model(data).argmax(dim=1)
      total_correct += (pred == data.y.view(-1)).sum().item()

  accuracy = total_correct / max(size, 1)
  print(f"test accuracy: {accuracy:>7f}")
  return accuracy


In [5]:
# define methods for unpacking Quickdraw .bin files
def unpack_drawing(file_handle):
  """Read one drawing record from a Quickdraw binary (.bin) file handle.

  Layout consumed: a 15-byte header that is skipped, a uint16 stroke count,
  then per stroke a uint16 point count followed by that many x bytes and that
  many y bytes.

  Returns:
    list of ``(x, y)`` tuples of ints, one per stroke.

  Raises:
    struct.error: when the handle runs out of data (e.g. at end of file).
  """
  # Skip the fixed-size header; presumably key_id, country code, recognized
  # flag and timestamp per the Quickdraw binary format (not used here).
  file_handle.read(15)
  # '<' pins little-endian, standard-size fields so parsing does not depend on
  # the host's native byte order.
  n_strokes, = unpack('<H', file_handle.read(2))
  image = []
  for _ in range(n_strokes):
    n_points, = unpack('<H', file_handle.read(2))
    fmt = '<' + str(n_points) + 'B'
    x = unpack(fmt, file_handle.read(n_points))
    y = unpack(fmt, file_handle.read(n_points))
    image.append((x, y))

  return image


def unpack_drawings(filename):
  """Read every drawing record from the Quickdraw .bin file at ``filename``.

  Records are parsed with ``unpack_drawing`` until it raises ``struct.error``.
  That signals end of data, so a truncated trailing record is silently dropped.
  """
  drawings = []
  with open(filename, 'rb') as handle:
    try:
      while True:
        drawings.append(unpack_drawing(handle))
    except struct.error:
      pass  # deliberate: end of file (or a partial final record) stops reading
  return drawings

# Per-class Quickdraw binary files ('<class name>.bin') for the train and test splits.
train_dir = '/content/drive/My Drive/Fourier/Quickdraw Dataset Small/Train/'
test_dir = '/content/drive/My Drive/Fourier/Quickdraw Dataset Small/Test/'
# Filled by the loading cell: flat lists of drawings plus per-class counts in
# list_of_classes order, which find_class uses to map a flat index to a label.
train_imgs = []
test_imgs = []
train_nums = []
test_nums = []
# Class names: used as file names, and their position in this list is the
# integer label. 343 entries, matching NUM_CLASSES.
list_of_classes = ["The Eiffel Tower", "The Great Wall of China", "The Mona Lisa",
                   "aircraft carrier", "airplane", "alarm clock", "ambulance", 
                   "angel", "ant", "anvil", "apple", "arm", "asparagus", "axe", 
                   "backpack", "banana", "bandage", "barn", "baseball bat", 
                   "baseball", "basket", "basketball", "bathtub", "beach", "bear", 
                   "beard", "bed", "bee", "belt", "bench", "bicycle", "binoculars", 
                   "bird", "birthday cake", "blackberry", "blueberry", "book", 
                   "boomerang", "bottlecap", "bowtie", "bracelet", "brain", 
                   "bread", "bridge", "broccoli", "broom", "bucket", "bulldozer", 
                   "bus", "bush", "butterfly", "cactus", "cake", "calculator", 
                   "calendar", "camel", "camera", "camouflage", "campfire", 
                   "candle", "cannon", "canoe", 'car', 'carrot', "castle", "cat", "ceiling fan", 
                   "cell phone", "cello", "chair", "chandelier", "church", 
                   "circle", "clarinet", "clock", "cloud", "coffee cup", 
                   "compass", "computer", "cookie", "cooler", "couch", "cow",
                   "crab", "crayon", "crocodile", "crown", "cruise ship", 
                   "cup", "diamond", "dishwasher", "diving board", "dog", 
                   "dolphin", "donut", "door", "dragon", "dresser", "drill", 
                   "drums", "duck", "dumbbell", "ear", "elbow", "elephant", 
                   "envelope", "eraser", "eye", "eyeglasses", "face", "fan",
                   "feather", "fence", "finger", "fire hydrant", "fireplace",
                   "firetruck", "fish", "flamingo", "flashlight", "flip flops", 
                   "floor lamp", "flower", "flying saucer", "foot", "fork", 
                   "frog", "frying pan", "garden hose", "garden", "giraffe", 
                   "goatee", "golf club", "grapes", "grass", "guitar", 
                   "hamburger", "hammer", "hand", "harp", "hat", "headphones", 
                   "hedgehog", "helicopter", "helmet", "hexagon", "hockey puck", 
                   "hockey stick", "horse", "hospital", "hot air balloon", 
                   "hot dog", "hot tub", "hourglass", "house plant", "house", 
                   "hurricane", "ice cream", "jacket", "jail", "kangaroo", 
                   "key", "keyboard", "knee", "knife", "ladder", "lantern", 
                   "laptop", "leaf", "leg", "light bulb", "lighter", "lighthouse",
                   "lightning", "line", "lion", "lipstick", "lobster", "lollipop",
                   "mailbox", "map", "marker", "matches", "megaphone", "mermaid", 
                   "microphone", "microwave", "monkey", "moon", "mosquito", 
                   "motorbike", "mountain", "mouse", "moustache", "mouth", "mug",
                   "mushroom", "nail", "necklace", "nose", "ocean", "octagon", 
                   "octopus", "onion", "oven", "owl", "paint can", "paintbrush", 
                   "palm tree", "panda", "pants", "paper clip", "parachute", 
                   "parrot", "passport", "peanut", "pear", "peas", "pencil", 
                   "penguin", "piano", "pickup truck", "picture frame", "pig", 
                   "pillow", "pineapple", "pizza", "pliers", "police car", 
                   "pond", "pool", "popsicle", "postcard", "potato", 
                   "power outlet", "purse", "rabbit", "raccoon", "radio", 
                   "rain", 'rainbow', 'rake', 'remote control', 'rhinoceros', 
                   'rifle', 'river', 'roller coaster', 'rollerskates', 
                   'sailboat', 'sandwich', 'saw', 'saxophone', 'school bus', 
                   'scissors', 'scorpion', 'screwdriver', 'sea turtle', 
                   'see saw', 'shark', 'sheep', 'shoe', 'shorts', 'shovel', 
                   'sink', 'skateboard', 'skull', 'skyscraper', 'sleeping bag', 
                   'smiley face', 'snail', 'snake', 'snorkel', 'snowflake', 
                   'snowman', 'soccer ball', 'sock', 'speedboat', 'spider', 
                   'spoon', 'spreadsheet', 'square', 'squiggle', 'squirrel', 
                   'stairs', 'star', 'steak', 'stereo', 'stethoscope', 'stitches', 
                   'stop sign', 'stove', 'strawberry', 'streetlight', 
                   'string bean', 'submarine', 'suitcase', 'sun', 'swan', 
                   'sweater', 'swing set', 'sword', 'syringe', 't-shirt', 
                   'table', 'teapot', 'teddy-bear', 'telephone', 'television', 
                   'tennis racquet', 'tent', 'tiger', 'toaster', 'toe', 'toilet', 
                   'tooth', 'toothbrush', 'toothpaste', 'tornado', 'tractor', 
                   'traffic light', 'train', 'tree', 'triangle', 'trombone', 
                   'truck', 'trumpet', 'umbrella', 'underwear', 'van', 'vase', 'violin', 
                   'washing machine', 'watermelon', 'waterslide', 'whale', 
                   'wheel', 'windmill', 'wine bottle', 'wine glass', 'wristwatch', 
                   'yoga', 'zebra', 'zigzag']

In [6]:
# load dataset
# Drawings are concatenated class by class in list_of_classes order, and
# *_nums records how many came from each class; find_class relies on this to
# recover labels from a flat index.
# The accumulators are reset here so that re-running this cell does not append
# a second copy of the data. A second copy would double the dataset and make
# find_class return labels >= NUM_CLASSES.
train_imgs, test_imgs = [], []
train_nums, test_nums = [], []
for item in list_of_classes:
  train_folder = train_dir + item + '.bin'
  test_folder = test_dir + item + '.bin'
  train_drawings = unpack_drawings(train_folder)
  train_imgs += train_drawings
  train_nums.append(len(train_drawings))
  test_drawings = unpack_drawings(test_folder)
  test_imgs += test_drawings
  test_nums.append(len(test_drawings))

In [7]:
for FOURIER_ORDER in reversed(range(6, 7)):
  # The loop variable rebinds the global FOURIER_ORDER that fourier_transform
  # and GCN read, so every object built below uses this order.
  # seed RNGs
  torch.manual_seed(RAND_SEED)
  random.seed(RAND_SEED)

  # create datasets (only the test set gets random affine perturbations)
  train_fourier_data = QuickdrawDataset(train_imgs, train_nums, is_test=False)
  test_fourier_data = QuickdrawDataset(test_imgs, test_nums, is_test=True)

  # create dataloaders
  g = torch.Generator()
  g.manual_seed(RAND_SEED)
  train_fourier_loader = DataLoader(train_fourier_data, batch_size=BATCH_SIZE, shuffle=True, worker_init_fn=seed_worker, generator=g)
  test_fourier_loader = DataLoader(test_fourier_data, batch_size=BATCH_SIZE, shuffle=False, worker_init_fn=seed_worker, generator=g)

  # Init model and optimizer, resuming from SAVE_PATH only when a checkpoint
  # exists, so a fresh run starts at epoch 0 instead of crashing.
  # The path is not keyed by FOURIER_ORDER: a checkpoint saved with another
  # order will not fit conv1 (input width FOURIER_ORDER * 4).
  # Resuming does not restore RNG states, so the shuffle order after a resume
  # differs from an uninterrupted run.
  model = GCN()
  checkpoint = None
  if os.path.exists(SAVE_PATH):
    # map_location lets a checkpoint saved on one device load onto DEVICE
    checkpoint = torch.load(SAVE_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
  model.to(DEVICE)
  optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
  epoch = 0
  if checkpoint is not None:
    optim.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
  print("\n\n\nFourier order is: "+str(FOURIER_ORDER)+"\n\n\n")

  # train for the remaining epochs (checkpointing after each one), then
  # evaluate on test data with affine transformations
  for i in range(epoch, EPOCHS):
    print("Epoch " + str(i + 1) + "\n")
    train_loop(dataloader=train_fourier_loader,model=model,loss_fn=LOSS_FN,optimizer=optim)
    torch.save({
                'epoch': i + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optim.state_dict()
                }, SAVE_PATH)
    print("\n-------------------------------\n")
  # Re-seed so the random test-time perturbations are reproducible. No
  # num_workers is passed to the loaders, so with PyTorch's default of 0 they
  # are drawn from this process's `random` state.
  random.seed(RAND_SEED)
  eval_loop(dataloader=test_fourier_loader,model=model)




Fourier order is: 6



Epoch 8

loss: 3.867212  [  500/342985]
loss: 3.638100  [50500/342985]
loss: 3.803391  [100500/342985]
loss: 3.771935  [150500/342985]
loss: 3.830309  [200500/342985]
loss: 3.955822  [250500/342985]
loss: 3.868090  [300500/342985]

epoch avg train loss: 3.815602   epoch avg train accuracy: 0.197251

-------------------------------

Epoch 9

loss: 3.747351  [  500/342985]
loss: 3.865071  [50500/342985]
loss: 3.697450  [100500/342985]
loss: 3.761675  [150500/342985]
loss: 3.893540  [200500/342985]
loss: 3.842951  [250500/342985]
loss: 3.854700  [300500/342985]

epoch avg train loss: 3.793753   epoch avg train accuracy: 0.201507

-------------------------------

Epoch 10

loss: 3.716927  [  500/342985]
loss: 3.720410  [50500/342985]
loss: 3.795022  [100500/342985]
loss: 3.736304  [150500/342985]
loss: 3.741310  [200500/342985]
loss: 3.924152  [250500/342985]
loss: 3.838437  [300500/342985]

epoch avg train loss: 3.773583   epoch avg train accuracy: 0.204169

----