In [None]:
import os
from absl import flags
from ml_collections import config_flags
import numpy as np
import tqdm
from data.synthetic import utils

In [2]:
# Generation settings. These constants become the defaults of the absl flags
# defined in the next cell, so edit them here before running the notebook.
NUM_SAMPLES = 10_000_000  # total number of samples to generate
BATCH_SIZE = 200          # number of samples drawn per generation step

# Available datasets: swissroll, circles, moons, 8gaussians, pinwheel,
# 2spirals, checkerboard, line, cos
DATA_NAME = "checkerboard"

# Everything produced by this notebook is written under this folder.
DATA_ROOT = f'data/synthetic/{DATA_NAME}'

In [3]:
# Register the command-line style flags that drive data generation, then parse
# a hard-coded argv so the notebook behaves like the `generate_data` script.
# NOTE(review): absl keeps flags in a process-wide registry, so re-running this
# cell in the same kernel is expected to fail with a duplicate-flag error;
# restart the kernel before re-running — confirm against your absl version.
FLAGS = flags.FLAGS

# `lock_config=False` leaves the loaded config mutable; a later cell assigns
# `data_name` on it.
_CONFIG = config_flags.DEFINE_config_file('data_config', lock_config=False)

flags.DEFINE_integer(
    name='num_samples',
    default=NUM_SAMPLES,
    help='num samples to be generated',
)
flags.DEFINE_integer(
    name='batch_size',
    default=BATCH_SIZE,
    help='batch size for datagen',
)
flags.DEFINE_string(
    name='data_root',
    default=DATA_ROOT,
    help='root folder of data',
)

# argv[0] is the program name; the only real argument points at the config
# file. The call returns the arguments left over after parsing, which is what
# this cell displays.
FLAGS(['generate_data', '--data_config=data/synthetic/data_config.py'])

['generate_data']

In [4]:
# Make sure the output folder exists. `exist_ok=True` replaces the separate
# existence check, so there is no window between "check" and "create" in which
# another process could create the folder and make `makedirs` fail.
os.makedirs(FLAGS.data_root, exist_ok=True)

# Pull the config loaded from --data_config and point it at the dataset chosen
# in the settings cell, then print it so the effective settings are recorded.
data_config = _CONFIG.value
data_config.data_name = DATA_NAME
print(data_config)

binmode: gray
data_name: checkerboard
discrete_dim: 32
int_scale: -1.0
plot_size: -1.0
vocab_size: 2



In [5]:
# Build the data source and the two bit mappings used below: `db` produces
# float batches via `gen_batch`, `bm` is handed to `utils.float2bin` and
# `inv_bm` to `utils.bin2float`.
# NOTE(review): the printed log suggests setup_data also derives `int_scale`;
# presumably it writes that back into `data_config`, which would be why the
# config is saved only after this call — confirm in data.synthetic.utils.
db, bm, inv_bm = utils.setup_data(data_config)

# Store the effective config next to the generated data.
config_path = os.path.join(FLAGS.data_root, 'config.yaml')
with open(config_path, 'w') as config_file:
    config_file.write(data_config.to_yaml())

remapping binary repr with gray code
f_scale, 4.999530215205527 int_scale, 5461.760975376213


In [None]:
# Generate the dataset in batches, save it as one boolean array, and write a
# preview plot of the first samples.

# Fail early with a clear message instead of a ZeroDivisionError or an
# "empty list" error from np.concatenate further down.
if FLAGS.batch_size <= 0:
    raise ValueError(f'batch_size must be positive, got {FLAGS.batch_size}')
if FLAGS.num_samples <= 0:
    raise ValueError(f'num_samples must be positive, got {FLAGS.num_samples}')

# Plan the batch sizes up front so that exactly `num_samples` rows are
# produced: plain floor division would silently drop the remainder whenever
# num_samples is not a multiple of batch_size.
num_full_batches, remainder = divmod(FLAGS.num_samples, FLAGS.batch_size)
batch_sizes = [FLAGS.batch_size] * num_full_batches
if remainder:
    # assumes db.gen_batch accepts any positive batch size — TODO confirm
    batch_sizes.append(remainder)

data_list = []
for batch_size in tqdm.tqdm(batch_sizes):
    batch = utils.float2bin(db.gen_batch(batch_size), bm,
                            data_config.discrete_dim, data_config.int_scale)
    # Keep each batch as bool to limit the memory held while accumulating.
    data_list.append(batch.astype(bool))
data = np.concatenate(data_list, axis=0)
print(data.shape[0], 'samples generated')

save_path = os.path.join(FLAGS.data_root, 'data.npy')
with open(save_path, 'wb') as f:
    np.save(f, data)

# Convert the first 1000 samples back to floats for a visual sanity check.
# This runs before the PDF is opened so that a conversion failure does not
# leave an empty samples.pdf behind.
float_data = utils.bin2float(data[:1000].astype(np.int32), inv_bm,
                             data_config.discrete_dim,
                             data_config.int_scale)
with open(os.path.join(FLAGS.data_root, 'samples.pdf'), 'wb') as f:
    utils.plot_samples(float_data, f, im_size=4.1, im_fmt='pdf')

100%|██████████| 50000/50000 [02:25<00:00, 344.14it/s]


10000000 samples generated
