In [1]:
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.95

import jax
from jax import jit
from t_model_new import *

import jax.numpy as jnp
from jax import grad, jit, vmap, lax 
from jax import random

LAYERS = 6
model_vocab_size = 38561
START_TOK = model_vocab_size - 2
END_TOK = model_vocab_size - 1
EMB_DIM=512
FFN_DIM=2048
NUM_HEADS = 8
params = init_transformer(model_vocab_size, EMB_DIM, LAYERS, NUM_HEADS, FFN_DIM, random.PRNGKey(0))

print(f'Vocabulary size: {model_vocab_size}')
num_params = sum([jnp.size(p_leaf) for p_leaf in jax.tree_util.tree_leaves(params)])
print(f'Number of params: {num_params}')

def avg_cross_entropy_loss(y_labels, x_logits): # y_labels: batch_len x seq_len, x_logits: batch_len x seq_len x vocab_size
    y_labels_1d = jnp.reshape(y_labels, -1)
    x_logits_2d = jnp.reshape(x_logits, (y_labels.size, -1))
    elements_loss = log_softmax(x_logits_2d)[(jnp.arange(y_labels.size), y_labels_1d)]
    elements_loss = jnp.where(y_labels_1d != 0, elements_loss, jnp.nan) # account for padding tokens
    result = -jnp.nanmean(elements_loss) 
    return result, jnp.count_nonzero(y_labels)
    
def accuracy(y_labels, x_logits):
    return jnp.nanmean(jnp.where(y_labels!=0, y_labels == jnp.argmax(x_logits, axis=-1), jnp.nan))
    
def loss(params, x, y, key, train):  # inputs: batch_size x seq_len
    y_lens = jnp.count_nonzero(y, axis=1)
    # It's possible that there are no padding tokens, and we will go out of boundary, hence the use of "drop" mode
    y = y.at[jnp.arange(y.shape[0]), y_lens].set(END_TOK, mode="drop") 
    y = y.at[:,0].set(jnp.where(x[:,0]!=END_TOK, x[:,0], 0)) # Account for possible empty sequences (which are used for in-complete batches)
    
    start_toks = jnp.full((y.shape[0], 1), START_TOK) 
    shifted_y = jnp.concatenate((start_toks, y[:,:-1]), axis=1) 
    
    # TODO: write it without copying memory? is it possible? 
    logits = batched_forward(params, x, shifted_y, key, train) 
    loss_val, tokens_count = avg_cross_entropy_loss(y, logits)
    acc = accuracy(y, logits) 
    #return loss_val, (loss_val, acc, tokens_count/jnp.size(y)) # TODO: this is wrapping, but we could make use of jax.value_and_grad instead
    return loss_val, (acc, tokens_count/jnp.size(y)) # TODO: this is wrapping, but we could make use of jax.value_and_grad instead

batch_size = 512
seq_len = 50
test_x  = random.randint(random.PRNGKey(0), (batch_size, seq_len), 0, model_vocab_size)
test_y  = random.randint(random.PRNGKey(0), (batch_size, seq_len), 0, model_vocab_size)

import time
from jax import grad, value_and_grad
from functools import partial
loss_train = partial(loss, train=True)
grad_loss_train = grad(loss_train, has_aux=True)
value_and_grad_loss_train = value_and_grad(loss_train, has_aux=True)

with jax.profiler.trace("/lego/storage/output/"):
    #result = jit(loss_train)(params, test_x, test_y, random.PRNGKey(0)) # No spike at the end
    #result = jit(grad_loss_train)(params, test_x, test_y, random.PRNGKey(0))[1] # Spike at the end
    result = jit(value_and_grad_loss_train)(params, test_x, test_y, random.PRNGKey(0))[0] # Spike at the end
    result = jax.block_until_ready(result)
    time.sleep(30)
    print(f'result {result}')
    time.sleep(30)
    print(f'The end')
    #print(jit(loss_train)(params, test_x, test_y, random.PRNGKey(0))[0])
    #print(len(jit(grad_loss_train)(params, test_x, test_y, random.PRNGKey(0))[0])) # Requires massive amount of memory at the end??

#jitted_test_proj_fwd = jit(test_proj_fwd)

env: XLA_PYTHON_CLIENT_MEM_FRACTION=0.95
Vocabulary size: 38561
Number of params: 63883425




result (Array(10.57429, dtype=float32), (Array(7.813416e-05, dtype=float32), Array(0.9998828, dtype=float32)))
The end


In [20]:
# TODO: Check whether grad gives spike at the end on a toy example
import jax
import jax.numpy as jnp
from jax import jit
from jax import random

model_vocab_size = 38561
EMB_DIM = 512
batch_size = 125 #250
seq_len = 100
test_x  = random.normal(random.PRNGKey(0), (batch_size*seq_len, EMB_DIM))
test_y = random.randint(random.PRNGKey(0), (batch_size*seq_len, ), 0, model_vocab_size)
params = random.normal(random.PRNGKey(0), (EMB_DIM, model_vocab_size))

def test_matmul(params, x): 
    return jnp.matmul(x, params)

from t_model_new import log_softmax

def test_avg_cross_entropy_loss(y_labels_1d, x_logits_2d):
    elements_loss = log_softmax(x_logits_2d)[(jnp.arange(y_labels_1d.size), y_labels_1d)]
    return -jnp.nanmean(elements_loss) 

def test_loss(params, x, y):
    x_logits = test_matmul(params, x)
    return test_avg_cross_entropy_loss(y, x_logits)

import time
from jax import make_jaxpr
from jax import grad, value_and_grad

#print(make_jaxpr(test_matmul)(params, test_x))
#print(make_jaxpr(test_loss)(params, test_x, test_y))
#print(make_jaxpr(grad(test_loss))(params, test_x, test_y))
#raise Exception("end")

with jax.profiler.trace("/lego/storage/output/"):
    #result = jit(test_matmul)(params, test_x) # No spike
    #result = jit(test_loss)(params, test_x, test_y) # No spike
    result = value_and_grad(test_loss)(params, test_x, test_y)[0] # Run on reduced batch_size
    #result = jit(value_and_grad(test_loss))(params, test_x, test_y)[0] # Spike (not as big as full model)
    result = jax.block_until_ready(result)
    #time.sleep(30)
    print(f'result {result}')
    #time.sleep(30)
    #print(f'The end')

result 94.72313690185547


In [2]:
# TODO: investigate why GEMM takes 2x matrix.size space
import jax
import jax.numpy as jnp
from jax import jit
from jax import random

model_vocab_size = 38561
EMB_DIM = 512
batch_size = 250
seq_len = 100
test_x  = random.normal(random.PRNGKey(0), (batch_size*seq_len, EMB_DIM))
params = random.normal(random.PRNGKey(0), (EMB_DIM, model_vocab_size))

def test_matmul(params, x): 
    return jnp.matmul(x, params)

with jax.profiler.trace("/lego/storage/output/"):
    print(jit(test_matmul)(params, test_x).shape)

(25000, 38561)
