In [None]:
import os
import os.path as osp
import json
import csv
import argparse

import sys
sys.path.append('../')
from configuration import Config

# Project configuration object; later cells read `config.colorpatch_data`
# (location of the original corpus) and `config.data_dir` (output location).
config = Config()

In [2]:
# hls value conversion (to floats), data parsing and building the data dictionary out of the original data

def hls_values_to_float(hls):
    """Scale an HLS triple to floats.

    The hue is divided by 359; lightness and saturation are each divided
    by 100.

    Args:
        hls: Iterable of exactly three numbers, in the order (h, l, s).

    Returns:
        Tuple ``(h, l, s)`` holding the scaled values.
    """
    hue, lightness, saturation = hls
    return hue / 359, lightness / 100, saturation / 100

def parse_entry(e):
    """Turn one CSV row (as a dict) into a normalized record.

    The row describes three colour patches under the column prefixes
    ``click`` (patch clicked by the listener), ``alt1`` and ``alt2`` (the
    two alternatives). Each patch has a status (``target``, ``distr1`` or
    ``distr2``), HLS values and a location for speaker and listener.

    Args:
        e: Mapping from column name to the raw (string) cell value.

    Returns:
        dict with the keys ``identifier`` ("<gameid>:<roundNum>"), ``role``,
        ``source``, ``hls_values`` (int triples ordered target, distr1,
        distr2), ``expression``, ``condition``, ``success`` (True iff the
        ``outcome`` cell equals ``'true'``), ``s_order_t_distr1_distr2`` and
        ``l_order_t_distr1_distr2`` (raw speaker / listener locations ordered
        target, distr1, distr2; ``None`` where a status is absent).

    Raises:
        AssertionError: if the patch sorted last by status is not the target.
    """
    prefixes = ('click', 'alt1', 'alt2')

    identifier = f"{e['gameid']}:{e['roundNum']}"
    success = e['outcome'] == 'true'
    role, source, condition = e['role'], e['source'], e['condition']

    # One (status, (h, l, s)) pair per patch, in prefix order.
    patches = [
        (
            e[f'{prefix}Status'],
            tuple(int(e[f'{prefix}Col{channel}']) for channel in 'HLS'),
        )
        for prefix in prefixes
    ]

    # The status labels sort alphabetically as distr1 < distr2 < target, so
    # sorting by status puts the target last.
    patches.sort(key=lambda patch: patch[0])
    first_distractor, second_distractor, target = patches
    assert target[0] == 'target'
    hls_values = (target[1], first_distractor[1], second_distractor[1])

    # Slot of each status in the "target, distr1, distr2" ordered outputs.
    slot_of_status = {'target': 0, 'distr1': 1, 'distr2': 2}
    speaker_positions = [None] * 3
    listener_positions = [None] * 3

    for prefix in prefixes:
        status = e[f'{prefix}Status']
        loc_s, loc_l = e[f'{prefix}LocS'], e[f'{prefix}LocL']
        slot = slot_of_status.get(status)
        if slot is not None:
            speaker_positions[slot] = loc_s
            listener_positions[slot] = loc_l

    return {
        'identifier': identifier,
        'role': role,
        'source': source,
        'hls_values': hls_values,
        'expression': e['contents'],
        'condition': condition,
        'success': success,
        's_order_t_distr1_distr2': speaker_positions,
        'l_order_t_distr1_distr2': listener_positions
    }

def build_data_dict(data_path):
    """Read the corpus CSV and parse every row into a record.

    Args:
        data_path: Path to a CSV file whose first line is the header row.

    Returns:
        list of dicts in file order. Each dict is the result of
        ``parse_entry`` for that row, plus an ``'idx'`` key holding the
        zero-based row index.
    """
    # newline='' is what the csv module asks for: it leaves line endings
    # untouched so that line breaks inside quoted fields (e.g. free-text
    # utterances) reach the parser unmodified.
    with open(data_path, 'r', newline='') as file:
        reader = csv.DictReader(file)
        return [
            {'idx': i, **parse_entry(dict(row))}
            for i, row in enumerate(reader)
        ]

# filtering and saving data splits

def filter_raw_data(data, args):
    """Drop records according to the ``keep_*`` switches on ``args``.

    Args:
        data: List of record dicts with ``'success'``, ``'role'`` and
            ``'source'`` keys.
        args: Object exposing the boolean attributes ``keep_failed``,
            ``keep_listener_turns`` and ``keep_non_human``.

    Returns:
        The records that pass every active filter, in their original order.
        When all three switches are truthy the input list itself is returned.
    """
    # (switch that disables the filter, test a record must pass to be kept)
    criteria = (
        (args.keep_failed, lambda record: record['success']),
        (args.keep_listener_turns, lambda record: record['role'] == 'speaker'),
        (args.keep_non_human, lambda record: record['source'] == 'human'),
    )

    for keep_everything, is_wanted in criteria:
        if keep_everything:
            continue
        data = [record for record in data if is_wanted(record)]

    return data

def save_data_splits(data, args):
    """Write the train / val / test splits as JSON files.

    Args:
        data: Iterable of splits, paired positionally with the names
            ``train``, ``val``, ``test``. Pairing stops at the shorter of the
            two, so extra splits are ignored and missing ones are not written.
        args: Object with an ``output_dir`` attribute naming the target
            directory; it is created (including parents) if needed.

    Side effects:
        Creates ``<output_dir>/<name>.json`` for every paired split and
        prints one confirmation line per file.
    """
    out_dir = osp.abspath(args.output_dir)
    # exist_ok=True replaces the separate isdir() check, so there is no gap
    # between testing for the directory and creating it.
    os.makedirs(out_dir, exist_ok=True)

    for split_data, name in zip(data, ['train', 'val', 'test']):
        out_path = osp.join(out_dir, f'{name}.json')
        with open(out_path, 'w') as f:
            json.dump(split_data, f)

        print(f'Saved {name}.json to {out_path}.')

# loading split data 

def load_data_splits(json_dir):
    """Load whichever of train.json / val.json / test.json exist.

    Args:
        json_dir: Directory expected to contain the split files.

    Returns:
        dict mapping split name (``'train'``, ``'val'``, ``'test'``) to the
        decoded JSON content. Splits whose file is missing are left out; a
        message is printed for every split, loaded or not.
    """
    preprocessed_data = {}

    for split_name in ('train', 'val', 'test'):
        split_file = f'{split_name}.json'
        split_file_path = osp.join(json_dir, split_file)

        if not osp.exists(split_file_path):
            print(f'{split_file} not found in {json_dir}.')
            continue

        with open(split_file_path, 'r') as f:
            preprocessed_data[split_name] = json.load(f)
        print(f'Loaded {split_file}.')

    return preprocessed_data

# further processing of split data // filtering out unsuccessful rounds

def filter_split_data(split_data):
    """Keep only successful rounds and tag each with its split name.

    Args:
        split_data: dict mapping split name to a list of entry dicts. Keys
            other than ``'train'``, ``'val'`` and ``'test'`` are ignored.

    Returns:
        dict with the same recognised split names, each mapped to the list
        of entries whose ``'success'`` value is truthy (a missing key counts
        as unsuccessful).

    Note:
        Kept entries are modified in place: each gets a ``'split'`` key, and
        the returned lists hold the same dict objects as the input.
    """
    known_splits = ('train', 'val', 'test')
    filtered_data = {}

    for split_name, entries in split_data.items():
        if split_name not in known_splits:
            continue

        kept = filtered_data.setdefault(split_name, [])
        for entry in entries:
            if not entry.get('success', False):
                # only successful rounds are carried over
                continue
            entry['split'] = split_name
            kept.append(entry)

    return filtered_data

# creating final data needed for model input: round information (gameid, roundnum, condition, split, hls values, speaker and listener order of patches) & conversation

def create_final_data_sctructure(data):
    """Group per-utterance rows into one record per game round.

    Rows sharing an ``'identifier'`` ("<gameid>:<roundNum>") are merged. The
    round-level fields are taken from the first row seen for that identifier,
    and every row contributes one conversation turn.

    Args:
        data: Iterable of row dicts with the keys ``identifier``,
            ``condition``, ``success``, ``role``, ``hls_values``,
            ``expression``, ``s_order_t_distr1_distr2`` and
            ``l_order_t_distr1_distr2``.

    Returns:
        list of round dicts, ordered by first appearance of their identifier.
        Each has ``identifier``, ``gameid``, ``roundNum`` (int),
        ``condition``, ``success``, ``hls_values``, the two
        ``*_order_t_distr1_distr2`` lists converted to int and shifted down
        by one, ``conversation`` (list of ``(role, expression, None)``
        tuples in row order) and ``n_utterances``.

    Raises:
        ValueError: if an identifier has no ``':'`` or the part after the
            last ``':'`` is not an integer.
    """
    games = {}

    for entry in data:
        identifier = entry['identifier']
        game = games.get(identifier)

        if game is None:
            # The round number is always the final ':'-separated component,
            # so split from the right; a game id that itself contains ':'
            # is then handled correctly.
            gameid, roundnum = identifier.rsplit(':', 1)
            game = games[identifier] = {
                'identifier': identifier,
                'gameid': gameid,
                'roundNum': int(roundnum),
                'condition': entry['condition'],
                'success': entry['success'],
                'hls_values': entry['hls_values'],
                # positions are shifted down by one ("offset") -- presumably
                # from 1-based to 0-based indices
                's_order_t_distr1_distr2': [int(x) - 1 for x in entry['s_order_t_distr1_distr2']],
                'l_order_t_distr1_distr2': [int(x) - 1 for x in entry['l_order_t_distr1_distr2']],
                'conversation': [],
                'n_utterances': 0
            }

        # Third slot is a placeholder that is always None here.
        game['conversation'].append((entry['role'], entry['expression'], None))
        game['n_utterances'] += 1

    return list(games.values())

In [3]:
# Run settings, bundled in a Namespace so the helper functions can read them
# as attributes. With all three keep_* flags set to True, filter_raw_data
# leaves the data unfiltered.
# NOTE: no `output_dir` is set here, which save_data_splits reads from args;
# add it before calling that function with this object.
args = argparse.Namespace(
    # directory containing the original corpus CSV
    originaldata_location = osp.join(config.colorpatch_data, 'behavioralAnalysis', 'humanOutput'),
    keep_failed = True,          # keep rows whose 'success' is False
    keep_listener_turns = True,  # keep rows whose 'role' is not 'speaker'
    keep_non_human = True,       # keep rows whose 'source' is not 'human'
    
)

In [4]:
# Load the original corpus and parse every CSV row into a record.
data_file_path = osp.join(args.originaldata_location, 'filteredCorpus.csv')
parsed_rows = build_data_dict(data_file_path)

# Apply the keep_* switches from args.
selected_rows = filter_raw_data(parsed_rows, args)

# Regroup the rows into one record per game round.
data = create_final_data_sctructure(selected_rows)

In [5]:
# Write the per-round records in `data` to a single JSON file inside
# config.data_dir.
out_path = osp.join(config.data_dir, 'color_patch_data.json')
with open(out_path, 'w') as f:
    # The (role, expression, None) conversation tuples are serialized as JSON
    # arrays, so they come back as lists when the file is loaded again.
    json.dump(data, f)