In [None]:
##Author: Kushin Mukherjee

I recommend running this notebook inside a conda environment to keep things organized and for reproducibility.

Download and install conda: https://docs.conda.io/projects/conda/en/latest/user-guide/install/  
Creating an environment: https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html

Some tips for installing packages:
First, activate your environment then install pip within the environment so that all the packages you install don't get installed to your global path. To do so:  
Type `conda install pip` in your terminal

Then, when in the project directory:  
Type `pip install -r requirements.txt`


We're using Python 3 plus the latest versions of all the packages listed below. Be sure to update your packages before running this notebook.

In [None]:
### import packages

import os.path
import random
import sys
from collections import Counter
from collections import OrderedDict
from importlib import reload

import numpy as np
import pandas as pd
import scipy.stats as stats
import seaborn as sns
import svgpathtools
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import clear_output
from matplotlib import pyplot as plt
%matplotlib inline
from PIL import Image

import svg_rendering_helpers as srh

In [None]:

# directory & file hierarchy
# proj_dir is the parent of the folder the notebook is run from (code_dir);
# plots and data live side by side underneath it.
proj_dir = os.path.abspath('..')
code_dir = os.getcwd()
plot_dir = os.path.join(proj_dir,'plots')
data_dir = os.path.join(proj_dir,'data')

# Create whichever of these folders is missing. exist_ok=True keeps the cell
# safe to re-run and avoids the check-then-create race of exists() + makedirs().
for _project_dir in (code_dir, plot_dir, data_dir):
    os.makedirs(_project_dir, exist_ok=True)

# NOTE(review): svg_rendering_helpers is imported as a plain module in the import
# cell, so it presumably sits in code_dir. If it is moved elsewhere, append its
# folder to sys.path here before importing it -- confirm the intended location.

In [None]:
# Read the annotation table from the project's data folder.
annotations_csv = os.path.join(data_dir, 'semantic_parts_annotated_data.csv')
D = pd.read_csv(annotations_csv)

In [None]:
def listify(string):
    """Turn the textual repr of a list of strings back into a Python list.

    The input is expected to look like ``"[u'first', u'second']"`` (the ``u``
    prefix is optional). It is parsed with ``ast.literal_eval``, which also
    copes with an empty list, with reprs that lack the ``u`` prefix and with
    elements that contain an apostrophe -- cases the plain quote-splitting
    approach gets wrong.

    Parameters
    ----------
    string : str
        Repr of a list, presumably as stored in the ``sketch_svg_string``
        column of the annotations CSV.

    Returns
    -------
    list
        The elements of the list. If ``string`` is not a valid list literal,
        the legacy behaviour is used instead: split on single quotes and drop
        the ``"[u"``, ``", u"`` and ``"]"`` separator fragments.

    Raises
    ------
    AttributeError
        If ``string`` is not a ``str`` (e.g. a missing value), because the
        fallback path calls ``string.split``.
    """
    import ast  # local import: only this helper needs it

    try:
        parsed = ast.literal_eval(string)
    except (ValueError, SyntaxError):
        parsed = None

    if isinstance(parsed, (list, tuple)):
        return list(parsed)

    # Fallback for text that is not a well-formed list literal: keep the
    # original quote-splitting so such inputs behave exactly as before.
    separators = [", u", "[u", "]"]
    return [piece for piece in string.split("'") if piece not in separators]
    

In [None]:
# Replace every stringified list in the column with the list listify() returns.
# Run this once per load of D: listify expects strings, not already-parsed lists.
D['sketch_svg_string'] = D['sketch_svg_string'].apply(listify)
# Sanity check: how many entries the first row's list contains
len(D['sketch_svg_string'][0])

In [None]:
# Sorted array of the distinct values found in the category column
unique_cats = np.unique(D['category'])
unique_cats

In [None]:
#Creating a spline-level df where the modal label is set as the 'true' label for any given spline
# Every remaining column is reduced to its most common value within the spline.
# sketch_svg_string is left out of the aggregation: after listify it holds Python
# lists, which are unhashable, so Counter cannot tally them. Older pandas dropped
# such columns silently; recent pandas raises a TypeError instead.
spline_df = (
    D.drop(columns=['sketch_svg_string'])
     .groupby('spline_id')
     .agg(lambda x: Counter(x).most_common(1)[0][0])
     .reset_index(level=0)
)

##Creating a stroke-level dataframe that takes the mode value of annotation for its children splines to set as its
##label value

# Columns carried over from the annotation table to the stroke level
stroke_label_cols = ['sketch_id','label','stroke_num','condition','target','category','outcome']

# Map every stroke_id to its own entry of the sketch's SVG list.
# A single sorted groupby visits the (category, sketch, stroke) combinations in
# the same order as nested loops over the sorted unique values would, without
# re-filtering the whole frame at every level.
stroke_svgs = OrderedDict()
for (_category, _sketch, stroke), stroke_rows in D.groupby(['category','sketch_id','stroke_num'], sort=True):
    # All rows of one stroke share the sketch-level list; the first row is used
    # and stroke_num indexes into that list.
    stroke_id = stroke_rows['stroke_id'].iloc[0]
    stroke_svgs[stroke_id] = stroke_rows['sketch_svg_string'].iloc[0][stroke]

# One row per stroke_id (as index) with the SVG string in column 0
stroke_svg_df = pd.DataFrame.from_dict(stroke_svgs, orient='index')

# Most common value of every column within each stroke.
# sketch_svg_string is excluded because its list values are unhashable and
# cannot be counted; recent pandas raises on such columns instead of dropping them.
stroke_group_data = (
    D.drop(columns=['sketch_svg_string'])
     .groupby('stroke_id')
     .agg(lambda x: Counter(x).most_common(1)[0][0])
)
labels = pd.DataFrame(stroke_group_data[stroke_label_cols])

# Join SVG strings and labels on stroke_id (the index of both frames), then move
# the id back into a regular column. The rename covers the case where the merged
# index is unnamed and reset_index therefore calls the new column "index".
stroke_df = (
    pd.merge(stroke_svg_df, labels, left_index=True, right_index=True)
      .reset_index(level=0)
      .rename(index=str, columns={"index": "stroke_id", 0: "svg"})
)


### Generating data for triplets task


In [None]:
## Some "bad sketches" have to be left out: they are mostly handwritten text rather than drawings

bad_sketches = [
    '3058-fb4fe740-d862-453b-a08f-44375a040165_21',
    '3113-105e6653-7fd1-4451-af00-46bb3145880a_8',
    '3113-105e6653-7fd1-4451-af00-46bb3145880a_12',
    '3113-105e6653-7fd1-4451-af00-46bb3145880a_23',
    '3113-105e6653-7fd1-4451-af00-46bb3145880a_24',
    '6786-9c3169eb-962e-468b-8922-b99247975eb2_15',
    '6786-9c3169eb-962e-468b-8922-b99247975eb2_24',
    '6786-9c3169eb-962e-468b-8922-b99247975eb2_16',
    '6786-9c3169eb-962e-468b-8922-b99247975eb2_20',
    '6786-9c3169eb-962e-468b-8922-b99247975eb2_22',
    '3113-105e6653-7fd1-4451-af00-46bb3145880a_7',
    '3113-105e6653-7fd1-4451-af00-46bb3145880a_13',
    '6311-cd21a68a-f1df-4290-b744-b0c7c7c60ed8_5',
    '6786-9c3169eb-962e-468b-8922-b99247975eb2_32',
]

# Keep only the strokes whose sketch is not on the exclusion list
is_bad_sketch = stroke_df['sketch_id'].isin(bad_sketches)
stroke_df = stroke_df[~is_bad_sketch]

In [None]:
### Currently constrained by minimum number of sketches in a conditionXcategoryXexemplar cell, which is 4
### We have 2*4*8*4 (256) sketches in total

n_per_cell = 4  # sketches drawn from every category x target x condition cell

random.seed(1022)
# The draw below uses NumPy, which random.seed() does not affect. A dedicated,
# seeded generator makes the sample identical on every run of this cell.
sketch_rng = np.random.RandomState(1022)
sample_sketches= []

for this_cat in unique_cats:
    cat_df = stroke_df[stroke_df['category']== this_cat]
    unique_items = np.unique(cat_df['target'])
    for this_item in unique_items:
        item_df = cat_df[cat_df['target']==this_item]
        unique_conds = np.unique(item_df['condition'])
        for this_cond in unique_conds:
            cond_df = item_df[item_df['condition']==this_cond]
            us = np.unique(cond_df['sketch_id']) ## unique sketches in cell
            if len(us)<n_per_cell:
                print("not enough in cell", this_item, this_cond,len(us))
                # Skip only this cell; the target's other conditions are still sampled
                continue
            rand_sl = sketch_rng.choice(us,size = n_per_cell,replace=False) ## list of random sketch ids
            sample_sketches.append(rand_sl)


sample_sketches = [y for x in sample_sketches for y in x] ##flatten list

# A sketch belongs to exactly one cell, so no id may have been drawn twice
assert(len(np.unique(sample_sketches))==len(sample_sketches))
    

In [None]:
# Restrict the stroke table to the sketches that were sampled
in_sample = stroke_df['sketch_id'].isin(sample_sketches)
render_df = stroke_df[in_sample]

In [None]:
# Number of distinct sketches in the render table
render_df['sketch_id'].nunique()

In [None]:
###Clear directories

svg_dir = os.path.join(plot_dir,'triplet_sketches')
png_dir =  os.path.join(plot_dir,'triplet_sketches_png')
for this_dir in [svg_dir,png_dir]:
    # Nothing to clear if the output folder has not been created yet
    if not os.path.isdir(this_dir):
        continue
    for this_sketch in os.listdir(this_dir):
        file_path = os.path.join(this_dir, this_sketch)
        # Only plain files are deleted; sub-directories are left untouched
        if not os.path.isfile(file_path):
            continue
        try:
            # A single removal is enough: os.unlink is the same operation, and
            # calling it afterwards would fail on the already-deleted file.
            os.remove(file_path)
        except OSError as e:
            # Best effort: report the file that could not be removed and carry on
            print(e)

In [None]:
###Render out SVGs and PNGs

reload(srh)  # re-import the helper module before using it
really_run = True

if really_run:
    # One SVG file per sketch, built from that sketch's stroke-level svg strings
    for sketch in render_df['sketch_id'].unique():
        sketch_strokes = render_df[render_df['sketch_id'] == sketch]
        svgs = list(sketch_strokes['svg'])
        srh.render_svg(svgs,out_dir ="triplet_sketches", base_dir=plot_dir,out_fname='{}.svg'.format(sketch))

### Create path to svgs and convert to png for feature extraction
really_run = True

if really_run:
    rendered_svg_dir = os.path.join(plot_dir,'triplet_sketches')
    svg_paths = srh.generate_svg_path_list(rendered_svg_dir)
    srh.svg_to_png(svg_paths,out_dir="triplet_sketches_png",base_dir=plot_dir)



In [None]:
# Count the distinct strokes for every sketch / category / target / label combination
stroke_counts = render_df.groupby(['sketch_id','category','target','label']).agg(
    num_strokes=pd.NamedAgg(column='stroke_id', aggfunc=lambda ids: len(ids.unique()))
)
# Turn the grouping keys back into ordinary columns
render_df_meta = stroke_counts.reset_index()
render_df_meta

In [None]:
# Write the metadata table to the data folder, without the row index
render_meta_csv = os.path.join(data_dir,'render_meta_data.csv')
render_df_meta.to_csv(render_meta_csv, index=False)

## Two feature analyses using VGG features extracted from the UW (Tim + Pablo) feature extractor and Judy's feature extractor

### Tim and Pablo

### Judy