# Split Creation
Generates train / eval / test splits for a GEECO dataset and exports the initial configurations associated with each split.

In [None]:
'''imports'''
import os
import re
import csv
import pprint

import numpy as np

In [None]:
'''input'''
# Dataset to split. Known datasets (see DATASET2CSV):
#   gym-push-pad1-cube1-v4, gym-push-pad1-cube2-v4,
#   gym-push-pad2-cube1-v4, gym-push-pad2-cube2-v4,
#   gym-pick-pad1-cube1-v4, gym-pick-pad1-cube2-v4,
#   gym-pick-pad2-cube1-v4, gym-pick-pad2-cube2-v4,
#   gym-pick-pad2-cube2-clutter4-v4, gym-pick-pad2-cube2-clutter12-v4
dataset_name = 'gym-pick-pad2-cube2-v4'

# Split preset to create (see SPLITRATIO): balanced | fasttest | default | debug
split_name = 'balanced'

In [None]:
'''presets'''
# Init CSV filename per dataset: '<name>-v4' -> 'init-<name>.csv'.
DATASET2CSV = {
    name: 'init-%s.csv' % name[:-len('-v4')]
    for name in (
        'gym-push-pad1-cube1-v4',
        'gym-push-pad1-cube2-v4',
        'gym-push-pad2-cube1-v4',
        'gym-push-pad2-cube2-v4',
        'gym-pick-pad1-cube1-v4',
        'gym-pick-pad1-cube2-v4',
        'gym-pick-pad2-cube1-v4',
        'gym-pick-pad2-cube2-v4',
        'gym-pick-pad2-cube2-clutter4-v4',
        'gym-pick-pad2-cube2-clutter12-v4',
    )
}
# (p_train, p_eval, p_test) fractions for each split preset.
SPLITRATIO = {
    'fasttest': (0.0, 0.0, 1.0),
    'default': (0.5, 0.3, 0.2),
    'debug': (0.01, 0.01, 0.01),
    'balanced': (0.5, 0.25, 0.25),
}
# Unknown names raise KeyError here.
p_train, p_eval, p_test = SPLITRATIO[split_name]
init_fn = DATASET2CSV[dataset_name]

In [None]:
'''path setup'''
# GEECO_ROOT must be set in the environment (a missing variable raises KeyError).
root_path = os.environ['GEECO_ROOT']
dataset_dir = os.path.join(root_path, 'data', dataset_name)

# Sub-directories of the dataset: records, metadata (init CSVs) and saved splits.
data_dir, meta_dir, splits_dir = (
    os.path.join(dataset_dir, sub) for sub in ('data', 'meta', 'splits'))
print("Dataset directory: %s" % dataset_dir)

In [None]:
'''set up data structures'''

# list all tfrecords in <dataset_dir>/data
tfrecord_list = [fn for fn in os.listdir(data_dir) if fn.endswith('.tfrecord.zlib')]
print("Found %d tfrecords in %s" % (len(tfrecord_list), data_dir))
# Record filenames carry a 1-based index (the first run of digits in the name);
# convert it once to the 0-based index used to address init_rows below.
tfrecord2idx = {fn: int(re.search(r'\d+', fn).group(0)) - 1 for fn in tfrecord_list}
idx2tfrecord = {idx: fn for fn, idx in tfrecord2idx.items()}

# Associate rows from the init CSV: data row i (header excluded) belongs to the
# record with 0-based index i.
init_path = os.path.join(meta_dir, init_fn)
init_rows = []
# newline='' is what the csv module requires for files it reads.
with open(init_path, newline='') as fp:
    reader = csv.reader(fp, delimiter=';')
    header_row = next(reader, None)
    if header_row is None:
        raise ValueError("Init file %s is empty!" % init_path)
    print(header_row)
    for row in reader:
        # All but the last two columns are numeric state values; the last two
        # columns are kept as strings and together identify the task.
        state_row = [float(e) for e in row[:-2]]
        task_row = [str(e) for e in row[-2:]]
        init_rows.append(state_row + task_row)

# Group records by task. The task name is the concatenation of the last two CSV
# columns. NOTE(review): concatenation without a separator could merge distinct
# column pairs (e.g. 'ab'+'c' vs 'a'+'bc'); confirm the column values rule this out.
task_map = {}
task_groups = {}
for record_id, record_name in idx2tfrecord.items():
    # Guard explicitly: a record numbered 0 would give index -1 and silently
    # pick the LAST CSV row.
    if not 0 <= record_id < len(init_rows):
        raise IndexError("Record %s has index %d, but %s has only %d rows!"
                         % (record_name, record_id, init_path, len(init_rows)))
    task_name = "".join(init_rows[record_id][-2:])
    task_map[record_name] = task_name
    task_groups.setdefault(task_name, []).append(record_name)
print("Found the following task groups:")
for task_name, record_names in task_groups.items():
    print(task_name, len(record_names))

In [None]:
'''helper functions'''
def create_split(items, p_train, p_eval, p_test, rng=None):
    """Randomly partition ``items`` into train / eval / test subsets.

    The input is copied before shuffling, so the caller's sequence is left
    untouched.

    Args:
        items: sequence of items (e.g. tfrecord filenames) to split.
        p_train: fraction of ``items`` assigned to train.
        p_eval: fraction of ``items`` assigned to eval.
        p_test: fraction of ``items`` assigned to test. When the three
            fractions sum to 1, test receives every item not assigned to
            train or eval (absorbing rounding leftovers). Otherwise test
            receives ``rint(len(items) * p_test)`` items and the remaining
            items are left out of the split.
        rng: optional object with a ``shuffle`` method (e.g. a
            ``numpy.random.Generator``). Defaults to the global
            ``numpy.random`` state.

    Returns:
        Tuple ``(items_train, items_eval, items_test)`` of lists.
    """
    pool = list(items)  # copy: never shuffle or truncate the caller's list
    num_items = len(pool)
    (np.random if rng is None else rng).shuffle(pool)
    idx_train = int(np.rint(num_items * p_train))
    idx_eval = idx_train + int(np.rint(num_items * p_eval))
    items_train = pool[:idx_train]
    items_eval = pool[idx_train:idx_eval]
    if np.isclose(p_train + p_eval + p_test, 1.0):
        items_test = pool[idx_eval:]  # remainder absorbs rounding leftovers
    else:
        # partial split (e.g. 'debug'): honour p_test instead of taking everything
        idx_test = idx_eval + int(np.rint(num_items * p_test))
        items_test = pool[idx_eval:idx_test]
    return items_train, items_eval, items_test

def load_split(split_dir):
    """Load a saved split from ``<split_dir>/{train,eval,test}.txt``.

    Each file holds one item per line. Blank lines are ignored. The last item
    is kept even if the file lacks a trailing newline.

    Args:
        split_dir: directory containing ``train.txt``, ``eval.txt`` and ``test.txt``.

    Returns:
        Tuple ``(items_train, items_eval, items_test)`` of lists of str.
    """
    loaded = []
    for split in ('train', 'eval', 'test'):
        split_file = os.path.join(split_dir, '%s.txt' % split)
        with open(split_file, 'r') as fp:
            loaded.append([line for line in fp.read().splitlines() if line])
    items_train, items_eval, items_test = loaded
    return items_train, items_eval, items_test

In [None]:
'''create split'''
split_dir = os.path.join(splits_dir, split_name)
os.makedirs(split_dir, exist_ok=True)

# A split is stored as <split_dir>/{train,eval,test}.txt. If all three files
# exist they are reused; if none exist a new split is created. If only some
# exist, execution stops so a partial split is never mixed with a new one.
split_paths = [os.path.join(split_dir, '%s.txt' % split) for split in ['train', 'eval', 'test']]
if all(os.path.exists(p) for p in split_paths):  # complete split, load files
    print("Complete split found at %s" % split_dir)
    print("Loading items...")
    train_items, eval_items, test_items = load_split(split_dir)
elif any(os.path.exists(p) for p in split_paths):  # incomplete split
    err_msg = "Incomplete split found at %s! Please fix before continuing!" % split_dir
    print(os.listdir(split_dir))
    # raised as an exception (not warnings.warn), so execution stops here
    raise RuntimeWarning(err_msg)
else:  # empty split, create from scratch
    print("Empty split. Create a new one from scratch.")
    train_items, eval_items, test_items = [], [], []
    # Stratify: split every task group with the same ratios so each task is
    # represented proportionally in train / eval / test.
    for task_name, record_names in task_groups.items():
        # pass a copy so the split helper cannot modify task_groups
        _train, _eval, _test = create_split(list(record_names), p_train, p_eval, p_test)
        train_items.extend(_train)
        eval_items.extend(_eval)
        test_items.extend(_test)

print("Split:\tTrain: %d\tEval: %d\tTest: %d" % (len(train_items), len(eval_items), len(test_items)))

In [None]:
'''associate init'''
# Look up each split's tfrecords in the init config CSV table: the record's
# 0-based index (tfrecord2idx) is its row ID in init_rows.
train_init, eval_init, test_init = (
    [init_rows[tfrecord2idx[record]] for record in part]
    for part in (train_items, eval_items, test_items)
)
print("Associated inits:\tTrain: %d\tEval: %d\tTest: %d" % (len(train_init), len(eval_init), len(test_init)))

In [None]:
'''save split files'''
# Write one item per line to <split_dir>/<part>.txt; existing files are never overwritten.
for part, items in (('train', train_items), ('eval', eval_items), ('test', test_items)):
    file_path = os.path.join(split_dir, '%s.txt' % (part, ))
    if not os.path.exists(file_path):
        print("Writing split file: %s" % file_path)
        with open(file_path, 'w') as fp:
            fp.writelines(entry + '\n' for entry in items)
        continue
    print("Split file %s already exists!" % file_path)

In [None]:
'''export init configurations associated with splits'''
# Write <split_dir>/init-<split>.csv (';'-delimited, original header first);
# existing files are never overwritten.
for split, rows in (('train', train_init), ('eval', eval_init), ('test', test_init)):
    csv_path = os.path.join(split_dir, 'init-%s.csv' % (split, ))
    if not os.path.exists(csv_path):
        print("Writing init file: %s" % csv_path)
        with open(csv_path, 'w', newline='') as fp:
            writer = csv.writer(fp, delimiter=';')
            writer.writerow(header_row)
            writer.writerows(rows)
    else:
        print("Init file %s already exists!" % csv_path)