In [1]:
# Suggested flags to set from the jax docs
import os
os.environ['XLA_FLAGS'] = (
    #'--xla_gpu_enable_triton_softmax_fusion=true '
    #'--xla_gpu_triton_gemm_any=True '
    #'--xla_gpu_enable_async_collectives=true '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_highest_priority_async_stream=true '
)

os.environ.update({
  "NCCL_LL128_BUFFSIZE": "-2",
  "NCCL_LL_BUFFSIZE": "-2",
   "NCCL_PROTO": "SIMPLE,LL,LL128",
 })

In [1]:
from absl import logging
import flax
import jax.numpy as jnp
import jax
import numpy as np

import tensorflow_datasets as tfds
import tensorflow as tf

tf.config.experimental.set_visible_devices([], "GPU")

logging.set_verbosity(logging.INFO)

2024-06-20 19:56:02.943577: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-20 19:56:02.943675: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-20 19:56:02.953796: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# Load the pre-trained weights from huggingface - Do not run this if you are 
# planning to train your own model. It will just take up memory
from transformers import FlaxGPT2LMHeadModel

model_hf = FlaxGPT2LMHeadModel.from_pretrained("openai-community/gpt2")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Here is the dictionary tree we have to replicate. pytrees are just nested dictionaries.
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(model_hf.params)))

initialized parameter shapes:
 {'transformer': {'h': {'0': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '1': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '10': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, 

In [3]:
from dataclasses import dataclass
from flax import linen as nn
from typing import Any
# define a full GTP2 class in Flax and see if we can replicate the original paper along with
# Karpathy

# TODO:
# - Add istraining variable to allow dropout and other operations that are different during training and inference
# - Use the Jax configuration utilities instead of the GPTConfig class
# - Develop my own training routing with all of the bells and whistles required: timing, checkpointinng, etc...
# - Move everything to python files for more professional-like code from the command line

@dataclass
class GPTConfig:
    block_size: int = 256
    vocab_size: int = 50257
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    dtype: Any = jnp.bfloat16

class GPTMLP(nn.Module):
    config: GPTConfig

    def setup(self):
        # Simple MLP that upscales, runs through a gelu activation,
        # and then resamples back to the n_embd size (the model size)
        #self.c_fc = nn.Dense(4 * self.config.n_embd)
        # Had to use Einsum to match the matrix multiplication of GPT2 and pytorch. They accept 
        # both shapes and then multiply by the transpose of the matrix (basically means the 
        # shape of the matrix is transpose, but the operation is the same). I confirmed that this
        # produces the same result as the complicated huggingface conv1d version (conv1d is also
        # just a linear matrix operation as well). They do add a lot of variables for 
        # mixed-precision training, that I do not.
        # TODO: Move this into a new module as I am repeating it everywhere
        #self.c_fc = nn.Einsum((4 * self.config.n_embd, self.config.n_embd), '...ij,...kj->...ik', kernel_init=jax.nn.initializers.normal(stddev=0.02), dtype=self.config.dtype)
        self.c_fc = nn.Dense(self.config.n_embd * 4, kernel_init=jax.nn.initializers.normal(stddev=0.02), dtype=self.config.dtype)
        #self.c_proj = nn.Dense(self.config.n_embd)
        #self.c_proj = nn.Einsum((self.config.n_embd, 4 * self.config.n_embd), '...ij,...kj->...ik', kernel_init=jax.nn.initializers.normal(stddev=0.02), dtype=self.config.dtype)
        self.c_proj = nn.Dense(self.config.n_embd, kernel_init=jax.nn.initializers.normal(stddev=0.02), dtype=self.config.dtype)

    def __call__(self, inputs):
        x = inputs

        x = self.c_fc(x)
        x = nn.gelu(x, approximate=True)
        x = self.c_proj(x)

        return x
    
class GPTAttention(nn.Module):
    config: GPTConfig

    # we will need to roll our own attention module because the built in one has a bunch of different
    # naming and structure compared to the original GPT, which just references the projection layers
    def setup(self):
        # The first thing we do is project up to 3x the model size, because we are going to split
        # the data into q, v, k
        #self.c_attn = nn.Einsum((3 * self.config.n_embd, self.config.n_embd), '...ij,...kj->...ik', kernel_init=jax.nn.initializers.normal(stddev=0.02), dtype=self.config.dtype)
        self.c_attn = nn.Dense(3 * self.config.n_embd, kernel_init=jax.nn.initializers.normal(stddev=0.02), dtype=self.config.dtype)
        
        # At the end we have to project everything back to the regular model size of n_embd
        #self.c_proj = nn.Einsum((self.config.n_embd, self.config.n_embd), '...ij,...kj->...ik', kernel_init=jax.nn.initializers.normal(stddev=(self.config.n_layer*2) ** -0.5), dtype=self.config.dtype)
        self.c_proj = nn.Dense(self.config.n_embd, kernel_init=jax.nn.initializers.normal(stddev=(self.config.n_layer*2) ** -0.5), dtype=self.config.dtype)
        
    def __call__(self, inputs):
        x = inputs
        B, T, C = jnp.shape(x)
        
        # Project to qkv
        qkv = self.c_attn(x)

        q, k, v = jnp.split(qkv, 3, axis=2)

        query_length, key_length = q.shape[1], k.shape[1]

        # Now reshape with a new head "batch" dimension
        k = jnp.reshape(k, (B, T, self.config.n_head, C // self.config.n_head)) # Shape is (batch, tokens, num_heads, size of head)
        q = jnp.reshape(q, (B, T, self.config.n_head, C // self.config.n_head)) # Shape is (batch, tokens, num_heads, size of head)
        v = jnp.reshape(v, (B, T, self.config.n_head, C // self.config.n_head)) # Shape is (batch, tokens, num_heads, size of head)

        # Make the attention mask
        # TODO: For round 1 I'm just trusting the linen functions. They seem to be doing the correct thing here, but I may have to 
        # return to this for a closer look at the math if I'm not getting the GPT2 results
        # Just copied this from the huggingface code to be consistent. First part just broadcasts the causal masks to the batch
        # dimensions (replicating it as a lower triangular matrix of truths). The attention_bias is a stripped down version of the 
        # huggingface code, but the bias has to be floats. They bias the attention softmax, so basically set to -inf where you
        # want it ignored. This will need to be OR'd with an attention mask in certain situations, like encoder/decoder networks.
        # TODO: Also, this mask does not need to be applied during inference, so we could have an 'istraining' variable passed 
        # down the network for these cases and then ignore the mask calculations for increased speed. In the Flax examples, they 
        # also appear to cache the results of previous calculations, which I guess makes sense because we are just adding one 
        # token at a time to the input sequence and then calculating again. There is no point in recalculating the previous 
        # token outputs every time. I should probbaly implement some version of that too.
        c_mask = nn.make_causal_mask(jnp.ones((1, self.config.block_size), dtype="bool"), dtype="bool")
        c_mask = c_mask[:, :, :query_length, :key_length]
        c_mask = jnp.broadcast_to(c_mask, (B,) + c_mask.shape[1:])

        attention_bias = jax.lax.select(
                c_mask > 0,
                jnp.full(c_mask.shape, 0.0).astype(self.config.dtype),
                jnp.full(c_mask.shape, jnp.finfo(self.config.dtype).min).astype(self.config.dtype),
            )

        # use the built in flax libraries to calculate the attention matrix - attention weights are not returned, but could be 
        # I think bias gives more control compared to mask - i.e. bias can be a float. They might result in the same output with a 
        # boolean mask, but I will have to test that.
        # TODO: I don't think Flax has a flash attention module. Is there any way to add that for Flax that will actually be 
        # optimized for hardware? I don't know.
        y = nn.dot_product_attention(q, k, v, bias=attention_bias, dtype=self.config.dtype)

        # Merge the heads back together
        y = y.reshape((B, T, C))

        # Project output with a FC layer
        y = self.c_proj(y)

        return y

class GPTBlock(nn.Module):
    config: GPTConfig

    def setup(self):
        self.ln_1 = nn.LayerNorm(epsilon=1e-05, dtype=jnp.float32)
        # I might have to write this manually to get the proper number of parameters, as the old GPT2 code 
        # migh have subtle differences from the Flax implementation
        self.attn = GPTAttention(self.config)
        self.ln_2 = nn.LayerNorm(epsilon=1e-05, dtype=jnp.float32)
        self.mlp = GPTMLP(self.config)

    def __call__(self, inputs):
        x = inputs
        x = self.ln_1(x)
        x = inputs + self.attn(x)
        inputs2 = x
        x = self.ln_2(x)
        x = inputs2 + self.mlp(x)
        return x

class GPTLayers(nn.Module):
    config: GPTConfig
    
    def setup(self):
        self.blocks = [ GPTBlock(self.config, name=str(i)) for i in range(self.config.n_layer) ]

    def __call__(self, inputs):
        x = inputs

        for block in self.blocks:
            x = block(x)
        return x

class GPTModel(nn.Module):
    config: GPTConfig

    def setup(self):
        # This is a little confusing. vocab size is the number of embeddings we need.
        # n_embd is the dimension of each embedding (i.e. capacity for learning properties
        # about this embedding token)
        # input size = (1 x self.block_size) - 1 int for each token
        # output size = then (self.block_size x self.n_embd)
        self.wte = nn.Embed(self.config.vocab_size, self.config.n_embd, embedding_init=jax.nn.initializers.normal(stddev=0.02), dtype=self.config.dtype)
        # embed is just a randomzied parameter matrix, so can be used for positional 
        # encoding as well. I think block size is the token length.
        # This has to match the size of the previous output, as we are just adding.
        self.wpe = nn.Embed(self.config.block_size, self.config.n_embd, embedding_init=jax.nn.initializers.normal(stddev=0.02), dtype=self.config.dtype)
        # The attention layers
        self.h = GPTLayers(self.config)
        self.ln_f = nn.LayerNorm(epsilon=1e-05, dtype=jnp.float32)

    def __call__(self, inputs):
        x = inputs
        input_shape = jnp.shape(x)       

        x = self.wte(x)

        # For the positional encodings we need an index that is simple the position
        # of each token. This will be the same shape as the input, but will simply
        # be repeating and increasing numbers from 1.
        # jnp.atleast_2d is needed so we can initialize with a batch size of 1
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(inputs).shape[-1]), input_shape)
        x_wpe = self.wpe(position_ids)
        
        x += x_wpe

        x = self.h(x)

        x = self.ln_f(x)

        return x


class GPT(nn.Module):
    config: GPTConfig

    def setup(self):
        self.transformer = GPTModel(self.config)
        # So the Flax model does not return parameters for lm_head, because lm_head is simply the 
        # inverse operation of the initial word embedding, so we can just reuse those weights. They
        # use the term 'tied' for this. The weights are tied together.
        self.lm_head = nn.Dense(self.config.vocab_size, use_bias=False, dtype=self.config.dtype)

    def __call__(self, inputs):
        x = inputs
        x = self.transformer(x)

        # This is from the huggingface code - might as well reuse it here.
        shared_kernel = self.transformer.variables['params']['wte']['embedding'].T
        lm_logits = self.lm_head.apply({'params': {'kernel': shared_kernel}}, x)
        return lm_logits

In [29]:
# Initialize
GPT2 = GPT(GPTConfig)

key1, key2 = jax.random.split(jax.random.key(0), 2)

# We never want to use int64 as it signficantly slows down training
x = np.random.randint(0, GPTConfig.vocab_size, (1, GPTConfig.block_size), dtype='i4')

y = GPT2.init(key2, x)

In [14]:
# Inference
import tiktoken
import jax
import jax.numpy as jnp
#from tensorflow_probability.substrates import jax as tfp

#tfd = tfp.distributions

# Generate a starting sequence and we will generate the rest of the sequence with GPT2 for
# however many batches we need
num_return_sequences = 1
max_length = 100

enc = tiktoken.get_encoding('gpt2')
tokens = enc.encode("")

tokens = jnp.array(tokens, dtype="i4")

initial_input = jnp.atleast_2d(tokens).repeat(num_return_sequences, 0)

init_length=jnp.shape(initial_input)[-1]

# Has to be a constant length for the fast-autogression, so we need to pad (we will ignore the pads until 
# those indicies are actually needed by the model as it grows)
initial_input=jnp.pad(initial_input, ((0, 0), (0, max_length-init_length)), mode='constant', constant_values=0)

In [15]:
# Jax-style fast autoregression using the jax looping tools. This takes a bit of time to get correct
# with each model as you have to design it using the XLA tools and every single parameter needs to 
# be deterministic. This leads to lots of gotchas. For example, dynamic_slice and dynamic_slice_update
# need to have a constant for the size, but can be placed anywhere in the matrix (i.e., you can't use
# a variable with this). And, take_along_axis needs to be used for the top_k type indexing as it uses 
# the XLA gather method underneath, which is super complicated to use by hand. Every model produces 
# new difficulties and I'm sure this can be optimized quite a bit if I understood jax a bit better.
rng = jax.random.key(27)

temperature = 1.0

def get_sentence(n, carry):
    x, rng = carry

    rng, new_rng = jax.random.split(rng)

    # pull out only the sequence up until the currently filled in
    # values. It would be better if we could pass in a variable 
    # length sequence, but I don't think that's allowed by jax
    logits = GPT2.apply({'params': jax_utils.unreplicate(state.params)}, x)

    # GPT2 predicted the new token int he last column of the output, 
    # so we only need this.
    logits = logits[:, n, :]

    # Sample from the top 50 tokens
    topk_probs, topk_indices = jax.lax.top_k(logits, 50)

    ix = jnp.expand_dims(jax.random.categorical(new_rng, topk_probs/temperature, axis=-1), axis=-1)

    #print(ix)

    xcol=jnp.atleast_2d(jnp.take_along_axis(topk_indices, ix.astype(jnp.int32), axis=-1))

    x = jax.lax.dynamic_update_slice(x, xcol, (0, n+1))

    return x, rng

x, rng = jax.lax.fori_loop(init_length-1, max_length-1, get_sentence, (initial_input, rng))


In [16]:
# Print the results
for i in range(0, x.shape[0]):
    print(enc.decode(x[i,:].tolist()))


Rie.

RINGBRO doth in all.

RICHARD II:

Nay, she, I fear:
' MARUEENRY BOLINGHAM:
Than it, if you, thou, no more I pray you come at death?
HORTIO:
O, my husband,
NANUS:
The father's love.

My mother.

GLOUCESUS:

To


In [16]:
# Just comparing the shapes of our model vs the GPT2 trained weights.
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(y['params']['transformer'])))
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(model_hf.params['transformer'])))

initialized parameter shapes:
 {'h': {'0': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '1': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '10': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '11': {'attn': {

In [4]:
# Initialize
GPT2 = GPT(GPTConfig(vocab_size=50304))

key1, key2 = jax.random.split(jax.random.key(2), 2)

# We never want to use int64 as it signficantly slows down training
x_init = np.random.randint(0, GPTConfig.vocab_size, (1, GPTConfig.block_size), dtype='i4')

y_init = GPT2.init(key2, x_init)

#logits = GPT2.apply({'params': y_init['params']}, x)

In [17]:
# In jax we create a separate apply function to calculate the loss

# From the flax examples. We should track some metrics as we train
def compute_metrics(logits, labels, loss):
    accuracy = jax.lax.pmean(jnp.mean(jnp.argmax(logits, -1) == labels), axis_name='batch')
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    metrics = jax.lax.pmean(metrics, axis_name='batch')
    return metrics

#@jax.jit
def apply_model(state, batch):
    def loss_fn(params):
        logits = state.apply_fn(
            {'params': params}, 
            batch['x'])
        #one_hot = jax.nn.one_hot(labels, 10)
        loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch['y']))
        return loss, logits
    
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (logits)), grads = grad_fn(state.params)
    #accuracy = jax.lax.pmean(jnp.mean(jnp.argmax(logits, -1) == batch['y']), axis_name='batch')

    #loss = jax.lax.pmean(loss, axis_name='batch')
    grads = jax.lax.pmean(grads, axis_name='batch')
    state = state.apply_gradients(grads=grads)

    metrics = compute_metrics(logits, batch['y'], loss)
    return metrics, state

In [18]:
# Now let's do this for the entire tiny Shakespeare dataset
# We can follow Karpathy's style, but we will use the tensorflow datasets
# tools as our data loader as it is the natural thing to use with jax/flax
# We can also then use prefetch to device, which should speed things up as well
import tiktoken

class DataLoaderLite:
    def __init__(self, B, T):
        self.num_devices = jax.device_count()
        self.B = B
        self.T = T

        assert self.B % self.num_devices == 0, "Number of devices must divide evenly into number of batches"
        
        with open("input.txt", 'r') as f:
            text = f.read()
        enc = tiktoken.get_encoding('gpt2')
        tokens = enc.encode(text)
        # Could load it into gpu memory here if you want for speed. It would 
        # depend on how much you have to spare
        self.tokens = tokens
        print(f'Loaded {len(self.tokens)} tokens')
        print(f'One epoch = {len(self.tokens) // (B*T) } batches')

        self.current_position = 0

    def next_batch(self):
        D, B, T = self.num_devices, self.B, self.T
        buf = np.array(self.tokens[self.current_position:self.current_position+B*T+1])
        x = buf[:-1].reshape((D, B // D, T))
        y = buf[1:].reshape((D, B // D ,T))
        self.current_position+=B*T

        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0
        
        return {'x': x, 'y': y}

    # The Tensorflow Data way
    def getIterator(self):
        D, B, T = self.num_devices, self.B, self.T

        ds = tf.data.Dataset.from_tensor_slices(self.tokens)
        
        #x = np.array(self.tokens)
        
        #ds = tf.data.Dataset.from_tensor_slices(x[:-(x.shape[0] % (T+1))].reshape((-1, T+1)))

        ds = ds.batch(B*T+1, drop_remainder=True).map(lambda x: {'x': tf.reshape(x[:-1], (D, B // D, T)), 'y': tf.reshape(x[1:], (D, B // D, T))}, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE).cache().repeat()
        
        #ds = ds.shuffle(1000).batch(B*T+1, drop_remainder=True)(e(x[1:], (D, B //2, T))}).prefetch(tf.data.AUTOTUNE).cache().repeat()
        
        #ds = ds.shuffle(1000).batch(B, drop_remainder=True).prefetch(tf.data.AUTOTUNE).cache().repeat()        
        return ds
        

In [11]:
def prepare_tf_data(xs):
  """Convert a input batch from tf Tensors to numpy arrays."""
  local_device_count = jax.local_device_count()

  def _prepare(x):
    # Use _numpy() for zero-copy conversion between TF and NumPy.
    x = x._numpy()  # pylint: disable=protected-access

    # reshape (host_batch_size, height, width, 3) to
    # (local_devices, device_batch_size, height, width, 3)
    return {'x': x[:, :-1].reshape((2, 8, -1)), 'y': x[:, 1:].reshape((2, 8, -1))}

  return jax.tree_util.tree_map(_prepare, xs)

#it = map(prepare_tf_data, it)

In [39]:
import time
from flax import jax_utils
from typing import Any
from flax.training import train_state
import optax

train_loader = DataLoaderLite(8, 256)

#learning_rate = 0.0003
max_lr = 6e-4

schedule = optax.warmup_cosine_decay_schedule(
    init_value=max_lr*0.1,
    peak_value=max_lr,
    warmup_steps=10,
    decay_steps=1_000,
    end_value=max_lr*0.1
)

# Start with AdamW
#tx = optax.adamw(learning_rate=learning_rate)
tx = optax.chain(
    optax.clip(1.0),
    optax.adamw(learning_rate=schedule, b2=0.95, weight_decay=0.1),
    )

# Gradient accumulation here
# Note that we can supply a function here instead if we want to adjust the batch sizes
# as we train.
tx = optax.MultiSteps(tx, every_k_schedule=2)

# Flax handles the training state for us. Generally, you should subclass the TrainState and 
# add anything you need to it. Here, we don't need any additions, but we may in the future
# so I will leave it in for now. Things you might add are the batch_stats or dropout keys
class TrainState(train_state.TrainState):
  key: jax.Array
  batch_stats: Any

# Replace this with our subclassed TrainState when we need it.
state = train_state.TrainState.create(apply_fn=GPT2.apply, params=y_init['params'], tx=tx)

# for distributed training we need to replicate the state
state = jax_utils.replicate(state)

Loaded 338025 tokens
One epoch = 165 batches


In [40]:
logdir = './metrics'
from clu import metric_writers
from absl import logging
from flax.training import common_utils
logging.set_verbosity(logging.INFO)

num_epochs = 100

# I can get about 245ms or 66500 tok/sec out of the compiled jax code with bfloat16. Karpathy
# is at about 93ms at this point, but I'm not sure where the slowdown is. Maybe pytorch is a
# bit more specialized for this application, whereas jax is more general and numpy-like in 
# it's optimization.
# Removing the Einsum layers helped a little bit. Down to 235ms.
# It could also just be the hardware and infrastructure. Maybe the cpu/ram/pcie lanes are 
# a bit slower on this box
# After pmap'ing the update function (or jit for non-distributed), we are at 116ms or 141000 tok/sec
# This is getting close to Karpathy's numbers. Prefetching the data may be the next step

# It is very important to jit the apply_gradients function as well or else you will take a 
# very large speed penalty
#@jax.jit
#def update_model(state, grads):
#    return state.apply_gradients(grads=grads)

it = train_loader.getIterator().as_numpy_iterator()
#it = map(prepare_tf_data, it)
it = jax_utils.prefetch_to_device(it, 2)

logging.info("The first step will take a bit of time to compile")
p_train_step = jax.pmap(apply_model, axis_name='batch')

writer = metric_writers.create_default_writer(logdir)

train_metrics = []

for epoch in range(1, num_epochs + 1):
    t0 = time.time()
    #batch = train_loader.next_batch()
    batch = next(it)
    metrics, state = p_train_step(state, batch)

   # state = update_model(state, grads)

    # Making sure we are synchronized - Probably the only way to do this in jax. This may 
    # be slowing things down a bit though. Probably remove this when not timing.
    jax.random.normal(jax.random.key(0), ()).block_until_ready()
    t1 = time.time()

    dt = (t1 - t0)*1000

    tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)

    metrics['lr'] = np.array([schedule(epoch)])
    #metrics['dt'] = np.array(dt)
    #metrics['tok_per_sec'] = np.array(tokens_per_sec)
    train_metrics.append(metrics)

    # This will unreplicate for us
    train_metrics=common_utils.get_metrics(train_metrics)

    # From the flax examples. Just average the train_stats if we only want to 
    # provide them at a certain number of steps
    summary = {
            f'train_{k}': v
            for k, v in jax.tree_util.tree_map(
                lambda x: np.mean(x), train_metrics
            ).items()
        }

    writer.write_scalars(epoch + 1, summary)
    logging.info(summary)
    train_metrics = []

    #print(f'step {epoch}, loss: {jax_utils.unreplicate(loss)}, accuracy: {jax_utils.unreplicate(accuracy):.2f}, lr: {schedule(epoch):.5f}, dt: {dt:.2f}ms, tok/sec: {tokens_per_sec}')

2024-06-20 21:34:17.703404: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
INFO:absl:The first step will take a bit of time to compile
INFO:absl:[2] train_accuracy=0.0, train_loss=11, train_lr=0.0001140000531449914
INFO:absl:{'train_accuracy': 0.0, 'train_loss': 11, 'train_lr': 0.00011400005}
INFO:absl:{'train_accuracy': 0.0, 'train_loss': 10.9375, 'train_lr': 0.00016800003}
INFO:absl:[3] train_accuracy=0.0, train_loss=10.9375, train_lr=0.00016800002777017653
INFO:absl:{'train_accuracy': 0.0, 'train_loss': 10.75, 'train_lr': 0.00022200006}
INFO:absl:[4] train_accuracy=0.0, train_loss=10.75, train_lr=0.00022200006060302258
INFO:absl:[5] train_a

In [38]:
from flax.training import common_utils

common_utils.get_metrics(train_metrics)

{'accuracy': array([0.], dtype=float32),
 'loss': array([10.6875], dtype=bfloat16),
 'lr': array([0.000114], dtype=float32)}

In [37]:
train_metrics

[{'accuracy': Array([0.], dtype=float32),
  'loss': Array([10.6875], dtype=bfloat16),
  'lr': array([0.000114], dtype=float32)}]