In [1]:
import itertools as it
import json
import numpy as np
import os
import pickle

# Input locations (DATA_DIR is relative to the notebook's working directory).
DATA_DIR = "../data"
SEG_DIR = "instance_segmentation_data"
HOUSE_DIR = "houses"

# Segment labels kept when building plans; segments with any other label are skipped.
GOOD_PARTS = set(("wall", "roof"))

# TODO move to constants file
# Action ids stored in every plan step.
ACTION_STOP, ACTION_GO = 0, 1
# Special tokens; they occupy the first ids of the annotation vocabulary.
TASK_PAD, TASK_SOS, TASK_EOS = "*pad*", "*sos*", "*eos*"

In [2]:
# Segmented training houses; each record is unpacked further down as
# (schematic, segmentation, annotations, house_id).
# NOTE: pickle.load can execute arbitrary code -- only point this at trusted files.
segments_path = os.path.join(DATA_DIR, SEG_DIR, "training_data.pkl")
with open(segments_path, mode="rb") as seg_f:
    segments = pickle.load(seg_f)

In [3]:
# Annotation string -> integer id. The special task tokens take ids 0, 1, 2;
# construct_plan registers each new annotation it meets after them.
annotation_vocab = {token: idx for idx, token in enumerate((TASK_PAD, TASK_SOS, TASK_EOS))}

# Shape of every snapshot grid; voxels beyond it are left out of snapshots.
SCHEMATIC_SIZE = (20, 20, 20)

def construct_plan(schematic, annotation, size=None):
    """Turn a voxel schematic into a block-by-block build plan plus state snapshots.

    Every non-zero voxel of ``schematic`` becomes one ACTION_GO step, visited in
    row-major (x, then y, then z) order. A single ACTION_STOP step is appended
    at the end.

    Args:
        schematic: 3-D array of block ids; 0 means "no block".
        annotation: label of this (sub)structure. It is added to the global
            ``annotation_vocab`` the first time it is seen (side effect).
        size: shape of the snapshot grids. Defaults to ``SCHEMATIC_SIZE``.

    Returns:
        ``(plan, snapshots)``, two lists of equal length.

        ``plan[k]`` is ``(annotation_id, action, (x, y, z, block_id))``. The
        final STOP step is ``(0, ACTION_STOP, (0, 0, 0, 0))``. Note that it
        carries id 0 (the ``TASK_PAD`` entry) rather than the annotation's id.

        ``snapshots[k]`` is a float grid of shape ``size`` holding the blocks
        placed *before* step k runs. Voxels outside ``size`` still appear in
        the plan but are never drawn into a snapshot.
    """
    if size is None:
        size = SCHEMATIC_SIZE
    if annotation not in annotation_vocab:
        annotation_vocab[annotation] = len(annotation_vocab)
    annotation_id = annotation_vocab[annotation]

    plan = []
    snapshots = []
    schematic_so_far = np.zeros(size)
    # np.argwhere returns indices in row-major order -- the same order as
    # it.product(range(X), range(Y), range(Z)) -- but only for non-zero voxels,
    # so empty space is never visited. .tolist() keeps coordinates as Python ints.
    for x, y, z in np.argwhere(schematic != 0).tolist():
        block = schematic[x, y, z]
        plan.append((annotation_id, ACTION_GO, (x, y, z, block)))
        # Snapshot the state the GO action is applied to, i.e. before placing.
        snapshots.append(schematic_so_far.copy())
        # Voxels outside the snapshot grid are silently cropped.
        if x < size[0] and y < size[1] and z < size[2]:
            schematic_so_far[x, y, z] = block
    plan.append((0, ACTION_STOP, (0, 0, 0, 0)))
    snapshots.append(schematic_so_far.copy())
    return plan, snapshots

# Build two training sets from every house with at least one wall/roof voxel:
#   * flat: all wall/roof voxels as a single "house" task, one plan per house;
#   * hier: one sub-plan per wall/roof segment, segments in random order, each
#     sub-plan ending with its own STOP step.
# Snapshot stacks go to snapshots/{flat,hier}/<idx>.npz, where <idx> is the
# house's position in plans_flat / plans_hier.
os.makedirs("snapshots/flat", exist_ok=True)  # np.savez does not create directories
os.makedirs("snapshots/hier", exist_ok=True)

# Seeded so the segment order of the hierarchical plans is reproducible.
part_order_rng = np.random.default_rng(0)

plans_flat = []
plans_hier = []
for schematic, segmentation, annotations, house_id in segments:
    # Voxels with segmentation == i belong to the segment labelled annotations[i].
    good_parts = [i for i, annotation in enumerate(annotations) if annotation in GOOD_PARTS]
    part_masks = {i: segmentation == i for i in good_parts}

    keep_schematic = np.zeros(schematic.shape, dtype=np.int32)
    for mask in part_masks.values():
        keep_schematic[mask] = schematic[mask]
    if not (keep_schematic > 0).any():
        continue

    house_idx = len(plans_flat)

    plan_flat, snapshots_flat = construct_plan(keep_schematic, "house")
    plans_flat.append(plan_flat)
    np.savez("snapshots/flat/{}.npz".format(house_idx), data=snapshots_flat)

    plan_hier = []
    snapshots_hier = []
    # construct_plan starts every part from an empty grid. Overlay the parts
    # finished earlier so hierarchical snapshots show everything built so far,
    # like the flat ones. Each voxel has a single segment id, so parts are
    # disjoint and adding the grids is the same as overlaying them.
    built_so_far = np.zeros(SCHEMATIC_SIZE)
    for i in part_order_rng.permutation(good_parts).tolist():
        part_schematic = np.zeros(schematic.shape, dtype=np.int32)
        part_schematic[part_masks[i]] = schematic[part_masks[i]]
        part_plan_hier, part_snapshots_hier = construct_plan(part_schematic, annotations[i])
        plan_hier += part_plan_hier
        snapshots_hier += [built_so_far + snapshot for snapshot in part_snapshots_hier]
        built_so_far = built_so_far + part_snapshots_hier[-1]
    plans_hier.append(plan_hier)
    np.savez("snapshots/hier/{}.npz".format(house_idx), data=snapshots_hier)

    # Both plans must place the same number of blocks, only in a different order.
    n_go_flat = sum(1 for p in plan_flat if p[1] != ACTION_STOP)
    n_go_hier = sum(1 for p in plan_hier if p[1] != ACTION_STOP)
    assert n_go_flat == n_go_hier, (house_id, n_go_flat, n_go_hier)
    
# Persist both plan sets, plus the vocabulary that maps their annotation ids back to names.
for out_path, payload in (("flat.pkl", plans_flat), ("hier.pkl", plans_hier)):
    with open(out_path, "wb") as f:
        pickle.dump(payload, f)

with open("annotation_vocab.json", "w") as f:
    json.dump(annotation_vocab, f)