# Checking stimuli for balance
This notebook helps to ensure that the generated stimuli are roughly balanced between positive and negative trials.

In [None]:
import os
import numpy as np
from PIL import Image
import pandas as pd
import json
import pymongo as pm
from glob import glob
from IPython.display import clear_output
import ast
import itertools
import random
import h5py
from tqdm import tqdm

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Lift the column limit so rendered DataFrames show every column
pd.options.display.max_columns = None

In [None]:
def list_files(paths, ext='mp4'):
    """Recursively collect all files with extension `ext` below one or more folders.

    Make sure that the folder directly containing each file is informative, as the
    rest of the path is ignored in naming.

    Args:
        paths: a single folder, or a list/tuple of folders, to search recursively.
        ext: file extension to match, given without the leading dot.

    Returns:
        A tuple `(results, names, hdf5s)` of three lists with one entry per file,
        in the same order:
        results -- path of each matching file.
        names   -- `<containing folder name>_<file name up to its first dot>`
                   (the filenames as uploaded to S3).
        hdf5s   -- each path cut off at the first "_img." (the whole path is kept
                   when that marker does not occur).
    """
    if not isinstance(paths, (list, tuple)):
        paths = [paths]
    results = []
    names = []
    for path in paths:
        # Walk the tree once and derive both the path and the name from the same
        # match, so the two lists cannot get out of step.
        for dirpath, _subdirs, _files in os.walk(path):
            for match in glob(os.path.join(dirpath, '*.%s' % ext)):
                results.append(match)
                folder = os.path.basename(os.path.dirname(match))
                stem = os.path.split(match)[1].split('.')[0]
                names.append(folder + '_' + stem)
    hdf5s = [r.split("_img.")[0] for r in results]
    return results,names,hdf5s

In [None]:
local_stem = 'XXX' #CHANGE THIS ⚡️
# Every directory directly under local_stem is treated as one stimulus set.
# os.path.join keeps the paths valid whether or not local_stem ends in a slash,
# plain files are skipped (only directories can hold a metadata.json), and the
# list is sorted so the notebook produces the same row order on every run.
data_dirs = sorted(d for d in glob(os.path.join(local_stem, '*')) if os.path.isdir(d))
dirnames = [os.path.basename(d) for d in data_dirs]

stimulus_extension = "hdf5" #what's the file extension for the stims? Provide without dot

## get a list of paths to each one
full_stim_paths,filenames, full_hdf5_paths = list_files(data_dirs,stimulus_extension)
full_map_paths, mapnames, _ = list_files(data_dirs, ext = 'png') #generate filenames and stimpaths for target/zone map
print('We have {} stimuli to evaluate.'.format(len(full_stim_paths)))

In [None]:
# Derive, per stimulus, its ID (name up to the first dot) and its set name
# (the ID with its last two underscore-separated fields removed).
stim_IDs = []
set_names = []
for filename in filenames:
    stim_id = filename.split('.')[0]
    id_fields = stim_id.split('_')
    stim_IDs.append(stim_id)
    set_names.append('_'.join(id_fields[:-2]))

In [None]:
## convert to pandas dataframe: one row per stimulus
M = pd.DataFrame({'stim_ID': stim_IDs, 'set_name': set_names})

In [None]:
# if needed, add code to add additional columns
# Add trial labels to the metadata using the stimulus metadata.json
target_hit_zone_labels = dict()  # stimulus name -> does_target_contact_zone
for _dir in data_dirs:
    with open(os.path.join(_dir, 'metadata.json'), 'rb') as f:
        trial_metas = json.load(f)

    # normpath drops a trailing slash so the folder name is never empty
    set_dir_name = os.path.basename(os.path.normpath(_dir))
    for i,meta in enumerate(trial_metas):
        stim_name = meta['stimulus_name']
        # The name may be missing as JSON null or as the literal string 'None';
        # in both cases recreate it from the trial's position in the metadata.
        if stim_name is None or stim_name == 'None':
            stim_name = set_dir_name + '_' + str(i).zfill(4)
        label = meta['does_target_contact_zone']
        target_hit_zone_labels[stim_name] = label

num_positive = sum(target_hit_zone_labels.values())
num_negative = len(target_hit_zone_labels) - num_positive
print("num positive labels: %d" % num_positive)
print("num negative labels: %d" % num_negative)
if num_negative:
    print("ratio",num_positive / num_negative)
else:
    # Without negative trials the ratio is undefined; report that instead of crashing.
    print("ratio","undefined (no negative labels)")

In [None]:
# make new df with all metadata: one (stim_ID, label) row per labelled stimulus.
# dtype=object keeps both columns as generic Python objects.
GT = pd.DataFrame(
    list(target_hit_zone_labels.items()),
    columns=['stim_ID', 'target_hit_zone_label'],
    dtype=object,
)

In [None]:
# merge with M
label_columns = list(GT.columns[1:])
# Drop label columns left behind by an earlier run of this cell; otherwise a
# re-run would create suffixed duplicates (*_x / *_y) instead of the expected column.
stimuli_without_labels = M.drop(columns=[c for c in label_columns if c in M.columns])
# The merge is an inner join, so stimuli with no entry in GT disappear from M.
# Report them rather than losing them silently.
unlabeled_stim_IDs = sorted(set(stimuli_without_labels['stim_ID']) - set(GT['stim_ID']))
M = stimuli_without_labels.merge(GT, on='stim_ID')
if unlabeled_stim_IDs:
    print("WARNING: dropped %d stimuli without a label, e.g. %s"
          % (len(unlabeled_stim_IDs), unlabeled_stim_IDs[:5]))
print("added labels %s" % label_columns)

In [None]:
metadata = {} #holds all the metadata for all stimuli

# Read every dataset in the 'static' group of each stimulus' HDF5 file.
# Loading is best-effort: a file that cannot be read is reported and skipped.
for name,hdf5_path in tqdm(list(zip([f.split('.')[0] for f in filenames],full_hdf5_paths))):
    try:
        # The context manager closes the file even when reading raises.
        with h5py.File(hdf5_path,'r') as hdf5:
            static = hdf5['static'] #get the static part of the HDF5
            metadatum = {} #metadata for the current stimulus
            for key in static.keys():
                datum = np.array(static[key])
                if datum.shape == (): datum = datum.item() #unwrap non-arrays
                metadatum[key] = datum
        metadata[name] = metadatum
    except Exception as e:
        print("Error with",hdf5_path,":",e)

Insert the per-stimulus metadata into `M`:

In [None]:
# Copy each stimulus' metadata into M, one column per metadata key.
stim_IDs_without_metadata = []
for index in M.index:
    stim_name = M.at[index,'stim_ID']
    stim_metadata = metadata.get(stim_name)
    if stim_metadata is None:
        # Nothing was stored for this stimulus in `metadata`; leave its
        # metadata columns empty rather than aborting with a KeyError.
        stim_IDs_without_metadata.append(stim_name)
        continue
    for key,value in stim_metadata.items():
        M.at[index,key] = str(value) #insert every item as string
if stim_IDs_without_metadata:
    print("WARNING: no HDF5 metadata for %d stimuli, e.g. %s"
          % (len(stim_IDs_without_metadata), stim_IDs_without_metadata[:5]))

In [None]:
# Show M as the cell output to inspect the merged label and metadata columns
M

In [None]:
# Integer copy of the trial label (1 = positive, 0 = negative) for the analysis below
M = M.assign(label=M['target_hit_zone_label'].astype(int))

## Analysis

How many stimuli?

In [None]:
# Number of rows in M, i.e. number of stimuli
M.shape[0]

What fraction of trials is positive (1) rather than negative (0)?

In [None]:
# Mean of the 0/1 labels = fraction of positive trials
M['label'].mean()

What fraction of trials *per set name* is positive (1) rather than negative (0)?

In [None]:
# Per stimulus set: number of stimuli and fraction of positive trials
per_set_aggregations = {'stim_ID': ['count'], 'label': ['mean']}
M.groupby('set_name').agg(per_set_aggregations)