In [1]:
### JAX

# UPDATE/TODO XXX: We can now move to jax24.04-py3 (https://docs.nvidia.com/deeplearning/frameworks/jax-release-notes/rel-24-04.html)
# TODO: this is slightly faster even with the warning below -> investigate (the current jax version is 0.4.26, while the image has 0.4.17)
#! pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
#2024-05-02 08:16:04.763248: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] 
#The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). 
#Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. 
#You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.

# TODO: It looks like I am suffering from memory fragmentation on the GPU, so preallocation stays enabled (only its fraction is raised below)
# Disable JAX memory preallocation
#import os
#os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
#os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".90"
#%env XLA_PYTHON_CLIENT_PREALLOCATE=false
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.95

#!LD_LIBRARY_PATH=/usr/local/cuda/compat:$LD_LIBRARY_PATH
import jax
jax.devices()

env: XLA_PYTHON_CLIENT_MEM_FRACTION=0.95


[cuda(id=0)]

In [2]:
### DATASETs # TODO XXX: set to 5%, and 0.02 for nightly
import datasets
from tokenized_dataset import load_tokenized_dataset_gpt2, load_tokenized_dataset_hellaswag, unpack_hellaswag_x, get_batched_examples, get_batched_examples_packed
# Load 1% of the tokenized train split (":1000" also works for quick runs) and hold out
# 10% of it as the validation split.
ds, (tokenize, detokenize, tokenizer_vocab_size) = load_tokenized_dataset_gpt2("train[:1%]")
# TODO: put seed in better place? does it mess up with resume_from_checkpoint logic?
ds = ds.train_test_split(test_size=0.1, seed=42)
ds = datasets.DatasetDict({'train': ds['train'], 'validation': ds['test']})  # rename 'test' -> 'validation'
print(ds)

# Some stats on HellaSwag, given the tokenizer:
#  - the max length of the concatenated y + longest choice is 149;
#  - the max sum of the choices' token lengths is 263 (important for flattening the
#    choices into x, and for the seq_len param of the data collator).
hellaswag_ds = load_tokenized_dataset_hellaswag(tokenize)
print(hellaswag_ds)

# Tests:
# item = next(x for x in hellaswag_ds)
# print(item)
# print(detokenize((item['y'],)))
# item_x = item['x']
# choices, label = unpack_hellaswag_x(item['x'])
# print(detokenize(choices)) # TODO XXX: one of the choices has ", while the others have '. Is it anything serious?
# print(label)

Loading FineWeb-Edu dataset


Resolving data files:   0%|          | 0/1630 [00:00<?, ?it/s]

Loading tokenizer bpe_tokenizer_fineweb-edu_sample-10BT_100k_ds_merges_30k.pickle
HotFix: Filter out items containing out-of-vocabulary words
Tokenizing dataset
DatasetDict({
    train: Dataset({
        features: ['x', 'y'],
        num_rows: 87048
    })
    validation: Dataset({
        features: ['x', 'y'],
        num_rows: 9673
    })
})
Loading HellaSwag dataset
Tokenizing dataset
Dataset({
    features: ['x', 'y'],
    num_rows: 10042
})


In [3]:
### MODEL

from model import *
import jax.numpy as jnp
from jax import grad, jit, vmap, lax 
from jax import random

LAYERS = 12
# The tokenizer's vocabulary is extended by three ids: padding token (0), start of
# sequence token and end of sequence token.
model_vocab_size = tokenizer_vocab_size + 3
START_TOK = tokenizer_vocab_size + 1
END_TOK = tokenizer_vocab_size + 2  # TODO: in standard LLM convention, it should be 1. Also, it could be part of tokenizer_vocab_size
EMB_DIM = 768  # alt: 512
FFN_DIM = 3072  # alt: 2048
NUM_HEADS = 12  # alt: 8
params = init_transformer_gpt2like(model_vocab_size, EMB_DIM, LAYERS, NUM_HEADS, FFN_DIM, random.PRNGKey(0))

print(f'Vocabulary size: {model_vocab_size:_}')
num_params = sum(jnp.size(leaf) for leaf in jax.tree_util.tree_leaves(params))
print(f'Number of params: {num_params:_}')

Vocabulary size: 35_374
Number of params: 112_220_206


In [4]:
### Loss + grads
#def one_hot(x, k, dtype=jnp.float32): 
#    """Create a one-hot encoding of x of size k.""" 
#    return jnp.array(x[:, None] == jnp.arange(k), dtype)
#
#batched_one_hot = vmap(one_hot, in_axes=(0, None))

def avg_cross_entropy_loss(y_labels, x_logits): # y_labels: batch_len x seq_len, x_logits: batch_len x seq_len x vocab_size
    """Mean next-token cross-entropy over non-padding positions.

    Positions whose label is 0 (the padding token) are left out of the mean.

    Returns:
        (loss, count): the mean of -log_softmax(x_logits)[label] over non-padding
        positions, and the number of non-zero labels.
    """
    # Flatten to (tokens,) and (tokens, vocab). Un-jitted, jnp.reshape copies; under jit
    # the copies are expected to be optimized away (the original author checked this).
    flat_labels = jnp.reshape(y_labels, -1)
    flat_logits = jnp.reshape(x_logits, (y_labels.size, -1))
    # Log-probability assigned to each target token.
    picked = jnp.take_along_axis(log_softmax(flat_logits), flat_labels[:, None], axis=1)[:, 0]
    # Padding positions become NaN so that nanmean skips them.
    picked = jnp.where(flat_labels == 0, jnp.nan, picked)
    return -jnp.nanmean(picked), jnp.count_nonzero(y_labels)
    
def accuracy(y_labels, x_logits):
    """Fraction of non-padding positions (label != 0) where argmax(x_logits) equals the label."""
    is_real = y_labels != 0
    hits = jnp.argmax(x_logits, axis=-1) == y_labels
    # Padding positions become NaN so that nanmean skips them.
    return jnp.nanmean(jnp.where(is_real, hits, jnp.nan))

from functools import partial

@partial(jax.jit, static_argnames=['sample_len']) # TODO XXX: don't pass y_mask nor y_indices (pass batch_size though!)
def predict(params, y_mask, y_indices, sample_len): # TODO: code up not-scanned version, which could be faster on GPU
    """Greedy (argmax) generation of `sample_len - 1` tokens per row, starting from START_TOK.

    The batch size is taken from `y_mask.shape[0]`. `y_mask` and `y_indices` are passed
    unchanged to `batched_forward_gpt2like`, which runs with a fixed key and train=False.

    Returns:
        int array of shape (batch, sample_len - 1): the generated tokens without the
        leading START_TOK. END_TOK and everything after it are replaced by 0 (padding),
        and any START_TOK is also replaced by 0.
    """
    def predict_step(step_i, y):
        # The logits at position step_i predict the token at position step_i + 1.
        # TODO: Cache key-value pairs
        new_y = batched_forward_gpt2like(params, y, y_mask, y_indices, random.PRNGKey(0), False)
        new_toks = jnp.argmax(new_y[:, step_i], axis=-1)
        return y.at[:, step_i + 1].set(new_toks)

    start_toks = jnp.full((y_mask.shape[0], sample_len), START_TOK)
    # Fill columns 1 .. sample_len-1. Looping up to sample_len would spend one more full
    # forward pass on a write to column sample_len, which is out of bounds (and dropped).
    y_sample = jax.lax.fori_loop(0, sample_len - 1, predict_step, start_toks)
    # END_TOK (tokenizer_vocab_size + 2) is the largest id in model_vocab_size, so, assuming
    # the model's logits span exactly model_vocab_size, the running max along the sequence
    # equals END_TOK from the first END_TOK onwards.
    # Replace the END token and everything after it with padding.
    y_sample = jnp.where(jax.lax.cummax(y_sample, axis=1) != END_TOK, y_sample, 0)

    y_sample = y_sample[:, 1:]
    return jnp.where(y_sample != START_TOK, y_sample, 0)  # It should not be happening, but for a random model it might.

def loss(params, y, y_mask, y_indices, key, train):  # inputs: batch_size x seq_len
    """Next-token loss of the model on token batch `y` (shifted by one).

    Returns:
        (loss_val, (loss_val, acc, non_pad_fraction)). The duplicated loss lets this be
        used with `grad(..., has_aux=True)`. non_pad_fraction is the share of target
        tokens that are not padding.
    """
    # Inputs are all but the last token; targets are all but the first.
    inputs, targets = y[:, :-1], y[:, 1:]

    # TODO: write it without copying memory? is it possible?
    logits = batched_forward_gpt2like(params, inputs, y_mask, y_indices, key, train)
    mean_loss, n_real = avg_cross_entropy_loss(targets, logits)
    token_acc = accuracy(targets, logits)  # TODO: Do I need to stop_gradient on this? I think not, but double-check
    real_frac = n_real / jnp.size(targets)
    # TODO: this is wrapping, but we could make use of jax.value_and_grad instead
    return mean_loss, (mean_loss, token_acc, real_frac)

# Partially-applied variants of `loss`:
#  - loss_train: train=True; the caller passes the PRNG `key` for each step.
#  - loss_eval:  train=False with a fixed PRNG key, jitted.
#  - grad_loss:  jitted gradient of loss_train w.r.t. its first argument (params),
#                returning (grads, aux) where aux is loss's (loss, acc, non-pad fraction).
loss_train = partial(loss, train=True)
loss_eval = jax.jit(partial(loss, key=random.PRNGKey(0), train=False))
grad_loss = jax.jit(jax.grad(loss_train, has_aux=True))
# TODO XXX XXX: write some test for it?
@jit #TODO XXX: take y_prefix_len which not to ignore probs for?
def log_probs(params, y, y_mask, y_indices):  # inputs: batch_size x seq_len
    """Total log-probability the model assigns to each row of `y`.

    Row b's score is the sum over positions t of log p(y[b, t+1] | y[b, :t+1]). Targets
    equal to 0 (padding) contribute 0. All positions are summed, including any shared
    prefix (e.g. a HellaSwag context).

    Returns:
        array of shape (batch_size,).
    """
    y_in = y[:, :-1]
    y_out = y[:, 1:]

    # copied a few lines from avg_cross_entropy_loss # TODO XXX XXX: reuse instead!
    def compute_log_probs(y_labels, x_logits): # y_labels: batch_len x seq_len, x_logits: batch_len x seq_len x vocab_size
        y_labels_1d = jnp.reshape(y_labels, -1)
        x_logits_2d = jnp.reshape(x_logits, (y_labels.size, -1))
        token_log_probs = log_softmax(x_logits_2d)[(jnp.arange(y_labels.size), y_labels_1d)]
        # Padding targets must add nothing to the sum. Any non-zero constant here would
        # bias comparisons toward rows with more padding, i.e. shorter sequences.
        token_log_probs = jnp.where(y_labels_1d != 0, token_log_probs, 0.0)
        token_log_probs_2d = jnp.reshape(token_log_probs, (x_logits.shape[0], x_logits.shape[1]))
        return jnp.sum(token_log_probs_2d, axis=1)

    # TODO: write it without copying memory? is it possible?
    logits = batched_forward_gpt2like(params, y_in, y_mask, y_indices, random.PRNGKey(0), False)
    return compute_log_probs(y_out, logits)

In [5]:
### Optimizers

# TODO: any call to this function can be replaced by jax's tree_map
def elwise(params_and_grads, func):
    """Apply `func` leaf by leaf across several parallel list-of-lists structures.

    `params_and_grads` is a tuple of structures shaped like `[[leaf, ...], ...]`. For
    every group and every position, `func` receives the matching leaves of all
    structures as positional arguments. Returns a new list of lists of the results.
    """
    result = []
    for groups in zip(*params_and_grads):
        result.append([func(*leaves) for leaves in zip(*groups)])
    return result

def sgd(params, grads, lr):
    """Plain SGD step: returns new params `p - lr * g`, with the same list-of-lists structure."""
    def step(p, g):
        return p - lr * g
    return elwise((params, grads), step)

@jit
def adam(params, grads, lr, betas, epsilon, moments, i):
    """Functional (non-donating) Adam step over list-of-lists params.

    `moments` holds the first and second moment estimates, each shaped like `params`.
    `i` is the 0-based step index; bias correction uses t = i + 1.

    Returns:
        (new_params, [new_first_moment, new_second_moment])
    """
    t = i + 1  # TODO: should we decouple iteration from t, and thread t instead?
    new_moments = []
    for beta, moment, power in zip(betas, moments, [1, 2]):
        # Exponential moving average of g (power 1) and g**2 (power 2).
        new_moments.append(elwise((moment, grads), lambda m, g: beta * m + (1 - beta) * pow(g, power)))
    debiased = [elwise((moment,), lambda m: m / (1 - pow(beta, t))) for beta, moment in zip(betas, new_moments)]
    new_params = elwise((params, *debiased), lambda p, m, v: p - lr * m / (jnp.sqrt(v) + epsilon))
    return new_params, new_moments

from functools import partial
@partial(jax.jit, donate_argnames=("params", "moments"))
def adam_in_place(params, grads, lr, betas, epsilon, moments, i):
    """Adam step expressed with `.at[]` updates on donated `params` / `moments`.

    Because `params` and `moments` are donated, XLA may reuse their buffers for the
    outputs instead of allocating new ones. `i` is the 0-based step index; bias
    correction uses t = i + 1.

    Returns:
        (params, moments) with the same list-of-lists structure as the inputs.
    """
    step = i + 1  # TODO: should we decouple iteration from t, and thread t instead?

    def debias(beta, moment):
        return moment / (1 - pow(beta, step))

    # TODO: once written more efficiently, combine both loops + vmap (if possible)?

    # 1) Moment updates: EMA of g (power 1) and of g**2 (power 2).
    for beta, moment_groups, power in zip(betas, moments, [1, 2]):
        for grp_idx, grad_group in enumerate(grads):
            for leaf_idx in range(len(grad_group)):
                decayed = moment_groups[grp_idx][leaf_idx].at[:].multiply(beta)
                moment_groups[grp_idx][leaf_idx] = decayed.at[:].add((1 - beta) * pow(grad_group[leaf_idx], power))

    # 2) Parameter updates from the bias-corrected moments.
    for grp_idx, grad_group in enumerate(grads):
        for leaf_idx in range(len(grad_group)):
            m_hat = debias(betas[0], moments[0][grp_idx][leaf_idx])
            v_hat = debias(betas[1], moments[1][grp_idx][leaf_idx])
            params[grp_idx][leaf_idx] = params[grp_idx][leaf_idx].at[:].add(-lr * m_hat / (jnp.sqrt(v_hat) + epsilon))

    return params, moments

# Testing Adam in place:
#original_pointer = params[0][0].unsafe_buffer_pointer()
#params, moments = adam_in_place(params, grads, lr, betas, epsilon, moments, i)
#assert params[0][0].unsafe_buffer_pointer() == original_pointer # will not fail
#params, moments = adam(params, grads, lr, betas, epsilon, moments, i)
#assert params[0][0].unsafe_buffer_pointer() == original_pointer # will fail

In [6]:
### Infra utils
def print_mem_stats():
    """Print `bytes_in_use` and `bytes_limit` of the first JAX device, in GB (10**9 bytes)."""
    stats = jax.devices()[0].memory_stats()
    in_use_gb = stats["bytes_in_use"] / pow(1000, 3)
    limit_gb = stats["bytes_limit"] / pow(1000, 3)
    print(f'GB in use: {in_use_gb}. GB limit: {limit_gb}')

import wandb

# Start a new wandb run to track this script, with TensorBoard syncing enabled.
# Set USE_WANDB = False to skip it.
USE_WANDB = True
if USE_WANDB:
    wandb.init(
        project="t",  # the wandb project the run is logged under
        # hyperparameters / run metadata could be tracked via config={...}
        sync_tensorboard=True,
    )

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
2024-11-12 10:59:31.133327: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1731409171.152126     949 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1731409171.158056     949 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[34m[1mwandb[0m: Currently logged in as: [33mmkukla[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
## Training loop
import datetime
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import itertools
import pickle
import evaluate
import numpy as np # should we get rid of it?
import math

# Infra training params
run_name = datetime.datetime.now().strftime("%h%d_%H-%M-%S")
# Since the implementation of gradient accumulation is very primitive, the logging & checkpoint
# step params need to be multiples of gradient_accumulations_steps.
# TODO: Introduce effective step (conditioned on accumulation steps), and do logging/checkpoint in respect to effective step
log_every_steps = 16
eval_every_steps = 4000 #500 * 8 machines
eval_n_examples = 4
# NOTE: TensorBoard logs go to this absolute path, while checkpoints are written under the relative 'runs/' dir.
writer = SummaryWriter(f'/lego/storage/output/runs/{run_name}')
#checkpoint_every_steps = None #500 * 8 machines
checkpoint_every_steps = 4000 #20000
resume_from_checkpoint = None
#resume_from_checkpoint = 'runs/Jun07_10-12-10/checkpoint_4000.pkl' # TODO: Confirm runs from checkpoints are still fully reproducible


# ML training params
key_training = random.PRNGKey(0)
batch_size = 16 #64 #128 #512 #416 # TODO: Investigate OOMs when 496? #512
gradient_accumulations_steps = 4 # to imitate paper's 8 devices
# Counted in micro-batches (one per loop iteration), not optimizer steps.
num_steps = 100000 #800000 # AIAYN paper's 100k steps *  8 devices # TODO XXX: think what it should be for GPT1/2
# Peak LR: reached at the end of warmup and then cosine-decayed (GPT1 schedule), or used
# as a constant LR when warmup_steps is None.
max_lr = 0.00025
warmup_steps = 2000 #8000 # counted in optimizer steps (micro-batches // gradient_accumulations_steps)
betas = (0.9, 0.98)
epsilon = 1e-9 # AIAYN's 10^-9 (the literal `10e-9` would be 1e-8)
moments = [elwise((params,), lambda p: jnp.zeros_like(p)) for _ in range(2)] # moment estimates
seq_len = 512 #200 #50 # TODO: 124 is maximum length in validation dataset.
hellaswag_seq_len = 300
x_tokens_per_batch = 15000 #For variable batch len, we don't use it as we can fit less data (paper does 25k)

# TODO XXX: remove below one
_, _, _, y_eval_mask, _, _, y_eval_indices = next(get_batched_examples(ds, eval_n_examples, seq_len, START_TOK, END_TOK, "validation"))

i = 0
ds_train_rows_read = 0
if resume_from_checkpoint is not None:
    # pickle.load can execute arbitrary code: only resume from checkpoints you wrote yourself.
    with open(resume_from_checkpoint, 'rb') as f:
        i, ds_train_rows_read, params, moments, key_training = pickle.load(f)
        print(f'Resuming training from the checkpoint: i {i} ds_train_rows_read {ds_train_rows_read}')

num_params = sum(jnp.size(p_leaf) for p_leaf in jax.tree_util.tree_leaves(params))
print(f'Number of params: {num_params:_}')
# Gradient accumulator, zero-initialised with the same structure as params.
grads = jax.tree_util.tree_map(jnp.zeros_like, params)

from functools import partial
@partial(jax.jit, donate_argnames=("acc_grads",))
def acc_grad_loss(acc_grads, params, y, y_mask, y_indices, key_iter):
    """Add the gradients of `grad_loss` on one micro-batch into the donated `acc_grads`.

    Returns:
        (acc_grads, aux), where aux is the auxiliary output of `grad_loss`.
    """
    step_grads, aux = grad_loss(params, y, y_mask, y_indices, key_iter)
    for acc_group, step_group in zip(acc_grads, step_grads):
        for idx, acc_g in enumerate(acc_group):
            acc_group[idx] = acc_g.at[:].add(step_group[idx])
    return acc_grads, aux

import os


def grad_group_sq_norms(grads):
    """Squared L2 norm of each parameter group in `grads` (a list of lists of arrays).

    Norms are taken leaf by leaf, so the gradients are never concatenated or copied.
    """
    return [sum(pow(jnp.linalg.norm(g), 2) for g in g_items) for g_items in grads]


def unpack_hellaswag_batched_x(batched_x: np.ndarray):
    """Decode a batch of packed HellaSwag `x` rows into (choices, labels).

    Each row has its leading/trailing zeros stripped (`np.trim_zeros`) and is decoded with
    `unpack_hellaswag_x`. The per-example choice lists are transposed, so `choices[k]`
    holds candidate k of every example in the batch. `labels[b]` is example b's label.
    """
    choices = []
    labels = []
    for item_x in batched_x:
        item_choices, item_label = unpack_hellaswag_x(list(np.trim_zeros(item_x)))
        choices.append(item_choices)
        labels.append(item_label)
    return list(map(list, zip(*choices))), labels


def concatenate_y_and_choice(batched_y: np.ndarray, batched_choice: np.ndarray):
    """Return (y_with_choice, masks) with one candidate ending per row; `batched_y` is not modified.

    In each row of a copy of `batched_y`, the first END_TOK is overwritten with
    `choice + [END_TOK]`. Each mask is the (seq_len-1) x (seq_len-1) lower-triangular
    causal mask, with rows/columns zeroed where `y[1:]` is padding (0).
    """
    y_and_choice = np.copy(batched_y)
    y_and_choice_mask = []
    full_mask = np.tri(batched_y.shape[1]-1)
    for y, choice in zip(y_and_choice, batched_choice):
        choice = choice + [END_TOK]
        zeros = np.where(y == END_TOK)[0]
        assert len(zeros) > 0
        assert zeros[0]+len(choice) < len(y)
        y[zeros[0]:zeros[0]+len(choice)] = choice
        # NOTE(review): padding is read from y[1:] (the targets) while the model input is
        # y[:-1]. Assuming no interior 0 tokens, this only differs at the final END_TOK
        # input position, whose target is padding.
        # TODO XXX: is it behaving correctly if START_TOK&END_TOK present
        y_pad_mask = np.where(y[1:] != 0, np.ones((y[1:].shape[0])), 0)
        y_mask = np.multiply(np.multiply(full_mask, y_pad_mask), y_pad_mask[:, None])
        y_and_choice_mask.append(y_mask)
    return y_and_choice, np.array(y_and_choice_mask)


# The outer loop runs epochs. It stops once `i` passes num_steps: the inner `break` alone
# would only end the current epoch and start a new one forever.
while i <= num_steps:
    #train_batches = get_batched_examples(ds, batch_size, seq_len, START_TOK, END_TOK, skip_n_rows = ds_train_rows_read)
    train_batches = get_batched_examples_packed(ds, batch_size, seq_len, START_TOK, END_TOK, pack_frac=0.75, skip_n_rows = ds_train_rows_read)
    for batch in tqdm(itertools.islice(train_batches, num_steps), initial=i, total=num_steps, smoothing=0):
        _, y, _, y_mask, _, _, y_indices = batch
        # Training step: add this micro-batch's gradients to the accumulator `grads`.
        # TODO: introduce update func, which does grad_loss and adam, and then call/jit that function instead of calling/jitting two separate ones
        key_training, key_iter = random.split(key_training, 2)
        grads, (loss_val, acc, _) = acc_grad_loss(grads, params, jnp.array(y), jnp.array(y_mask), jnp.array(y_indices), key_iter)

        # `i` counts micro-batches. The LR schedule and Adam's step count are in optimizer steps.
        opt_step = i // gradient_accumulations_steps
        if warmup_steps is None:
            lr = max_lr  # constant learning rate (e.g. for SGD)
        elif opt_step < warmup_steps:
            # GPT1: linear warmup to max_lr ...
            lr = (opt_step + 1) / warmup_steps * max_lr
        else:
            # ... followed by cosine decay to 0. num_steps counts micro-batches, so it is
            # converted to optimizer steps to keep t_step and t_max in the same unit.
            t_step = opt_step - warmup_steps
            t_max = max(num_steps // gradient_accumulations_steps - warmup_steps, 1)
            lr = max_lr * (1 + math.cos(math.pi * t_step / t_max)) / 2
        # AIAYN alternative:
        #lr = pow(EMB_DIM, -0.5) * min(pow((opt_step+1), -0.5), (opt_step+1) * pow(warmup_steps, -1.5))

        # Apply the optimizer once every gradient_accumulations_steps micro-batches.
        # NOTE: the first window (i = 0 .. G) accumulates G + 1 micro-batches.
        is_update_step = i > 0 and i % gradient_accumulations_steps == 0
        if is_update_step:
            for grp_i in range(len(grads)):
                for p_i in range(len(grads[grp_i])):
                    grads[grp_i][p_i] = grads[grp_i][p_i].at[:].divide(gradient_accumulations_steps)
            # adam_in_place uses t = step + 1 for bias correction, so pass the 0-based index
            # of this optimizer update rather than the micro-batch counter `i`.
            params, moments = adam_in_place(params, grads, lr, betas, epsilon, moments, opt_step - 1)

        # Logging:
        if i % log_every_steps == 0:
            loss_val = loss_val.item()
            acc = acc.item()

            grps_sq_norms = grad_group_sq_norms(grads)
            grps_grad_norms = [math.sqrt(sq) for sq in grps_sq_norms]
            grad_norm = math.sqrt(sum(grps_sq_norms))

            #print_mem_stats() # TODO: monitor it in tensorboard?
            writer.add_scalar('train/loss', loss_val, i)
            writer.add_scalar('train/acc', acc, i)
            writer.add_scalar('train/lr', lr, i)
            writer.add_scalar('train/grad_norm', grad_norm, i)
            for grp_i, grp_grad_norm in enumerate(grps_grad_norms):
                writer.add_scalar(f'train_details/grad_norm_grp_{grp_i}', grp_grad_norm, i)

            # TODO: some metrics computed on x, other on y. Make it consistent
            pad_tokens_prop = np.count_nonzero(y==0) / y.size
            writer.add_scalar('train_data/pad_tokens_prop', pad_tokens_prop, i)
            writer.add_scalar('train_data/batch_size', len(y), i)
            writer.add_scalar('train_data/batch_seq_len', len(y[0]), i)
            writer.add_scalar('train_data/batch_total_tokens', len(y) * len(y[0]), i)

        # Zero the accumulated grads: this has to happen after the grad norms are computed.
        if is_update_step:
            for grp_i in range(len(grads)):
                for p_i in range(len(grads[grp_i])):
                    grads[grp_i][p_i] = grads[grp_i][p_i].at[:].set(0)

        # Evaluation. Uses its own variable names so the training batch `y` (needed for
        # ds_train_rows_read below) is not overwritten.
        if i > 0 and i % eval_every_steps == 0:
            val_losses = []
            val_accs = []
            val_toks_props = []
            for val_batch in get_batched_examples(ds, batch_size, seq_len, START_TOK, END_TOK, split="validation"):
                _, y_val, _, y_val_mask, _, _, y_val_indices = val_batch
                _, (val_loss, val_acc, val_toks_prop) = loss_eval(params, jnp.array(y_val), jnp.array(y_val_mask), jnp.array(y_val_indices))
                val_losses.append(val_loss)
                val_accs.append(val_acc)
                val_toks_props.append(val_toks_prop)
            writer.add_scalar('eval/loss', jnp.average(jnp.hstack(val_losses), weights = jnp.hstack(val_toks_props)).item(), i)
            writer.add_scalar('eval/acc', jnp.average(jnp.hstack(val_accs), weights = jnp.hstack(val_toks_props)).item(), i)

            # Few predictions TODO XXX: vary temperature -> diff samples
            y_sample = predict(params, jnp.array(y_eval_mask), jnp.array(y_eval_indices), seq_len)
            y_sample = tuple([item.tolist() for item in y_sample])
            for detokenized_y_sample in detokenize(y_sample):
                print(f'PREDS: {detokenized_y_sample}\n')

            # Compute HellaSwag score
            # TODO XXX XXX: This is not efficient. As minimum, we could do it in parallel on CPU. We can also move the whole op
            # (i.e. unpacking, and copying individual choices to y) onto GPU: if so, data loader& data collator, need nested list
            # for x, in which each choice is of equal shape. Then, I believe we can concatenate y with choice in batches efficiently
            # we need different structure of batched_x coming from data collator: individual choices should be of equal shapes:
            # I assume existence of scatter like function, but operating on contiguous blocks instead of individual values.
            # Finally we need batched computes of value function itself..
            print(f'Compute HellaSwag score')
            hellaswag_accs = [] # TODO XXX: enable seq_len be different for x vs y;
            num_hellaswag_batches = 50 #TODO XXX:; run for the whole dataset
            hs_batches = itertools.islice(get_batched_examples(hellaswag_ds, 32, hellaswag_seq_len, START_TOK, END_TOK, split=None), num_hellaswag_batches)
            for hs_batch in tqdm(hs_batches):
                x_hs, y_hs, _, _, _, _, y_hs_indices = hs_batch
                choices, labels = unpack_hellaswag_batched_x(x_hs)

                choices_vals = []
                for choice in choices:
                    # Every candidate is spliced into the ORIGINAL context y_hs. Reusing the
                    # result as the next input would append each choice after the previous ones.
                    y_choice, y_choice_mask = concatenate_y_and_choice(y_hs, choice) # no need to return new y_indices for now.
                    choices_vals.append(log_probs(params, jnp.array(y_choice), jnp.array(y_choice_mask), jnp.array(y_hs_indices)))
                choices_vals = np.array(choices_vals).transpose() # we want choice per column
                hellaswag_accs.extend(np.argmax(choices_vals, axis=1)==labels)

            hellaswag_acc = sum(hellaswag_accs)/len(hellaswag_accs)
            print(f'HellaSwag score:', hellaswag_acc)
            writer.add_scalar('eval/hellaswag', hellaswag_acc, i)

        i = i + 1
        # NOTE(review): len(y) is the number of (possibly packed) rows in this batch; confirm
        # it matches what skip_n_rows expects when pack_frac packs several dataset rows.
        ds_train_rows_read = ds_train_rows_read + len(y)

        # Checkpointing (i, ds_train_rows_read, params, moments, key_training).
        if checkpoint_every_steps is not None and (i>0 and i%checkpoint_every_steps==0):
            training_state = (i, ds_train_rows_read, params, moments, key_training)
            filename = f'runs/{run_name}/checkpoint_{i}.pkl'
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            with open(filename, 'wb') as f:
                pickle.dump(training_state, f)

        if i > num_steps:
            break

    ds_train_rows_read = 0 # After each epoch, reset dataset pointer

writer.close()



Number of params: 112_220_206


  0%|          | 155/100000 [01:27<15:41:19,  1.77it/s]

In [None]:
# FOR TESTING

# Compute HellaSwag score
import numpy as np
def unpack_hellaswag_batched_x(batched_x: np.ndarray):
    """Decode a batch of packed HellaSwag `x` rows into (choices, labels).

    Each row has its leading/trailing zeros stripped (`np.trim_zeros`) and is decoded with
    `unpack_hellaswag_x`. The per-example choice lists are transposed, so `choices[k]`
    holds candidate k of every example. `labels[b]` is example b's label.
    """
    unpacked = [unpack_hellaswag_x(list(np.trim_zeros(row))) for row in batched_x]
    per_example_choices = [item_choices for item_choices, _ in unpacked]
    labels = [item_label for _, item_label in unpacked]
    return [list(choice_k) for choice_k in zip(*per_example_choices)], labels

def concatenate_y_and_choice(batched_y: np.ndarray, batched_choice: np.ndarray):
    """Return (y_with_choice, masks) with one candidate ending per row; `batched_y` is not modified.

    In each row of a copy of `batched_y`, the first END_TOK is overwritten with
    `choice + [END_TOK]`. Each mask is the (seq_len-1) x (seq_len-1) lower-triangular
    causal mask, with rows/columns zeroed where `y[1:]` is padding (0).
    """
    combined = np.copy(batched_y)
    causal = np.tri(batched_y.shape[1] - 1)
    masks = []
    for row, choice in zip(combined, batched_choice):  # `row` is a view into `combined`
        tokens = choice + [END_TOK]
        end_positions = np.where(row == END_TOK)[0]
        assert len(end_positions) > 0
        start = end_positions[0]
        assert start + len(tokens) < len(row)
        row[start:start + len(tokens)] = tokens
        # NOTE(review): padding is read from y[1:] (targets) while the model input is y[:-1].
        # TODO XXX: is it behaving correctly if START_TOK&END_TOK present
        not_pad = np.where(row[1:] != 0, np.ones((row[1:].shape[0])), 0)
        masks.append(np.multiply(np.multiply(causal, not_pad), not_pad[:, None]))
    return combined, np.array(masks)

# Score a few HellaSwag batches: each example is assigned the candidate ending with the
# highest total log-probability, and the mean hit rate is printed.
from tqdm import tqdm
import itertools

hellaswag_accs = []
hellaswag_batches = itertools.islice(get_batched_examples(hellaswag_ds, 20, 400, START_TOK, END_TOK, split=None), 5)
#hellaswag_batches = get_batched_examples(hellaswag_ds, 1, 400, START_TOK, END_TOK, split=None)
for batch in tqdm(hellaswag_batches):
    x, y, _, _, _, _, y_indices = batch
    choices, labels = unpack_hellaswag_batched_x(x)

    choices_vals = []
    for choice in choices:
        # Splice every candidate into the ORIGINAL context `y`. Reassigning `y` here would
        # append each choice after all the previous ones.
        y_choice, y_choice_mask = concatenate_y_and_choice(y, choice) # no need to return new y_indices for now.
        choices_vals.append(log_probs(params, jnp.array(y_choice), jnp.array(y_choice_mask), jnp.array(y_indices)))
    choices_vals = np.array(choices_vals).transpose() # one column per choice
    hellaswag_accs.extend(np.argmax(choices_vals, axis=1) == labels)

#print("hellaswag_accs", hellaswag_accs)
hellaswag_acc = sum(hellaswag_accs)/len(hellaswag_accs)
print(hellaswag_acc)


In [None]:
### Final test predictions + BLEU computation
# NOTE(review): this cell does not run against the code above. It looks left over from an
# encoder-decoder (x -> y) setup:
#  - `x_eval`, `y_eval` and `get_batched_examples_per_length` are not defined anywhere in
#    this notebook;
#  - `predict` is declared as predict(params, y_mask, y_indices, sample_len) and only
#    generates from START_TOK, so it cannot be conditioned on an `x`;
#  - `references` / `predictions` are appended to in the first loop before they are
#    initialised further down.
print(f'Few predictions for validation dataset')
y_sample = predict(params, jnp.array(x_eval))
y_sample = tuple([item.tolist() for item in y_sample])
for detekonized_x_eval, detokenized_y_eval, detokenized_y_sample in zip(detokenize(x_eval), detokenize(y_eval), detokenize(y_sample)):
    print(f'X:{detekonized_x_eval}\tY: {detokenized_y_eval} \tPREDS: {detokenized_y_sample}\n')
    references.append(detokenized_y_eval)  # NOTE(review): `references` is only defined below
    predictions.append(detokenized_y_sample)  # NOTE(review): `predictions` is only defined below

print(f'Computing BLEU for validation dataset')
import evaluate
references = []
predictions = []
# NOTE(review): batches are expected to be (x, y) pairs; the loaders used elsewhere in this
# notebook yield 7-tuples, and `predict` is called with 3 args instead of 4.
for _, (x, y) in tqdm(enumerate(get_batched_examples_per_length(ds, x_tokens_per_batch, split="validation"))):
    y_sample = predict(params, jnp.array(x), seq_len)
    y_sample = tuple([item.tolist() for item in y_sample])
    for detekonized_x_eval, detokenized_y_eval, detokenized_y_sample in zip(detokenize(x), detokenize(y), detokenize(y_sample)):
        references.append(detokenized_y_eval)
        predictions.append(detokenized_y_sample)

bleu = evaluate.load("bleu")
results = bleu.compute(predictions=predictions, references=references)
print(results)