# Generate dataset splits

In [1]:
import os.path as osp
import os

import numpy as np

In [2]:
# -- Set up paths -- #
data_dir = '../data/'

mug_stereo_path = osp.join(data_dir, 'stereo_training_data', 'mug')
bowl_stereo_path = osp.join(data_dir, 'stereo_training_data', 'bowl')
bottle_stereo_path = osp.join(data_dir, 'stereo_training_data', 'bottle')


def _list_shapenet_ids(stereo_path):
    """Return the entry names found in `stereo_path`, sorted, skipping split lists.

    Any entry whose name contains '.lst' is skipped, so the train/test/val
    `.lst` files this notebook writes into the same directory are never
    mistaken for object ids when the notebook is re-run.

    `os.listdir` returns entries in arbitrary order, so the result is sorted
    to give every run the same starting order.
    """
    return sorted(name for name in os.listdir(stereo_path) if '.lst' not in name)


mug_shapenet_ids = _list_shapenet_ids(mug_stereo_path)
bowl_shapenet_ids = _list_shapenet_ids(bowl_stereo_path)
bottle_shapenet_ids = _list_shapenet_ids(bottle_stereo_path)

# Indexing the archive reads the array into memory, so the underlying file can be
# closed as soon as 'bad_ids' has been extracted.
with np.load(osp.join(data_dir, 'stereo_training_data', 'bad_mugs_all.npz')) as bad_mugs_archive:
    bad_mugs = bad_mugs_archive['bad_ids']

In [3]:
# -- Shuffle and split ids -- #
# Fixed seed so that the generated splits are identical on every run.
SPLIT_SEED = 0
# Cumulative boundaries of the split: [0, 0.8) train, [0.8, 0.9) test, [0.9, 1] val.
TRAIN_END_FRAC = 0.8
TEST_END_FRAC = 0.9

split_rng = np.random.default_rng(SPLIT_SEED)

# Sort before shuffling: the seeded shuffle is only reproducible when it starts
# from the same order, and sorting also undoes any shuffle left over from a
# previous execution of this cell. Both operations work in place, so the
# `*_shapenet_ids` names keep referring to the same (now shuffled) lists.
for shapenet_ids in (mug_shapenet_ids, bowl_shapenet_ids, bottle_shapenet_ids):
    shapenet_ids.sort()
    split_rng.shuffle(shapenet_ids)

n_mugs = len(mug_shapenet_ids)
n_bowls = len(bowl_shapenet_ids)
n_bottles = len(bottle_shapenet_ids)


def _split_ids(shapenet_ids, obj_type):
    """Slice an already-shuffled id list into train / test / val parts.

    Args:
        shapenet_ids: List of ids, in the order they should be split.
        obj_type: Category name used to build the dictionary keys
            (e.g. 'mug' -> 'train_mug_shapenet_ids').

    Returns:
        Dict with keys 'train_<obj_type>_shapenet_ids',
        'test_<obj_type>_shapenet_ids' and 'val_<obj_type>_shapenet_ids'.
        The three slices are contiguous and together cover the whole list.
    """
    n_ids = len(shapenet_ids)
    train_end = int(TRAIN_END_FRAC * n_ids)
    test_end = int(TEST_END_FRAC * n_ids)
    return {
        f'train_{obj_type}_shapenet_ids': shapenet_ids[:train_end],
        f'test_{obj_type}_shapenet_ids': shapenet_ids[train_end:test_end],
        f'val_{obj_type}_shapenet_ids': shapenet_ids[test_end:],
    }


mug_ids = _split_ids(mug_shapenet_ids, 'mug')
bowl_ids = _split_ids(bowl_shapenet_ids, 'bowl')
bottle_ids = _split_ids(bottle_shapenet_ids, 'bottle')

# Category name -> dict of train/test/val id lists.
type_to_ids = {
    'mug': mug_ids,
    'bowl': bowl_ids,
    'bottle': bottle_ids
}

# Category name -> directory holding that category's data.
type_to_data = {
    'mug': mug_stereo_path,
    'bowl': bowl_stereo_path,
    'bottle': bottle_stereo_path
}

In [4]:
# -- Make sure nothing overlaps -- #
# If any id occurs more than once across a category's train/test/val lists,
# de-duplicating the concatenation yields fewer entries than the concatenation itself.
for obj_type in ('mug', 'bowl', 'bottle'):
    obj_ids = type_to_ids[obj_type]

    train_ids, test_ids, val_ids = (
        obj_ids[f'{split}_{obj_type}_shapenet_ids'] for split in ('train', 'test', 'val')
    )

    all_ids = train_ids + test_ids + val_ids
    assert len(set(all_ids)) == len(all_ids)

In [5]:
# -- Save splits -- #
# For every category, write train.lst / test.lst / val.lst into that category's
# data directory: one id per line, no newline after the last id.
for obj_type in ('mug', 'bowl', 'bottle'):
    obj_ids = type_to_ids[obj_type]

    train_ids = obj_ids[f'train_{obj_type}_shapenet_ids']
    test_ids = obj_ids[f'test_{obj_type}_shapenet_ids']
    val_ids = obj_ids[f'val_{obj_type}_shapenet_ids']

    save_dir = type_to_data[obj_type]
    lst_contents = (
        ('train.lst', train_ids),
        ('test.lst', test_ids),
        ('val.lst', val_ids),
    )
    for lst_name, split_ids in lst_contents:
        with open(osp.join(save_dir, lst_name), 'w') as f:
            f.write('\n'.join(split_ids))


In [6]:
# -- See if any of our data is in the bad object list -- #
print(bad_mugs)
print(len(bad_mugs))
print(len(mug_shapenet_ids))

# Count the mug ids that also appear in the bad-mug list.
intersection_cnt = sum(1 for mug_id in mug_shapenet_ids if mug_id in bad_mugs)
print('Intersection count: ', intersection_cnt)

['32e197b8118b3ff6a4bd4f46ba404890' '7374ea7fee07f94c86032c4825f3450'
 '9196f53a0d4be2806ffeedd41ba624d6' 'b9004dcda66abf95b99d2a3bbaea842a'
 '9ff8400080c77feac2ad6fd1941624c3' '4f9f31db3c3873692a6f53dd95fd4468'
 '1c3fccb84f1eeb97a3d0a41d6c77ec7c' 'cc5b14ef71e87e9165ba97214ebde03'
 '159e56c18906830278d8f8c02c47cde0' 'c6b24bf0a011b100d536be1f5e11c560'
 '9880097f723c98a9bd8c6965c4665b41' 'e71102b6da1d63f3a363b55cbd344baa'
 '27119d9b2167080ec190cb14324769d' '89bd0dff1b386ebed6b30d74fff98ffd'
 '127944b6dabee1c9e20e92c5b8147e4a' '513c3410e8354e5211c7f3807925346a'
 '1bc5d303ff4d6e7e1113901b72a68e7c' 'b98fa11a567f644344b25d683fe71de'
 'a3cd44bbd3ba5b019a4cbf5d3b79df06' 'b815d7e084a5a75b8d543d7713b96a41'
 '645b0e2ef3b95979204df312eabf367f' '599e604a8265cc0a98765d8aa3638e70'
 '2997f21fa426e18a6ab1a25d0e8f3590' 'c34718bd10e378186c6c61abcbd83e5a'
 'b7841572364fd9ce1249ffc39a0c3c0b' '604fcae9d93201d9d7f470ee20dce9e0'
 'e16a895052da87277f58c33b328479f4' '659192a6ba300f1f4293529704725d98'
 '30933679