In [64]:
from tokenized_dataset import load_tokenized_dataset
import numpy as np

def build_masks(x, y):
    """Build the attention masks for a single (x, y) pair of token-id arrays.

    Token id 0 is treated as padding. Both inputs are 1-D arrays.

    Returns:
        x_mask:  (len(x), len(x)) array, 1.0 where both positions hold a
                 non-pad x token.
        y_mask:  (len(y), len(y)) lower-triangular ("autoregressive") array,
                 additionally zeroed wherever either position is y padding.
        yx_mask: (len(y), len(x)) array, 1.0 where the y position (row) and
                 the x position (column) are both non-pad.
    """
    n_y = y.shape[0]

    # 1.0 for real tokens, 0.0 for padding
    x_keep = (x != 0).astype(np.float64)
    y_keep = (y != 0).astype(np.float64)

    # Row/column views so the products below broadcast to 2-D
    x_cols = x_keep[None, :]
    y_rows = y_keep[:, None]

    # x attends to x: both positions must be real tokens
    x_mask = x_keep[:, None] * x_cols

    # y attends to y: causal (lower triangle) and no padding on either side
    y_mask = np.tri(n_y, n_y) * y_keep[None, :] * y_rows

    # y attends to x: real y token on the row, real x token on the column
    yx_mask = y_rows * x_cols

    return x_mask, y_mask, yx_mask

def pad_and_trunc(toks, seq_len):
    """Return a new list of exactly ``seq_len`` ids: ``toks`` cut or 0-padded on the right."""
    return toks[:seq_len] + [0] * (seq_len - len(toks))


def pack_batch(batch):
    """Turn a list of per-example tuples into a list of stacked arrays, one per tuple field."""
    stacked = []
    for field_values in zip(*batch):
        stacked.append(np.stack(list(field_values)))
    return stacked

# Complete batch to batch_size with pad-token-only sequences
def complete_last_batch(batch, batch_size, seq_len=None):
    """Fill a short batch up to ``batch_size`` with all-padding examples.

    Args:
        batch: list of ``(x, y, x_mask, y_mask, yx_mask)`` tuples. It is
            extended in place.
        batch_size: number of examples the batch should hold on return. If
            the batch is already that long (or longer) nothing is added.
        seq_len: length of the filler sequences. When omitted, it is taken
            from the first example already in ``batch`` instead of relying on
            a ``seq_len`` variable that happens to exist in the notebook's
            global namespace.

    Returns:
        The same ``batch`` list, now containing at least ``batch_size`` items.

    Raises:
        ValueError: if ``seq_len`` is omitted and ``batch`` is empty, because
            there is nothing to infer the length from.
    """
    if seq_len is None:
        if not batch:
            raise ValueError("seq_len must be given when batch is empty")
        seq_len = len(batch[0][0])

    # Filler ids take the dtype of the real examples so that stacking the
    # batch does not promote every token id to float64.
    dtype = np.asarray(batch[0][0]).dtype if batch else np.int64

    for _ in range(batch_size - len(batch)):
        x = np.zeros(seq_len, dtype=dtype)
        y = np.zeros(seq_len, dtype=dtype)
        x_mask, y_mask, yx_mask = build_masks(x, y)
        batch.append((x, y, x_mask, y_mask, yx_mask))
    return batch
    

def get_batched_examples(ds, batch_size, seq_len, split="train", skip_n_rows=None):
    """Yield fixed-size batches of padded examples together with their masks.

    Args:
        ds: mapping from split name to an iterable of items exposing ``'x'``
            and ``'y'`` token-id lists.
        batch_size: number of examples per yielded batch.
        seq_len: every ``x`` and ``y`` is padded/truncated to this length.
        split: which split of ``ds`` to read.
        skip_n_rows: when not None, ``ds[split].skip(skip_n_rows)`` is used
            as the row source.

    Yields:
        ``[x, y, x_mask, y_mask, yx_mask]`` as produced by ``pack_batch``.
        Leftover rows that do not fill a batch are dropped for the "train"
        split; for any other split the final batch is topped up through
        ``complete_last_batch``.
    """
    rows = ds[split]
    if skip_n_rows is not None:
        rows = rows.skip(skip_n_rows)

    pending = []
    for row in rows:
        src = np.array(pad_and_trunc(row['x'], seq_len))
        tgt = np.array(pad_and_trunc(row['y'], seq_len))
        src_mask, tgt_mask, cross_mask = build_masks(src, tgt)
        pending.append((src, tgt, src_mask, tgt_mask, cross_mask))

        if len(pending) != batch_size:
            continue
        yield pack_batch(pending)
        pending = []

    # Note: the last few rows left over in the train split are not used.
    has_leftover = len(pending) > 0
    if has_leftover and split != "train":
        yield pack_batch(complete_last_batch(pending, batch_size))

ds, _ = load_tokenized_dataset()

# Smoke check: pull only the first batch (10 examples, 10 tokens each)
# and show the padded token ids for both sides.
x, y, x_mask, y_mask, yx_mask = next(
    get_batched_examples(ds, batch_size=10, seq_len=10)
)
print(x, y, sep="\n")

Loading dataset
Loading tokenizer bpe_tokenizer_ds_train_all_merges_35k.pickle
Tokenizing dataset
[[23954 12114  3613  3586 12466     0     0     0     0     0]
 [ 3667 21876 28867  3586 12466  3613  3586  3960  4795  3751]
 [ 3585 15064  3626  3829  3850  3827 31442  3586  4489 11945]
 [ 3829  3827 15275  3612  6044  3599  3745  7010  3594  3586]
 [ 3594  3586 32712  3667  4156  4239  3617 20984  3612  6958]
 [ 5895  4535  3788 11576  3687  3745  6958  3655  4435  3571]
 [ 9339  5544 19234  3614 23218  3612  6958  3655  4435  3571]
 [ 6698  4468  3599  3612  5318  3613 16215     0     0     0]
 [ 3829  3850  3743  8528  3857  3586  6314  3614 13569  3686]
 [ 3943  3613  3586  4641 31869  4017  4043  9155  3594 22023]]
[[31352  3606 29833     0     0     0     0     0     0     0]
 [ 3638 37931  3610  3979 36357  3787 15348  9617 19088 27915]
 [ 3869  3725 26275 24456  3732  3606  3619 23742  7960    35]
 [ 3755  4794  6502  3606 10816  4052  3884  7320  3755 16943]
 [ 4910  4547  3638

In [69]:
from tokenized_dataset import load_tokenized_dataset
import numpy as np

def get_batched_examples_packed(ds, batch_size, seq_len, start_tok, end_tok, pack_frac = 0.5, split="train", skip_n_rows=None):
    """Yield training batches in which short examples are packed into one row.

    Consecutive examples are concatenated into the current row, separated by
    ``[end_tok, start_tok]``, for as long as both the row's ``x`` and ``y``
    are shorter than ``seq_len * pack_frac``. ``x`` and ``y`` are always
    packed together, so the i-th segment of ``x`` lines up with the i-th
    segment of ``y``. Rows are then padded/truncated to ``seq_len``.

    A batch is emitted only once a new row would be needed and the batch is
    already full, so the last row of a batch gets packed like every other
    row (emitting as soon as the last row received its first example would
    leave that row permanently unpacked).

    Args:
        ds: mapping from split name to an iterable of items exposing ``'x'``
            and ``'y'`` token-id lists.
        batch_size: number of packed rows per yielded batch (must be >= 1).
        seq_len: final length of every row.
        start_tok, end_tok: ids inserted between packed examples.
        pack_frac: fraction of ``seq_len`` below which a row keeps accepting
            further examples.
        split: only ``"train"`` is supported.
        skip_n_rows: when not None, ``ds[split].skip(skip_n_rows)`` is used
            as the row source.

    Yields:
        ``[x, y, x_mask, y_mask, yx_mask]`` as produced by ``pack_batch``.
        A trailing batch with fewer than ``batch_size`` rows is dropped.

    Raises:
        AssertionError: if ``split`` is not ``"train"``.
        ValueError: if ``batch_size`` is smaller than 1.
    """
    assert split=="train", "packed batching is only implemented for the train split"
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    ds_split = ds[split].skip(skip_n_rows) if skip_n_rows is not None else ds[split]

    pack_limit = seq_len * pack_frac
    separator = [end_tok, start_tok]

    # NOTE(review): build_masks only masks padding (plus the causal mask on
    # y); it does not stop tokens of one packed example from attending to
    # another example in the same row -- confirm this is intended.
    def convert_batch_item(x, y):
        x = np.array(pad_and_trunc(x, seq_len))
        y = np.array(pad_and_trunc(y, seq_len))
        x_mask, y_mask, yx_mask = build_masks(x, y)
        return (x, y, x_mask, y_mask, yx_mask)

    def emit(rows):
        return pack_batch([convert_batch_item(x, y) for x, y in rows])

    batch = []
    for item in ds_split:
        # Either append to the current row or start a new one
        if batch and len(batch[-1][0]) < pack_limit and len(batch[-1][1]) < pack_limit:
            prev_x, prev_y = batch[-1]
            batch[-1] = (prev_x + separator + item['x'], prev_y + separator + item['y'])
            continue

        # A new row is needed; flush first if the batch has no room left
        if len(batch) == batch_size:
            yield emit(batch)
            batch = []
        batch.append((item['x'], item['y']))

    # The data ran out: a batch that is full by row count is still emitted,
    # a partially filled one is dropped (train split only).
    if len(batch) == batch_size:
        yield emit(batch)

ds, (_, _, tokenizer_vocab_size) = load_tokenized_dataset()

# Separator ids for packing, derived from the tokenizer vocabulary size.
START_TOK, END_TOK = tokenizer_vocab_size + 1, tokenizer_vocab_size + 2

# Smoke check: first packed batch (2 rows of 50 tokens, rows keep accepting
# examples while shorter than 75% of the row length).
x, y, x_mask, y_mask, yx_mask = next(
    get_batched_examples_packed(ds, 2, 50, START_TOK, END_TOK, pack_frac=0.75)
)
print(x, y, sep="\n")

Loading dataset
Loading tokenizer bpe_tokenizer_ds_train_all_merges_35k.pickle
Tokenizing dataset
50
50
[[23954 12114  3613  3586 12466 38560 38559  3667 21876 28867  3586 12466
   3613  3586  3960  4795  3751  7291  5434  3599 20578  9201 10178 20119
   3614  3667  4081  4239  6204  6664  3617  6454  3829  3612  9825  4161
   5567  3594  3586  6096  3686  3829 16708  3612 11638  6807  4110 17835
      0     0]
 [ 3585 15064  3626  3829  3850  3827 31442  3586  4489 11945    34 21147
  16112  4435 13814  3617  5324  4351  3788  5221  3586  4641  3594  3612
   5239  3613  4924 18161  3612  8545  3613  7082 23814  3686 12332  4424
   4489  3751 16858     0     0     0     0     0     0     0     0     0
      0     0]]
[[31352  3606 29833 38560 38559  3638 37931  3610  3979 36357  3787 15348
   9617 19088 27915 29833  3730  4157  6131  3704 22472 12444 10706  4521
  14141  6282  6715  4083 10585 16869  3618  8898  4213  3725 10838 21975
  18117     0     0     0     0     0     0     0  