### Notebook for creating image to image autoencoders for sketches using pytorch

We're using python 3 plus the latest versions of all the packages listed below. Be sure to update before running this nb

In [67]:
### import package

import sys
from importlib import reload
from matplotlib import pyplot as plt
%matplotlib inline
import seaborn as sns
import numpy as np
import scipy.stats as stats
import pandas as pd
import torch
from PIL import Image
from IPython.display import clear_output
import svgpathtools
import os.path
from collections import Counter



import svg_rendering_helpers as srh

In [68]:

# directory & file hierarchy
proj_dir = os.path.abspath('..')
code_dir = os.getcwd()
plot_dir = os.path.join(proj_dir,'plots')
data_dir = os.path.join(proj_dir,'data')


if not os.path.exists(code_dir):
    os.makedirs(code_dir)
    
if not os.path.exists(plot_dir):
    os.makedirs(plot_dir)
    
if not os.path.exists(data_dir):
    os.makedirs(data_dir)
    
# if svg_rendering_helpers.py not in sys.path:
#     sys.path.append(os.path.join(proj_dir,svg_rendering_helpers.py)) 

## TO DO: Reformat sketch_svg_string column so that it's a list of strokes instead of one giant string

### Maybe find the png_string version of the sketches so we can just work with those for now?

In [127]:
D = pd.read_csv(os.path.join(data_dir,'semantic_parts_annotated_data.csv'))

In [128]:
D.sketch_svg_string[0]

"[u'M88,40c0,2.06264 -0.6772,18.1614 1,19c9.41403,4.70702 38.10521,0 48,0c6.97504,0 29.82865,3.17135 34,-1c3.42085,-3.42085 -3.19012,-10.78518 -4,-12c-3.78389,-5.67584 -63.17292,-1 -79,-1', u'M92,61c0,18.89105 3,36.40812 3,55c0,1.58004 -0.75655,5 1,5', u'M169,60c0,20.0258 3,38.53083 3,57', u'M102,122c16.31683,10.87789 43.89528,-5 61,-5c0.74802,0 9,0 9,0c0,0 -1.47891,5.69728 -2,7c-4.79137,11.97843 -12.34671,43.34671 -20,51c-2.8003,2.8003 -21.92143,0 -26,0c-2.08325,0 -39.79187,-0.79187 -40,-1c-4.02358,-4.02358 13,-45.63903 13,-52c0,-2.36036 -0.54955,-3 2,-3', u'M83,174c0,17.6105 4,36.94202 4,53', u'M148,176c0,14.22848 5,31.9937 5,44', u'M172,120c3.28295,0 2.90737,5.90421 4,9c4.66601,13.22036 11.28855,72.28855 16,77', u'M103,179c3.79513,0 8.1022,13 12,13']"

In [129]:
def listify(string):
    split_list = string.split("'")
    l=[", u","[u","]"]
    out = [x for x in split_list if x not in l]
    return(out)
    

In [130]:
D.sketch_svg_string = D.sketch_svg_string.apply(listify)

In [132]:
len(D.sketch_svg_string[0])

8

In [135]:
unique_cats = np.unique(D.category)
unique_cats

array(['bird', 'car', 'chair', 'dog'], dtype=object)

In [136]:
#Creating a spline-level df where the modal label is set as the 'true' label for any given spline
spline_df= D.groupby('spline_id').agg(lambda x: Counter(x).most_common(1)[0][0])
spline_df.reset_index(level=0, inplace=True)

##Creating a stroke-level dataframe that takes the mode value of annotation for its children splines to set as its
##label value

from collections import Counter


from collections import OrderedDict
stroke_svgs=OrderedDict()
for category in unique_cats:
    DS=D[D['category']==category]
    for sketch in np.unique(DS['sketch_id']):
        DSS=DS[DS['sketch_id']==sketch]
        for stroke in np.unique(DSS['stroke_num']):
            DSA=DSS[DSS['stroke_num']==stroke]
            DSA=DSA.reset_index()
            stroke_svgs[DSA['stroke_id'][0]] = DSA['sketch_svg_string'][0][stroke]

            
            
stroke_svg_df= pd.DataFrame.from_dict(stroke_svgs, orient='index')    
stroke_group_data= D.groupby('stroke_id').agg(lambda x: Counter(x).most_common(1)[0][0])
labels= pd.DataFrame(stroke_group_data[['sketch_id','label','stroke_num','condition','target','category','outcome']])
stroke_df=pd.merge(stroke_svg_df,labels,left_index=True, right_index =True)
stroke_df.reset_index(level=0, inplace=True)
stroke_df=stroke_df.rename(index=str, columns={"index": "stroke_id", 0: "svg"})


In [137]:
D.sketch_svg_string[0]

['M88,40c0,2.06264 -0.6772,18.1614 1,19c9.41403,4.70702 38.10521,0 48,0c6.97504,0 29.82865,3.17135 34,-1c3.42085,-3.42085 -3.19012,-10.78518 -4,-12c-3.78389,-5.67584 -63.17292,-1 -79,-1',
 'M92,61c0,18.89105 3,36.40812 3,55c0,1.58004 -0.75655,5 1,5',
 'M169,60c0,20.0258 3,38.53083 3,57',
 'M102,122c16.31683,10.87789 43.89528,-5 61,-5c0.74802,0 9,0 9,0c0,0 -1.47891,5.69728 -2,7c-4.79137,11.97843 -12.34671,43.34671 -20,51c-2.8003,2.8003 -21.92143,0 -26,0c-2.08325,0 -39.79187,-0.79187 -40,-1c-4.02358,-4.02358 13,-45.63903 13,-52c0,-2.36036 -0.54955,-3 2,-3',
 'M83,174c0,17.6105 4,36.94202 4,53',
 'M148,176c0,14.22848 5,31.9937 5,44',
 'M172,120c3.28295,0 2.90737,5.90421 4,9c4.66601,13.22036 11.28855,72.28855 16,77',
 'M103,179c3.79513,0 8.1022,13 12,13']

In [138]:
D_birds = stroke_df[stroke_df['category']=='bird']

In [139]:
D_bj = D_birds[D_birds['target']=='bluejay']

In [140]:
D_bj.svg[0]

'M40,105c20.79393,0 34.03323,-2.09968 40,-20c2.95354,-8.86062 6.23882,-37.61941 15,-42c0.81311,-0.40656 0.82741,-1.58629 2,-1c27.81184,13.90592 11.62806,45.25612 21,64c7.17843,14.35687 9.92147,29.8822 18,42c12.91028,19.36542 61.95776,24.41047 82,33c4.98839,2.13788 31,6.9385 31,10c0,0.4714 -0.57836,0.78918 -1,1c-21.21578,10.60789 -51.57616,-17.71192 -73,-7c-9.90588,4.95294 -9.69981,20.09942 -13,30c-4.6214,13.86419 -14.25489,18.25489 -24,28c-1.5915,1.5915 -8.13429,9 -11,9c-1.37437,0 -0.56539,-2.69616 -1,-4c-2.6389,-7.9167 1.81078,-19.62155 5,-26c0.90007,-1.80014 6,-29 6,-29c0,0 -13.83383,-4.91692 -16,-6c-14.98873,-7.49436 -29.64401,45.28803 -35,56c-0.52395,1.04791 -3,6 -3,6c0,0 -1,-6.95235 -1,-8c0,-16.15664 2.22943,-32.49566 7,-48c1.67887,-5.45632 9,-14.79325 9,-20'

In [78]:
reload(srh)
really_run = True

if really_run==True:

    for sketch in D_bj.sketch_id.unique():
        this_sketch = D_bj.query('sketch_id == @sketch')
        svgs = list(this_sketch.svg)
        srh.render_svg(svgs,base_dir=plot_dir,out_fname='{}.svg'.format(sketch))


TypeError: '[' is not a valid value for attribute 'd' at svg-element <path>.

In [79]:
D_bj.svg

54      [
55      u
56      '
57      M
58      4
59      0
60      ,
61      1
62      0
63      5
219     [
220     u
221     '
222     M
223     3
224     7
225     ,
226     9
227     6
228     c
229     8
298     [
299     u
300     '
301     M
302     1
311     [
312     u
313     '
314     M
       ..
1706    1
1707    6
1708    3
1717    [
1718    u
1719    '
1720    M
1721    1
1722    0
1723    1
1724    ,
1725    1
1786    [
1787    u
1788    '
1789    M
1790    5
1791    4
1792    ,
1793    1
1794    1
1795    8
1875    [
1876    u
1877    '
1878    M
1879    3
1880    5
1881    ,
1882    7
Name: svg, Length: 295, dtype: object