In [1]:
import torch
from egg.core import Interaction
import os
from glob import glob
from collections import Counter, defaultdict
from pprint import pprint
import json


In [14]:
# Maps the short dataset key used in this notebook to a pair
# (image-name prefix, dataset folder name):
#   * element 0 is the prefix from which image names are built (see get_image_name),
#   * element 1 is the folder below `dataset_root_path` whose `scenes/` directory is read.
# NOTE(review): the spelling "UNAMBIGOUS" presumably mirrors the on-disk names;
# do not "correct" it without renaming the data as well.
dataset_mapping = {
    "dale-2": ("CLEVR_UNAMBIGOUS-DALE-TWO", 'clevr-images-unambigous-dale-two'),
    "dale-5": ("CLEVR_UNAMBIGOUS-DALE", 'clevr-images-unambigous-dale'),
    "single": ("CLEVR_RANDOM-SINGLE", 'clevr-images-random-single'),
    "colour": ("CLEVR_UNAMBIGOUS-COLOR", 'clevr-images-unambigous-colour'),
}
# Directory that contains the dataset folders named in `dataset_mapping`.
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
dataset_root_path = '/home/dominik/Development/'

# Root directory of the experiment outputs (also machine-specific).
root_dir = '/home/dominik/Nextcloud/020_Masterstudium/Language Technology/LT2402_Master Thesis/experiments/language-games/'
# root_dir = "/home/dominik/Development/MLT_Master-Thesis/out/"
model_dir = os.path.join(root_dir, 'bb_one-hot_generator/')
# Key into `dataset_mapping`; read as a global by get_mapping and the scene loader.
dataset = 'colour'
# dataset_dir = os.path.join(model_dir, f"{dataset}/")
# Directory of the single run that is analysed below.
run_dir = os.path.join(model_dir, '2024-03-16_23-33-27_bounding_box_one_hot_generator_colour_1_100_500_100_100_500_500_2')
# Take the first interaction file matched under the train / validation folders.
# `[0]` raises IndexError when nothing matches.
# NOTE(review): glob does not guarantee a sorted result, so if several epochs
# are stored the file picked here depends on directory listing order - confirm
# which epoch is intended.
train_interaction_path = glob(os.path.join(run_dir, "interactions/", "train/", "epoch*/", "interaction*"))[0]
test_interaction_path = glob(os.path.join(run_dir, 'interactions/', 'validation/', 'epoch*/', 'interaction*'))[0]

In [16]:
# SECURITY: torch.load unpickles arbitrary Python objects, so only point these
# paths at interaction files written by your own, trusted training runs.
# weights_only=False is passed explicitly because the files are expected to
# hold pickled Interaction objects rather than plain tensors; PyTorch >= 2.6
# defaults to weights_only=True, which refuses to unpickle such objects.
train_interaction: Interaction = torch.load(train_interaction_path, weights_only=False)
test_interaction: Interaction = torch.load(test_interaction_path, weights_only=False)

def remove_eos(tensor):
    """Return the symbols of a message up to, but excluding, the first EOS.

    The symbol id ``0`` is treated as the end-of-sequence marker.

    Args:
        tensor: one-dimensional sequence of symbol ids that provides
            ``tolist()`` (e.g. a tensor of indices).

    Returns:
        tuple: the symbol ids that precede the first ``0``. If the sequence
        contains no ``0`` (or is empty) the complete sequence is returned,
        instead of falling off the end of the function and yielding ``None``.
    """
    # Convert once to plain Python numbers; avoids per-element tensor slicing.
    symbols = tensor.tolist()
    for index, symbol in enumerate(symbols):
        if int(symbol) == 0:
            return tuple(symbols[:index])
    # No EOS marker present: keep the whole message so callers always get a tuple.
    return tuple(symbols)
        
def get_messages(interaction):
    """Decode every message of an interaction into a tuple of symbol ids.

    For each entry of ``interaction.message`` the index of the maximum along
    dimension 1 is taken and the result is passed through ``remove_eos``.
    (Presumably each entry is a (length, vocabulary) score matrix - confirm.)

    Returns:
        tuple: ``(messages, counts)`` where ``messages`` is the list of decoded
        messages in their original order and ``counts`` is a ``Counter`` of
        how often each decoded message occurs.
    """
    messages = []
    for message in interaction.message:
        symbol_ids = message.max(dim=1).indices
        messages.append(remove_eos(symbol_ids))
    return messages, Counter(messages)

# Decode the messages of both splits and count how often each one occurs.
train_messages, train_counter = get_messages(train_interaction)
test_messages, test_counter = get_messages(test_interaction)

# Per split: first the full list of decoded messages, then the frequency table.
print(train_messages, train_counter, sep="\n")
print(test_messages, test_counter, sep="\n")

def get_image_name(image_id, dataset):
    """Build the image name for ``image_id`` within ``dataset``.

    The name is the dataset's prefix (first element of its ``dataset_mapping``
    entry), an underscore, and the integer id left-padded with zeros to at
    least six characters.
    """
    prefix = dataset_mapping[dataset][0]
    return f"{prefix}_{int(image_id):06d}"

def get_mapping(interaction, messages, counter):
    """Group image names by the message that was produced for them.

    Args:
        interaction: object whose ``aux_input['image_id']`` yields one image id
            per sample, aligned with ``messages``.
        messages: decoded message of every sample, in sample order.
        counter: iterable of the messages to report (typically the ``Counter``
            returned by ``get_messages``); its iteration order determines the
            key order of the result.

    Returns:
        dict: maps every key of ``counter`` to the list of image names (in
        sample order) whose message equals that key. Keys without any matching
        sample map to an empty list; messages that are not in ``counter`` are
        ignored.

    Note:
        The dataset used to build the image names is read from the
        module-level ``dataset`` variable.
    """
    # Single pass over the samples instead of re-scanning all of them once
    # per distinct message.
    image_mapping = {key: [] for key in counter}
    for image_id, message in zip(interaction.aux_input['image_id'], messages):
        names = image_mapping.get(message)
        if names is not None:
            names.append(get_image_name(image_id, dataset))
    return image_mapping

# For each split, map every distinct message to the names of the images it was
# produced for.
train_image_mapping = get_mapping(train_interaction, train_messages, train_counter)
test_image_mapping = get_mapping(test_interaction, test_messages, test_counter)
# pprint(test_image_mapping)
# Number of distinct messages observed in the training interaction.
print(len(train_counter))

[(), (), (), (1,), (), (1,), (), (1,), (), (1,), (), (1,), (), (1,), (), (1,), (1,), (), (), (1,), (1,), (), (), (), (), (1,), (1,), (1,), (), (), (1,), (), (), (), (), (1,), (), (1,), (1,), (), (1,), (), (), (), (), (1,), (), (1,), (), (), (), (), (), (), (), (), (), (1,), (), (), (), (), (), (1,), (), (1,), (), (), (), (), (), (), (1,), (), (1,), (), (), (), (), (), (1,), (1,), (), (1,), (1,), (), (1,), (), (), (), (), (1,), (), (1,), (), (), (1,), (), (), (), (), (), (1,), (1,), (), (), (1,), (), (), (1,), (), (1,), (), (), (), (), (1,), (), (1,), (), (), (), (), (1,), (), (), (1,), (), (), (), (), (1,), (), (), (), (), (1,), (1,), (1,), (), (), (1,), (), (1,), (1,), (1,), (1,), (1,), (), (), (), (), (), (), (), (1,), (), (1,), (1,), (), (), (), (), (1,), (), (1,), (), (1,), (1,), (), (1,), (), (), (), (), (), (), (), (), (), (), (), (1,), (), (), (), (), (), (), (), (), (), (), (1,), (), (1,), (1,), (1,), (), (1,), (), (), (1,), (), (), (1,), (1,), (1,), (), (), (), (1,), (), (), (

In [18]:
def get_dale(order, target_attributes, scene):
    """Build an incremental description of the target object.

    Attributes are added one at a time, in the given order, until at most one
    object of the scene still matches; each newly added attribute value is
    placed in front of the ones already chosen.

    Args:
        order: attribute names; ``order[i]`` names the attribute whose target
            value is ``target_attributes[i]``. Only the first two names are
            used for filtering.
        target_attributes: the target object's values for those attributes.
        scene: dict with an ``"objects"`` list of attribute dicts.

    Returns:
        tuple: one to three attribute values, the most recently added first.
        The third value is added without a further uniqueness check.
    """
    def still_matching(objects, position):
        # Keep the objects that share the target's value for this attribute.
        return [
            candidate
            for candidate in objects
            if candidate[order[position]] == target_attributes[position]
        ]

    words = (target_attributes[0],)
    candidates = still_matching(scene["objects"], 0)
    if len(candidates) <= 1:
        return words

    words = (target_attributes[1],) + words
    candidates = still_matching(candidates, 1)
    if len(candidates) <= 1:
        return words

    return (target_attributes[2],) + words


# Per-image lookup table: the target object's attributes plus, for every
# ordering of the three attributes, the description produced by get_dale.
scenes = {}
for scene_file in glob(os.path.join(dataset_root_path, dataset_mapping[dataset][1], 'scenes/*')):
    with open(scene_file, "r", encoding="utf-8") as f:
        scene = json.load(f)

    # The file name without directory and ".json" suffix is used as the key.
    image_id = scene_file.rsplit('/', 1)[-1].removesuffix(".json")
    # The first index listed under groups -> target selects the target object.
    target_object_index = scene["groups"]["target"][0]
    target_object = scene["objects"][target_object_index]

    scenes[image_id] = {
        attribute: target_object[attribute] for attribute in ('shape', 'color', 'size')
    }

    # All six orderings of the three attributes.
    combinations = [
        ('shape', 'color', 'size'), ('shape', 'size', 'color'),
        ('color', 'shape', 'size'), ('color', 'size', 'shape'),
        ('size', 'shape', 'color'), ('size', 'color', 'shape'),
    ]
    for combination in combinations:
        # The ordering tuple itself is the key of the resulting description.
        scenes[image_id][combination] = get_dale(
            combination,
            tuple(target_object[attribute] for attribute in combination),
            scene,
        )

# print(scenes['CLEVR_UNAMBIGOUS-DALE-TWO_000000'])


In [20]:
# For every message in the validation mapping, tally the target colours of the
# images that received it.
colors = {}

for message, image_ids in test_image_mapping.items():
    colors[message] = Counter(scenes[image_id]['color'] for image_id in image_ids)

pprint(colors)


{(): Counter({'brown': 218,
              'green': 217,
              'yellow': 196,
              'cyan': 193,
              'red': 183,
              'blue': 181,
              'purple': 175,
              'gray': 168}),
 (1,): Counter({'yellow': 106,
                'cyan': 105,
                'green': 103,
                'purple': 101,
                'red': 100,
                'brown': 92,
                'blue': 87,
                'gray': 79})}
