### Run setup.ipynb to get set up

In [None]:
import os
import re
import sys
import socket
import requests
import numpy as np
import pandas as pd

from PIL import Image
from sklearn import preprocessing

%matplotlib inline
import matplotlib
from matplotlib import pylab, mlab, pyplot, gridspec
from IPython.core.pylabtools import figsize, getfigs
plt = pyplot
import seaborn as sns
sns.set_context('talk')
sns.set_style('white')

from IPython.display import clear_output

import utils
from utils import generate_acc_probs, generate_acc_probs_2x2, generate_2x2_plots, perform_cross_validation, perform_cross_validation_twice, adjacent_plots, cat_cond_diffplots

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=matplotlib.cbook.mplDeprecation)
warnings.filterwarnings("ignore", message="numpy.dtype size changed")
warnings.filterwarnings("ignore", message="numpy.ufunc size changed")

### set up paths

In [None]:
# Directory & file hierarchy. The project root is the parent of the directory
# this notebook is run from.
proj_dir = os.path.abspath('..')
analysis_dir = os.getcwd()
results_dir = os.path.join(proj_dir, 'results')
plot_dir = os.path.join(results_dir, 'plots')
csv_dir = os.path.join(results_dir, 'csv')

# Features live on a mounted volume on the 'nightingale' host and under the
# project's features/photodraw12 folder everywhere else.
feature_dir = (
    os.path.abspath('/mnt/pentagon/photodraw/features/')
    if socket.gethostname() == 'nightingale'
    else os.path.abspath(os.path.join(proj_dir, 'features', 'photodraw12'))
)

# Pixel-level and fc6 feature files, each with its metadata table.
meta_path = os.path.abspath(os.path.join(feature_dir, 'metadata_pixels.csv'))
image_path = os.path.abspath(os.path.join(feature_dir, 'flattened_sketches_pixels.npy'))
meta_path_fc6 = os.path.abspath(os.path.join(feature_dir, 'METADATA_sketch.csv'))
image_path_fc6 = os.path.abspath(os.path.join(feature_dir, 'FEATURES_FC6_photodraw_sketch.npy'))

# Make the project's helper modules importable.
if not os.path.join(proj_dir, 'utils') in sys.path:
    sys.path.append(os.path.join(proj_dir, 'utils'))

def make_dir_if_not_exists(dir_name):
    """Create ``dir_name`` (and any missing parents) if needed and return it.

    ``exist_ok=True`` replaces the separate exists-check. That check was racy:
    the directory could appear between the check and ``makedirs``.
    If ``dir_name`` exists but is a regular file, ``FileExistsError`` is raised
    rather than silently returning a path that cannot hold outputs.
    """
    os.makedirs(dir_name, exist_ok=True)
    return dir_name

# Make sure every output / feature directory exists (each call returns its path).
result = list(map(make_dir_if_not_exists, [results_dir, plot_dir, csv_dir, feature_dir]))

In [None]:
# Read in the stroke-level (T), sketch-level (K) and survey (S) tables, plus the
# metadata (M) that accompanies the flattened pixel features.
T = pd.read_csv(os.path.join(csv_dir,'photodraw_stroke_data.csv'))
K = pd.read_csv(os.path.join(csv_dir,'photodraw_sketch_data.csv'))
S = pd.read_csv(os.path.join(csv_dir,'photodraw_survey_data.csv'))
M = pd.read_csv(meta_path)
# Keep only metadata rows whose game also appears in the sketch table K.
M = M[M.gameID.isin(K.gameID.values)]
# Select the matching pixel rows *before* resetting M's index. At this point
# M.index still holds each row's original position in metadata_pixels.csv,
# which F[M.index] uses as the row index into the pixel array.
F = np.load(image_path)
F = F[M.index]
M = M.reset_index(drop=True)
M_fc6 = pd.read_csv(meta_path_fc6)
F_fc6 = np.load(image_path_fc6)
# Mean-centred and standardised copies of both feature sets.
F_norm = F - F.mean(axis=0)
F_fc6_norm = F_fc6 - F_fc6.mean(axis=0)
F_scaled = preprocessing.scale(F)
F_fc6_scaled = preprocessing.scale(F_fc6)

classes = ['airplane', 'bike', 'bird', 'car', 'cat', 'chair', 'cup', 'hat', 'house', 'rabbit', 'tree', 'watch']

In [None]:
# remove images flagged as invalid or outliers
def remove_invalid(frame):
    """Return only the rows of ``frame`` whose ``isInvalid`` flag equals False."""
    keep_mask = frame.isInvalid.eq(False)
    return frame[keep_mask]
def remove_flagged(frame):
    """Return only the rows of ``frame`` that are neither outliers nor invalid."""
    not_outlier = frame.isOutlier.eq(False)
    not_invalid = frame.isInvalid.eq(False)
    return frame[not_outlier & not_invalid]
def remove_invalid_T(T, sketches=None):
    """Drop strokes that belong to sketches flagged as invalid.

    A sketch is identified by its (gameID, trialNum) pair, so a stroke is
    removed only when BOTH fields match an invalid sketch. Filtering the two
    columns independently would also drop valid trials that merely share a
    gameID or a trialNum with some invalid sketch.

    Parameters
    ----------
    T : DataFrame
        Stroke-level data with ``gameID`` and ``trialNum`` columns.
    sketches : DataFrame, optional
        Sketch-level data with ``gameID``, ``trialNum`` and ``isInvalid``
        columns. Defaults to the module-level sketch table ``K``.

    Returns
    -------
    DataFrame
        Rows of ``T`` (original index kept) not belonging to an invalid sketch.
    """
    if sketches is None:
        sketches = K
    invalid = sketches[sketches.isInvalid.eq(True)]
    bad_pairs = set(zip(invalid['gameID'], invalid['trialNum']))
    keep = np.array([(game, trial) not in bad_pairs
                     for game, trial in zip(T['gameID'], T['trialNum'])], dtype=bool)
    return T[keep]
def remove_flagged_T(T, sketches=None):
    """Drop strokes that belong to sketches flagged as outliers or invalid.

    A sketch is identified by its (gameID, trialNum) pair, so a stroke is
    removed only when BOTH fields match a flagged sketch. Filtering the two
    columns independently would also drop valid trials that merely share a
    gameID or a trialNum with some flagged sketch.

    Parameters
    ----------
    T : DataFrame
        Stroke-level data with ``gameID`` and ``trialNum`` columns.
    sketches : DataFrame, optional
        Sketch-level data with ``gameID``, ``trialNum``, ``isOutlier`` and
        ``isInvalid`` columns. Defaults to the module-level sketch table ``K``.

    Returns
    -------
    DataFrame
        Rows of ``T`` (original index kept) not belonging to a flagged sketch.
    """
    if sketches is None:
        sketches = K
    flagged = sketches[sketches.isOutlier.eq(True) | sketches.isInvalid.eq(True)]
    bad_pairs = set(zip(flagged['gameID'], flagged['trialNum']))
    keep = np.array([(game, trial) not in bad_pairs
                     for game, trial in zip(T['gameID'], T['trialNum'])], dtype=bool)
    return T[keep]

In [None]:
# Being extra explicit now to avoid possible errors and headaches down the line
# Two filtering levels are built for every table: "remove_invalid" drops only
# sketches flagged isInvalid; "remove_flagged" also drops isOutlier sketches.
# Sketch-level tables (K).
K_remove_invalid = remove_invalid(K)
K_remove_flagged = remove_flagged(K)

# Stroke-level tables (T), filtered via the flags of the sketch each stroke belongs to.
T_remove_invalid = remove_invalid_T(T)
T_remove_flagged = remove_flagged_T(T)

# Pixel features: M's index was reset after loading, so the filtered M's index
# values are row positions in F_norm.
F_norm_remove_invalid = F_norm[remove_invalid(M).index]
F_norm_remove_flagged = F_norm[remove_flagged(M).index]
# fc6 features: rows are gathered through K's fc6_feature_ind column, so the
# result follows the row order of the filtered K tables.
F_fc6_norm_remove_invalid = F_fc6_norm[remove_invalid(K).fc6_feature_ind.values]
F_fc6_norm_remove_flagged = F_fc6_norm[remove_flagged(K).fc6_feature_ind.values]

# Pixel metadata re-indexed 0..n-1 so row i lines up with row i of the F_norm subsets.
M_remove_invalid = remove_invalid(M).reset_index(drop=True)
M_remove_flagged = remove_flagged(M).reset_index(drop=True)

## How many strokes are people using?

In [None]:
# Violin plot of strokes per sketch by cue condition, for both filtered sketch sets.
adjacent_plots(K_remove_invalid, K_remove_flagged,
               x='condition',
               y='numStrokes',
               plottype='violinplot')

In [None]:
# Bar plot of strokes per sketch by condition; saved with a tight bounding box
# (the keyword is bbox_inches; the misspelled bbox_inces never took effect).
adjacent_plots(K_remove_invalid, K_remove_flagged, x='condition', y='numStrokes', plottype='barplot')
plt.savefig(os.path.abspath(os.path.join(plot_dir,'numStrokes_condition_barplot.png')), bbox_inches='tight')

In [None]:
# Violin plot of strokes per sketch by category; saved with a tight bounding box.
adjacent_plots(K_remove_invalid, K_remove_flagged, x='category', y='numStrokes', plottype='violinplot')
plt.savefig(os.path.abspath(os.path.join(plot_dir,'numStrokes_category_violinplot.png')), bbox_inches='tight')

## How much time are people spending drawing?

In [None]:
# Violin plot of active sketching time by cue condition, for both filtered sketch sets.
adjacent_plots(K_remove_invalid, K_remove_flagged,
               x='condition',
               y='activeSketchTime',
               plottype='violinplot')

In [None]:
# Bar plot of active sketching time by condition; saved with a tight bounding box.
adjacent_plots(K_remove_invalid, K_remove_flagged, x='condition', y='activeSketchTime', plottype='barplot')
plt.savefig(os.path.abspath(os.path.join(plot_dir,'activeSketchTime_condition_barplot.png')), bbox_inches='tight')

In [None]:
# Violin plot of active sketching time by category; saved with a tight bounding box.
adjacent_plots(K_remove_invalid, K_remove_flagged, x='category', y='activeSketchTime', plottype='violinplot')
plt.savefig(os.path.abspath(os.path.join(plot_dir,'activeSketchTime_category_violinplot.png')), bbox_inches='tight')

## How much ink are people using?

In [None]:
# Violin plot of total ink per sketch by cue condition, for both filtered sketch sets.
adjacent_plots(K_remove_invalid, K_remove_flagged,
               x='condition',
               y='totalInk',
               plottype='violinplot')

In [None]:
# Bar plot of total ink by condition; saved with a tight bounding box.
adjacent_plots(K_remove_invalid, K_remove_flagged, x='condition', y='totalInk', plottype='barplot')
plt.savefig(os.path.abspath(os.path.join(plot_dir,'totalInk_condition_barplot.png')), bbox_inches='tight')

In [None]:
# Violin plot of total ink by category; saved with a tight bounding box.
adjacent_plots(K_remove_invalid, K_remove_flagged, x='category', y='totalInk', plottype='violinplot')
plt.savefig(os.path.abspath(os.path.join(plot_dir,'totalInk_category_violinplot.png')), bbox_inches='tight')

## How recognizable are the sketches?

#### Pixel-level classifier

In [None]:
# Violin plot of the pixel classifier's probability for the true category, by condition.
adjacent_plots(K_remove_invalid, K_remove_flagged,
               x='condition',
               y='prob_true_predict_pixel',
               plottype='violinplot')

In [None]:
# Bar plot of the pixel classifier's true-category probability by condition; tight bounding box on save.
adjacent_plots(K_remove_invalid, K_remove_flagged, x='condition', y='prob_true_predict_pixel', plottype='barplot')
plt.savefig(os.path.abspath(os.path.join(plot_dir,'pixelprobs_condition_barplot.png')), bbox_inches='tight')

In [None]:
# Violin plot of the pixel classifier's true-category probability by category; tight bounding box on save.
adjacent_plots(K_remove_invalid, K_remove_flagged, x='category', y='prob_true_predict_pixel', plottype='violinplot')
plt.savefig(os.path.abspath(os.path.join(plot_dir,'pixelprobs_category_violinplot.png')), bbox_inches='tight')

#### fc6 feature-level classifier

In [None]:
# Violin plot of the fc6 classifier's probability for the true category, by condition.
adjacent_plots(K_remove_invalid, K_remove_flagged,
               x='condition',
               y='prob_true_predict_fc6',
               plottype='violinplot')

In [None]:
# Bar plot of the fc6 classifier's true-category probability by condition; tight bounding box on save.
adjacent_plots(K_remove_invalid, K_remove_flagged, x='condition', y='prob_true_predict_fc6', plottype='barplot')
plt.savefig(os.path.abspath(os.path.join(plot_dir,'fc6probs_condition_barplot.png')), bbox_inches='tight')

In [None]:
# Violin plot of the fc6 classifier's true-category probability by category; tight bounding box on save.
adjacent_plots(K_remove_invalid, K_remove_flagged, x='category', y='prob_true_predict_fc6', plottype='violinplot')
plt.savefig(os.path.abspath(os.path.join(plot_dir,'fc6probs_category_violinplot.png')), bbox_inches='tight')

## How long are the strokes?

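Note: arc length is stroke-level data, so individual strokes are not screened for outliers (removing a single stroke would not make sense); only strokes belonging to sketches flagged as invalid or as outliers are excluded.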

In [None]:
# Violin plot of stroke arc length by cue condition, for both filtered stroke sets.
adjacent_plots(T_remove_invalid, T_remove_flagged,
               x='condition',
               y='arcLength',
               plottype='violinplot')

In [None]:
# Bar plot of stroke arc length by condition; saved with a tight bounding box.
adjacent_plots(T_remove_invalid, T_remove_flagged, x='condition', y='arcLength', plottype='barplot')
plt.savefig(os.path.abspath(os.path.join(plot_dir,'arcLength_condition_barplot.png')), bbox_inches='tight')

In [None]:
# Bar plot of stroke arc length by category; saved with a tight bounding box.
adjacent_plots(T_remove_invalid, T_remove_flagged, x='category', y='arcLength', plottype='barplot')
plt.savefig(os.path.abspath(os.path.join(plot_dir,'arcLength_category_barplot.png')), bbox_inches='tight')

### Density of low-level features

#### How are the low-level features (number of strokes, sketch time, ink, arc length, recognizability) distributed, by condition?

In [None]:
# Density of strokes per sketch, for both filtered sketch sets.
adjacent_plots(K_remove_invalid, K_remove_flagged,
               x='numStrokes',
               y='density',
               plottype='distplot')

In [None]:
# Density of active sketching time, for both filtered sketch sets.
adjacent_plots(
    K_remove_invalid, K_remove_flagged,
    'activeSketchTime', 'density', 'distplot',
)

In [None]:
# Density of total ink per sketch, for both filtered sketch sets.
adjacent_plots(
    K_remove_invalid, K_remove_flagged,
    'totalInk', 'density', 'distplot',
)

In [None]:
# Density of stroke arc length, for both filtered stroke sets.
adjacent_plots(
    T_remove_invalid, T_remove_flagged,
    'arcLength', 'density', 'distplot',
)

In [None]:
# Density of the pixel classifier's true-category probability, for both filtered sketch sets.
adjacent_plots(
    K_remove_invalid, K_remove_flagged,
    'prob_true_predict_pixel', 'density', 'distplot',
)

In [None]:
# Density of the fc6 classifier's true-category probability, for both filtered sketch sets.
adjacent_plots(
    K_remove_invalid, K_remove_flagged,
    'prob_true_predict_fc6', 'density', 'distplot',
)

#### Is there a clear relationship between stroke arc length and stroke duration, and does it differ between conditions?

In [None]:
# One point per stroke (flagged sketches removed): stroke duration vs. arc
# length, coloured by cue condition.
sns.scatterplot(
    x=T_remove_flagged['endStrokeTime'] - T_remove_flagged['startStrokeTime'],
    y=T_remove_flagged['arcLength'],
    hue=T_remove_flagged['condition'],
)
plt.xlabel('stroke time')
plt.title('Arc length vs time needed for each stroke');

<br><br>
## Within-participant analyses
<br>

In [None]:
# Participant-level regplot for numStrokes; the returned differences feed the
# combined box-plot cell further down.
numStrokes_diff = adjacent_plots(
    K_remove_invalid, K_remove_flagged,
    plottype='regplot',
    x='numStrokes',
)

In [None]:
# Category x condition difference bar plot for numStrokes; saved with a tight bounding box.
cat_cond_diffplots(K_remove_invalid, K_remove_flagged, 'numStrokes', 'barplot')
plt.savefig(os.path.abspath(os.path.join(plot_dir,'numStrokes_catcond_diffplot.png')), bbox_inches='tight')

In [None]:
# Participant-level regplot for activeSketchTime; the returned differences feed
# the combined box-plot cell further down.
activeSketchTime_diff = adjacent_plots(
    K_remove_invalid, K_remove_flagged,
    plottype='regplot',
    x='activeSketchTime',
)

In [None]:
# Category x condition difference bar plot for activeSketchTime; saved with a tight bounding box.
cat_cond_diffplots(K_remove_invalid, K_remove_flagged, 'activeSketchTime', 'barplot')
plt.savefig(os.path.abspath(os.path.join(plot_dir,'activeSketchTime_catcond_diffplot.png')), bbox_inches='tight')

In [None]:
# Participant-level regplot for totalInk; the returned differences feed the
# combined box-plot cell further down.
totalInk_diff = adjacent_plots(
    K_remove_invalid, K_remove_flagged,
    plottype='regplot',
    x='totalInk',
)

In [None]:
# Category x condition difference bar plot for totalInk; saved with a tight bounding box.
cat_cond_diffplots(K_remove_invalid, K_remove_flagged, 'totalInk', 'barplot')
plt.savefig(os.path.abspath(os.path.join(plot_dir,'totalInk_catcond_diffplot.png')), bbox_inches='tight')

#### The data are too noisy to draw any conclusions about arc length across participants

In [None]:
# Participant-level regplot for stroke arcLength; the returned differences feed
# the combined box-plot cell further down.
arcLength_diff = adjacent_plots(
    T_remove_invalid, T_remove_flagged,
    plottype='regplot',
    x='arcLength',
)

In [None]:
# Category x condition difference bar plot for stroke arcLength; saved with a tight bounding box.
cat_cond_diffplots(T_remove_invalid, T_remove_flagged, 'arcLength', 'barplot')
plt.savefig(os.path.abspath(os.path.join(plot_dir,'arcLength_catcond_diffplot.png')), bbox_inches='tight')

In [None]:
# Box plots of the participant-level differences returned by the regplot cells
# above, one panel per low-level feature; saved with a tight bounding box.
diffframe = pd.DataFrame([activeSketchTime_diff, numStrokes_diff, totalInk_diff, arcLength_diff]).T
fig, ax = plt.subplots(1, 4, figsize=(12,3))
sns.boxplot(diffframe[0], ax=ax[0]).set(xlabel='active sketch time')
sns.boxplot(diffframe[1], ax=ax[1]).set(xlabel='number of strokes')
sns.boxplot(diffframe[2], ax=ax[2]).set(xlabel='total ink')
sns.boxplot(diffframe[3], ax=ax[3]).set(xlabel='arc length')
plt.suptitle('Participant-level across-category difference of low-level features'),plt.tight_layout(rect=[0, 0.06, 1, 0.95]);
fig.text(0.5, 0.03, 'photo-cue by text-cue difference', ha='center');
plt.savefig(os.path.abspath(os.path.join(plot_dir,'low_level_diffplot.png')), bbox_inches='tight')

### Category- and condition-level distribution plots of low-level features (flagged sketches removed)

In [None]:
# One panel per category: distribution curves of active sketch time, one per
# condition (flagged sketches removed).
g = sns.FacetGrid(K_remove_flagged, col="category", hue='condition', col_wrap=4)
g = g.map(sns.distplot, "activeSketchTime", rug=False, hist=False).add_legend()
plt.suptitle('Distribution of active sketch time by condition per category')
plt.tight_layout(rect=[0, 0.03, 1, 0.95]);

In [None]:
# One panel per category: distribution curves of strokes per sketch, one per
# condition (flagged sketches removed).
g = sns.FacetGrid(K_remove_flagged, col="category", hue='condition', col_wrap=4)
g = g.map(sns.distplot, "numStrokes", rug=False, hist=False).add_legend()
plt.suptitle('Distribution of strokes per sketch by condition per category')
plt.tight_layout(rect=[0, 0.03, 1, 0.95]);

In [None]:
# One panel per category: distribution curves of stroke arc length, one per
# condition. Uses the already-filtered T_remove_flagged (same subset as the
# other stroke-level plots) instead of re-running the filter here.
g = sns.FacetGrid(T_remove_flagged, col="category", hue='condition', col_wrap=4)
g = (g.map(sns.distplot, "arcLength", rug=False, hist=False).add_legend())
plt.suptitle('Distribution of stroke arc length by condition per category'), plt.tight_layout(rect=[0, 0.03, 1, 0.95]);

In [None]:
# One panel per category: distribution curves of total ink, one per condition
# (flagged sketches removed).
g = sns.FacetGrid(K_remove_flagged, col="category", hue='condition', col_wrap=4)
g = g.map(sns.distplot, "totalInk", rug=False, hist=False).add_legend()
plt.suptitle('Distribution of "ink" used per sketch by condition per category')
plt.tight_layout(rect=[0, 0.03, 1, 0.95]);

<br>  <br><br><br><br><br>

## Are there any effects of trial number (indicating fatigue)?

In [None]:
# Possible fatigue effects: stroke count, active sketch time and stroke arc
# length as a function of trial number (flagged sketches removed).
plt.figure(figsize=(14, 3))

plt.subplot(1, 3, 1)
sns.barplot(x='trialNum', y='numStrokes', data=K_remove_flagged).set_ylabel('num strokes')

plt.subplot(1, 3, 2)
sns.barplot(x='trialNum', y='activeSketchTime', data=K_remove_flagged).set_ylabel('sketch time')

plt.subplot(1, 3, 3)
sns.barplot(x='trialNum', y='arcLength', data=T_remove_flagged).set_ylabel('arc length')

plt.suptitle('num strokes, active sketch time, and arc length by trial number').set_position([.5, 1.05])
plt.tight_layout(rect=[0, 0.03, 1, 0.95]);

## Are there any effects of device type?

In [None]:
# Attach each game's survey-reported input device to the sketch (K) and stroke (T) tables.
inputdevice_dict = dict(zip(S.gameID, S.inputDevice))
K['inputDevice'] = K['gameID'].map(inputdevice_dict)
T['inputDevice'] = T['gameID'].map(inputdevice_dict)
# K has one row per sketch, so count each gameID once. The bars are then games
# (presumably one per participant) rather than sketches. Games without a
# survey entry map to NaN and are dropped, since a NaN mixed into the string
# categories would break the histogram.
plt.hist(K.drop_duplicates('gameID')['inputDevice'].dropna()), plt.ylabel('number of participants'), plt.title('Number of devices used in experiment');

In [None]:
# Same device mapping as the previous cell, repeated so this cell can run on its own.
inputdevice_dict = dict(zip(S.gameID,S.inputDevice))
K['inputDevice'] = K['gameID'].map(inputdevice_dict)
T['inputDevice'] = T['gameID'].map(inputdevice_dict)

# Three panels on a single 1x3 row, so the figure has no empty second row.
# NOTE(review): these use the unfiltered K and T, unlike most other cells; confirm intended.
plt.figure(figsize=(15, 4))
plt.subplot(131)
sns.barplot(x='inputDevice', y='numStrokes', data=K).set_ylabel('num strokes')
plt.subplot(132)
sns.barplot(x='inputDevice', y='activeSketchTime', data=K).set_ylabel('sketch time')
plt.subplot(133)
sns.barplot(x='inputDevice', y='arcLength', data=T).set_ylabel('arc length')
plt.suptitle('num strokes, active sketch time, and arc length by input device')
plt.tight_layout(rect=[0, 0.03, 1, 0.95])

### Quantifying category-level evidence
Train a 12-way classifier to quantify how much category evidence is in each sketch

In [None]:
# NOTE: one-off alignment check between K, the pixel features F and the pixel
# metadata M; possibly obsolete, but still runnable.

def auditor(pngData, i, F, M):
    """Assert that row ``i`` of K, F and M all describe the same drawing.

    Parameters
    ----------
    pngData : the ``pngData`` value of K's i-th sketch, decoded with
        ``utils.pngToArray``.
    i : int
        Row position shared by K, F and M.
    F : feature matrix; row ``i`` is compared value-by-value with the decoded PNG.
    M : DataFrame with an ``_id`` column aligned with the rows of F.

    Returns None; raises AssertionError on the first mismatch.
    Also reads the module-level sketch table ``K``.
    """
    n_values = 150528  # 224 * 224 * 3 values per flattened sketch
    rgbarr = utils.pngToArray(pngData)

    # Do all pixel values of the i-th sketch match the i-th row of F?
    assert np.count_nonzero(rgbarr == F[i]) == n_values

    # Do the rows of K line up with the rows of M?
    assert K.iloc[i]['_id'] == M['_id'][i]

    # Does the drawing that M's row points to (looked up in K by _id) match F's row?
    # .iloc[0] extracts the pngData string itself; str() of the .values array
    # would wrap the payload in "['...']" brackets and quotes.
    matched_png = K.loc[K['_id'] == M['_id'][i], 'pngData'].iloc[0]
    assert np.count_nonzero(utils.pngToArray(matched_png) == F[i]) == n_values

# Audit every sketch. Any mismatch raises AssertionError, so the cell evaluates
# to True only when features and metadata line up.
all(
    outcome is None
    for outcome in [auditor(K['pngData'][row], row, F, M) for row in range(len(F))]
)

### Softmax classifiers

In [None]:
# Baseline: 10-fold cross-validation of a 12-way softmax classifier predicting
# category from mean-centred sketch pixels; prints the average accuracy.
cv_results = perform_cross_validation(
    features=F_norm_remove_invalid,
    labels=K_remove_invalid['category'],
    num_folds=10,
    input_type='sketch pixels',
    prediction_type='category',
    output=True,
)

In [None]:
# Average the classifier weights over every fitted cross-validation fold
# (rather than assuming exactly 10 folds) and show each class's weights as a
# 224x224x3 image, min-max scaled to 0-255.
fold_coefs = [est.coef_.reshape(len(classes), 224, 224, 3) for est in cv_results['estimator']]
coefs = np.mean(fold_coefs, axis=0)
coefs_scaled = ((coefs - coefs.min()) / (coefs.max() - coefs.min()) * 255.0).astype(int)

for index, weights in enumerate(coefs_scaled):
    plt.subplot(2, 6, index + 1)
    plt.imshow(weights)
    plt.axis('off')
    plt.title(classes[index])

In [None]:
# Separate 10-fold cross-validations for photo-cue and text-cue sketches
# (pixel features); the fitted estimators are reused below.
cv_results_photo, cv_results_text = perform_cross_validation_twice(
    features=F_norm_remove_invalid,
    metadata=M_remove_invalid,
    labels=M_remove_invalid['cat_codes'],
    num_folds=10,
    input_type='sketch pixel data',
    prediction_type='category',
    output=True,
)

In [None]:
import io  # only this cell needs it: wraps the downloaded bytes for PIL

# Download the 36 photo-cue stimuli (photos 1-3 for each of the 12 categories)
# and resize them to 224x224 to match the sketch pixel features.
baseurl = 'https://photodraw.s3.amazonaws.com/___.png'
urls = [baseurl.replace('___', name + "_" + str(i)) for name in classes for i in range(1, 4)]
stims = []
for url in urls:
    response = requests.get(url, timeout=30)
    # Fail on HTTP errors here rather than with an opaque PIL decode error.
    response.raise_for_status()
    # .content is already transfer-decoded, unlike the raw stream.
    stims.append(np.array(Image.open(io.BytesIO(response.content)).resize((224, 224))))
stims = np.array(stims, dtype='double')

# Integer category label per stimulus, in the same order as `urls`.
category2score = dict(zip(classes, range(len(classes))))
photocue_stim_classes = np.array([category2score[name] for name in classes for i in range(1, 4)])

# One flattened row per stimulus; copy so that centring does not modify `stims`.
photocue_stims = stims.reshape(len(stims), -1).copy()
# NOTE(review): the stimuli are centred on their own mean image, whereas the
# training sketches were centred on the sketch mean (F_norm); confirm intended.
mean_image = np.mean(photocue_stims, axis=0)
photocue_stims -= mean_image

# Mean accuracy on the stimuli across every fitted fold estimator.
photo_preds = np.mean([est.score(photocue_stims, photocue_stim_classes) for est in cv_results_photo['estimator']])
text_preds = np.mean([est.score(photocue_stims, photocue_stim_classes) for est in cv_results_text['estimator']])
print(f'When predicting on the photo-cue stimuli, the test accuracy of a logistic regression trained on photo-cue data is {round(photo_preds,3)}')
print(f'When predicting on the photo-cue stimuli, the test accuracy of a logistic regression trained on text-cue  data is {round(text_preds,3)}')

In [None]:
# Photo-cue classifier: average weights over every fitted fold and show each
# class's weights as a 224x224x3 image, min-max scaled to 0-255.
fold_coefs = [est.coef_.reshape(len(classes), 224, 224, 3) for est in cv_results_photo['estimator']]
coefs = np.mean(fold_coefs, axis=0)
coefs_scaled = ((coefs - coefs.min()) / (coefs.max() - coefs.min()) * 255.0).astype(int)

for index, weights in enumerate(coefs_scaled):
    plt.subplot(2, 6, index + 1)
    plt.imshow(weights)
    plt.axis('off')
    plt.title(classes[index])

In [None]:
# Text-cue classifier: average weights over every fitted fold and show each
# class's weights as a 224x224x3 image, min-max scaled to 0-255.
fold_coefs = [est.coef_.reshape(len(classes), 224, 224, 3) for est in cv_results_text['estimator']]
coefs = np.mean(fold_coefs, axis=0)
coefs_scaled = ((coefs - coefs.min()) / (coefs.max() - coefs.min()) * 255.0).astype(int)

for index, weights in enumerate(coefs_scaled):
    plt.subplot(2, 6, index + 1)
    plt.imshow(weights)
    plt.axis('off')
    plt.title(classes[index])

In [None]:
# Class probabilities and confusion matrices for the photo/text train x test
# combinations, using pixel features.
class_probs_2x2_pixel, acc_scores_2x2_pixel = generate_acc_probs_2x2(
    F_norm_remove_invalid,
    M_remove_invalid,
    num_splits=5,
    num_repeats=5,
)

In [None]:
# 2x2 grid of class-probability heatmaps (pixel-level).
generate_2x2_plots(
    class_probs_2x2_pixel,
    'pixel-level',
    'probabilities',
)

In [None]:
# 2x2 grid of confusion matrices (pixel-level).
generate_2x2_plots(
    acc_scores_2x2_pixel,
    'pixel-level',
    'confusion matrix',
)

### fc6 feature-level classification

In [None]:
# F_fc6_norm_remove_invalid gathers rows via remove_invalid(K).fc6_feature_ind,
# so its rows follow K_remove_invalid's row order. The labels are taken in that
# same order; re-sorting them by fc6_feature_ind would misalign features and
# labels whenever K is not already sorted by that column.
perform_cross_validation(features=F_fc6_norm_remove_invalid,
                         labels=K_remove_invalid.reset_index(drop=True)['category'],
                         num_folds=10,
                         input_type='fc6 feature vectors',
                         prediction_type='category',
                         output=False)

In [None]:
# Class-probability and confusion matrices for category prediction from fc6 features.
class_probs_fc6, acc_scores_fc6 = generate_acc_probs(
    features=F_fc6_norm_remove_invalid,
    metadata=K_remove_invalid,
    num_splits=5,
    num_repeats=2,
)

In [None]:
# Heatmap of mean predicted class probabilities per true category (fc6 features).
plt.figure(figsize=(12, 10))
sns.heatmap(class_probs_fc6)
plt.xlabel('predicted class probabilities')
plt.ylabel('correct category')
plt.title('Category prediction probabilities per category label (VGG19 layers)');

In [None]:
# Confusion-matrix heatmap for category prediction from fc6 features.
plt.figure(figsize=(12, 10))
sns.heatmap(acc_scores_fc6)
plt.title('Confusion matrix for category predictions')
plt.xlabel('Prediction')
plt.ylabel('True label');

In [None]:
# Photo-cue vs text-cue cross-validation on fc6 features. The metadata and
# labels follow K_remove_invalid's row order, which is the order in which
# F_fc6_norm_remove_invalid's rows were gathered; sorting by fc6_feature_ind
# could misalign them.
perform_cross_validation_twice(features=F_fc6_norm_remove_invalid,
                               metadata=K_remove_invalid.reset_index(drop=True),
                               labels=K_remove_invalid.reset_index(drop=True)['category'],
                               num_folds=10,
                               input_type='fc6 feature vectors',
                               prediction_type='category',
                               output=False)

In [None]:
# Class probabilities and confusion matrices for the photo/text train x test
# combinations, using fc6 features.
class_probs_2x2_fc6, acc_scores_2x2_fc6 = generate_acc_probs_2x2(
    features=F_fc6_norm_remove_invalid,
    metadata=K_remove_invalid,
    num_splits=5,
    num_repeats=5,
)

In [None]:
# 2x2 grid of class-probability heatmaps (VGG-19 fc6 features).
generate_2x2_plots(
    class_probs_2x2_fc6,
    'VGG-19',
    'probabilities',
)

In [None]:
# 2x2 grid of confusion matrices (VGG-19 fc6 features).
generate_2x2_plots(
    acc_scores_2x2_fc6,
    'VGG-19',
    'confusion matrix',
)

### Can the pixel-level data classify things other than category?

In [None]:
# Can sketch pixels predict the cue condition (photo vs text)?
perform_cross_validation(
    features=F_norm_remove_invalid,
    labels=M_remove_invalid['cond_codes'],
    num_folds=10,
    input_type='sketch pixel data',
    prediction_type='condition',
    output=False,
)

In [None]:
# Can sketch pixels predict category-photoid pairs (text-cue sketches included)?
perform_cross_validation(
    features=F_norm_remove_invalid,
    labels=M_remove_invalid['cat_id_codes'],
    num_folds=3,
    input_type='sketch pixel data',
    prediction_type='category-photoid pairs (including text)',
    output=False,
)

In [None]:
# Can sketch pixels predict category-condition pairs?
perform_cross_validation(
    features=F_norm_remove_invalid,
    labels=M_remove_invalid['cat_cond_codes'],
    num_folds=5,
    input_type='sketch pixel data',
    prediction_type='category-condition pairs',
    output=False,
)

In [None]:
# Can sketch pixels predict the photo id?
perform_cross_validation(
    features=F_norm_remove_invalid,
    labels=M_remove_invalid['photoid_codes'],
    num_folds=5,
    input_type='sketch pixel data',
    prediction_type='photo-ids',
    output=False,
)

### Can numStrokes, activeSketchTime, and totalInk predict category and condition?

In [None]:
# Mean-centred low-level features (flagged sketches removed) and integer codes
# for category and condition, used as classification targets below.
features = K_remove_flagged[['numStrokes', 'activeSketchTime', 'totalInk']]
features_norm = features.sub(features.mean(axis=0))

metadata_cat = pd.Categorical(K_remove_flagged['category']).codes
metadata_cond = pd.Categorical(K_remove_flagged['condition']).codes

In [None]:
# Can the three low-level features predict category?
perform_cross_validation(
    features=features_norm,
    labels=metadata_cat,
    num_folds=10,
    input_type='number of strokes, active sketch time, and total ink',
    prediction_type='category',
    output=False,
)

In [None]:
# Can the three low-level features predict the cue condition?
perform_cross_validation(
    features=features_norm,
    labels=metadata_cond,
    num_folds=10,
    input_type='number of strokes, active sketch time, and total ink',
    prediction_type='condition',
    output=False,
)

### Category-photoid prediction probabilities and confusion matrices

In [None]:
# Photo-cue sketches only. M_remove_invalid has a reset index, so M_photo.index
# gives the matching row positions in F_norm_remove_invalid.
M_photo = M_remove_invalid[M_remove_invalid.condition == 'photo']
F_norm_photo = F_norm_remove_invalid[M_photo.index]
M_photo = M_photo.reset_index(drop=True)
# alt_labels for generate_acc_probs: presumably [label-code column, label-name
# column, ordered "<category>_<1..3>" class names] -- confirm against utils.
cat_ids_photo = ['cat_id_codes', 'category_id_pair', [f'{cat}_{i}' for cat in classes for i in range(1, 4)]]

cat_photoid_probs, cat_photoid_accs = generate_acc_probs(
    features=F_norm_photo,
    metadata=M_photo,
    num_splits=2,
    num_repeats=1,
    alt_labels=cat_ids_photo,
)

In [None]:
# Heatmap of predicted probabilities per category-photoid label (pixel features).
plt.figure(figsize=(17, 14))
sns.heatmap(cat_photoid_probs)
plt.xlabel('predicted class probabilities')
plt.ylabel('correct category')
plt.title('Category-photoid prediction probabilities per category-photoid label (pixel-level)');

In [None]:
# Confusion-matrix heatmap for category-photoid prediction (pixel features).
plt.figure(figsize=(17, 14))
sns.heatmap(cat_photoid_accs)
plt.title('Confusion matrix for category predictions (pixel-level)')
plt.xlabel('Prediction')
plt.ylabel('True label');

In [None]:
# Photo-cue sketches only. The fc6 rows are gathered through the subset's own
# fc6_feature_ind column, so features and metadata share the same row order.
# (Positionally indexing F_fc6_norm_remove_invalid with K's non-reset index
# labels picks the wrong rows.) The metadata is then reset from the photo
# subset itself, not from the full K_remove_invalid, so that its length
# matches the features.
M_fc6_photo = K_remove_invalid[K_remove_invalid.condition == 'photo']
F_fc6_norm_photo = F_fc6_norm[M_fc6_photo.fc6_feature_ind.values]
M_fc6_photo = M_fc6_photo.reset_index(drop=True)

# NOTE(review): cat_ids_photo names 'cat_id_codes' / 'category_id_pair' columns;
# confirm the sketch table K provides them (they are used on M elsewhere).
cat_photoid_probs_fc6, cat_photoid_accs_fc6 = generate_acc_probs(features=F_fc6_norm_photo,
                                                         metadata=M_fc6_photo,
                                                         num_splits=3,
                                                         num_repeats=3,
                                                         alt_labels=cat_ids_photo)

In [None]:
# Heatmap of predicted probabilities per category-photoid label (VGG-19 fc6 features).
plt.figure(figsize=(17, 14))
sns.heatmap(cat_photoid_probs_fc6)
plt.xlabel('predicted class probabilities')
plt.ylabel('correct category')
plt.title('Category-photoid prediction probabilities per category-photoid label (VGG-19)');

In [None]:
# Confusion-matrix heatmap for category-photoid prediction (VGG-19 fc6 features).
plt.figure(figsize=(17, 14))
sns.heatmap(cat_photoid_accs_fc6)
plt.title('Confusion matrix for category predictions (VGG-19)')
plt.xlabel('Prediction')
plt.ylabel('True label');

<br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br>