In [1]:
import json
import os
import os.path as osp
import re
from collections import defaultdict
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from configuration import Config

# Project settings; this notebook reads config.data_dir and config.output_dir from it.
config = Config()

# Basic color terms used to score how much of a description consists of plain
# color words (both spellings 'gray' and 'grey' are included).
BASIC_COLOR_TERMS = set([
    'black', 'white', 'gray', 'grey',
    'red', 'green', 'yellow', 'blue',
    'brown', 'orange', 'pink', 'purple',
])

In [2]:
# Human reference data for both tasks, read from the configured data directory.
grid_data, patch_data = (
    pd.read_json(osp.join(config.data_dir, fname))
    for fname in ("color_grid_data.json", "color_patch_data.json")
)

In [3]:
def clean_and_tokenize(s):
    """Normalize a string and split it into tokens.

    The string is lower-cased, every run of characters outside ``[0-9a-z ]``
    (punctuation, accents, other whitespace) becomes a single space, and the
    result is split on whitespace.

    Args:
        s: raw text.

    Returns:
        list of lower-case alphanumeric tokens (empty for blank input).
    """
    normalized = re.sub('[^0-9a-z ]+', ' ', s.lower())
    return normalized.split()

def merge_utts(utts):
    """Concatenate the text of all utterances with single spaces.

    Args:
        utts: sequence of utterance records whose element at index 1 is the
            utterance text.

    Returns:
        one string with all texts joined by ' '.
    """
    return ' '.join(utterance[1] for utterance in utts)

def merge_clean_tokenize(utts):
    """Merge all utterance texts and return the normalized token list.

    Combines ``merge_utts`` (joins the text field of each utterance) with
    ``clean_and_tokenize`` (lower-case, strip non-alphanumerics, split).
    """
    return clean_and_tokenize(merge_utts(utts))

def color_term_ratio(tokenized_utts, color_terms):
    """Return the share of tokens that are color terms.

    Args:
        tokenized_utts: list of tokens.
        color_terms: container of color-term strings (tested with ``in``).

    Returns:
        number of tokens found in ``color_terms`` divided by the number of
        tokens. An empty token list yields 0 instead of raising
        ZeroDivisionError, matching ``get_color_term_ratio``.
    """
    if len(tokenized_utts) == 0:
        return 0
    n_color_tokens = sum(1 for token in tokenized_utts if token in color_terms)
    return n_color_tokens / len(tokenized_utts)

def get_color_term_ratio(utts, color_terms=BASIC_COLOR_TERMS):
    """Share of color terms among all tokens of a conversation.

    Args:
        utts: utterance records (text at index 1), merged and tokenized via
            ``merge_clean_tokenize``.
        color_terms: set of terms counted as color words.

    Returns:
        ratio in [0, 1]; 0 when the conversation contains no tokens.
    """
    tokens = merge_clean_tokenize(utts)
    return color_term_ratio(tokens, color_terms) if tokens else 0

def get_utt_length(utts):
    """Number of normalized tokens across all utterances of a conversation."""
    return len(merge_clean_tokenize(utts))


In [4]:
ct_threshold = 1/3  # minimum share of basic color terms among all tokens
len_threshold = 5   # maximum number of tokens in the merged conversation


def filter_simple_color_descriptions(df, utt_col, label, min_ct_ratio, max_len):
    """Keep rows whose conversation is short and dominated by basic color terms.

    Args:
        df: data frame with one conversation per row.
        utt_col: name of the column holding the utterance records.
        label: name printed in the size report.
        min_ct_ratio: minimum color-term ratio (inclusive).
        max_len: maximum number of tokens (inclusive).

    Returns:
        the filtered subset of ``df``. A one-line report of the subset size is
        printed as a side effect.
    """
    # tokenize each conversation once and derive both statistics from it
    tokens = df[utt_col].map(merge_clean_tokenize)
    ratios = tokens.map(
        lambda toks: color_term_ratio(toks, BASIC_COLOR_TERMS) if toks else 0)
    lengths = tokens.map(len)

    subset = df[(ratios >= min_ct_ratio) & (lengths <= max_len)]
    rel_size = round(len(subset) / len(df) * 100, 1) if len(df) else 0.0
    print(f'{label}: {len(df)} -> {len(subset)} entries ({rel_size} %)')
    return subset


grid_data_subset = filter_simple_color_descriptions(
    grid_data, 'utterances', 'color grids', ct_threshold, len_threshold)
patch_data_subset = filter_simple_color_descriptions(
    patch_data, 'conversation', 'color patches', ct_threshold, len_threshold)

color grids: 10925 -> 2556 entries (23.4 %)
color patches: 47041 -> 29201 entries (62.1 %)


In [5]:
def has_listener_utt(utts):
    """Tell whether any utterance in a conversation was produced by the listener.

    Args:
        utts: iterable of 3-element utterance records whose first element is
            the interlocutor role (e.g. 'listener').

    Returns:
        True if at least one record's role equals 'listener', else False.
    """
    found = False
    # every record is unpacked (no early exit), so malformed records still raise
    for speaker, _text, _extra in utts:
        if speaker == 'listener':
            found = True
    return found

def parse_patch_results(data):
    """Merge raw color-patch model results with the human patch data.

    Args:
        data: raw model results (anything ``pd.DataFrame`` accepts); must
            provide an ``identifier`` column.

    Returns:
        DataFrame of model results inner-joined (pandas default) on
        ``identifier`` with every human-data column the results do not already
        carry (human ``success`` renamed to ``human_success``), plus a boolean
        ``has_listener_utt`` column derived from ``conversation``.
    """
    human_df = pd.read_json(
        osp.join(config.data_dir, 'color_patch_data.json')
    ).rename(columns={'success': 'human_success'})

    model_df = pd.DataFrame(data)

    # take the join key plus every human column missing from the model results
    extra_cols = [
        col for col in human_df.columns
        if col == 'identifier' or col not in model_df.columns
    ]
    merged = model_df.merge(human_df[extra_cols], on=['identifier'])

    merged['has_listener_utt'] = merged.conversation.map(has_listener_utt)
    return merged


def parse_grid_results(data):
    """Merge raw color-grid model results with the human grid data.

    Args:
        data: raw model results (anything ``pd.DataFrame`` accepts); must
            provide ``gameid`` and ``roundNum`` columns.

    Returns:
        DataFrame of model results inner-joined (pandas default) on
        (``gameid``, ``roundNum``) with a fixed selection of human-data
        columns (human ``success`` renamed to ``human_success``), plus a
        boolean ``has_listener_utt`` column derived from ``utterances``.
    """
    join_keys = ['gameid', 'roundNum']
    human_cols = join_keys + [
        'human_success', 'utterances', 'n_utterances', 'objs', 'target',
        'speaker_order', 'listener_order', 'listener_clicked',
    ]

    human_df = pd.read_json(
        osp.join(config.data_dir, 'color_grid_data.json')
    ).rename(columns={'success': 'human_success'})
    model_df = pd.DataFrame(data)

    merged = model_df.merge(human_df[human_cols], on=join_keys)

    merged['has_listener_utt'] = merged.utterances.map(has_listener_utt)
    return merged

def parse_results(path):
    """Load one model-results JSON file and merge it with the human data.

    The file must contain a ``config`` mapping (with at least a ``task`` key)
    and a ``results`` entry that is handed to the task-specific parser.

    Args:
        path: path to the results JSON file.

    Returns:
        tuple ``(config, merged results DataFrame)``.

    Raises:
        ValueError: if ``config['task']`` is neither 'grid' nor 'patch'
            (previously any unknown task was silently parsed as 'patch').
    """
    with open(path, 'r') as f:
        data = json.load(f)
    results_config = data['config']
    results_data = data['results']

    task = results_config['task']
    if task == 'grid':
        return results_config, parse_grid_results(results_data)
    if task == 'patch':
        return results_config, parse_patch_results(results_data)
    raise ValueError(f'unknown task {task!r} in {path}')

In [6]:
# read and parse inputs
# read and parse inputs; sorted because glob order is filesystem-dependent
results_files = sorted(glob(osp.join(config.output_dir, '*.json')))
assert results_files, f'no result files found in {config.output_dir}'
all_results = [parse_results(f) for f in results_files]

In [7]:
def summarize_run(r_config, df):
    """Summarize one model run on the color-term subset of its task.

    Args:
        r_config: run config with keys 'model_type', 'model_size' (e.g. '7b'),
            'task' ('grid' or 'patch') and 'quant'.
        df: parsed results with 'correct', 'condition' and
            'predicted_location' columns.

    Returns:
        dict with system/size/task/quant, total and per-condition accuracy
        (fractions, keys '<condition>_acc') and the share of each predicted
        location (keys '<location>_ratio').

    Raises:
        ValueError: for an unknown task.
    """
    task = r_config['task']
    if task == 'grid':
        # NOTE(review): assumes grid results carry a round_id matching grid_data — confirm
        subset_df = df.loc[df.round_id.isin(grid_data_subset.round_id)]
    elif task == 'patch':
        subset_df = df.loc[df.identifier.isin(patch_data_subset.identifier)]
    else:
        raise ValueError(f'unknown task: {task!r}')

    summary = {
        'system': r_config['model_type'],
        'size': int(r_config['model_size'].replace('b', '')),
        'task': task,
        'quant': r_config['quant'],
        'total_acc': subset_df.correct.mean(),
    }

    # accuracy per condition
    per_condition = subset_df.groupby('condition').correct.mean()
    summary.update({f'{cond.lower()}_acc': acc for cond, acc in per_condition.items()})

    # share of predicted locations; null predictions are counted as 'unknown'
    # (a plain groupby would drop them from the numerator)
    pred_locations = (
        subset_df.predicted_location.fillna('None')
        .value_counts(normalize=True)
        .rename({'None': 'unknown'})
        .to_dict()
    )
    summary.update(
        {f'{d}_ratio': pred_locations.get(d, 0) for d in ['left', 'middle', 'right', 'unknown']}
    )
    return summary


results_dict = {}
for r_config, df in all_results:
    summary = summarize_run(r_config, df)
    # quant is part of the key so differently quantized runs of one model do not
    # overwrite each other; true duplicates are reported instead of silently dropped
    key = f"{summary['system']}-{summary['size']}-{summary['task']}-{summary['quant']}"
    if key in results_dict:
        raise ValueError(f'duplicate results for {key}; remove one of the result files')
    results_dict[key] = summary

results_df = pd.DataFrame(results_dict).T
acc_cols = [c for c in results_df.columns if 'acc' in c]
results_df[acc_cols] = results_df[acc_cols].astype(float) * 100  # convert to %

In [8]:
display_df = True
print_latex = True
print_str = True

acc_cols = ['total', 'far', 'split', 'close']


def build_task_table(results, task, human_data):
    """Build the accuracy table of one task, with a human reference row.

    Args:
        results: per-run summary frame (columns system, size, quant, task and
            '<name>_acc' accuracies in %).
        task: 'patch' or 'grid'.
        human_data: human data subset with 'condition' and 'success' columns.

    Returns:
        frame indexed by (system, size, quant) with columns
        ``(task, <acc_col>)``; the human row is labelled ('human', -1, '-').
    """
    task_results = (
        results.loc[results.task == task]
        .sort_values(by=['system', 'size'])
        .drop(columns=['task'])
    )
    rename = {c: c.replace('_acc', '') for c in task_results.columns}
    table = task_results.set_index(['system', 'size', 'quant']).rename(columns=rename)[acc_cols]

    # human accuracy on the same subset the models were evaluated on, in %
    human_per_condition = human_data.groupby('condition').success.mean()
    human_per_condition.index = human_per_condition.index.map(str.lower)
    human_per_condition['total'] = human_data.success.mean()
    table.loc['human', -1, '-'] = human_per_condition * 100

    table.columns = pd.MultiIndex.from_tuples([(task, c) for c in table.columns])
    return table


results_sorted = results_df.sort_index()
task_dfs = [
    build_task_table(results_sorted, task, grid_data_subset if task == 'grid' else patch_data_subset)
    for task in ['patch', 'grid']
]
merged_results = pd.concat(task_dfs, axis=1).sort_index()

# bold the best model per column; the human row is excluded by label rather
# than by position so the highlighting does not depend on sort order
is_model_row = merged_results.index.get_level_values(0) != 'human'
style_idx = pd.IndexSlice[merged_results.index[is_model_row], :]

if display_df:
    display(merged_results.style.highlight_max(subset=style_idx, axis=0, props="font-weight:bold;").format(precision=1))
if print_latex:
    print(merged_results.style.highlight_max(subset=style_idx, axis=0, props="textbf:--rwrap;").format(precision=1).to_latex())
if print_str:
    print(merged_results.round(1).to_string())

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,patch,patch,patch,patch,grid,grid,grid,grid
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,total,far,split,close,total,far,split,close
system,size,quant,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2
Janus,1,,36.9,39.1,36.1,34.8,33.1,32.7,33.6,33.3
Janus,7,,75.2,89.4,68.9,61.5,41.5,41.7,39.4,43.6
LLaVa,7,,64.5,74.2,62.2,52.9,37.6,36.7,37.3,39.7
LLaVa,13,,62.0,70.1,59.5,52.7,36.0,35.5,35.1,38.1
LLaVa,34,,84.9,95.2,80.8,74.6,37.5,36.9,36.9,39.7
LLaVa,72,8bit,65.2,77.0,60.9,52.6,40.5,40.1,40.5,41.1
Qwen,2,,69.9,83.7,64.4,55.7,41.0,40.8,39.9,42.8
Qwen,7,,87.3,95.9,84.3,77.9,44.0,45.2,41.7,44.7
Qwen,72,awq,90.1,96.4,88.6,82.5,65.9,66.9,64.9,65.1
human,-1,-,91.6,97.7,90.7,83.6,94.5,96.8,93.6,90.9


\begin{tabular}{lllrrrrrrrr}
 &  &  & \multicolumn{4}{r}{patch} & \multicolumn{4}{r}{grid} \\
 &  &  & total & far & split & close & total & far & split & close \\
system & size & quant &  &  &  &  &  &  &  &  \\
\multirow[c]{2}{*}{Janus} & 1 & nan & 36.9 & 39.1 & 36.1 & 34.8 & 33.1 & 32.7 & 33.6 & 33.3 \\
 & 7 & nan & 75.2 & 89.4 & 68.9 & 61.5 & 41.5 & 41.7 & 39.4 & 43.6 \\
\multirow[c]{4}{*}{LLaVa} & 7 & nan & 64.5 & 74.2 & 62.2 & 52.9 & 37.6 & 36.7 & 37.3 & 39.7 \\
 & 13 & nan & 62.0 & 70.1 & 59.5 & 52.7 & 36.0 & 35.5 & 35.1 & 38.1 \\
 & 34 & nan & 84.9 & 95.2 & 80.8 & 74.6 & 37.5 & 36.9 & 36.9 & 39.7 \\
 & 72 & 8bit & 65.2 & 77.0 & 60.9 & 52.6 & 40.5 & 40.1 & 40.5 & 41.1 \\
\multirow[c]{3}{*}{Qwen} & 2 & nan & 69.9 & 83.7 & 64.4 & 55.7 & 41.0 & 40.8 & 39.9 & 42.8 \\
 & 7 & nan & 87.3 & 95.9 & 84.3 & 77.9 & 44.0 & 45.2 & 41.7 & 44.7 \\
 & 72 & awq & \textbf{90.1} & \textbf{96.4} & \textbf{88.6} & \textbf{82.5} & \textbf{65.9} & \textbf{66.9} & \textbf{64.9} & \textbf{65.1} \\
human 

In [None]:
# Output directory for generated tables (relative to the notebook's working dir).
# Requires `import os` (only os.path was imported before, so makedirs raised NameError).
table_out_dir = osp.abspath('./generated_tables/')
if not osp.isdir(table_out_dir):
    print(f'make new dir: {table_out_dir}')
    os.makedirs(table_out_dir, exist_ok=True)

fname = 'results_simplified.csv'
fpath = osp.join(table_out_dir, fname)
print(f'save table to {fpath}')
merged_results.to_csv(fpath)

In [10]:
general_results_path = osp.join(table_out_dir, 'results_general.csv')
assert osp.isfile(general_results_path), f'file {general_results_path} does not exist, run other notebook first'

# two header rows (task, metric) and three index columns (system, size, quant),
# matching the layout written for merged_results
general_results = pd.read_csv(general_results_path, header=[0, 1], index_col=[0, 1, 2])

# relative deviation of the simplified-subset scores from the general ones, in %
result_differences = merged_results.sub(general_results).div(general_results).mul(100)

# Table 2: simplified-subset results, with improvements over the general results in bold

In [11]:
def apply_bold_font(val):
    """CSS rule for the HTML Styler: bold for truthy ``val``, no style otherwise."""
    if not val:
        return None
    return 'font-weight:bold;'

def apply_bold_font_latex(val):
    """Styler rule rendered by ``to_latex`` as ``\\textbf{...}`` for truthy ``val``."""
    if not val:
        return None
    return 'textbf:--rwrap;'

# True where the simplified-subset score exceeds the general one
m = result_differences > 0

# Styler.apply(axis=None) expects a frame of CSS strings shaped like the data.
# Build each once with Series.map per column (DataFrame.applymap is deprecated
# since pandas 2.1; Series.map works on all versions).
html_css = m.apply(lambda col: col.map(apply_bold_font))
latex_css = m.apply(lambda col: col.map(apply_bold_font_latex))

display(merged_results.style.apply(lambda _: html_css, axis=None).format(precision=1))

print(merged_results.style.apply(lambda _: latex_css, axis=None).format(precision=1).to_latex())

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,patch,patch,patch,patch,grid,grid,grid,grid
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,total,far,split,close,total,far,split,close
system,size,quant,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2
Janus,1,,36.9,39.1,36.1,34.8,33.1,32.7,33.6,33.3
Janus,7,,75.2,89.4,68.9,61.5,41.5,41.7,39.4,43.6
LLaVa,7,,64.5,74.2,62.2,52.9,37.6,36.7,37.3,39.7
LLaVa,13,,62.0,70.1,59.5,52.7,36.0,35.5,35.1,38.1
LLaVa,34,,84.9,95.2,80.8,74.6,37.5,36.9,36.9,39.7
LLaVa,72,8bit,65.2,77.0,60.9,52.6,40.5,40.1,40.5,41.1
Qwen,2,,69.9,83.7,64.4,55.7,41.0,40.8,39.9,42.8
Qwen,7,,87.3,95.9,84.3,77.9,44.0,45.2,41.7,44.7
Qwen,72,awq,90.1,96.4,88.6,82.5,65.9,66.9,64.9,65.1
human,-1,-,91.6,97.7,90.7,83.6,94.5,96.8,93.6,90.9


\begin{tabular}{lllrrrrrrrr}
 &  &  & \multicolumn{4}{r}{patch} & \multicolumn{4}{r}{grid} \\
 &  &  & total & far & split & close & total & far & split & close \\
system & size & quant &  &  &  &  &  &  &  &  \\
\multirow[c]{2}{*}{Janus} & 1 & nan & \textbf{36.9} & \textbf{39.1} & \textbf{36.1} & \textbf{34.8} & 33.1 & 32.7 & \textbf{33.6} & 33.3 \\
 & 7 & nan & \textbf{75.2} & \textbf{89.4} & \textbf{68.9} & \textbf{61.5} & \textbf{41.5} & \textbf{41.7} & \textbf{39.4} & \textbf{43.6} \\
\multirow[c]{4}{*}{LLaVa} & 7 & nan & \textbf{64.5} & \textbf{74.2} & \textbf{62.2} & \textbf{52.9} & 37.6 & 36.7 & 37.3 & \textbf{39.7} \\
 & 13 & nan & \textbf{62.0} & \textbf{70.1} & \textbf{59.5} & \textbf{52.7} & 36.0 & 35.5 & 35.1 & \textbf{38.1} \\
 & 34 & nan & \textbf{84.9} & \textbf{95.2} & \textbf{80.8} & \textbf{74.6} & 37.5 & 36.9 & 36.9 & \textbf{39.7} \\
 & 72 & 8bit & \textbf{65.2} & \textbf{77.0} & \textbf{60.9} & \textbf{52.6} & \textbf{40.5} & 40.1 & \textbf{40.5} & \textbf{41.1} \