In [1]:
import pandas as pd
import os.path as osp
from configuration import Config
from glob import glob
from collections import defaultdict
import matplotlib.pyplot as plt
import json
import numpy as np
import re

# Project configuration object. This notebook reads `config.data_dir`
# (location of the color_grid / color_patch JSON data) and
# `config.output_dir` (location of the colorgrids*.json result files) from it.
config = Config()

In [2]:
# Human color-grid data; its `utterances` column drives the location subsets below.
grid_data_path = osp.join(config.data_dir, "color_grid_data.json")
grid_data = pd.read_json(grid_data_path)

In [3]:
def clean_and_tokenize(s):
    """Lower-case ``s`` and split it into alphanumeric tokens.

    Every run of characters outside ``[0-9a-z ]`` (after lower-casing) is
    replaced by a single space, then the string is split on whitespace.
    Example: ``"Left-most, please!"`` -> ``['left', 'most', 'please']``.
    """
    s = re.sub('[^0-9a-z ]+', ' ', s.lower())
    return s.split()


def merge_utts(utts):
    """Join the text field (element ``[1]``) of every utterance record with single spaces."""
    return ' '.join(u[1] for u in utts)


def merge_clean_tokenize(utts):
    """Merge all utterance texts of ``utts`` and tokenize them with ``clean_and_tokenize``."""
    return clean_and_tokenize(merge_utts(utts))


def contains_dir(utts: list, dir: str, possible_dirs) -> bool:
    """Return True iff ``dir`` is the only direction word mentioned in ``utts``.

    Parameters
    ----------
    utts : list
        Utterance records; element ``[1]`` of each record is its text.
    dir : str
        The direction word that must appear as a whole token.
        (The name shadows the builtin ``dir`` locally; kept for caller compatibility.)
    possible_dirs : iterable of str
        All candidate direction words (e.g. a list or set). None of them other
        than ``dir`` may appear as a token.

    Returns
    -------
    bool
        False for empty ``utts``, when ``dir`` is absent, or when any other
        candidate direction is also mentioned.
    """
    # A set gives O(1) membership tests for the direction lookups below.
    tokens = set(merge_clean_tokenize(utts))
    if dir not in tokens:
        return False
    # set(...) so callers may pass a list as well as a set.
    other_dirs = set(possible_dirs) - {dir}
    return tokens.isdisjoint(other_dirs)

In [4]:
# Split the grid rounds by location word: for each direction keep the rounds
# whose utterances mention that direction and no other one.
dir_subsets = dict()
possible_dirs = ['left', 'middle', 'right']
possible_dirs_set = set(possible_dirs)  # loop-invariant, built once
n_total = len(grid_data)

# `direction` rather than `dir`, so the builtin `dir()` is not shadowed
# in the notebook namespace for the rest of the session.
for direction in possible_dirs:
    # Bind `direction` as a default argument so the predicate never depends
    # on the loop variable's later value.
    mentions_only_direction = grid_data.utterances.map(
        lambda utts, d=direction: contains_dir(utts, d, possible_dirs_set)
    )
    data_subset = grid_data[mentions_only_direction]
    dir_subsets[direction] = data_subset

    print(f'Location: {direction}')
    print(f'# Entries: {len(data_subset)}')
    print(f'% of full data: {round((len(data_subset) / n_total) * 100, 1)}')
    print('######################')

Location: left
# Entries: 2071
% of full data: 19.0
######################
Location: middle
# Entries: 1937
% of full data: 17.7
######################
Location: right
# Entries: 1898
% of full data: 17.4
######################


In [5]:
def has_listener_utt(utts):
    """Return True if any utterance record in ``utts`` was produced by the listener.

    Each record is unpacked as ``(interlocutor, text, extra)``; the check is
    whether the interlocutor ``'listener'`` occurs among them.
    """
    roles = tuple(role for role, _text, _extra in utts)
    return 'listener' in roles

def parse_patch_results(data):
    """Join raw patch-task model results with the human patch data.

    ``data`` is turned into a DataFrame and inner-joined on ``identifier``
    with ``color_patch_data.json`` (whose ``success`` column is exposed as
    ``human_success``). From the human data only the key and the columns the
    results do not already contain are taken, so no ``_x``/``_y`` suffixed
    duplicates appear. A boolean ``has_listener_utt`` column is derived from
    the ``conversation`` column.
    """
    patch_path = osp.join(config.data_dir, 'color_patch_data.json')
    human_df = pd.read_json(patch_path).rename(columns={'success': 'human_success'})

    model_df = pd.DataFrame(data)

    keep_cols = [
        col for col in human_df.columns
        if col == 'identifier' or col not in model_df.columns
    ]

    merged = model_df.merge(human_df[keep_cols], on=['identifier'])
    merged['has_listener_utt'] = merged.conversation.map(has_listener_utt)
    return merged


def parse_grid_results(data):
    """Join raw grid-task model results with the human grid data.

    ``data`` is turned into a DataFrame and inner-joined on
    ``('gameid', 'roundNum')`` with a fixed selection of columns from
    ``color_grid_data.json`` (whose ``success`` column is exposed as
    ``human_success``). A boolean ``has_listener_utt`` column is derived from
    the ``utterances`` column.

    NOTE(review): the JSON is re-read on every call, although the notebook
    already holds the same file in ``grid_data``.
    """
    grid_path = osp.join(config.data_dir, 'color_grid_data.json')
    human_df = pd.read_json(grid_path).rename(columns={'success': 'human_success'})

    model_df = pd.DataFrame(data)

    # Only these human-data columns are attached to the model results.
    keep_cols = [
        'gameid', 'roundNum', 'human_success', 'utterances',
        'n_utterances', 'objs', 'target',
        'speaker_order', 'listener_order', 'listener_clicked',
    ]
    merged = model_df.merge(human_df[keep_cols], on=['gameid', 'roundNum'])
    merged['has_listener_utt'] = merged.utterances.map(has_listener_utt)
    return merged

def parse_results(path):
    """Load a results JSON file and parse its payload according to its task.

    The file must contain a ``config`` dict (with a ``task`` key) and a
    ``results`` payload. Returns ``(config_dict, results_df)``: a ``'grid'``
    task is parsed with ``parse_grid_results``, any other task with
    ``parse_patch_results``.
    """
    with open(path, 'r') as fh:
        payload = json.load(fh)
    results_config = payload['config']
    results_data = payload['results']

    is_grid = results_config['task'] == 'grid'
    parser = parse_grid_results if is_grid else parse_patch_results
    return results_config, parser(results_data)

In [6]:
# glob() returns files in a filesystem-dependent order. Sort them so that runs
# are always processed in the same order (the aggregation below keeps the last
# run when two runs map to the same table row).
results_files = sorted(glob(osp.join(config.output_dir, 'colorgrids*.json')))
# Fail early with a clear message instead of an obscure KeyError when building the table.
assert results_files, f"no 'colorgrids*.json' result files found in {config.output_dir}"
all_results = [parse_results(path) for path in results_files]

In [7]:
# One row per (system, size, quant). For every location subset from
# `dir_subsets`, record the % of predictions at that location and the accuracy (%).
results_dict = defaultdict(dict)

for r_config, df in all_results:

    model_type = r_config['model_type']
    model_size = int(r_config['model_size'].replace('b', ''))
    task = r_config['task']
    quant = r_config['quant']

    assert task == 'grid'

    # `quant` is part of the key because the table is indexed by
    # (system, size, quant). Without it, runs of the same model that differ
    # only in quantisation would silently overwrite each other.
    row = results_dict[f'{model_type}-{model_size}-{task}-{quant}']
    row['system'] = model_type
    row['size'] = model_size
    row['quant'] = quant

    # `direction` rather than `dir`, to avoid shadowing the builtin.
    for direction in possible_dirs:
        data_subset = dir_subsets[direction]
        pred_subset = df.loc[df.round_id.isin(data_subset.round_id)]
        # % of these rounds where the model predicted `direction`.
        row[f'same_dir_{direction}'] = (pred_subset.predicted_location == direction).mean() * 100
        # Accuracy (%) on the same rounds.
        row[f'acc_{direction}'] = pred_subset.correct.mean() * 100

results_df = pd.DataFrame(results_dict).T.set_index(['system', 'size', 'quant']).sort_index().astype(float)

# Two-level columns: the top level is the location (the suffix after the last '_').
arrays = [[c.split('_')[-1] for c in results_df.columns], list(results_df.columns)]
results_df.columns = pd.MultiIndex.from_arrays(arrays)


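# Location Biases (Table 3)

For each location word (left / middle / right), rounds are restricted to those whose utterances mention only that word. `same_dir_<loc>` is the percentage of those rounds in which the model predicted `<loc>`, and `acc_<loc>` is the model's accuracy (%) on them.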

In [8]:
# Display the location-bias table with one decimal place.
results_df.round(decimals=1)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,left,left,middle,middle,right,right
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,same_dir_left,acc_left,same_dir_middle,acc_middle,same_dir_right,acc_right
system,size,quant,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
Janus,1,,100.0,33.1,32.2,36.2,13.1,33.7
Janus,7,,94.4,34.7,95.6,34.7,53.8,41.9
LLaVa,7,,90.9,37.6,89.2,37.0,94.4,32.2
LLaVa,13,,88.7,37.6,95.4,34.8,95.8,32.5
LLaVa,34,,92.5,34.3,99.5,33.7,90.5,32.6
LLaVa,72,8bit,88.1,38.4,87.1,37.2,91.1,32.5
Qwen,2,,95.9,35.0,67.0,38.3,73.2,38.7
Qwen,7,,76.1,46.4,92.8,37.2,30.7,47.4
Qwen,72,awq,47.4,71.0,69.0,55.8,35.9,74.9


In [9]:
# Round once, then emit both a plain-text and a LaTeX rendering of the table.
table_rounded = results_df.round(1)
print(table_rounded.to_string())

print(table_rounded.to_latex())

                           left                   middle                     right          
                  same_dir_left acc_left same_dir_middle acc_middle same_dir_right acc_right
system size quant                                                                           
Janus  1    NaN           100.0     33.1            32.2       36.2           13.1      33.7
       7    NaN            94.4     34.7            95.6       34.7           53.8      41.9
LLaVa  7    NaN            90.9     37.6            89.2       37.0           94.4      32.2
       13   NaN            88.7     37.6            95.4       34.8           95.8      32.5
       34   NaN            92.5     34.3            99.5       33.7           90.5      32.6
       72   8bit           88.1     38.4            87.1       37.2           91.1      32.5
Qwen   2    NaN            95.9     35.0            67.0       38.3           73.2      38.7
       7    NaN            76.1     46.4            92.8       37.2   

  print(results_df.round(1).to_latex())
