In [None]:
import tensorflow as tf
import tensorflow.contrib.slim as slim
import pickle
import numpy as np
import scipy.sparse as sp
import tqdm
import random
import numpy as np
import more_itertools
from prepare_dataset.config import SPARSE_DIR, TMP_DIR

%pylab inline

In [None]:
def pack_samples(samples):
    """Stack a list of per-sample sparse matrices into one batch.

    Each element of ``samples`` is a dict with (at least) these keys:
      * ``'mat'``          -- a scipy.sparse matrix. All samples are stacked
                              vertically, so they must share a column count.
      * ``'observed_idx'`` -- integer paired with the sample's position in the
                              batch (presumably the index of the logged/observed
                              row within ``'mat'`` -- TODO confirm).
      * ``'propensity'``   -- scalar, collected into ``props``.
      * ``'cost'``         -- scalar, collected into ``cost``.

    Returns:
        A 4-tuple ``(sparse, idx, cost, props)``:
          * ``sparse`` -- a ``tf.SparseTensorValue`` over the vertically stacked
            matrices. Its indices are the ``(row, col)`` of every stored entry.
            Every value is ``1.0`` (float32): the original entry values are
            discarded and only the sparsity pattern is kept.
          * ``idx`` -- int32 array of shape ``(len(samples), 2)`` whose rows are
            ``[batch_position, observed_idx]``.
          * ``cost``, ``props`` -- 1-D arrays of the per-sample scalars.

    Raises:
        ValueError: if ``samples`` is empty (there is nothing to stack).
    """
    if len(samples) == 0:
        raise ValueError('pack_samples() needs at least one sample')
    # Request COO explicitly. The default output format of sp.vstack is
    # unspecified (e.g. all-CSR inputs come back as CSR), and only COO
    # exposes the .row/.col arrays used below.
    X_coo = sp.vstack([x['mat'] for x in samples], format='coo')
    idx = np.array([
        [i, x['observed_idx']] for i, x in enumerate(samples)
    ]).astype(np.int32)
    props = np.array([x['propensity'] for x in samples])
    cost = np.array([x['cost'] for x in samples])
    # (nnz, 2) array of [row, col] pairs. np.mat is deprecated (removed in NumPy 2).
    indices = np.column_stack([X_coo.row, X_coo.col])
    values = np.ones(X_coo.nnz, dtype=np.float32)

    return tf.SparseTensorValue(indices, values, X_coo.shape), idx, cost, props

def read_shard(i):
    """Load training shard ``i`` from SPARSE_DIR.

    Reads ``<SPARSE_DIR>/train_<i>.pickled`` and returns the unpickled object.
    Callers in this notebook iterate it as a sequence of per-sample dicts.

    NOTE: pickle.load can execute arbitrary code -- only call this on shards
    produced by this project's own preprocessing, never on untrusted files.
    """
    path = SPARSE_DIR + '/train_{}.pickled'.format(i)
    # Context manager closes the file handle deterministically.
    with open(path, 'rb') as f:
        return pickle.load(f)

# Train batches (shards 0-7, class-balanced by resampling negatives)

In [None]:
# Build class-balanced training batches from shards 0-7.
# For each shard, keep only samples with exactly 11 candidates, split them into
# positives (cost < 0.5) and negatives (cost > 0.5), and make 3 resampled pools.
# Each pool holds every positive plus an equal number of negatives drawn with
# replacement. Each pool is shuffled and packed into batches of `batch_size`.
train_batches = []
batch_size = 512

for i in tqdm.tqdm_notebook(range(0, 8)):
    curr_ds = [x for x in read_shard(i) if x['n_candidates'] == 11]
    pos_samples = [x for x in curr_ds if x['cost'] < 0.5]
    neg_samples = [x for x in curr_ds if x['cost'] > 0.5]
    del curr_ds
    for it in range(3):
        # Seed *before* sampling so both the negative draw and the shuffle
        # are reproducible.
        np.random.seed(42 + i*1000 + it)
        # Use np.random explicitly rather than `random`. %pylab rebinds the name
        # `random`, and stdlib random.choice cannot draw k items.
        neg_idx = np.random.choice(len(neg_samples), size=len(pos_samples))  # with replacement
        curr_train_pool = pos_samples + [neg_samples[j] for j in neg_idx]
        np.random.shuffle(curr_train_pool)
        batches = list(map(pack_samples, more_itertools.chunked(curr_train_pool, batch_size)))
        train_batches.extend(batches)
print('Train: {} batches (x {} samples)'.format(len(train_batches), batch_size))
# Recorded output: Train: 3162 batches (x 512 samples)
with open(TMP_DIR + '/train_batches.pickled', 'wb') as f:
    pickle.dump(train_batches, f)

# Validation set (shards 8-11)

In [None]:
# Validation data: shards 8-11, restricted to samples with exactly 11 candidates.
ds_va = []
for i in tqdm.tqdm_notebook(range(8, 12)):
    ds_va.extend(x for x in read_shard(i) if x['n_candidates'] == 11)

# ctr = fraction of *all* filtered samples with cost < 0.5.
ctr = np.mean([x['cost'] < 0.5 for x in ds_va])
# Only the cost < 0.5 samples are packed. ctr is appended to the pickled tuple,
# presumably so evaluation can rescale metrics computed on them -- TODO confirm.
valid_pack = pack_samples([x for x in ds_va if x['cost'] < 0.5])
with open(TMP_DIR + '/valid_pack.pickled', 'wb') as f:
    pickle.dump(valid_pack + (ctr,), f)
del ds_va

# Holdout set (shards 12-15)

In [None]:
# Holdout data: shards 12-15, restricted to samples with exactly 11 candidates.
ds_ho = []
for i in tqdm.tqdm_notebook(range(12, 16)):
    ds_ho.extend(x for x in read_shard(i) if x['n_candidates'] == 11)

# ctr = fraction of *all* filtered samples with cost < 0.5.
ctr = np.mean([x['cost'] < 0.5 for x in ds_ho])
# Only the cost < 0.5 samples are packed. ctr is appended to the pickled tuple,
# presumably so evaluation can rescale metrics computed on them -- TODO confirm.
holdout_pack = pack_samples([x for x in ds_ho if x['cost'] < 0.5])
with open(TMP_DIR + '/holdout_pack.pickled', 'wb') as f:
    pickle.dump(holdout_pack + (ctr,), f)
del ds_ho