### This notebook generates metadata that associates each sketch with its confederate stimuli (the most similar photos from the same category) for the instance-level recognition experiment

In [None]:
import os
import ast
import random
import requests
import numpy as np
import pandas as pd

import seaborn as sns
from PIL import Image
from io import BytesIO
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

from sklearn.neighbors import NearestNeighbors
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold

In [None]:
# Set up paths: everything hangs off the project root, i.e. the parent of the current working directory.
# proj_dir is already absolute and normalized, so the joined paths below are absolute too.
proj_dir = os.path.abspath('..')
feature_dir = os.path.join(proj_dir, 'features')
stims_dir = os.path.join(proj_dir, 'stimuli', 'photodraw32_stims_agglomerate')
results_dir = os.path.join(proj_dir, 'results')
csv_dir = os.path.join(results_dir, 'csv')

In [None]:
# import data: the sketch table from results/csv, then FC6 features + metadata for sketches and
# images, and instance-level features + metadata, all from the features directory.
def _in_features(filename):
    """Return the path of `filename` inside feature_dir."""
    return os.path.join(feature_dir, filename)


K = pd.read_csv(os.path.join(csv_dir, 'photodraw2x2_sketch_data.csv'))
F = np.load(_in_features('FEATURES_FC6_photodraw2x2_sketch.npy'))
M = pd.read_csv(_in_features('METADATA_photodraw2x2_sketch.csv'))
IMF = np.load(_in_features('FEATURES_FC6_photodraw2x2_image.npy'))
IMM = pd.read_csv(_in_features('METADATA_photodraw2x2_image.csv'))
IMFI = np.load(_in_features('photodraw2x2_instance_features.npy'))
IMMI = pd.read_csv(_in_features('photodraw2x2_metadata_instance.csv'))

In [None]:
# Get data into neater formats for us to work with.
# NOTE(review): assumes the rows of K line up one-to-one with the rows of the sketch feature matrix F.
KF = pd.concat([K, pd.DataFrame(F)], axis=1)
# image_id is split from the right into '<category>_<id>_<instance>', so any underscores stay in the
# category part. `n` must be passed by keyword: pandas >= 2.0 rejects it positionally (TypeError).
IMM[['category', 'id', 'instance']] = IMM.image_id.str.rsplit('_', n=2, expand=True)
IMM['instance_id'] = IMM['id'] + '_' + IMM['instance']
# Reorder feature rows to follow the metadata via feature_ind; the features become integer columns 0..D-1.
IMMF = pd.concat([IMM, pd.DataFrame(IMF[IMM.feature_ind])], axis=1)
IMMFI = pd.concat([IMMI, pd.DataFrame(IMFI[IMMI.feature_ind])], axis=1)

### First: run image classification on the stimuli as a sanity check

In [None]:
def compute_class_predictions(data, labels):
    """Cross-validate a logistic-regression classifier and score every sample out-of-fold.

    Uses a shuffled StratifiedKFold with 5 splits (random_state=0) and a
    LogisticRegression(max_iter=1000) that is refit on each training fold.

    Parameters
    ----------
    data : array indexable by integer index arrays (one row per sample), e.g. a 2-D numpy array.
    labels : numpy array holding one class label per row of `data`.

    Returns
    -------
    prob_dict : dict mapping row position -> probability the held-out model gave the true label.
    pred_dict : dict mapping row position -> True when the predicted label equals the true label.
    """
    folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
    clf = LogisticRegression(max_iter=1000)

    true_label_prob = {}
    correct = {}
    for fit_rows, held_out_rows in folds.split(data, labels):
        fitted = clf.fit(data[fit_rows], labels[fit_rows])
        # column of predict_proba's output that belongs to each class label
        column_for = dict(zip(fitted.classes_, range(len(fitted.classes_))))

        held_out_X = data[held_out_rows]
        held_out_y = labels[held_out_rows]
        proba = fitted.predict_proba(held_out_X)
        true_cols = np.array([column_for[y] for y in held_out_y])
        # pick, row by row, the probability assigned to that row's true label
        picked = proba[np.arange(proba.shape[0]), true_cols]
        hits = fitted.predict(held_out_X) == held_out_y

        for row, p, hit in zip(held_out_rows, picked, hits):
            true_label_prob[row] = p
            correct[row] = hit

    # sanity check: every stored value must be a valid probability
    probs = np.fromiter(true_label_prob.values(), dtype=float)
    assert all(probs <= 1)
    assert all(probs >= 0)

    return true_label_prob, correct

In [None]:
# Cross-validated category classification of the stimulus images from their FC6 features.
# The feature columns of IMMF are the integer labels 0..D-1 created by pd.DataFrame(IMF[...]),
# so derive D from the feature matrix instead of hard-coding 4096.
fc6_columns = np.arange(IMF.shape[1])
prob_dict, pred_dict = compute_class_predictions(IMMF[fc6_columns].values, IMMF.category.values)
# The dicts are keyed by row position; mapping through IMMF.index relies on IMMF keeping a default RangeIndex.
IMMF['prob_true_predict_fc6'] = IMMF.index.map(prob_dict)
IMMF['true_predict_fc6'] = IMMF.index.map(pred_dict)

# check that we are getting a reasonable classification accuracy
IMMF['true_predict_fc6'].mean()

In [None]:
# Per-category mean probability assigned to the true category. If the features were extracted
# incorrectly, these bars would sit near chance level.
fig, ax = plt.subplots(figsize=(6, 8))
sns.barplot(data=IMMF, x='prob_true_predict_fc6', y='category', ax=ax);

In [None]:
# Rank categories by mean true-category probability (computed once) and list the 5 lowest / 5 highest.
category_recognizability = IMMF.groupby('category')['prob_true_predict_fc6'].mean()
print('Least recognizable: ', category_recognizability.nsmallest(5).index.values)
print('Most recognizable: ', category_recognizability.nlargest(5).index.values)

### Goal: for each sketch, get the 8 images (from the same category) nearest in FC6 feature space to the photo the sketch was drawn from

#### Store the filenames and S3 URLs of these images in a dataframe. This will be the metadata for `recogdraw_instance`

In [None]:
# Set to True to regenerate the metadata CSV; otherwise later cells read the previously saved file.
reallyRun = False
if reallyRun:
    # import stimuli data
    stims_metadata = pd.read_csv('photodraw32_metadata.csv')
    # instance_id = sketchy filename without its extension; matched against IMMF.instance_id below
    stims_metadata['instance_id'] = stims_metadata.sketchy_filename.str.split('.', expand=True)[0]
    sketches_s3_metadata = pd.read_csv('photodraw32_s3_sketches_metadata.csv')

    # create dataframe which is supposed to match a sketch to its 8 most similar stimuli
    # (only sketches from the 'photo' condition)
    sketch2simstims_metadata = sketches_s3_metadata[sketches_s3_metadata.condition == 'photo']
    sketch2simstims_metadata = sketch2simstims_metadata.drop(columns='filepath')
    sketch2simstims_metadata = sketch2simstims_metadata.rename(columns = {'filename' : 'sketch_filename',
                                                                          's3_url'   : 'sketch_s3_url'})
    # filename without extension; used to find each sketch's row in the loop below
    sketch2simstims_metadata['sketch_file'] = sketch2simstims_metadata.sketch_filename.str.split('.', expand=True)[0]

    sketch2simstims_metadata['nearest_photo_filenames'] = ''
    sketch2simstims_metadata['nearest_photo_s3_urls']   = ''
    sketch2simstims_metadata['true_photo_filename'] = ''
    sketch2simstims_metadata['true_photo_s3_url']   = ''

    # FC6 feature columns of IMMF are the integers 0..D-1
    feature_cols = np.arange(IMF.shape[1])
    # One fitted 8-NN index per category, built the first time that category is seen
    # (the per-category image set is the same for every sketch of that category).
    knn_by_category = {}

    # extract the 8-nearest neighbors for each sketch and store in a dataframe
    for index, sketch in KF[KF.condition == 'photo'].iterrows():
        # FC6 features of the photo this sketch was drawn from
        sketch_original_image = IMMF[IMMF.image_id == f"{sketch['category']}_{sketch['imageURL']}"]
        sketch_original_image = sketch_original_image[feature_cols].values[0]

        sketch_string = f"{sketch['gameID']}_"\
                        f"{sketch['trialNum']}_"\
                        f"{sketch['condition']}_"\
                        f"{sketch['category']}_"\
                        f"{sketch['imageURL']}_"\
                        f"{sketch['goal']}"
        simstims_index = sketch2simstims_metadata[sketch2simstims_metadata.sketch_file == sketch_string].index[0]

        # Neighbors are searched only among images of the sketch's own category.
        all_neighbors = IMMF[IMMF.category == sketch['category']]
        knn = knn_by_category.get(sketch['category'])
        if knn is None:
            knn = NearestNeighbors(n_neighbors=8, metric="cosine").fit(all_neighbors[feature_cols].values)
            knn_by_category[sketch['category']] = knn
        _, indices = knn.kneighbors([sketch_original_image])  # positions into all_neighbors, nearest first
        neighbors = all_neighbors.iloc[indices[0]]

        sketch2simstims_metadata.at[simstims_index, 'nearest_photo_filenames'] = list(neighbors.image_id + '.png')

        # Keep the S3 URLs in the same nearest-first order as nearest_photo_filenames. A bare isin()
        # filter returns rows in stims_metadata's file order, scrambling the pairing between the two lists.
        neighbor_rank = {instance_id: rank for rank, instance_id in enumerate(neighbors.instance_id)}
        matched_stims = stims_metadata[stims_metadata['instance_id'].isin(list(neighbor_rank))]
        order = np.argsort(matched_stims['instance_id'].map(neighbor_rank).to_numpy(), kind='stable')
        nearest_urls = list(matched_stims.s3_url.iloc[order])
        sketch2simstims_metadata.at[simstims_index, 'nearest_photo_s3_urls'] = nearest_urls
        sketch2simstims_metadata.at[simstims_index, 'true_photo_filename'] = f'{sketch.category}_{sketch.imageURL}.png'

        # The ground-truth URL is the neighbor URL whose basename, with its extension and last
        # '_'-separated field removed, equals '<imageURL>_<category>' (last match wins).
        for url in nearest_urls:
            if url.split('/')[-1].split('.')[0].rsplit('_',1)[0] == f'{sketch.imageURL}_{sketch.category}':
                sketch2simstims_metadata.at[simstims_index, 'true_photo_s3_url'] = url

    # Save out data to csv
    sketch2simstims_metadata = sketch2simstims_metadata.drop(columns='sketch_file')
    sketch2simstims_metadata.to_csv('photodraw32_instance_validation_metadata.csv', index = False)
    print(sketch2simstims_metadata.head())

In [None]:
# Every sketch's ground-truth photo must have been found among its neighbors' S3 URLs.
# Check the saved CSV so this cell also runs when the generation cell was skipped (reallyRun = False),
# in which case sketch2simstims_metadata is not defined yet. keep_default_na=False keeps empty cells
# as '' rather than NaN, so a missing URL still fails the check.
_saved_metadata = pd.read_csv('photodraw32_instance_validation_metadata.csv', keep_default_na=False)
_missing_true_photo = _saved_metadata['true_photo_s3_url'] == ''
assert not _missing_true_photo.any(), f'{_missing_true_photo.sum()} sketches have no ground-truth photo URL'

### Demonstration: using our metadata, pull a sketch's 8 most similar images from S3

In [None]:
# Reload the saved metadata. List-valued columns are read back as their string representation,
# and empty cells are read back as NaN.
sketch2simstims_metadata = pd.read_csv(filepath_or_buffer='photodraw32_instance_validation_metadata.csv')

In [None]:
def _fetch_image(url, timeout=30):
    """Download `url` and return it as a PIL image; raise on HTTP errors instead of failing inside PIL."""
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


# pick one random sketch and show its category/goal
sketch = sketch2simstims_metadata.sample()
print(sketch[['category', 'goal']])

# sketch
display(_fetch_image(sketch.sketch_s3_url.values[0]))

# ground truth image
display(_fetch_image(sketch.true_photo_s3_url.values[0]))

fig = plt.figure(figsize=(8., 8.))
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(2, 4),  # creates 2x4 grid of axes (one per neighbor)
                 axes_pad=0.1,  # pad between axes in inch.
                 share_all=True)
# with share_all=True these tick settings are expected to carry over to every axis
grid[0].get_yaxis().set_ticks([])
grid[0].get_xaxis().set_ticks([])

# nearest_photo_s3_urls was saved as the string repr of a Python list, so parse it back
images = [_fetch_image(url) for url in ast.literal_eval(sketch.nearest_photo_s3_urls.values[0])]

# the list is nearest-first, so shuffle to avoid showing the ground-truth photo in a predictable slot
random.shuffle(images)
for ax, im in zip(grid, images):
    # Iterating over the grid returns the Axes.
    ax.imshow(im)

In [None]:
# Show the stimulus images of the sampled sketch's category. The category is selected explicitly
# rather than reusing the generation loop's `all_neighbors`, which only exists when reallyRun is True.
category_images = IMMF[IMMF.category == sketch.category.values[0]]

fig = plt.figure(figsize=(12., 8.))
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(4, 8),  # creates 4x8 grid of axes (32 slots; extra images are not shown)
                 axes_pad=0.1,  # pad between axes in inch.
                 share_all=True)
grid[0].get_yaxis().set_ticks([])
grid[0].get_xaxis().set_ticks([])

images = [Image.open(os.path.join(stims_dir, f'{img_id}.png')) for img_id in category_images.image_id]

for ax, im in zip(grid, images):
    # Iterating over the grid returns the Axes.
    ax.imshow(im)