### Notebook for building image-to-image autoencoders for sketches using PyTorch

We're using Python 3 plus the latest versions of all the packages imported below. Be sure to update them before running this notebook.

In [67]:
### import package

import ast
import os.path
import sys
from collections import Counter, OrderedDict
from importlib import reload

import numpy as np
import pandas as pd
import scipy.stats as stats
import seaborn as sns
import svgpathtools
import torch
from IPython.display import clear_output
from matplotlib import pyplot as plt
from PIL import Image

%matplotlib inline

import svg_rendering_helpers as srh

In [68]:

# Directory & file hierarchy: the project root is the parent of the notebook's
# working directory; plots/ and data/ live directly under it.
proj_dir = os.path.abspath('..')
code_dir = os.getcwd()
plot_dir = os.path.join(proj_dir, 'plots')
data_dir = os.path.join(proj_dir, 'data')

# exist_ok=True makes this idempotent and avoids the race between an
# os.path.exists check and os.makedirs. code_dir is the current working
# directory, which necessarily exists already, so it is not created here.
for required_dir in (plot_dir, data_dir):
    os.makedirs(required_dir, exist_ok=True)

# NOTE(review): `import svg_rendering_helpers as srh` in the import cell needs
# that module to be importable (e.g. located in code_dir next to this notebook).

## TODO: Reformat the `sketch_svg_string` column so that it is a list of strokes rather than one giant string

### Open question: can we find the `png_string` version of the sketches, so we can just work with PNGs for now?

In [127]:
semantic_parts_csv = os.path.join(data_dir, 'semantic_parts_annotated_data.csv')
D = pd.read_csv(semantic_parts_csv)  # one row per annotation

In [128]:
D['sketch_svg_string'].loc[0]

"[u'M88,40c0,2.06264 -0.6772,18.1614 1,19c9.41403,4.70702 38.10521,0 48,0c6.97504,0 29.82865,3.17135 34,-1c3.42085,-3.42085 -3.19012,-10.78518 -4,-12c-3.78389,-5.67584 -63.17292,-1 -79,-1', u'M92,61c0,18.89105 3,36.40812 3,55c0,1.58004 -0.75655,5 1,5', u'M169,60c0,20.0258 3,38.53083 3,57', u'M102,122c16.31683,10.87789 43.89528,-5 61,-5c0.74802,0 9,0 9,0c0,0 -1.47891,5.69728 -2,7c-4.79137,11.97843 -12.34671,43.34671 -20,51c-2.8003,2.8003 -21.92143,0 -26,0c-2.08325,0 -39.79187,-0.79187 -40,-1c-4.02358,-4.02358 13,-45.63903 13,-52c0,-2.36036 -0.54955,-3 2,-3', u'M83,174c0,17.6105 4,36.94202 4,53', u'M148,176c0,14.22848 5,31.9937 5,44', u'M172,120c3.28295,0 2.90737,5.90421 4,9c4.66601,13.22036 11.28855,72.28855 16,77', u'M103,179c3.79513,0 8.1022,13 12,13']"

In [129]:
def listify(string):
    """Parse a stringified list of SVG path strings into a Python list.

    The ``sketch_svg_string`` column holds the repr of a list of strings,
    e.g. ``"[u'M88,40c0,...', u'M92,61c0,...']"``. ``ast.literal_eval``
    evaluates only Python literals (no arbitrary code). Python 3 accepts the
    ``u''`` prefix, so both Python-2 and Python-3 style reprs parse correctly,
    as does an empty list ``"[]"``.

    Parameters
    ----------
    string : str or list
        The stringified list. A value that is already a list is returned
        unchanged, so re-running ``D.sketch_svg_string.apply(listify)`` on an
        already-converted column is harmless.

    Returns
    -------
    list
        One element (an SVG path string) per stroke.

    Raises
    ------
    ValueError, SyntaxError
        If ``string`` is not a valid list/tuple literal.
    """
    if isinstance(string, list):
        return string
    parsed = ast.literal_eval(string)
    if not isinstance(parsed, (list, tuple)):
        raise ValueError('expected a list literal, got {!r}'.format(string))
    return list(parsed)
    

In [130]:
D['sketch_svg_string'] = D['sketch_svg_string'].map(listify)  # string repr -> list of stroke paths

In [132]:
len(D['sketch_svg_string'].loc[0])

8

In [135]:
unique_cats = np.unique(D['category'].to_numpy())  # sorted distinct categories
unique_cats

array(['bird', 'car', 'chair', 'dog'], dtype=object)

In [136]:
# Spline-level df: the modal annotation is taken as the 'true' value for each spline.
# Stroke-level df: each stroke takes the modal annotation of its child splines as its label.


def _modal_value(values):
    """Return the most common value in `values`.

    Ties go to the value encountered first (Counter.most_common orders equal
    counts by first occurrence).
    """
    return Counter(values).most_common(1)[0][0]


def _hashable_columns(frame):
    """Return the names of the columns of `frame` that contain no list values.

    `_modal_value` feeds values to Counter, which cannot hash lists. After the
    `listify` cell, sketch_svg_string holds lists. Older pandas silently dropped
    such "nuisance" columns inside groupby().agg(); pandas >= 2.0 raises a
    TypeError instead, so they are excluded explicitly here.
    """
    return [col for col in frame.columns
            if not frame[col].map(lambda value: isinstance(value, list)).any()]


hashable_cols = _hashable_columns(D)

spline_df = (D[hashable_cols]
             .groupby('spline_id')
             .agg(_modal_value)
             .reset_index())

# For each (category, sketch, stroke), take the first annotation row and pick that
# stroke's path out of the sketch's list of stroke paths. groupby visits keys in
# sorted order and keeps the original row order within each group. That matches
# the previous nested np.unique loops, but D is scanned once instead of once per sketch.
# NOTE(review): assumes stroke_num is a 0-based position in sketch_svg_string -- TODO confirm.
stroke_svgs = OrderedDict()
for (category, sketch, stroke), stroke_rows in D.groupby(['category', 'sketch_id', 'stroke_num']):
    stroke_svgs[stroke_rows['stroke_id'].iat[0]] = stroke_rows['sketch_svg_string'].iat[0][stroke]

stroke_svg_df = pd.DataFrame.from_dict(stroke_svgs, orient='index')
stroke_group_data = D[hashable_cols].groupby('stroke_id').agg(_modal_value)
labels = stroke_group_data[['sketch_id', 'label', 'stroke_num', 'condition',
                            'target', 'category', 'outcome']].copy()

# Only columns are renamed: the former rename(index=str, ...) turned the integer
# row labels into strings, so e.g. stroke_df.loc[0] raised KeyError.
stroke_df = (pd.merge(stroke_svg_df, labels, left_index=True, right_index=True)
             .reset_index()
             .rename(columns={"index": "stroke_id", 0: "svg"}))


In [137]:
D['sketch_svg_string'].loc[0]

['M88,40c0,2.06264 -0.6772,18.1614 1,19c9.41403,4.70702 38.10521,0 48,0c6.97504,0 29.82865,3.17135 34,-1c3.42085,-3.42085 -3.19012,-10.78518 -4,-12c-3.78389,-5.67584 -63.17292,-1 -79,-1',
 'M92,61c0,18.89105 3,36.40812 3,55c0,1.58004 -0.75655,5 1,5',
 'M169,60c0,20.0258 3,38.53083 3,57',
 'M102,122c16.31683,10.87789 43.89528,-5 61,-5c0.74802,0 9,0 9,0c0,0 -1.47891,5.69728 -2,7c-4.79137,11.97843 -12.34671,43.34671 -20,51c-2.8003,2.8003 -21.92143,0 -26,0c-2.08325,0 -39.79187,-0.79187 -40,-1c-4.02358,-4.02358 13,-45.63903 13,-52c0,-2.36036 -0.54955,-3 2,-3',
 'M83,174c0,17.6105 4,36.94202 4,53',
 'M148,176c0,14.22848 5,31.9937 5,44',
 'M172,120c3.28295,0 2.90737,5.90421 4,9c4.66601,13.22036 11.28855,72.28855 16,77',
 'M103,179c3.79513,0 8.1022,13 12,13']

In [138]:
D_birds = stroke_df.loc[stroke_df['category'].eq('bird')]

In [139]:
D_bj = D_birds.loc[D_birds['target'].eq('bluejay')]

In [144]:
D_bj['svg']

54      M40,105c20.79393,0 34.03323,-2.09968 40,-20c2....
55           M92,236c-6.92959,6.92959 -15.29336,21 -26,21
56               M84,242c3.79976,3.79976 14.8012,12 20,12
57                 M130,240c0,1.57556 3.56254,5.56254 5,7
58                   M137,245c-8.6616,0 -18.84486,9 -27,9
59      M92,175c0,-14.62466 -18.36294,-28.54441 -26,-4...
60          M93,88c0.33333,-0.33333 0.66667,-0.66667 1,-1
61      M96,84c-2.13931,2.13931 -2,3.07978 -2,6c0,3.95...
62                M130,246c0,3.51362 10.7946,9.7946 14,13
63      M133,243c-2.95162,0 -1.2633,4.63165 -4,6c-0.54...
219               M37,96c8.63927,0 39.32001,4.67999 45,-1
220     M37,94c13.15865,0 22.69514,7.09931 36,9c0.6635...
221     M78,98c0,1.56667 0.66667,1 2,1c11.69836,0 21.9...
222                 M136,227c0,-1.58471 -1,-2.19668 -1,-4
223                M135,220c0.82177,0.82177 4.27177,3 5,3
224                M139,224c0,3.22863 6.64342,2.64342 9,5
225     M108,178c-1.73941,0 -1,2.51791 -1,4c0,2.76598 ...
226           

In [149]:
# Render each bluejay sketch's stroke paths through srh.render_svg into plot_dir.
# reload() picks up edits to svg_rendering_helpers without restarting the kernel.
reload(srh)
really_run = True  # set to False to skip rendering

if really_run:
    # groupby(sort=False) yields sketches in first-appearance order (the same order
    # as D_bj.sketch_id.unique()) without re-scanning D_bj with query() per sketch.
    for sketch, this_sketch in D_bj.groupby('sketch_id', sort=False):
        svgs = list(this_sketch.svg)
        srh.render_svg(svgs, base_dir=plot_dir, out_fname='{}.svg'.format(sketch))


In [151]:
### Collect paths to the lesioned SVGs and convert them to PNG for feature extraction
really_run = True

if really_run:
    svg_dir = os.path.join(plot_dir, 'svg_images')
    svg_paths = srh.generate_svg_path_list(svg_dir)
    srh.svg_to_png(svg_paths, base_dir=plot_dir)

convert /Users/kushin/Documents/Github/UW_sketch_work/plots/svg_images/2721-f28245be-a3ac-425e-9538-5c0803980807_23.svg /Users/kushin/Documents/Github/UW_sketch_work/plots/png_images/2721-f28245be-a3ac-425e-9538-5c0803980807_23.png
