In [None]:
import sys
import os.path as osp
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from collections import Counter
import re
from itertools import chain
from collections import defaultdict, Counter
import spacy

# spaCy pipeline; handed to transform_list_into_heads when the KiloGram
# annotations are reduced to head nouns further down.
nlp = spacy.load("en_core_web_sm")

# The helper module lives in a sibling directory, so that directory has to be
# on sys.path *before* the import below is executed.
sys.path.append('../analyzing_annotations')
from analysis_utils import read_ann_df, transform_list_into_heads

# Turn off pandas' chained-assignment warning for the whole notebook.
pd.options.mode.chained_assignment = None  # default='warn'

# NOTE(review): not referenced in the visible cells - presumably the folder
# where generated images are stored; confirm before relying on it.
IMG_LOCATION=osp.abspath('../generated_items/')

# Input directories, resolved to absolute paths from the current working
# directory: model predictions, SceneGram annotations, and the KiloGram data.
data_dir = osp.abspath('../predicted_data/processed')
ann_dir = osp.abspath('../scenegram_data')
kilogram_dir = osp.abspath('../kilogram')

def model_size_from_col(c):
    """Extract the model size (in billions of parameters) from a column name.

    The size is the integer between a hyphen and a following ``b``, e.g.
    ``'correct_llava-13b'`` -> ``13``. Used as a sort key to order result
    columns by model size.

    Args:
        c: Column name containing a ``<name>-<digits>b`` fragment.

    Returns:
        The parsed size as ``int``.

    Raises:
        ValueError: If ``c`` contains no ``-<digits>b`` fragment. Without this
            check the failure would surface as an ``AttributeError`` on
            ``None`` with no hint about the offending column.
    """
    match = re.search(r'.*\w\-(\d+)b', c)
    if match is None:
        raise ValueError(f'cannot parse a model size from column name {c!r}')
    return int(match.group(1))

# Loading Data

In [2]:
# load annotations

input_file = osp.join(ann_dir, 'scenegram.csv')
ann_df = read_ann_df(input_file)
# head nouns may list alternatives separated by '/'; keep the first one only
ann_df['head_noun'] = ann_df['head_noun'].apply(lambda noun: noun.split('/')[0].strip())

# the index holds (tangram, scene) pairs; collect the sorted unique values of each
tangrams = sorted({tangram for tangram, _ in ann_df.index})
scenes = sorted({scene for _, scene in ann_df.index})

# move the 'none' condition to the end of the scene list
scenes.append(scenes.pop(scenes.index('none')))

# bidirectional lookup between tangram names and their position in `tangrams`
tangram2idx = dict(zip(tangrams, range(len(tangrams))))
idx2tangram = dict(enumerate(tangrams))

#display(ann_df.head())

In [3]:
# load predictions

input_file = osp.join(data_dir, 'processed_predictions_twostep.csv')
pred_df = pd.read_csv(input_file, index_col=0)

response_cols = list(filter(lambda c: c.startswith('response_'), pred_df.columns))
models = ['llava-7b', 'llava-13b', 'llava-34b', 'llava-72b']
utt_types = ['label', 'response', 'synset']

# columns whose values are collected into one list per item;
# every other column keeps the first value seen for the item
agg_columns = [
    f'{utt_type}_{model}' for utt_type in utt_types for model in models
] + ['set_idx', 'item_identifyer']

pred_df = (
    pred_df
    .groupby('item_id')
    .agg({c: (list if c in agg_columns else 'first') for c in pred_df.columns})
    .set_index(['tangram', 'scene'])
)

# display(pred_df.head())

# Location Determination (Table 4)

In [4]:
# Location predictions of every model, one row per item (indexed by item_id).
location_labels = pred_df.set_index('item_id')[[f'location_label_{model}' for model in models]]

# Attach the annotated tangram position (plus tangram / scene) for each item.
location_preds = pd.merge(
    left=location_labels,
    right=ann_df.reset_index()[['tangram', 'scene', 'tangram_pos', 'item_id']].groupby('item_id').first(),
    left_index=True,
    right_index=True
)

# A prediction counts as correct when it equals the annotated position.
# Column-wise comparison instead of a row-wise apply: same result, no Python loop.
correct_cols = [f'correct_{model}' for model in models]
for model in models:
    location_preds[f'correct_{model}'] = (
        location_preds[f'location_label_{model}'] == location_preds['tangram_pos']
    )

# Accuracy in % per scene, plus a 'global' row over all items.
correct = location_preds.groupby('scene')[correct_cols].mean() * 100
correct.loc['global'] = location_preds[correct_cols].mean() * 100
col_order = sorted(correct.columns, key=model_size_from_col)
correct = correct[col_order]
correct = correct.rename(columns={c:c.replace('correct_', '') for c in correct.columns})
display(correct.round(1))
print(correct.round(1).to_latex())

# Add the per-item correctness flags to pred_df. Columns left over from a
# previous execution of this cell are dropped first, so re-running the cell
# does not create suffixed duplicates (correct_*_x / correct_*_y).
pred_df = pd.merge(
    pred_df.drop(columns=correct_cols, errors='ignore'),
    location_preds[correct_cols],
    left_on='item_id',
    right_index=True
)

Unnamed: 0_level_0,llava-7b,llava-13b,llava-34b,llava-72b
scene,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
bathroom,86.5,100.0,100.0,100.0
beach,83.8,100.0,100.0,100.0
bedroom,78.4,100.0,100.0,100.0
forest,89.2,100.0,100.0,100.0
kitchen,75.7,97.3,100.0,97.3
mountain,91.9,100.0,100.0,100.0
none,56.8,100.0,100.0,100.0
office,83.8,100.0,100.0,97.3
sea_bottom,73.0,100.0,100.0,94.6
sky,83.8,100.0,100.0,100.0


\begin{tabular}{lrrrr}
\toprule
{} &  llava-7b &  llava-13b &  llava-34b &  llava-72b \\
scene      &           &            &            &            \\
\midrule
bathroom   &      86.5 &      100.0 &      100.0 &      100.0 \\
beach      &      83.8 &      100.0 &      100.0 &      100.0 \\
bedroom    &      78.4 &      100.0 &      100.0 &      100.0 \\
forest     &      89.2 &      100.0 &      100.0 &      100.0 \\
kitchen    &      75.7 &       97.3 &      100.0 &       97.3 \\
mountain   &      91.9 &      100.0 &      100.0 &      100.0 \\
none       &      56.8 &      100.0 &      100.0 &      100.0 \\
office     &      83.8 &      100.0 &      100.0 &       97.3 \\
sea\_bottom &      73.0 &      100.0 &      100.0 &       94.6 \\
sky        &      83.8 &      100.0 &      100.0 &      100.0 \\
street     &      86.5 &      100.0 &      100.0 &      100.0 \\
global     &      80.8 &       99.8 &      100.0 &       99.0 \\
\bottomrule
\end{tabular}



  print(correct.round(1).to_latex())


# Lexical Frequencies (Table 5)

In [5]:
n_most_frequent = 5  # number of top labels reported per scene
prec = 2             # decimal places used for percentages

def counts_to_perc(c, total, prec=prec):
    """Return count `c` as a percentage of `total`, rounded to `prec` decimals."""
    share = c / total
    return round(share * 100, prec)

for model in models:

    print(model)

    # get lexical counts: one Counter per scene over all labels of this model
    _df = pred_df.reset_index()
    raw_counts = _df.groupby('scene').agg(list)[f'label_{model}'].map(lambda x: list(chain(*x))).map(Counter)
    most_common_per_scene = raw_counts.map(lambda x: x.most_common(n_most_frequent))

    # Normalise every scene by its OWN number of labels. Using the total of the
    # first scene for all rows is only correct when every scene happens to
    # contain the same number of labels.
    total_per_scene = {scene: sum(counts.values()) for scene, counts in raw_counts.items()}

    results = defaultdict(list)
    for scene in scenes:
        for word, s_count in most_common_per_scene[scene]:
            results[scene].append(f'{word} ({counts_to_perc(s_count, total_per_scene[scene])} %)')

    # compile results: one row per scene, one column per rank; scenes with
    # fewer than n_most_frequent distinct labels are padded instead of failing
    words_and_counts = pd.DataFrame.from_dict(results, orient='index')
    words_and_counts = words_and_counts.rename(columns={c:f'#{c+1}' for c in words_and_counts.columns})
    display(words_and_counts)
    print(words_and_counts.to_latex())

llava-7b


Unnamed: 0,#1,#2,#3,#4,#5
bathroom,bathtub (28.65 %),house (19.19 %),diamond (7.03 %),person (6.76 %),rectangle (5.95 %)
beach,sun (14.59 %),house (12.43 %),bird (10.27 %),triangle (9.46 %),beach (9.19 %)
bedroom,house (28.11 %),bed (14.59 %),diamond (10.54 %),bird (8.38 %),chair (7.3 %)
forest,diamond (14.32 %),house (11.35 %),triangle (10.27 %),forest (8.65 %),tree (7.57 %)
kitchen,house (28.65 %),bird (12.16 %),diamond (7.57 %),chair (7.3 %),triangle (7.3 %)
mountain,mountain (74.05 %),triangle (4.05 %),figure (3.78 %),person (3.24 %),pyramid (2.97 %)
office,chair (14.86 %),house (14.05 %),diamond (8.92 %),person (7.3 %),triangle (7.3 %)
sea_bottom,triangle (9.19 %),diamond (8.65 %),square (7.57 %),fish (7.03 %),letter (7.03 %)
sky,triangle (14.32 %),bird (13.78 %),house (11.35 %),diamond (10.27 %),letter (7.84 %)
street,house (22.7 %),triangle (11.35 %),diamond (10.81 %),person (8.65 %),dog (7.57 %)


\begin{tabular}{llllll}
\toprule
{} &                  \#1 &                  \#2 &                  \#3 &                 \#4 &                  \#5 \\
\midrule
bathroom   &   bathtub (28.65 \%) &     house (19.19 \%) &    diamond (7.03 \%) &    person (6.76 \%) &  rectangle (5.95 \%) \\
beach      &       sun (14.59 \%) &     house (12.43 \%) &      bird (10.27 \%) &  triangle (9.46 \%) &      beach (9.19 \%) \\
bedroom    &     house (28.11 \%) &       bed (14.59 \%) &   diamond (10.54 \%) &      bird (8.38 \%) &       chair (7.3 \%) \\
forest     &   diamond (14.32 \%) &     house (11.35 \%) &  triangle (10.27 \%) &    forest (8.65 \%) &       tree (7.57 \%) \\
kitchen    &     house (28.65 \%) &      bird (12.16 \%) &    diamond (7.57 \%) &      chair (7.3 \%) &    triangle (7.3 \%) \\
mountain   &  mountain (74.05 \%) &   triangle (4.05 \%) &     figure (3.78 \%) &    person (3.24 \%) &    pyramid (2.97 \%) \\
office     &     chair (14.86 \%) &     house (14.05 \%) &    diamond 

  print(words_and_counts.to_latex())


Unnamed: 0,#1,#2,#3,#4,#5
bathroom,house (18.11 %),bird (16.22 %),person (13.78 %),square (10.81 %),tree (9.19 %)
beach,bird (19.73 %),person (14.05 %),house (11.35 %),tree (10.27 %),a (7.57 %)
bedroom,house (20.81 %),bird (17.84 %),person (12.7 %),square (7.84 %),a (7.57 %)
forest,bird (26.22 %),person (18.92 %),house (15.14 %),tree (8.65 %),animal (4.86 %)
kitchen,house (20.81 %),bird (18.65 %),person (15.95 %),square (8.38 %),tree (5.41 %)
mountain,mountain (35.41 %),house (16.76 %),bird (12.7 %),person (10.81 %),landscape (6.49 %)
office,house (18.92 %),bird (16.49 %),person (14.86 %),square (11.35 %),tree (5.41 %)
sea_bottom,bird (21.35 %),person (14.86 %),fish (9.19 %),house (8.92 %),square (8.38 %)
sky,bird (24.86 %),house (18.11 %),person (14.32 %),square (7.03 %),tree (6.49 %)
street,bird (20.54 %),person (18.11 %),house (15.68 %),square (11.35 %),a (5.95 %)


\begin{tabular}{llllll}
\toprule
{} &                  \#1 &                \#2 &                \#3 &                \#4 &                  \#5 \\
\midrule
bathroom   &     house (18.11 \%) &    bird (16.22 \%) &  person (13.78 \%) &  square (10.81 \%) &       tree (9.19 \%) \\
beach      &      bird (19.73 \%) &  person (14.05 \%) &   house (11.35 \%) &    tree (10.27 \%) &          a (7.57 \%) \\
bedroom    &     house (20.81 \%) &    bird (17.84 \%) &   person (12.7 \%) &   square (7.84 \%) &          a (7.57 \%) \\
forest     &      bird (26.22 \%) &  person (18.92 \%) &   house (15.14 \%) &     tree (8.65 \%) &     animal (4.86 \%) \\
kitchen    &     house (20.81 \%) &    bird (18.65 \%) &  person (15.95 \%) &   square (8.38 \%) &       tree (5.41 \%) \\
mountain   &  mountain (35.41 \%) &   house (16.76 \%) &     bird (12.7 \%) &  person (10.81 \%) &  landscape (6.49 \%) \\
office     &     house (18.92 \%) &    bird (16.49 \%) &  person (14.86 \%) &  square (11.35 \%) &       

  print(words_and_counts.to_latex())


Unnamed: 0,#1,#2,#3,#4,#5
bathroom,bird (25.14 %),house (14.86 %),triangle (14.32 %),dog (8.92 %),man (4.86 %)
beach,bird (24.05 %),triangle (15.68 %),dog (13.51 %),house (10.0 %),man (3.78 %)
bedroom,bird (20.54 %),house (18.11 %),dog (14.05 %),triangle (12.97 %),horse (4.05 %)
forest,bird (21.62 %),triangle (19.73 %),dog (10.81 %),house (9.73 %),man (4.32 %)
kitchen,bird (22.43 %),triangle (16.49 %),house (13.78 %),dog (12.7 %),square (4.05 %)
mountain,triangle (21.08 %),mountain (16.22 %),bird (12.16 %),dog (8.92 %),house (7.84 %)
office,bird (21.62 %),triangle (17.57 %),dog (11.62 %),house (11.62 %),person (5.68 %)
sea_bottom,triangle (20.27 %),bird (19.46 %),dog (7.3 %),house (6.76 %),man (4.59 %)
sky,bird (23.78 %),triangle (17.03 %),dog (10.0 %),house (9.73 %),man (5.95 %)
street,triangle (20.27 %),bird (19.19 %),dog (12.16 %),house (11.62 %),diamond (5.14 %)


\begin{tabular}{llllll}
\toprule
{} &                  \#1 &                  \#2 &                  \#3 &                  \#4 &                \#5 \\
\midrule
bathroom   &      bird (25.14 \%) &     house (14.86 \%) &  triangle (14.32 \%) &        dog (8.92 \%) &      man (4.86 \%) \\
beach      &      bird (24.05 \%) &  triangle (15.68 \%) &       dog (13.51 \%) &      house (10.0 \%) &      man (3.78 \%) \\
bedroom    &      bird (20.54 \%) &     house (18.11 \%) &       dog (14.05 \%) &  triangle (12.97 \%) &    horse (4.05 \%) \\
forest     &      bird (21.62 \%) &  triangle (19.73 \%) &       dog (10.81 \%) &      house (9.73 \%) &      man (4.32 \%) \\
kitchen    &      bird (22.43 \%) &  triangle (16.49 \%) &     house (13.78 \%) &        dog (12.7 \%) &   square (4.05 \%) \\
mountain   &  triangle (21.08 \%) &  mountain (16.22 \%) &      bird (12.16 \%) &        dog (8.92 \%) &    house (7.84 \%) \\
office     &      bird (21.62 \%) &  triangle (17.57 \%) &       dog (11.62 \

  print(words_and_counts.to_latex())


Unnamed: 0,#1,#2,#3,#4,#5
bathroom,house (71.08 %),bird (6.76 %),letter (5.14 %),person (3.51 %),figure (1.89 %)
beach,house (64.86 %),bird (7.3 %),letter (5.68 %),person (5.14 %),dog (3.51 %)
bedroom,house (69.73 %),letter (6.22 %),person (4.59 %),bird (4.32 %),bed (3.24 %)
forest,house (63.78 %),person (8.38 %),dog (5.14 %),bird (4.59 %),letter (3.51 %)
kitchen,house (68.92 %),person (4.59 %),bird (4.05 %),letter (3.78 %),dog (3.78 %)
mountain,mountain (64.86 %),house (19.19 %),person (5.95 %),pyramid (2.43 %),letter (2.16 %)
office,house (57.3 %),person (7.84 %),letter (7.84 %),bird (4.86 %),pyramid (3.24 %)
sea_bottom,house (54.32 %),bird (7.3 %),boat (6.76 %),person (4.86 %),letter (4.05 %)
sky,house (62.43 %),bird (9.19 %),person (5.14 %),horse (4.59 %),letter (3.51 %)
street,house (64.59 %),person (8.11 %),letter (7.3 %),bird (5.41 %),dog (2.7 %)


\begin{tabular}{llllll}
\toprule
{} &                  \#1 &               \#2 &               \#3 &                \#4 &                \#5 \\
\midrule
bathroom   &     house (71.08 \%) &    bird (6.76 \%) &  letter (5.14 \%) &   person (3.51 \%) &   figure (1.89 \%) \\
beach      &     house (64.86 \%) &     bird (7.3 \%) &  letter (5.68 \%) &   person (5.14 \%) &      dog (3.51 \%) \\
bedroom    &     house (69.73 \%) &  letter (6.22 \%) &  person (4.59 \%) &     bird (4.32 \%) &      bed (3.24 \%) \\
forest     &     house (63.78 \%) &  person (8.38 \%) &     dog (5.14 \%) &     bird (4.59 \%) &   letter (3.51 \%) \\
kitchen    &     house (68.92 \%) &  person (4.59 \%) &    bird (4.05 \%) &   letter (3.78 \%) &      dog (3.78 \%) \\
mountain   &  mountain (64.86 \%) &  house (19.19 \%) &  person (5.95 \%) &  pyramid (2.43 \%) &   letter (2.16 \%) \\
office     &      house (57.3 \%) &  person (7.84 \%) &  letter (7.84 \%) &     bird (4.86 \%) &  pyramid (3.24 \%) \\
sea\_bottom & 

  print(words_and_counts.to_latex())


# Comparison with human annotations

In [6]:
# make df with kilogram anns

def unpack_anns(list_of_annotations):
    """Collect the whole-shape annotation string of every annotation record.

    Each record is expected to be a mapping with the nested entry
    ``record['whole']['wholeAnnotation']``; the values are returned in input order.
    """
    whole_annotations = []
    for annotation in list_of_annotations:
        whole_annotations.append(annotation['whole']['wholeAnnotation'])
    return whole_annotations

kilogram_path = osp.join(kilogram_dir, 'dataset', 'dense.json')
# transpose the loaded JSON so tangrams form the rows, then keep only the
# tangrams that occur in the SceneGram annotations
kilogram_df = pd.read_json(kilogram_path).T.loc[tangrams]

# raw whole-shape annotation strings per tangram ...
kilogram_df['annotation_strings'] = kilogram_df['annotations'].map(unpack_anns)
# ... and the heads derived from them by transform_list_into_heads
kilogram_df['annotation_heads'] = kilogram_df['annotation_strings'].map(
    lambda strings: transform_list_into_heads(strings, nlp, normalize=True)
)

#display(kilogram_df.head())

In [7]:
# SceneGram anns: set of human wn_lemma values per (tangram, scene)
_adf = (
    ann_df
    .groupby(['tangram', 'scene'])
    .agg({'wn_lemma': set})
    .rename(columns={'wn_lemma': 'label_human'})
)

# model labels merged with SceneGram anns on the shared (tangram, scene) index
_pdf = (
    pred_df[[f'label_{model}' for model in models]]
    .merge(_adf, left_index=True, right_index=True)
    .reset_index()
)

# merge with KiloGram anns (keyed by tangram only)
_pdf = _pdf.merge(
    kilogram_df[['annotation_heads']],
    left_on='tangram',
    right_index=True
)

In [8]:
def overlap(x_anns, y_anns):
    """Return the fraction of items in `x_anns` that also occur in `y_anns`.

    Duplicates in `x_anns` are counted individually, so the result is the
    share of entries (not of distinct values) that have a match.

    Args:
        x_anns: Sized collection of items to check (e.g. model labels).
        y_anns: Container of reference items; only membership tests are used.

    Returns:
        A float in [0, 1]. An empty `x_anns` yields 0.0 instead of raising
        ZeroDivisionError, since nothing was matched.
    """
    if len(x_anns) == 0:
        return 0.0
    n_shared = sum(1 for x in x_anns if x in y_anns)
    return n_shared / len(x_anns)

for model in models:
    # calculate overlaps of the model labels with both human annotation sources
    _pdf[f'overlap_{model}_kilogram'] = [
        overlap(predicted, heads)
        for predicted, heads in zip(_pdf[f'label_{model}'], _pdf['annotation_heads'])
    ]
    _pdf[f'overlap_{model}_scenegram'] = [
        overlap(predicted, human)
        for predicted, human in zip(_pdf[f'label_{model}'], _pdf['label_human'])
    ]

# KiloGram Overlap (Table 3)

In [9]:
print('KILOGRAM')

# mean overlap with the KiloGram heads per scene
kilogram_overlap = (
    _pdf
    .groupby('scene')[[c for c in _pdf.columns if c.endswith('_kilogram')]]
    .mean()
    .mul(100)  # as %
)
# sort by model size
# (`order` is computed here but not applied; the column order below follows `models`)
order = sorted(kilogram_overlap.columns, key=model_size_from_col)

# arrange rows by `scenes`, strip the column names down to the model name,
# and arrange columns by `models`
kilogram_overlap = (
    kilogram_overlap
    .loc[scenes]
    .rename(columns=lambda c: c.replace('overlap_', '').replace('_kilogram', ''))
    .loc[:, models]
)

# average over scenes, one row per model
print(kilogram_overlap.mean().round(2).to_frame().to_latex())

KILOGRAM
\begin{tabular}{lr}
\toprule
{} &      0 \\
\midrule
llava-7b  &  34.99 \\
llava-13b &  34.42 \\
llava-34b &  50.22 \\
llava-72b &  54.52 \\
\bottomrule
\end{tabular}



  print(kilogram_overlap.mean().round(2).to_frame().to_latex())


# SceneGram Overlap (Table 3)

In [10]:
print('SCENEGRAM')

# NOTE(review): this cell stores the SceneGram overlaps under the name
# `kilogram_overlap`, overwriting the KiloGram table; the name is kept as-is
# so that anything reading it afterwards sees the same values as before.
kilogram_overlap = (
    _pdf
    .groupby('scene')[[c for c in _pdf.columns if c.endswith('_scenegram')]]
    .mean()
    .mul(100)  # as %
)
# sort by model size
# (`order` is computed here but not applied; the column order below follows `models`)
order = sorted(kilogram_overlap.columns, key=model_size_from_col)

# arrange rows by `scenes`, strip the column names down to the model name,
# and arrange columns by `models`
kilogram_overlap = (
    kilogram_overlap
    .loc[scenes]
    .rename(columns=lambda c: c.replace('overlap_', '').replace('_scenegram', ''))
    .loc[:, models]
)

# average over scenes, one row per model
print(kilogram_overlap.mean().round(2).to_frame().to_latex())

SCENEGRAM
\begin{tabular}{lr}
\toprule
{} &      0 \\
\midrule
llava-7b  &  26.61 \\
llava-13b &  21.13 \\
llava-34b &  27.64 \\
llava-72b &  26.00 \\
\bottomrule
\end{tabular}



  print(kilogram_overlap.mean().round(2).to_frame().to_latex())


# % Top (Table 3)

In [11]:
def get_perc_top(labels):
    """Return the share (in %) of `labels` taken up by its most frequent value."""
    frequencies = Counter(labels).values()
    top_share = max(frequencies) / sum(frequencies)
    return top_share * 100

# calculate % top: replace each item's label list by the share of its most frequent label
_df = pred_df.reset_index()[['scene', 'tangram'] + [f'label_{model}' for model in models]]
for model in models:
    _df[f'label_{model}'] = _df[f'label_{model}'].map(get_perc_top)
# mean % top over all items, per model
_df[[f'label_{model}' for model in models]].mean().round(2)

label_llava-7b     58.50
label_llava-13b    36.71
label_llava-34b    59.16
label_llava-72b    79.46
dtype: float64