In [1]:
import argparse
import json
import logging
import os
from typing import Dict, List, Union

import numpy as np

# Notebook configuration: where the raw dataset is read from, where the parsed
# copy is written to, and whether the parsed copy is written at all.
flags = dict(
    dataset_path="../../../data-files/ReaSCAN-compositional/data-compositional-splits.txt",
    output_file="../../../data-files/ReaSCAN-compositional/parsed_dataset.txt",
    save_data=True,
)

# Timestamped log lines (minute resolution) at DEBUG verbosity; used for the
# progress messages emitted while the splits are loaded.
FORMAT = "%(asctime)-15s %(message)s"
logging.basicConfig(level=logging.DEBUG, datefmt="%Y-%m-%d %H:%M", format=FORMAT)
logger = logging.getLogger(__name__)

In [2]:
def parse_sparse_situation(situation_representation: dict, grid_size: int) -> np.ndarray:
    """
    Convert a sparse situation description into a dense grid array.

    Each grid cell is a vector of object-attribute bits followed by one agent
    bit and four agent-direction bits. Illustrative layout:
    [_ _ _ _ _ _ _   _       _      _       _   _ _ _ _]
     1 2 3 4 r g b circle square cylinder agent E S W N
     _______ _____ ______________________ _____ _______
       size  color        shape           agent agent dir.
    The number of object-attribute channels is not fixed by this function; it
    is taken from the length of the target object's vector, so it may differ
    from the illustration above.

    :param situation_representation: data from dataset.txt at key "situation".
        Must provide "target_object" (with "vector"), "agent_position" (with
        "row" and "column"), "agent_direction" and "placed_objects" (mapping
        to entries with "vector" and "position").
    :param grid_size: int determining row/column number.
    :return: integer array of shape [grid_size, grid_size, num_object_attributes + 5]
        to be parsed by computational models.
    """
    num_object_attributes = len(situation_representation["target_object"]["vector"])
    # Object representation + agent bit + agent direction bits (see docstring).
    num_agent_channels = 1 + 4
    num_grid_channels = num_object_attributes + num_agent_channels

    # Initialize the grid. The builtin `int` is used as dtype throughout: the
    # `np.int` alias was removed in NumPy 1.24 and raises AttributeError there.
    grid = np.zeros([grid_size, grid_size, num_grid_channels], dtype=int)

    # Place the agent.
    agent_row = int(situation_representation["agent_position"]["row"])
    agent_column = int(situation_representation["agent_position"]["column"])
    agent_direction = int(situation_representation["agent_direction"])
    agent_representation = np.zeros([num_grid_channels], dtype=int)
    # Fifth channel from the end flags the agent; the last four are a one-hot
    # direction, indexed from the end by the direction value.
    agent_representation[-5] = 1
    agent_representation[-4 + agent_direction] = 1
    grid[agent_row, agent_column, :] = agent_representation

    # Loop over the objects in the world and place them. An object written to
    # a cell replaces that cell's whole vector, with the agent channels zeroed.
    agent_padding = np.zeros([num_agent_channels], dtype=int)
    for placed_object in situation_representation["placed_objects"].values():
        object_vector = np.array([int(bit) for bit in placed_object["vector"]], dtype=int)
        object_row = int(placed_object["position"]["row"])
        object_column = int(placed_object["position"]["column"])
        grid[object_row, object_column, :] = np.concatenate([object_vector, agent_padding])
    return grid


def data_loader(file_path: str) -> Dict[str, Union[List[str], np.ndarray]]:
    """
    Read a grounded-SCAN style dataset file (JSON) and densify every example.

    For each split found under the "examples" key, the comma-separated
    "command" and "target_commands" strings are split into token lists and
    the "situation" is expanded into a dense grid via parse_sparse_situation.

    :param file_path: Full path to file containing dataset (dataset.txt)
    :returns: dict mapping each split name to a list of example dicts with
        keys "input" (list of str), "target" (list of str) and "situation"
        (nested lists of ints, JSON-serializable).
    """
    with open(file_path, 'r') as infile:
        raw_dataset = json.load(infile)
        world_size = int(raw_dataset["grid_size"])
        examples_by_split = raw_dataset["examples"]
        split_names = list(examples_by_split.keys())
        logger.info("Found data splits: {}".format(split_names))

        parsed_splits = {}
        for split_name in split_names:
            logger.info("Now loading data for split: {}".format(split_name))
            parsed_splits[split_name] = [
                {
                    "input": example["command"].split(','),
                    "target": example["target_commands"].split(','),
                    # .tolist() turns the numpy grid into plain lists so the
                    # result can be passed to json.dump later on.
                    "situation": parse_sparse_situation(
                        situation_representation=example["situation"],
                        grid_size=world_size,
                    ).tolist(),
                }
                for example in examples_by_split[split_name]
            ]
            logger.info("Loaded {} examples in split {}.\n".format(len(parsed_splits[split_name]),
                                                                   split_name))
    return parsed_splits


In [3]:
# Load and densify every split of the dataset into memory. `data` maps each
# split name to a list of example dicts with "input", "target" and
# "situation" keys (see data_loader).
data = data_loader(flags["dataset_path"])

2021-05-18 01:50 Found data splits: ['train', 'novel_inside_of_as_yellow_box', 'gscan_yellow_square_command_target_only', 'gscan_yellow_square_command', 'gscan_red_box_visual', 'novel_green_circle_box_coexist_box_shape', 'gscan_small_cylinder_command_target_only', 'novel_yellow_square_blue_circle_coexist_shape', 'novel_same_shape_is_inside_coexist_relation', 'few_shot_single_clause_logic', 'dev', 'test']
2021-05-18 01:50 Now loading data for split: train
2021-05-18 01:52 Loaded 315421 examples in split train.

2021-05-18 01:52 Now loading data for split: novel_inside_of_as_yellow_box
2021-05-18 01:52 Loaded 15950 examples in split novel_inside_of_as_yellow_box.

2021-05-18 01:52 Now loading data for split: gscan_yellow_square_command_target_only
2021-05-18 01:52 Loaded 21801 examples in split gscan_yellow_square_command_target_only.

2021-05-18 01:52 Now loading data for split: gscan_yellow_square_command
2021-05-18 01:52 Loaded 76979 examples in split gscan_yellow_square_command.

202

In [4]:
# Optionally dump all parsed splits into a single JSON file.
if flags["save_data"]:
    with open(flags["output_file"], 'w') as outfile:
        # indent=4 pretty-prints every nesting level, including the dense
        # situation grids, so the file is far larger than a compact dump.
        json.dump(data, outfile, indent=4)

In [5]:
# Export each split to its own file with one JSON-encoded example per line
# (JSON-lines content, although the files keep the ".json" extension).
split_output_dir = 'parsed_dataset'
# Create the target directory up front: open(..., 'w') does not create missing
# parent directories and would fail with FileNotFoundError on a fresh checkout.
os.makedirs(split_output_dir, exist_ok=True)
for split, dt in data.items():
    with open(os.path.join(split_output_dir, split + '.json'), 'w') as f:
        for line in dt:
            f.write(json.dumps(line) + '\n')