In [17]:
from tokenized_dataset import load_tokenized_dataset
import numpy as np

def build_masks(x, y, x_packs=None, y_packs=None, yx_packs=None): # x: seq_len, y: seq_len
    # Mask padded x tokens
    x_mask = np.ones((x.shape[0], x.shape[0]))
    x_pad_mask = np.where(x != 0, np.ones((x.shape[0])), 0)
    x_mask = np.multiply(np.multiply(x_mask, x_pad_mask), x_pad_mask[:, None])
    if x_packs is not None:
        x_mask = np.multiply(x_mask, x_packs)

    # Mask padded y tokens + add "autoregressive" mask
    y_pad_mask = np.where(y != 0, np.ones((y.shape[0])), 0)
    y_mask = np.tri(y.shape[0], y.shape[0])
    y_mask = np.multiply(np.multiply(y_mask, y_pad_mask), y_pad_mask[:, None])
    if y_packs is not None:
        y_mask = np.multiply(y_mask, y_packs)

    # Mask padded yx tokens
    yx_mask = np.ones((y.shape[0], x.shape[0]))
    yx_mask = np.multiply(np.multiply(yx_mask, x_pad_mask), y_pad_mask[:, None])
    if yx_packs is not None:
        yx_mask = np.multiply(yx_mask, yx_packs)

    return x_mask, y_mask, yx_mask

pad_and_trunc = lambda toks, seq_len: (toks + [0] * (seq_len - len(toks)))[:seq_len]
pack_batch = lambda batch: [np.stack(el) for el in map(list, zip(*batch))]

# Complete batch to batch_size with pad tokens only sequences 
def complete_last_batch(batch, batch_size):
    for _ in range(batch_size - len(batch)):
        x = np.zeros(seq_len)
        y = np.zeros(seq_len)
        x_mask, y_mask, yx_mask = build_masks(x,y)
        batch.append((x, y, x_mask, y_mask, yx_mask))
    return batch

def ones_block_diag(lens_dim0, lens_dim1):
    result = np.zeros((sum(lens_dim0), sum(lens_dim1)))
    start_ind_dim0 = 0
    start_ind_dim1 = 0
    for len_dim0, len_dim1 in zip(lens_dim0, lens_dim1):
        result[start_ind_dim0:start_ind_dim0+len_dim0, start_ind_dim1:start_ind_dim1+len_dim1] = np.ones((len_dim0, len_dim1))
        start_ind_dim0 +=len_dim0
        start_ind_dim1 +=len_dim1
    return result

def create_packs(x_lens, y_lens):
    x_packs = ones_block_diag(x_lens, x_lens)
    y_packs = ones_block_diag(y_lens, y_lens)
    yx_packs = ones_block_diag(y_lens, x_lens)
    return x_packs, y_packs, yx_packs
    
def convert_batch_item(x, y, seq_len, x_lens=None, y_lens=None):
    x = np.array(pad_and_trunc(x, seq_len))
    y = np.array(pad_and_trunc(y, seq_len))

    # Account for packing
    if x_lens is not None:
        assert y_lens is not None
        x_packs, y_packs, yx_packs = create_packs(x_lens, y_lens)
    else:
        x_packs, y_packs, yx_packs = None, None, None
        
    x_mask, y_mask, yx_mask = build_masks(x,y, x_packs, y_packs, yx_packs)
    return (x, y, x_mask, y_mask, yx_mask)
                
def get_batched_examples(ds, batch_size, seq_len, split="train", skip_n_rows=None):
    ds_split = ds[split].skip(skip_n_rows) if skip_n_rows is not None else ds[split]

    batch = []
    for item in ds_split:
        batch.append(convert_batch_item(item['x'], item['y'], seq_len))
        if len(batch) == batch_size:
            yield pack_batch(batch)
            batch = []
    if split!="train" and len(batch) > 0: # Note I don't use last few rows left in train split..
        yield pack_batch(complete_last_batch(batch, batch_size))

ds, _ = load_tokenized_dataset()

x, y, x_mask, y_mask, yx_mask = next(get_batched_examples(ds, 2, 10))
print(x)
print(y)

Loading dataset
Loading tokenizer bpe_tokenizer_ds_train_all_merges_35k.pickle
Tokenizing dataset
[[23954 12114  3613  3586 12466     0     0     0     0     0]
 [ 3667 21876 28867  3586 12466  3613  3586  3960  4795  3751]]
[[31352  3606 29833     0     0     0     0     0     0     0]
 [ 3638 37931  3610  3979 36357  3787 15348  9617 19088 27915]]


In [18]:
def get_batched_examples_packed(ds, batch_size, seq_len, start_tok, end_tok, pack_frac = 0.5, split="train", skip_n_rows=None):
    assert split=="train"
    ds_split = ds[split].skip(skip_n_rows) if skip_n_rows is not None else ds[split]
    
    batch = []
    batch_x_lens = []
    batch_y_lens = []
    for item in ds_split:
        # Either append to previous batch item or create new one
        if len(batch)>0 and (len(batch[-1][0]) < seq_len * pack_frac and len(batch[-1][1]) < seq_len * pack_frac):
            pack_func = lambda x, y: x + [end_tok, start_tok] + y
            batch[-1] = (pack_func(batch[-1][0], item['x']), pack_func(batch[-1][1], item['y']))
            #batch_x_lens[-1].append() # TODO
        else:
            batch.append((item['x'],item['y']))
            batch_x_lens.append([len(item['x'])])
            batch_y_lens.append([len(item['y'])])
            
        if len(batch) == batch_size:
            batch = [convert_batch_item(*it, seq_len, x_lens, y_lens) for it, x_lens, y_lens in zip(batch, batch_x_lens, batch_y_lens)]
            
            yield pack_batch(batch)
            batch = []
            
    if split!="train" and len(batch) > 0: # Note I don't use last few rows left in train split..
        yield pack_batch(complete_last_batch(batch, batch_size))

ds, (_, _, tokenizer_vocab_size) = load_tokenized_dataset()
START_TOK = tokenizer_vocab_size + 1
END_TOK = tokenizer_vocab_size + 2

x, y, x_mask, y_mask, yx_mask = next(get_batched_examples_packed(ds, 2, 10, START_TOK, END_TOK, pack_frac=0.75))
print(x)
print(y)

Loading dataset
Loading tokenizer bpe_tokenizer_ds_train_all_merges_35k.pickle
Tokenizing dataset


ValueError: operands could not be broadcast together with shapes (10,10) (5,5) 

In [9]:
x = np.array([1, 2, 3, 5, 6, 7, 0])
y = np.array([11, 12, 13, 15, 16, 0, 0])
print("INPUT")
print(x)
print(y)

print("\nPACKS")
x_packs, y_packs, yx_packs = create_packs([3, 3, 1], [3, 2, 2])
print(x_packs)
print(y_packs)
print(yx_packs)

print("\nMASKS")
x_mask, y_mask, yx_mask = build_masks(x, y, x_packs=x_packs, y_packs=y_packs, yx_packs = yx_packs)
print(x_mask)
print(y_mask)
print(yx_mask)

INPUT
[1 2 3 5 6 7 0]
[11 12 13 15 16  0  0]

PACKS
[[1. 1. 1. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 1. 1. 0.]
 [0. 0. 0. 1. 1. 1. 0.]
 [0. 0. 0. 1. 1. 1. 0.]
 [0. 0. 0. 0. 0. 0. 1.]]
[[1. 1. 1. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 1. 0. 0.]
 [0. 0. 0. 1. 1. 0. 0.]
 [0. 0. 0. 0. 0. 1. 1.]
 [0. 0. 0. 0. 0. 1. 1.]]
[[1. 1. 1. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 1. 1. 0.]
 [0. 0. 0. 1. 1. 1. 0.]
 [0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 1.]]

MASKS
[[1. 1. 1. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 1. 1. 0.]
 [0. 0. 0. 1. 1. 1. 0.]
 [0. 0. 0. 1. 1. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0.]]
[[1. 0. 0. 0. 0. 0. 0.]
 [1. 1. 0. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 1. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]]
[[1. 1. 1. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 1. 1. 0.]


In [3]:
import jax.numpy as jnp
from jax import random

def pos_encodings(x, indices): # input: seq_len x emb_dim
    seq_len, emb_dim = x.shape

    #indices = jnp.arange(seq_len)[:, None]
    print(indices.shape)
    div_term = jnp.fromfunction(lambda i: 1 / pow(10000, 2 * i/emb_dim), (int(emb_dim/2),), dtype=float)[None, :]
    print(div_term.shape)

    print(f'\nPOS_ARRAY')
    pos_array = jnp.dot(indices, div_term)
    print(pos_array.shape)
    print(pos_array)

    print(f'\nSTACKED')
    stacked = jnp.stack((jnp.sin(pos_array), jnp.cos(pos_array)), axis=2)
    print(stacked.shape)
    print(stacked)
    return stacked.reshape(seq_len, emb_dim)

x =random.uniform(random.PRNGKey(0), (7, 4)) # 7 seq_len, 4 emb_dim
#x_indices = jnp.arange(x.shape[0])[:, None]
# TODO: indices need to be inferred from mask OR 
# should we get lens from data collator instead,
# (creating indices from lens is relatively straightforward)
# However, lens have variable size, which doesn't make it perfect
# for JAX
x_indices = jnp.array([0, 1, 2, 0, 1, 2, 3])[:, None] 
print(x_indices)
print(pos_encodings(x, x_indices))

#y =random.uniform(random.PRNGKey(0), (7, 4)) # 7 seq_len, 4 emb_dim
#print(pos_encodings(y))

[[0]
 [1]
 [2]
 [0]
 [1]
 [2]
 [3]]
(7, 1)
(1, 2)

POS_ARRAY
(7, 2)
[[0.   0.  ]
 [1.   0.01]
 [2.   0.02]
 [0.   0.  ]
 [1.   0.01]
 [2.   0.02]
 [3.   0.03]]

STACKED
(7, 2, 2)
[[[ 0.          1.        ]
  [ 0.          1.        ]]

 [[ 0.841471    0.5403023 ]
  [ 0.00999983  0.99995   ]]

 [[ 0.90929747 -0.4161468 ]
  [ 0.01999867  0.9998    ]]

 [[ 0.          1.        ]
  [ 0.          1.        ]]

 [[ 0.841471    0.5403023 ]
  [ 0.00999983  0.99995   ]]

 [[ 0.90929747 -0.4161468 ]
  [ 0.01999867  0.9998    ]]

 [[ 0.14112    -0.9899925 ]
  [ 0.0299955   0.99955004]]]
[[ 0.          1.          0.          1.        ]
 [ 0.841471    0.5403023   0.00999983  0.99995   ]
 [ 0.90929747 -0.4161468   0.01999867  0.9998    ]
 [ 0.          1.          0.          1.        ]
 [ 0.841471    0.5403023   0.00999983  0.99995   ]
 [ 0.90929747 -0.4161468   0.01999867  0.9998    ]
 [ 0.14112    -0.9899925   0.0299955   0.99955004]]
