### Notebook for creating image-to-image autoencoders for sketches using PyTorch

In [None]:
### Author: Kushin Mukherjee

We're using Python 3 plus the latest versions of all the packages imported below. Be sure to update them before running this notebook.

In [None]:
### import package

import sys
import random
from importlib import reload
from matplotlib import pyplot as plt
%matplotlib inline
import seaborn as sns
import numpy as np
import scipy.stats as stats
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from IPython.display import clear_output
import svgpathtools
import os.path
from collections import Counter



import svg_rendering_helpers as srh

In [None]:

# Project layout: the current working directory is treated as the code
# directory, and plots/ and data/ sit under its parent directory.
proj_dir = os.path.abspath('..')
code_dir = os.getcwd()
plot_dir = os.path.join(proj_dir, 'plots')
data_dir = os.path.join(proj_dir, 'data')

# Create whichever of the working directories does not exist yet.
for required_dir in (code_dir, plot_dir, data_dir):
    if not os.path.exists(required_dir):
        os.makedirs(required_dir)

# Disabled snippet kept from the original cell (sys.path extension for the helper module):
# if svg_rendering_helpers.py not in sys.path:
#     sys.path.append(os.path.join(proj_dir,svg_rendering_helpers.py))

In [None]:
# Load the annotation table from the data directory into D.
annotations_path = os.path.join(data_dir, 'semantic_parts_annotated_data.csv')
D = pd.read_csv(annotations_path)

In [None]:
# Peek at the first row's sketch_svg_string value as it was read from the CSV.
D.sketch_svg_string[0]

In [None]:
def listify(string):
    """Turn the text form of a list of strings into a real Python list.

    The input is expected to look like the repr of a list of (optionally
    ``u``-prefixed) string literals, e.g. ``"[u'M1,2', u'M3,4']"``.

    Args:
        string: The text to parse. A list or tuple is also accepted and is
            returned as a new list, so re-applying the function to an
            already-parsed value is harmless.

    Returns:
        list: The strings contained in the list literal. Text that is not a
        valid literal list of strings is split on single quotes with the
        ``", u"``, ``"[u"`` and ``"]"`` separator pieces removed.
    """
    import ast  # local import: only this parser needs it

    # Already parsed, e.g. because the cell applying this function was re-run.
    if isinstance(string, (list, tuple)):
        return list(string)

    # Preferred path: evaluate the text as a Python literal. This copes with
    # empty lists and with reprs that carry no ``u`` prefix.
    try:
        parsed = ast.literal_eval(string)
    except (ValueError, SyntaxError, TypeError):
        parsed = None
    if isinstance(parsed, list) and all(isinstance(item, str) for item in parsed):
        return parsed

    # Fallback: split on single quotes and discard the pieces that are
    # separators rather than content.
    separators = (", u", "[u", "]")
    return [piece for piece in string.split("'") if piece not in separators]
    

In [None]:
# Replace every entry of sketch_svg_string with the list that listify builds from it.
# The column is overwritten, so re-running this cell passes listify its own output.
D['sketch_svg_string'] = D['sketch_svg_string'].apply(listify)

In [None]:
# Number of entries in the first row's sketch_svg_string value.
len(D.sketch_svg_string[0])

In [None]:
# Sorted array of the distinct values found in the category column.
unique_cats = np.unique(D['category'])
unique_cats

In [None]:
# Build two tables from the annotation rows in D:
#   spline_df - one row per spline_id, each column set to its most common value
#   stroke_df - one row per stroke_id with the stroke's SVG string and modal labels
from collections import Counter, OrderedDict


def _modal_value(values):
    """Return the most frequent entry of ``values`` (ties go to the first one seen)."""
    return Counter(values).most_common(1)[0][0]


# sketch_svg_string holds lists, which are unhashable, so Counter cannot tally
# that column. Older pandas releases silently left such columns out of the
# aggregated result; pandas 2.x raises instead. Dropping the column up front
# gives the same result on both.
D_hashable = D.drop(columns=['sketch_svg_string'])

# Spline-level table: the modal value of every column is taken as the 'true' value.
spline_df = D_hashable.groupby('spline_id').agg(_modal_value).reset_index(level=0)

# Stroke-level SVG lookup. Groups are visited in sorted (category, sketch_id,
# stroke_num) order; the first row of each group supplies the stroke_id, and
# stroke_num indexes into that row's list of per-stroke SVG strings.
stroke_svgs = OrderedDict()
for (_, _, stroke), stroke_rows in D.groupby(['category', 'sketch_id', 'stroke_num']):
    first_row = stroke_rows.iloc[0]
    stroke_svgs[first_row['stroke_id']] = first_row['sketch_svg_string'][stroke]

stroke_svg_df = pd.DataFrame.from_dict(stroke_svgs, orient='index')

# Stroke-level labels: modal value over the rows belonging to each stroke.
stroke_group_data = D_hashable.groupby('stroke_id').agg(_modal_value)
labels = pd.DataFrame(stroke_group_data[['sketch_id','label','stroke_num','condition','target','category','outcome']])

# Join SVG strings and labels on stroke_id (the index of both frames), then
# expose the id as an ordinary column and name the SVG column.
stroke_df = pd.merge(stroke_svg_df, labels, left_index=True, right_index=True)
stroke_df = stroke_df.reset_index(level=0)
stroke_df = stroke_df.rename(index=str, columns={"index": "stroke_id", 0: "svg"})


### Generating data for triplets task

In [None]:
# sketch_id values to exclude from everything that follows.
bad_sketches = [
    '3058-fb4fe740-d862-453b-a08f-44375a040165_21',
    '3113-105e6653-7fd1-4451-af00-46bb3145880a_8',
    '3113-105e6653-7fd1-4451-af00-46bb3145880a_12',
    '3113-105e6653-7fd1-4451-af00-46bb3145880a_23',
    '3113-105e6653-7fd1-4451-af00-46bb3145880a_24',
    '6786-9c3169eb-962e-468b-8922-b99247975eb2_15',
    '6786-9c3169eb-962e-468b-8922-b99247975eb2_24',
    '6786-9c3169eb-962e-468b-8922-b99247975eb2_16',
    '6786-9c3169eb-962e-468b-8922-b99247975eb2_20',
    '6786-9c3169eb-962e-468b-8922-b99247975eb2_22',
    '3113-105e6653-7fd1-4451-af00-46bb3145880a_7',
    '3113-105e6653-7fd1-4451-af00-46bb3145880a_13',
    '6311-cd21a68a-f1df-4290-b744-b0c7c7c60ed8_5',
    '6786-9c3169eb-962e-468b-8922-b99247975eb2_32',
]

# Keep only the strokes whose sketch is not on the exclusion list.
is_bad_sketch = stroke_df['sketch_id'].isin(bad_sketches)
stroke_df = stroke_df[~is_bad_sketch]

In [None]:
# Draw 4 distinct sketches from every category x target x condition cell.
#
# np.random.choice draws from NumPy's global generator, which random.seed()
# does not touch. NumPy is therefore seeded as well, so the same sketches are
# selected on every run.
random.seed(1022)
np.random.seed(1022)
sample_sketches = []

for this_cat in unique_cats:
    cat_df = stroke_df[stroke_df['category'] == this_cat]
    unique_items = np.unique(cat_df['target'])
    for this_item in unique_items:
        item_df = cat_df[cat_df['target'] == this_item]
        unique_conds = np.unique(item_df['condition'])
        for this_cond in unique_conds:
            cond_df = item_df[item_df['condition'] == this_cond]
            us = np.unique(cond_df['sketch_id'])  # unique sketches in this cell
            if len(us) < 4:
                print("not enough in cell", this_item, this_cond,len(us))
                # NOTE(review): break also skips the remaining conditions of
                # this item; `continue` would skip only this cell - confirm
                # which one is intended.
                break
            rand_sl = np.random.choice(us, size=4, replace=False)  # random sketch ids
            sample_sketches.append(rand_sl)

# Flatten the per-cell arrays into one list of sketch ids.
sample_sketches = [y for x in sample_sketches for y in x]

# No sketch may have been drawn twice.
assert len(np.unique(sample_sketches)) == len(sample_sketches)
    

In [None]:
# Strokes belonging to the sampled sketches only.
in_sample = stroke_df['sketch_id'].isin(sample_sketches)
render_df = stroke_df[in_sample]


In [None]:
# Count of distinct sketch_id values present in render_df.
render_df.sketch_id.nunique()

In [None]:
### Clear directories
# Remove every regular file from the triplet SVG and PNG output folders.

svg_dir = os.path.join(plot_dir,'triplet_sketches')
png_dir = os.path.join(plot_dir,'triplet_sketches_png')
for this_dir in [svg_dir, png_dir]:
    if not os.path.isdir(this_dir):
        # Nothing to clear; os.listdir would raise on a missing folder.
        continue
    for this_sketch in os.listdir(this_dir):
        file_path = os.path.join(this_dir, this_sketch)
        if not os.path.isfile(file_path):
            continue  # sub-directories are left untouched
        try:
            # One delete is enough: a second remove/unlink of the same path
            # always fails because the file is already gone.
            os.remove(file_path)
        except OSError as e:
            # Report the failure and carry on with the remaining files.
            print(e)

In [None]:
### Render out SVGs and PNGs

# Re-import the helper module so that recent edits to it take effect.
reload(srh)
really_run = True

if really_run:
    # One SVG file per sampled sketch, built from that sketch's stroke SVG strings.
    for sketch in render_df['sketch_id'].unique():
        this_sketch = render_df[render_df['sketch_id'] == sketch]
        svgs = this_sketch['svg'].tolist()
        srh.render_svg(svgs, out_dir="triplet_sketches", base_dir=plot_dir,
                       out_fname='{}.svg'.format(sketch))

### Create path to svgs and convert to png for feature extraction
really_run = True

if really_run:
    triplet_svg_dir = os.path.join(plot_dir, 'triplet_sketches')
    svg_paths = srh.generate_svg_path_list(triplet_svg_dir)
    srh.svg_to_png(svg_paths, out_dir="triplet_sketches_png", base_dir=plot_dir)



In [None]:
# List the column names available in render_df.
render_df.columns

In [None]:
# One row per (sketch_id, category, target, label) combination, with
# num_strokes = number of distinct stroke_id values in that group.
grouping_keys = ['sketch_id', 'category', 'target', 'label']
render_df_meta = (
    render_df
    .groupby(grouping_keys)
    .agg(num_strokes=pd.NamedAgg(column='stroke_id', aggfunc=lambda ids: len(ids.unique())))
    .reset_index()
)
render_df_meta



In [None]:
# Write the per-sketch metadata table to the data directory, without the index column.
meta_csv_path = os.path.join(data_dir, 'render_meta_data.csv')
render_df_meta.to_csv(meta_csv_path, index=False)

### Data for the CNN: bluejay sketches for training, tomtit sketches for testing

In [None]:
# Strokes whose category is 'bird'.
is_bird = stroke_df['category'] == 'bird'
D_birds = stroke_df[is_bird]

In [None]:
# Bird strokes whose target is 'bluejay'.
is_bluejay = D_birds['target'] == 'bluejay'
D_bj = D_birds[is_bluejay]

In [None]:
# Bird strokes whose target is 'tomtit'.
is_tomtit = D_birds['target'] == 'tomtit'
D_tt = D_birds[is_tomtit]

In [None]:
# Render every bluejay sketch to an SVG file in the train_bj folder.
reload(srh)
really_run = True

if really_run:
    for sketch in D_bj['sketch_id'].unique():
        this_sketch = D_bj[D_bj['sketch_id'] == sketch]
        svgs = this_sketch['svg'].tolist()
        srh.render_svg(svgs, out_dir="train_bj", base_dir=plot_dir,
                       out_fname='{}.svg'.format(sketch))


In [None]:
# Render every tomtit sketch to an SVG file in the test_tt folder.
reload(srh)
really_run = True

if really_run:
    for sketch in D_tt['sketch_id'].unique():
        this_sketch = D_tt[D_tt['sketch_id'] == sketch]
        svgs = this_sketch['svg'].tolist()
        srh.render_svg(svgs, out_dir="test_tt", base_dir=plot_dir,
                       out_fname='{}.svg'.format(sketch))


In [None]:
### Collect the SVG paths in train_bj and convert them to PNG
really_run = True

if really_run:
    train_svg_dir = os.path.join(plot_dir, 'train_bj')
    svg_paths = srh.generate_svg_path_list(train_svg_dir)
    srh.svg_to_png(svg_paths, out_dir="train_bj_png", base_dir=plot_dir)

In [None]:
### Collect the SVG paths in test_tt and convert them to PNG
really_run = True

if really_run:
    test_svg_dir = os.path.join(plot_dir, 'test_tt')
    svg_paths = srh.generate_svg_path_list(test_svg_dir)
    srh.svg_to_png(svg_paths, out_dir="test_tt_png", base_dir=plot_dir)

### Autoencoder Work

In [None]:
from torchvision.utils import save_image
from torchvision import datasets
import torchvision
from torch.autograd import Variable

In [None]:
# Image dataset rooted at plot_dir, with every image converted to a tensor.
# NOTE(review): the root is plot_dir itself rather than the train_bj_png folder
# named in the disabled line below - confirm this is intended.
#train_set = os.path.join(plot_dir, "train_bj_png")
to_tensor = torchvision.transforms.ToTensor()
train_set = datasets.ImageFolder(plot_dir, transform=to_tensor)

In [None]:
# Training batches of 32 images, served in dataset order (no shuffling).
dataloader = torch.utils.data.DataLoader(dataset=train_set, batch_size=32, shuffle=False)

In [None]:
# Image dataset rooted at plot_dir, with every image converted to a tensor.
# NOTE(review): this is the same root that train_set is built from - confirm
# that a separate test folder is not intended.
test_set = datasets.ImageFolder(plot_dir, transform=torchvision.transforms.ToTensor())

In [None]:
# Test batches of 4 images, served in dataset order (no shuffling).
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=4, shuffle=False)

In [None]:

class Autoencoder(nn.Module):
    """Small convolutional autoencoder for 3-channel images.

    The encoder applies two 5x5 convolutions (3 -> 6 -> 16 channels); the
    decoder mirrors them with two 5x5 transposed convolutions
    (16 -> 6 -> 3 channels). Every layer is followed by an in-place ReLU,
    including the last one, so outputs are non-negative.
    """

    def __init__(self):
        super().__init__()

        encoder_layers = [
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(inplace=True),
        ]
        decoder_layers = [
            nn.ConvTranspose2d(in_channels=16, out_channels=6, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=6, out_channels=3, kernel_size=5),
            nn.ReLU(inplace=True),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and decode it again, returning the reconstruction."""
        return self.decoder(self.encoder(x))

In [None]:
# Number of passes over the training DataLoader made by the training loop.
num_epochs = 5 #you can go for more epochs, I am using a mac
# NOTE(review): the DataLoader cells hard-code their own batch sizes (32 and 4)
# and do not read this value.
batch_size = 128

In [None]:
# Model on the CPU, mean-squared-error criterion, and an Adam optimizer with
# weight decay 1e-5 over all model parameters.
model = Autoencoder()
model = model.cpu()
distance = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), weight_decay=1e-5)

In [None]:
# Train the autoencoder to reproduce its input images.
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    n_batches = 0
    for img, _ in dataloader:  # the folder labels are not used
        # ===================forward=====================
        output = model(img)
        loss = distance(output, img)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        n_batches += 1
    # ===================log========================
    # The log line had been commented out and called loss.data(), which is not
    # callable; loss.item() reads the scalar. The mean over the epoch's
    # batches is reported, and an empty DataLoader prints nothing.
    if n_batches:
        print('epoch [{}/{}], loss:{:.4f}'.format(epoch+1, num_epochs, running_loss / n_batches))

In [None]:
# Display the model's layers. Calling model() with no argument raises a
# TypeError because forward() requires an input, which stops a Run All.
model

In [None]:
def test(epoch):
    """Evaluate the autoencoder on ``test_loader`` and save a sample reconstruction.

    Runs the module-level ``model`` over every batch of ``test_loader`` with
    gradient tracking disabled and accumulates the reconstruction loss given
    by the module-level ``distance`` criterion. For the first batch, up to 8
    inputs followed by their reconstructions are written as an image grid to
    ``results/reconstruction_<epoch>.png``. The mean loss per image is printed.

    Args:
        epoch: Used only to label the saved image file.
    """
    model.eval()
    # Run on whichever device holds the model's weights.
    device = next(model.parameters()).device
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            # The autoencoder returns the reconstruction only.
            recon_batch = model(data)
            # distance is a batch mean; weighting by the batch size makes the
            # division by the dataset size below a per-image average.
            test_loss += distance(recon_batch, data).item() * data.size(0)
            if i == 0:
                n = min(data.size(0), 8)
                # The reconstruction already has the input's shape, so the two
                # can be concatenated directly.
                comparison = torch.cat([data[:n], recon_batch[:n]])
                os.makedirs('results', exist_ok=True)
                save_image(comparison.cpu(),
                         'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [None]:
# Collect one reconstruction per test image as a (height, width, channel) array.
# Each batch is passed through the model exactly once; iterating over the
# (images, labels) pair itself would repeat the pass and store every
# reconstruction twice.
ims = []
with torch.no_grad():  # inference only, no gradients needed
    for img, _ in test_loader:
        p_img = model(img)
        for p in p_img:
            # channel-first tensor -> channel-last array for plotting
            ims.append(np.array(p).transpose(1, 2, 0))

In [None]:
# Number of reconstruction arrays collected in ims.
len(ims)

In [None]:
# Shape of the first reconstruction array.
ims[0].shape

In [None]:
# Shape of the reconstruction at index 25, for comparison with index 0.
ims[25].shape

In [None]:
# Shape of the first slice along axis 0 of the first reconstruction.
ims[0][0].shape

In [None]:
# Shape of the first reconstruction array (same check as in an earlier cell).
ims[0].shape

In [None]:
# Show the reconstruction at the hard-coded index 120; ims needs at least 121 entries.
plt.imshow(ims[120])

In [None]:
# Largest value in the reconstruction at the hard-coded index 600; ims needs
# at least 601 entries. The decoder ends in a ReLU, so values are not capped at 1.
np.amax(ims[600])