### Scrips for generating splits
This script assums you have the main ReaSCAN generated by the generate_ReaSCAN.py script. After that, you can use this file to generate/extrapolate different splits. In the future, we may consolidate two files.

In [10]:
from collections import namedtuple, OrderedDict
import os
from typing import List
from typing import Tuple
import logging
from collections import defaultdict
from collections import Counter
import json
import torch
import numpy as np

def isnotebook():
    try:
        shell = get_ipython().__class__.__name__
        if shell == 'ZMQInteractiveShell':
            return True   # Jupyter notebook or qtconsole
        elif shell == 'TerminalInteractiveShell':
            return False  # Terminal running IPython
        else:
            return False  # Other type (?)
    except NameError:
        return False      # Probably standard Python interpreter
if isnotebook():
    device = torch.device("cpu")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

FORMAT = "%(asctime)-15s %(message)s"
logging.basicConfig(format=FORMAT, level=logging.INFO,
                    datefmt="%Y-%m-%d %H:%M")
logger = logging.getLogger(__name__)

from world import *
from vocabulary import Vocabulary as ReaSCANVocabulary
from object_vocabulary import *

In [2]:
path_to_data = "../../data-files/ReaSCAN-compositional/data-train.txt"
logger.info(f"Reading dataset from file: {path_to_data}...")
data_json = json.load(open(path_to_data, "r"))

In [3]:
all_fake_train = data_json["examples"]["train"]
# for dev and test, it is simple, let us just shuffle, and random select.
len(all_fake_train)

544579

In [54]:
# For generating the splits, we actually have to go through compositional splits first
# and then consider random splits like dev and test. Because, we don't want things mixed up
# in the dev and test. Dev and test should only contain commands that appear in the train,
# so a total random partition at the end should work.

# We do the splits step-by-step!
id_example_map = OrderedDict({})
id_splits_map = OrderedDict({})
index = 0
for example in data_json["examples"]["train"]:
    id_example_map[index] = example
    id_splits_map[index] = set([]) # set of splits that this example belongs to.
    
    # gscan_yellow_square_command_target_only
    if "yellow,square" in example['command'].split("that")[0]:
        id_splits_map[index].add("gscan_yellow_square_command_target_only")

    # gscan_yellow_square_command
    if "yellow,square" in example['command']:
        id_splits_map[index].add("gscan_yellow_square_command")
    
    # gscan_red_box_visual
    if "red,box" in example['command'] or \
        ((example['situation']['placed_objects']['1']['object']['shape'] == "box" and \
        example['situation']['placed_objects']['1']['object']['color'] == "red") or \
        (example['situation']['placed_objects']['2']['object']['shape'] == "box" and \
        example['situation']['placed_objects']['2']['object']['color'] == "red")):
        id_splits_map[index].add("gscan_red_box_visual")
    
    # gscan_small_cylinder_command_target_only
    if "small,cylinder" in example['command'].split("that")[0] or \
        "small,red,cylinder" in example['command'].split("that")[0] or \
        "small,blue,cylinder" in example['command'].split("that")[0] or \
        "small,yellow,cylinder" in example['command'].split("that")[0] or \
        "small,green,cylinder" in example['command'].split("that")[0]:
        id_splits_map[index].add("gscan_small_cylinder_command_target_only")
    
    # novel_yellow_square_blue_circle_coexist_shape
    if "yellow,square" in example['command'] and "blue,circle" in example['command']:
        id_splits_map[index].add("novel_yellow_square_blue_circle_coexist_shape")
    
    # novel_green_circle_box_coexist (must be down side objects)
    if "green,circle" not in example['command'].split("that")[0] and \
        "green,circle" in example['command'] and "box" in example['command']:
        id_splits_map[index].add("novel_green_circle_box_coexist_box_shape")

    # novel_same_shape_is_inside_coexist_relation
    if "same,shape" in example['command'] and "is,inside" in example['command']:
        id_splits_map[index].add("novel_same_shape_is_inside_coexist_relation")
        
    # novel_inside_of_as_yellow_box
    if "is,inside,of,a,yellow,box" in example['command'] or \
        "is,inside,of,the,yellow,box" in example['command'] or \
        "is,inside,of,a,small,yellow,box" in example['command'] or \
        "is,inside,of,the,small,yellow,box" in example['command']:
        id_splits_map[index].add("novel_inside_of_as_yellow_box")
    
    # few_shot_single_clause_logic
    if example['grammer_pattern'] == "$OBJ_0 ^ $OBJ_1":
        id_splits_map[index].add("few_shot_single_clause_logic")
        
    index += 1

In [64]:
splits_distribution = OrderedDict({})
splits_assignment = OrderedDict({})
for index, splits in id_splits_map.items():
    if len(splits) == 0:
        split = "train" # let us split this up later!
        if split in splits_distribution.keys():
            splits_distribution[split] += 1
        else:
            splits_distribution[split] = 1
        
        if split in splits_assignment:
            splits_assignment[split].append(index)
        else:
            splits_assignment[split] = [index]
    else:   
        for split in splits:
            if split in splits_distribution.keys():
                splits_distribution[split] += 1
            else:
                splits_distribution[split] = 1
                
            if split in splits_assignment:
                splits_assignment[split].append(index)
            else:
                splits_assignment[split] = [index]

# Let us further segment train into dev and test!
gscan_dev_size = 3716
gscan_test_size = 19282
all_example_id = splits_assignment["train"]
random.shuffle(all_example_id)
train_example_id = all_example_id[:(-3716-19282)]
dev_example_id = all_example_id[(-3716-19282):-3716]
test_example_id = all_example_id[-3716:]
splits_assignment["train"] = train_example_id
splits_assignment["dev"] = dev_example_id
splits_assignment["test"] = test_example_id

In [67]:
for split, all_ids in splits_assignment.items():
    print(f"for {split} split, we have {len(all_ids)} examples.")

for train split, we have 315421 examples.
for novel_inside_of_as_yellow_box split, we have 15950 examples.
for gscan_yellow_square_command_target_only split, we have 21801 examples.
for gscan_yellow_square_command split, we have 76979 examples.
for gscan_red_box_visual split, we have 52433 examples.
for novel_green_circle_box_coexist_box_shape split, we have 13700 examples.
for gscan_small_cylinder_command_target_only split, we have 31867 examples.
for novel_yellow_square_blue_circle_coexist_shape split, we have 8450 examples.
for novel_same_shape_is_inside_coexist_relation split, we have 4698 examples.
for few_shot_single_clause_logic split, we have 49511 examples.
for dev split, we have 19282 examples.
for test split, we have 3716 examples.


In [68]:
# remake our data file accordingly.
updated_examples = OrderedDict({})
for split, all_ids in splits_assignment.items():
    updated_examples[split] = []
    for _id in all_ids:
        updated_examples[split].append(id_example_map[_id])

In [74]:
# save it to the disk
data_json["examples"] = updated_examples
with open("../../data-files/ReaSCAN-compositional/data-compositional-splits.txt", "w") as fd:
    json.dump(data_json, fd, indent=4)