In [None]:
import os
import sys
import shutil
import pandas as pd
import socket
import json
import numpy as np
import base64
import time
from io import BytesIO
from scipy.spatial.distance import pdist, squareform
from itertools import combinations 
import scipy as sp
from tqdm import tqdm


import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns

sns.set(style="whitegrid")

In [None]:
# directory & file hierarchy (every path hangs off the parent of the working directory)
proj_dir = os.path.abspath(os.pardir)
results_dir, feature_dir = (os.path.join(proj_dir, sub) for sub in ('results', 'features'))
csv_dir = os.path.join(results_dir, 'csv')
    
def make_dir_if_not_exists(dir_name):
    """Create `dir_name` (including missing parents) if needed and return it.

    Parameters
    ----------
    dir_name : str or path-like
        Directory to create.

    Returns
    -------
    The same `dir_name` that was passed in, so the call can be used inline
    when building paths.

    Raises
    ------
    FileExistsError
        If `dir_name` already exists but is not a directory.
    """
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test (e.g. two kernels starting at the same time).
    os.makedirs(dir_name, exist_ok=True)
    return dir_name

## make sure every output directory exists (created on first run only)
result = list(map(make_dir_if_not_exists, (results_dir, csv_dir, feature_dir)))

### load in features and metadata

In [None]:
# Metadata tables; going by the file names: K = sketch data, T = stroke data, S = survey data.
K, T, S = (
    pd.read_csv(os.path.join(csv_dir, csv_name))
    for csv_name in (
        'photodraw2x2_sketch_data.csv',
        'photodraw2x2_stroke_data.csv',
        'photodraw2x2_survey_data.csv',
    )
)
S = S.reset_index(drop=True)

# Feature matrices stored as .npy arrays (F: FC6 features, Fi: instance features).
F = np.load(os.path.join(feature_dir, 'FEATURES_FC6_photodraw2x2_sketch.npy'))
Fi = np.load(os.path.join(feature_dir, 'photodraw2x2_instance_features.npy'))

In [None]:
# Attach each feature matrix to the sketch metadata and average the features
# within every (goal, condition, category) cell.
#
# pd.concat(axis=1) aligns rows on *index labels*, not on position, so sorting K
# alone does not change which feature row is paired with which sketch. Giving the
# feature frame the index of the sorted metadata makes the pairing explicit:
# row i of F goes to the sketch with the i-th smallest `feature_ind`
# (assumes `feature_ind` / `feature_ind_instance` order the sketches the same way
# the rows of F / Fi are ordered, as the original sort implies — TODO confirm).
# The DataFrame constructor raises ValueError if the row counts do not match.
group_keys = ['goal', 'condition', 'category']

K = K.sort_values(by='feature_ind')
KF = pd.concat([pd.DataFrame(F, index=K.index), K], axis=1)
KF = KF.sort_values(by=group_keys, ascending=True)

K = K.sort_values(by='feature_ind_instance')
KFi = pd.concat([pd.DataFrame(Fi, index=K.index), K], axis=1)
KFi = KFi.sort_values(by=group_keys, ascending=True)

# Feature columns are labelled 0..d-1; take d from the arrays instead of
# hard-coding it, so all feature dimensions are always averaged.
fc6_cols = list(range(F.shape[1]))
inst_cols = list(range(Fi.shape[1]))

# One row per (goal, condition, category) group, in sorted group-key order.
category_means = KF.groupby(group_keys)[fc6_cols].mean().to_numpy()
category_means_i = KFi.groupby(group_keys)[inst_cols].mean().to_numpy()

### within category/experiment variance!

In [None]:
def high_dim_variance(X):
    """Total variance of a set of points in feature space.

    Computes ``sum_{i<j} ||x_i - x_j||**2 / n**2`` over the ``n`` rows of `X`.
    That quantity is algebraically equal to the mean squared distance of the
    rows from their centroid, ``sum_i ||x_i - mean(X)||**2 / n``, which is what
    is evaluated here: it needs a single vectorised pass instead of a Python
    loop over all n*(n-1)/2 pairs.

    Parameters
    ----------
    X : array-like, shape (n, d) (or (n,) for scalar observations)
        One observation per row.

    Returns
    -------
    float
        The total variance; 0.0 for a single observation and NaN for empty
        input (the quantity is undefined when there are no observations).
    """
    X = np.asarray(X, dtype=float)
    n = len(X)
    if n == 0:
        return float('nan')
    centred = X - X.mean(axis=0)
    return float(np.sum(centred ** 2) / n)

# Build one frame holding the instance features, the FC6 features and the sketch
# metadata, then compute the feature-space variance of every
# (category, condition, goal) cell.
#
# pd.concat(axis=1) aligns on index labels, so the feature frames are given the
# index of the metadata sorted by the matching feature-index column; sorting K
# alone would not change the pairing. (Assumes `feature_ind` /
# `feature_ind_instance` order the sketches the same way the rows of F / Fi are
# ordered, as the original sorts imply — TODO confirm.)
n_fc6, n_inst = F.shape[1], Fi.shape[1]
fc6_cols = list(range(n_fc6))                    # FC6 features: columns 0 .. n_fc6-1
inst_cols = list(range(n_fc6, n_fc6 + n_inst))   # instance features follow directly after

K = K.sort_values(by='feature_ind')
fc6_df = pd.DataFrame(F, columns=fc6_cols, index=K.index)

K = K.sort_values(by='feature_ind_instance')
inst_df = pd.DataFrame(Fi, columns=inst_cols, index=K.index)

# Put all three pieces in the same row order before concatenating.
KF = pd.concat([inst_df, fc6_df.loc[K.index], K], axis=1)
KF = KF.sort_values(by=['goal', 'condition', 'category'], ascending=True)

# Collect one record per group and build the table once, which keeps the
# variance columns numeric instead of object dtype.
records = []
for (category, condition, goal), grp in KF.groupby(['category', 'condition', 'goal']):
    records.append({
        'category': category,
        'condition': condition,
        'goal': goal,
        'fc6_variance': high_dim_variance(grp[fc6_cols].to_numpy()),
        'inst_variance': high_dim_variance(grp[inst_cols].to_numpy()),
    })
df = pd.DataFrame(records, columns=['category', 'condition', 'goal', 'fc6_variance', 'inst_variance'])

sns.barplot(data=df, x='condition', y='fc6_variance', hue='goal');
plt.title('variance (fc6)!');
plt.show()
sns.barplot(data=df, x='condition', y='inst_variance', hue='goal');
plt.title('variance (instance)!');
plt.show()

In [None]:
# NOTE(review): this CSV is not written by any cell shown here — presumably it was saved
# from the variance table computed earlier; confirm it exists before a fresh "Run All".
variances_path = os.path.join(csv_dir, 'photodraw2x2_category_by_experiment_variances.csv')
df = pd.read_csv(variances_path)

# Mean FC6 variance for each condition x goal cell (displayed as the cell output).
mean_fc6_by_cell = df.groupby(['condition', 'goal'])['fc6_variance'].mean()
mean_fc6_by_cell

### Get gallery stims for cogsci 2021

In [None]:
# Export example sketches (and, for photo-cued experiments, the matching stimulus
# photos) of one category into per-experiment gallery folders.
#
# NOTE(review): `Image` is used below but is not imported in the import cell shown
# at the top of this notebook — presumably `from PIL import Image`; confirm and add
# it to the imports, otherwise this cell raises NameError on a fresh kernel.
# NOTE(review): the sampling below is unseeded, so every run picks different examples.
gall_path_2x2 = make_dir_if_not_exists(os.path.abspath('../../photodraw_latex/cogsci2021/photodraw32_gallery_examples'))
cat = 'butterfly'


def fn(obj):
    """Return one row of `obj`, picked uniformly at random."""
    return obj.loc[np.random.choice(obj.index), :]


def _decode_sketch(png_b64, size=(224, 224)):
    """Decode a base64-encoded PNG string into an image resized to `size`."""
    return Image.open(BytesIO(base64.b64decode(png_b64))).resize(size)


# Photo-cued sketches of the category; the three lowest / highest distinct
# typicality values define the "atypical" / "typical" pools.
group = K[(K.category == cat) & (K.condition == 'photo')]
typicality_levels = sorted(group['inst_typicality'].unique())
lows = typicality_levels[:3]
highs = typicality_levels[-3:]

# One random sketch per stimulus image, then three stimulus images from each pool.
lowURLs = (group[group.inst_typicality.isin(lows)]
           .groupby('imageURL', as_index=False).apply(fn)
           .sample(3).imageURL.values)
highURLs = (group[group.inst_typicality.isin(highs)]
            .groupby('imageURL', as_index=False).apply(fn)
            .sample(3).imageURL.values)

for i, g in K[K.category == cat].groupby('experiment'):
    exp_name = g.experiment.values[0]
    path = make_dir_if_not_exists(os.path.join(gall_path_2x2, exp_name))

    if (g.condition == 'text').all():
        # Text-cued experiment: there is no stimulus photo, so save six random sketches.
        images = [_decode_sketch(png) for png in g.pngData.sample(6).values]
        for k, im in enumerate(images):
            im.save(os.path.join(path, f"{exp_name}_{cat}_{k}.png"))
    else:
        # Photo-cued experiment: one random sketch for each selected stimulus image.
        atyp = g[g.imageURL.isin(lowURLs)].groupby('imageURL', as_index=False).apply(fn).sample(3)
        typ = g[g.imageURL.isin(highURLs)].groupby('imageURL', as_index=False).apply(fn).sample(3)

        for label, chosen in (('atypical', atyp), ('typical', typ)):
            for url, png in zip(chosen.imageURL.values, chosen.pngData.values):
                _decode_sketch(png).save(os.path.join(path, f"{exp_name}_{cat}_{label}_{url}.png"))

        # Copy the stimulus photos next to the sketches.
        for chosen in (atyp, typ):
            for url in chosen.imageURL:
                src = os.path.abspath(os.path.join(proj_dir, f'stimuli/photodraw32_stims_copy/{cat}_{url}.png'))
                # os.path.basename works with any path separator. Splitting on '\\'
                # only works on Windows; elsewhere it returns the whole absolute path,
                # so the destination would be the source file itself.
                shutil.copyfile(src, os.path.join(path, os.path.basename(src)))
