In [None]:
import numpy as np

import scipy as sc
import scipy.io as scio # for loading .mat file
from scipy import linalg
from scipy.spatial import KDTree

from sklearn.decomposition import PCA as PCAdimReduc
from sklearn.feature_extraction import DictVectorizer

import networkx as nx

import pygmtools as pygm

import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch

import cv2

import shapely

import svgpathtools
import drawsvg as draw

from bplustree import BPlusTree
from bplustree.serializer import Serializer
from bplustree.node import Node

import pickle

from xml.dom import minidom

import functools
from itertools import product, combinations
from typing import Optional, Tuple, List
from collections import Counter, defaultdict

import struct

import os, errno
import sys

import random

pygm.BACKEND = 'numpy' # set numpy as backend for pygmtools

# TOPOLOGY GRAPH / GEOMETRY EXTRACTION

In [None]:
from src.svg import *
from src.extraction import *
from src.database import *

In [None]:
# Load one example sketch from the 'ant' directory and show it in the notebook.
svg = load('assets/svg/ant/286.svg')
display(svg)

In [None]:
# Convert every SVG path to control points and build line strings from them.
# NOTE(review): `step=15` is presumably a sampling step used by get_line_strings — confirm in src.
line_strings = get_line_strings(map(to_control_points, svg['paths']), step=15)

In [None]:
# Inspect a single line string in isolation; change `index` to look at another one.
index = 0
ls = line_strings[index]

print(ls)
print(type(ls))
print(ls.length)

# Distance between the two endpoints returned by get_endpoints.
s, t = get_endpoints(ls)
print(s.distance(t))

plot_line_strings([ls])

# NOTE: `segments` is reassigned by the cell that segments all line strings.
segments = get_segments(ls)

# Plot segmented line string
plot_line_strings(segments)

# Plot closed line string
plot_line_strings([detect_approximate_polygon(ls)])

# Bare last expression: displayed as the cell output.
ls

In [None]:
# Scratch check: area of the axis-aligned bounding box around all line strings.
DEBUG = shapely.MultiLineString(line_strings)

shapely.box(*DEBUG.bounds).area

# 1200.0 vs 70400.0
# NOTE(review): the two numbers above are unexplained — presumably areas observed
# for two different inputs; confirm what they refer to or remove them.

In [None]:
# Break every line string into segments and gather them in one flat list
segments = []
for ls in line_strings:
    segments.extend(get_segments(ls))

# Draw all segments together
plot_line_strings(segments)

shapely.MultiLineString(segments)

In [None]:
# Collect polygons from the line strings, filter them, and report how many remain.
polygons = filter_polygons(get_polygons(line_strings), step=15)

print(len(polygons))

# Bare last expression: displayed as the cell output.
shapely.MultiPolygon(polygons)

### EXTRACTION

In [None]:
# Extract the graph of the example sketch (second argument 'ant' is its label) and plot it.
# PATHS is a lazy `map` object, so it can only be iterated once; re-run this cell
# to rebuild it before calling extract_graph again.
PATHS = map(to_control_points, svg['paths'])

G = extract_graph(PATHS, 'ant', step=15)
plot_graph(G)

G

# DATABASE

## Offline graph extraction

In [None]:
IMAGE_DIRECTORY = os.sep.join(['assets', 'svg'])

def load_svg_files(ext='svg'):
    """
    Collect the image files below IMAGE_DIRECTORY together with their labels.

    The label of a file is the name of the directory that directly contains it.

    Parameters
    ----------
    ext : str
        File extension to look for, without the leading dot.

    Returns
    -------
    list of (str, str)
        (file_path, label) pairs. The order is deterministic: directories are
        visited in sorted order and the files of a directory are sorted by name,
        so repeated runs (and runs on other machines) yield the same sequence.
    """
    files = []

    for root, dirs, names in os.walk(IMAGE_DIRECTORY):
        # Sorting `dirs` in place makes os.walk descend in a fixed order
        # instead of whatever order the filesystem happens to report.
        dirs.sort()

        # The label only depends on the directory, so compute it once per directory.
        label = os.path.basename(root)

        for name in sorted(names):
            if name.endswith(f'.{ext}'):
                files.append((os.path.join(root, name), label))

    return files


def load_svg_images(files):
    """
    Lazily load the images listed in `files`.

    `files` is an iterable of (file_path, label) pairs. For every pair the path
    is printed, the file is read with `load`, and an (image, label) tuple is
    yielded.
    """
    for path, label in files:
        print(path)
        image = load(path)
        yield image, label

def extract_graphs(images, step=20):
    """
    Build one graph per (image, label) pair in `images`.

    Each image's 'paths' are mapped through `to_control_points` and handed to
    `extract_graph` together with the label and `step`. When `extract_graph`
    raises StopIteration the image is skipped with a message. The running index
    of every processed image is printed. Returns the list of extracted graphs.
    """
    graphs = []

    for position, (image, label) in enumerate(images):
        control_points = map(to_control_points, image['paths'])

        try:
            graph = extract_graph(control_points, label, step=step)
        except StopIteration:
            print('drawing was too small')
        else:
            graphs.append(graph)

        print(position)

    return graphs

In [None]:
# Extract graphs for each sketch in the dataset, offline processing step

# paths = load_svg_files()
# print(len(paths))

# images = load_svg_images(paths)

# graphs = extract_graphs(images, step=20)
# print(len(graphs))

In [None]:
import pickle

def dump_graphs(graphs, prefix=""):
    """Pickle `graphs` into the file named `prefix` + 'graphs.p', overwriting it."""
    target = prefix+'graphs.p'
    with open(target, 'wb') as out:
        pickle.dump(graphs, out)
        
def load_graphs(prefix=""):
    """
    Return the object pickled in the file named `prefix` + 'graphs.p'.

    Unpickling can execute arbitrary code, so only point this at files written
    by dump_graphs.
    """
    source = prefix+'graphs.p'
    with open(source, 'rb') as src:
        graphs = pickle.load(src)
    return graphs

In [None]:
# Dump the graphs to avoid this offline step later
# dump_graphs(graphs)

In [None]:
# Load graphs from disk
# With the default prefix "" this reads 'graphs.p' from the current working directory.
graphs = load_graphs()

len(graphs)

## DIMENSIONALITY REDUCTION

In [None]:
# Check the statistics for number of nodes

# Distribution of the node count over all graphs (`n` is reused by later cells)
n = [len(graph.nodes) for graph in graphs]

fig, ax = plt.subplots()

ax.hist(n, bins=170)
ax.set_xticks(np.arange(0, 170, step=10))

for caption, value in (('max: ', max(n)),
                       ('min: ', min(n)),
                       ('99% percentile: ', np.percentile(n, 99))):
    print(caption, value)

In [None]:
# Check the statistics for number of edges

# Distribution of the edge count over all graphs
e = [len(graph.edges) for graph in graphs]

fig, ax = plt.subplots()

ax.hist(e, bins=400)
ax.set_xticks(np.arange(0, 400, step=30))

for caption, value in (('max: ', max(e)),
                       ('min: ', min(e)),
                       ('99% percentile: ', np.percentile(e, 99))):
    print(caption, value)

In [None]:
# Check descriptor collisions for a given descriptor length

# Descriptor length = node count of the largest graph
maximum = max(n)

print('100th percentile (maximum): ', maximum)

# Serialise each descriptor to bytes so it can be used as a hashable Counter key
descriptors_max = [descriptor(graph, N=maximum).tobytes() for graph in graphs]

c_max = Counter(descriptors_max)

# A descriptor is unique when exactly one graph produced it
print('Unique descriptors: ', sum(1 for occurrences in c_max.values() if occurrences == 1))

[(np.frombuffer(key, dtype=float), occurrences) for key, occurrences in c_max.most_common()]

In [None]:
# Check descriptor collisions for a given descriptor length

# Descriptor length = node count at the chosen percentile (truncated to an int)
percent = 84

percentile = int(np.percentile(n, percent))

print(f'{percent}th percentile: ', percentile)

# Serialise each descriptor to bytes so it can be used as a hashable Counter key
descriptors = [descriptor(graph, N=percentile).tobytes() for graph in graphs]

c_percentile = Counter(descriptors)

# A descriptor is unique when exactly one graph produced it
print('Unique descriptors: ', sum(1 for occurrences in c_percentile.values() if occurrences == 1))

[(np.frombuffer(key, dtype=float), occurrences) for key, occurrences in c_percentile.most_common()]

### Derive the frequencies of each label per descriptor

In [None]:
# The 250 sketch category labels. A label's position in this list is its index in
# the frequency vectors (see to_index / get_frequencies), so reordering the list
# would change the meaning of vectors that were already computed or stored.
LABELS = ['airplane',
         'alarm clock',
         'angel',
         'ant',
         'apple',
         'arm',
         'armchair',
         'ashtray',
         'axe',
         'backpack',
         'banana',
         'barn',
         'baseball bat',
         'basket',
         'bathtub',
         'bear (animal)',
         'bed',
         'bee',
         'beer-mug',
         'bell',
         'bench',
         'bicycle',
         'binoculars',
         'blimp',
         'book',
         'bookshelf',
         'boomerang',
         'bottle opener',
         'bowl',
         'brain',
         'bread',
         'bridge',
         'bulldozer',
         'bus',
         'bush',
         'butterfly',
         'cabinet',
         'cactus',
         'cake',
         'calculator',
         'camel',
         'camera',
         'candle',
         'cannon',
         'canoe',
         'car (sedan)',
         'carrot',
         'castle',
         'cat',
         'cell phone',
         'chair',
         'chandelier',
         'church',
         'cigarette',
         'cloud',
         'comb',
         'computer monitor',
         'computer-mouse',
         'couch',
         'cow',
         'crab',
         'crane (machine)',
         'crocodile',
         'crown',
         'cup',
         'diamond',
         'dog',
         'dolphin',
         'donut',
         'door',
         'door handle',
         'dragon',
         'duck',
         'ear',
         'elephant',
         'envelope',
         'eye',
         'eyeglasses',
         'face',
         'fan',
         'feather',
         'fire hydrant',
         'fish',
         'flashlight',
         'floor lamp',
         'flower with stem',
         'flying bird',
         'flying saucer',
         'foot',
         'fork',
         'frog',
         'frying-pan',
         'giraffe',
         'grapes',
         'grenade',
         'guitar',
         'hamburger',
         'hammer',
         'hand',
         'harp',
         'hat',
         'head',
         'head-phones',
         'hedgehog',
         'helicopter',
         'helmet',
         'horse',
         'hot air balloon',
         'hot-dog',
         'hourglass',
         'house',
         'human-skeleton',
         'ice-cream-cone',
         'ipod',
         'kangaroo',
         'key',
         'keyboard',
         'knife',
         'ladder',
         'laptop',
         'leaf',
         'lightbulb',
         'lighter',
         'lion',
         'lobster',
         'loudspeaker',
         'mailbox',
         'megaphone',
         'mermaid',
         'microphone',
         'microscope',
         'monkey',
         'moon',
         'mosquito',
         'motorbike',
         'mouse (animal)',
         'mouth',
         'mug',
         'mushroom',
         'nose',
         'octopus',
         'owl',
         'palm tree',
         'panda',
         'paper clip',
         'parachute',
         'parking meter',
         'parrot',
         'pear',
         'pen',
         'penguin',
         'person sitting',
         'person walking',
         'piano',
         'pickup truck',
         'pig',
         'pigeon',
         'pineapple',
         'pipe (for smoking)',
         'pizza',
         'potted plant',
         'power outlet',
         'present',
         'pretzel',
         'pumpkin',
         'purse',
         'rabbit',
         'race car',
         'radio',
         'rainbow',
         'revolver',
         'rifle',
         'rollerblades',
         'rooster',
         'sailboat',
         'santa claus',
         'satellite',
         'satellite dish',
         'saxophone',
         'scissors',
         'scorpion',
         'screwdriver',
         'sea turtle',
         'seagull',
         'shark',
         'sheep',
         'ship',
         'shoe',
         'shovel',
         'skateboard',
         'skull',
         'skyscraper',
         'snail',
         'snake',
         'snowboard',
         'snowman',
         'socks',
         'space shuttle',
         'speed-boat',
         'spider',
         'sponge bob',
         'spoon',
         'squirrel',
         'standing bird',
         'stapler',
         'strawberry',
         'streetlight',
         'submarine',
         'suitcase',
         'sun',
         'suv',
         'swan',
         'sword',
         'syringe',
         't-shirt',
         'table',
         'tablelamp',
         'teacup',
         'teapot',
         'teddy-bear',
         'telephone',
         'tennis-racket',
         'tent',
         'tiger',
         'tire',
         'toilet',
         'tomato',
         'tooth',
         'toothbrush',
         'tractor',
         'traffic light',
         'train',
         'tree',
         'trombone',
         'trousers',
         'truck',
         'trumpet',
         'tv',
         'umbrella',
         'van',
         'vase',
         'violin',
         'walkie talkie',
         'wheel',
         'wheelbarrow',
         'windmill',
         'wine-bottle',
         'wineglass',
         'wrist-watch',
         'zebra'
         ]
len(LABELS)

In [None]:
# Get the weights of all categories 

# Both sequences are lazy, single-pass iterators that `zip` walks in lockstep
descriptors = (descriptor(g, N=7).tobytes() for g in graphs)

labels = (g.graph['label'] for g in graphs)

# Group the labels of all graphs that share the same descriptor bytes
res = defaultdict(list)
for key, val in zip(descriptors, labels):
    res[key].append(val)

# Per descriptor: how often each label occurs
counts = {k: Counter(v) for k, v in res.items()}
counts

In [None]:
# Vectorize these frequencies

# Position of each label inside LABELS; get_frequencies uses it as the vector index.
to_index = {label: i for i, label in enumerate(LABELS)}

def get_frequencies(descriptor, counter):
    """
    Turn label counts into a relative-frequency vector aligned with LABELS.

    Parameters
    ----------
    descriptor
        Unused. Kept so existing callers that pass the descriptor key keep working.
    counter : mapping of label -> count
        Number of occurrences per label, e.g. a collections.Counter.

    Returns
    -------
    numpy.ndarray
        Float array of length len(LABELS). Entry to_index[label] holds that
        label's count divided by the sum of all counts. If the counts sum to
        zero (for example an empty counter) the all-zero vector is returned
        instead of the NaN vector that 0/0 would produce.

    Raises
    ------
    KeyError
        If `counter` contains a label that is not in to_index.
    """
    freq = np.zeros(len(LABELS), dtype=float)

    for label, count in counter.items():
        freq[to_index[label]] = count

    total = freq.sum()
    # Guard the normalisation: dividing an all-zero vector by 0 yields NaNs
    # and a RuntimeWarning.
    if total != 0:
        freq /= total

    return freq

# Map each descriptor key to the raw bytes of its label-frequency vector.
FREQUENCIES = {d: get_frequencies(d, c).tobytes() for d, c in counts.items()}
FREQUENCIES  

In [None]:
# Dump frequencies

def dump_frequencies():
    """Pickle the module-level FREQUENCIES mapping into 'frequencies.p'."""
    with open('frequencies.p', 'wb') as out:
        pickle.dump(FREQUENCIES, out)

def load_frequencies():
    """
    Return the object pickled in 'frequencies.p'.

    Unpickling can execute arbitrary code, so only use this on a file written
    by dump_frequencies.
    """
    with open('frequencies.p', 'rb') as src:
        frequencies = pickle.load(src)
    return frequencies
        
# Write the current FREQUENCIES to 'frequencies.p' (overwrites an existing file).
dump_frequencies()

In [None]:
# Check the frequencies for the zero descriptor
# bytes(8*7) is 56 zero bytes: the key of an all-zero descriptor, assuming
# descriptors consist of 7 float64 values (N=7, 8 bytes each) — TODO confirm dtype.
f = FREQUENCIES[bytes(8*7)]
# Decode the stored bytes back into the frequency vector.
np.frombuffer(f, dtype=float)

## Database construction

### Disk B+-tree mapping descriptors to graph data

In [None]:
# Reload the descriptor -> frequency-bytes mapping from 'frequencies.p'.
FREQUENCIES = load_frequencies()

def serialize_frequency_features(graph):
    """Return the FREQUENCIES entry stored under the bytes of `graph`'s N=7 descriptor."""
    return FREQUENCIES[descriptor(graph, N=7).tobytes()]
    
# Create iterator for all key/value pairs to be inserted into database
# `sorted` returns a list of (descriptor bytes, frequency bytes) pairs in
# descending key order.
iterator = sorted(FREQUENCIES.items(), key=lambda p: p[0], reverse=True)

# Construct database for online use
# NOTE(review): each value is a 250-element float vector serialised with
# tobytes(), i.e. 2000 bytes, while value_size=250 — confirm the unit
# construct_database expects.
# NOTE(review): serialize_frequency_features takes a graph, but `iterator`
# yields (key, value) byte pairs — confirm how construct_database uses `serialize`.
db = construct_database(iterator, N=7, value_size=250, serialize=serialize_frequency_features)

In [None]:
# Test query
# Load a second 'ant' sketch to use as the query (rebinds `svg`).
svg = load('assets/svg/ant/279.svg')

# NOTE(review): the other extract_graph calls in this notebook pass
# map(to_control_points, svg['paths']) and an explicit `step`; here the loaded
# svg object is passed directly and `step` is left at its default — confirm
# this is intended and matches how the stored graphs were extracted.
query = extract_graph(svg, 'ant')
plot_graph(query)

print(descriptor(query))

# Look the query graph up in the database built in the construction cell.
candidates = query_database(db, query)
print(candidates, "\n")

In [None]:
# Close the database object created in the construction cell.
db.close()

In [None]:
# Reopen database

# Reopen the database, run the query again, and always close the handle.
db = open_database()
try:
    candidates = query_database(db, query)
finally:
    # Release the handle even if the query raises, so a failed cell does not
    # leave the database open in the kernel.
    db.close()

candidates

# GRAPH MATCHING

In [None]:
from src.matching import *

In [None]:
# Demo graph for the matching section (rebinds `G`).
# spring_layout starts from random positions; a fixed seed makes the stored
# positions, and every plot drawn from them, identical on each run.
LAYOUT_SEED = 0

G = nx.tutte_graph()
G.graph['positions'] = nx.spring_layout(G, seed=LAYOUT_SEED)

In [None]:
# Draw the graph with networkx using the positions stored on the graph.
nx.draw(G, pos=G.graph['positions'])

In [None]:
# Plot the same graph in both cells of a 1x2 subplot grid.
# NOTE(review): assumes plot_graph draws on the current axes — confirm in src.
ax1 = plt.subplot(1, 2, 1)
plot_graph(G)

ax2 = plt.subplot(1, 2, 2)
plot_graph(G)

In [None]:
# Match the graph against itself and inspect the result and its shape.
X = match(G, G)
print(X.shape)
X

In [None]:
# Visualise the matching result X between the two (identical) graphs.
plot_mapping(X, G, G)