In [1]:
from absl import logging
import flax
import jax.numpy as jnp
import jax
import numpy as np

import tensorflow_datasets as tfds
import tensorflow as tf

tf.config.experimental.set_visible_devices([], "GPU")

logging.set_verbosity(logging.INFO)

2024-06-18 15:52:09.020073: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-18 15:52:09.020133: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-18 15:52:09.020959: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# Load the pre-trained weights from huggingface - Do not run this if you are 
# planning to train your own model. It will just take up memory
from transformers import FlaxGPT2LMHeadModel

model_hf = FlaxGPT2LMHeadModel.from_pretrained("openai-community/gpt2")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Here is the dictionary tree we have to replicate. pytrees are just nested dictionaries.
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(model_hf.params)))

initialized parameter shapes:
 {'transformer': {'h': {'0': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '1': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '10': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, 

In [2]:
from dataclasses import dataclass
from flax import linen as nn
from typing import Any
# define a full GTP2 class in Flax and see if we can replicate the original paper along with
# Karpathy

# TODO:
# - Add options for different dtypes of the model
# - Add options for difference precisions (and understand that better - I think it has to do with the mantissa of the floats
# - Add istraining variable to allow dropout and other operations that are different during training and inference
# - Maybe: Flash attention?
# - Maybe: Sharding for multiple gpu training?
# - Use the Jax configuration utilities instead of the GPTConfig class
# - Develop my own training routing with all of the bells and whistles required: timing, checkpointinng, etc...
# - Move everything to python files for more professional-like code from the command line

@dataclass
class GPTConfig:
    block_size: int = 256
    vocab_size: int = 50257
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    dtype: Any = jnp.bfloat16

class GPTMLP(nn.Module):
    config: GPTConfig

    def setup(self):
        # Simple MLP that upscales, runs through a gelu activation,
        # and then resamples back to the n_embd size (the model size)
        #self.c_fc = nn.Dense(4 * self.config.n_embd)
        # Had to use Einsum to match the matrix multiplication of GPT2 and pytorch. They accept 
        # both shapes and then multiply by the transpose of the matrix (basically means the 
        # shape of the matrix is transpose, but the operation is the same). I confirmed that this
        # produces the same result as the complicated huggingface conv1d version (conv1d is also
        # just a linear matrix operation as well). They do add a lot of variables for 
        # mixed-precision training, that I do not.
        # TODO: Move this into a new module as I am repeating it everywhere
        self.c_fc = nn.Einsum((4 * self.config.n_embd, self.config.n_embd), '...ij,...kj->...ik', kernel_init=jax.nn.initializers.normal(stddev=0.02), dtype=self.config.dtype)
        #self.c_proj = nn.Dense(self.config.n_embd)
        self.c_proj = nn.Einsum((self.config.n_embd, 4 * self.config.n_embd), '...ij,...kj->...ik', kernel_init=jax.nn.initializers.normal(stddev=0.02), dtype=self.config.dtype)

    def __call__(self, inputs):
        x = inputs

        x = self.c_fc(x)
        x = nn.gelu(x, approximate=True)
        x = self.c_proj(x)

        return x
    
class GPTAttention(nn.Module):
    config: GPTConfig

    # we will need to roll our own attention module because the built in one has a bunch of different
    # naming and structure compared to the original GPT, which just references the projection layers
    def setup(self):
        # The first thing we do is project up to 3x the model size, because we are going to split
        # the data into q, v, k
        self.c_attn = nn.Einsum((3 * self.config.n_embd, self.config.n_embd), '...ij,...kj->...ik', kernel_init=jax.nn.initializers.normal(stddev=0.02), dtype=self.config.dtype)

        # At the end we have to project everything back to the regular model size of n_embd
        self.c_proj = nn.Einsum((self.config.n_embd, self.config.n_embd), '...ij,...kj->...ik', kernel_init=jax.nn.initializers.normal(stddev=(self.config.n_layer*2) ** -0.5), dtype=self.config.dtype)

    def __call__(self, inputs):
        x = inputs
        B, T, C = jnp.shape(x)
        
        # Project to qkv
        qkv = self.c_attn(x)

        q, k, v = jnp.split(qkv, 3, axis=2)

        query_length, key_length = q.shape[1], k.shape[1]

        # Now reshape with a new head "batch" dimension
        k = jnp.reshape(k, (B, T, self.config.n_head, C // self.config.n_head)) # Shape is (batch, tokens, num_heads, size of head)
        q = jnp.reshape(q, (B, T, self.config.n_head, C // self.config.n_head)) # Shape is (batch, tokens, num_heads, size of head)
        v = jnp.reshape(v, (B, T, self.config.n_head, C // self.config.n_head)) # Shape is (batch, tokens, num_heads, size of head)

        # Make the attention mask
        # TODO: For round 1 I'm just trusting the linen functions. They seem to be doing the correct thing here, but I may have to 
        # return to this for a closer look at the math if I'm not getting the GPT2 results
        # Just copied this from the huggingface code to be consistent. First part just broadcasts the causal masks to the batch
        # dimensions (replicating it as a lower triangular matrix of truths). The attention_bias is a stripped down version of the 
        # huggingface code, but the bias has to be floats. They bias the attention softmax, so basically set to -inf where you
        # want it ignored. This will need to be OR'd with an attention mask in certain situations, like encoder/decoder networks.
        # TODO: Also, this mask does not need to be applied during inference, so we could have an 'istraining' variable passed 
        # down the network for these cases and then ignore the mask calculations for increased speed. In the Flax examples, they 
        # also appear to cache the results of previous calculations, which I guess makes sense because we are just adding one 
        # token at a time to the input sequence and then calculating again. There is no point in recalculating the previous 
        # token outputs every time. I should probbaly implement some version of that too.
        c_mask = nn.make_causal_mask(jnp.ones((1, self.config.block_size), dtype="bool"), dtype="bool")
        c_mask = c_mask[:, :, :query_length, :key_length]
        c_mask = jnp.broadcast_to(c_mask, (B,) + c_mask.shape[1:])

        attention_bias = jax.lax.select(
                c_mask > 0,
                jnp.full(c_mask.shape, 0.0).astype(jnp.float32),
                jnp.full(c_mask.shape, jnp.finfo(jnp.float32).min).astype(jnp.float32),
            )

        # use the built in flax libraries to calculate the attention matrix - attention weights are not returned, but could be 
        # I think bias gives more control compared to mask - i.e. bias can be a float. They might result in the same output with a 
        # boolean mask, but I will have to test that.
        # TODO: I don't think Flax has a flash attention module. Is there any way to add that for Flax that will actually be 
        # optimized for hardware? I don't know.
        y = nn.dot_product_attention(q, k, v, bias=attention_bias, dtype=self.config.dtype)

        # Merge the heads back together
        y = y.reshape((B, T, C))

        # Project output with a FC layer
        y = self.c_proj(y)

        return y

class GPTBlock(nn.Module):
    config: GPTConfig

    def setup(self):
        self.ln_1 = nn.LayerNorm(epsilon=1e-05, dtype=jnp.float32)
        # I might have to write this manually to get the proper number of parameters, as the old GPT2 code 
        # migh have subtle differences from the Flax implementation
        self.attn = GPTAttention(self.config)
        self.ln_2 = nn.LayerNorm(epsilon=1e-05, dtype=jnp.float32)
        self.mlp = GPTMLP(self.config)

    def __call__(self, inputs):
        x = inputs
        x = self.ln_1(x)
        x = inputs + self.attn(x)
        inputs2 = x
        x = self.ln_2(x)
        x = inputs2 + self.mlp(x)
        return x

class GPTLayers(nn.Module):
    config: GPTConfig
    
    def setup(self):
        self.blocks = [ GPTBlock(self.config, name=str(i)) for i in range(self.config.n_layer) ]

    def __call__(self, inputs):
        x = inputs

        for block in self.blocks:
            x = block(x)
        return x

class GPTModel(nn.Module):
    config: GPTConfig

    def setup(self):
        # This is a little confusing. vocab size is the number of embeddings we need.
        # n_embd is the dimension of each embedding (i.e. capacity for learning properties
        # about this embedding token)
        # input size = (1 x self.block_size) - 1 int for each token
        # output size = then (self.block_size x self.n_embd)
        self.wte = nn.Embed(self.config.vocab_size, self.config.n_embd, embedding_init=jax.nn.initializers.normal(stddev=0.02), dtype=self.config.dtype)
        # embed is just a randomzied parameter matrix, so can be used for positional 
        # encoding as well. I think block size is the token length.
        # This has to match the size of the previous output, as we are just adding.
        self.wpe = nn.Embed(self.config.block_size, self.config.n_embd, embedding_init=jax.nn.initializers.normal(stddev=0.02), dtype=self.config.dtype)
        # The attention layers
        self.h = GPTLayers(self.config)
        self.ln_f = nn.LayerNorm(epsilon=1e-05, dtype=jnp.float32)

    def __call__(self, inputs):
        x = inputs
        input_shape = jnp.shape(x)       

        x = self.wte(x)

        # For the positional encodings we need an index that is simple the position
        # of each token. This will be the same shape as the input, but will simply
        # be repeating and increasing numbers from 1.
        # jnp.atleast_2d is needed so we can initialize with a batch size of 1
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(inputs).shape[-1]), input_shape)
        x_wpe = self.wpe(position_ids)
        
        x += x_wpe

        x = self.h(x)

        x = self.ln_f(x)

        return x


class GPT(nn.Module):
    config: GPTConfig

    def setup(self):
        self.transformer = GPTModel(self.config)
        # So the Flax model does not return parameters for lm_head, because lm_head is simply the 
        # inverse operation of the initial word embedding, so we can just reuse those weights. They
        # use the term 'tied' for this. The weights are tied together.
        self.lm_head = nn.Dense(self.config.vocab_size, use_bias=False, dtype=self.config.dtype)

    def __call__(self, inputs):
        x = inputs
        x = self.transformer(x)

        # This is from the huggingface code - might as well reuse it here.
        shared_kernel = self.transformer.variables['params']['wte']['embedding'].T
        lm_logits = self.lm_head.apply({'params': {'kernel': shared_kernel}}, x)
        return lm_logits

In [29]:
# Initialize
GPT2 = GPT(GPTConfig)

key1, key2 = jax.random.split(jax.random.key(0), 2)

# We never want to use int64 as it signficantly slows down training
x = np.random.randint(0, GPTConfig.vocab_size, (1, GPTConfig.block_size), dtype='i4')

y = GPT2.init(key2, x)

In [9]:
# Inference
import tiktoken
import jax
import jax.numpy as jnp
#from tensorflow_probability.substrates import jax as tfp

#tfd = tfp.distributions

# Generate a starting sequence and we will generate the rest of the sequence with GPT2 for
# however many batches we need
num_return_sequences = 1
max_length = 50

enc = tiktoken.get_encoding('gpt2')
tokens = enc.encode("")

tokens = jnp.array(tokens, dtype="i4")

initial_input = jnp.atleast_2d(tokens).repeat(num_return_sequences, 0)

init_length=jnp.shape(initial_input)[-1]

# Has to be a constant length for the fast-autogression, so we need to pad (we will ignore the pads until 
# those indicies are actually needed by the model as it grows)
initial_input=jnp.pad(initial_input, ((0, 0), (0, max_length-init_length)), mode='constant', constant_values=0)

In [10]:
# Jax-style fast autoregression using the jax looping tools. This takes a bit of time to get correct
# with each model as you have to design it using the XLA tools and every single parameter needs to 
# be deterministic. This leads to lots of gotchas. For example, dynamic_slice and dynamic_slice_update
# need to have a constant for the size, but can be placed anywhere in the matrix (i.e., you can't use
# a variable with this). And, take_along_axis needs to be used for the top_k type indexing as it uses 
# the XLA gather method underneath, which is super complicated to use by hand. Every model produces 
# new difficulties and I'm sure this can be optimized quite a bit if I understood jax a bit better.
rng = jax.random.key(25)

temperature = 1.0

def get_sentence(n, carry):
    x, rng = carry

    rng, new_rng = jax.random.split(rng)

    # pull out only the sequence up until the currently filled in
    # values. It would be better if we could pass in a variable 
    # length sequence, but I don't think that's allowed by jax
    logits = GPT2.apply({'params': state.params}, x)

    # GPT2 predicted the new token int he last column of the output, 
    # so we only need this.
    logits = logits[:, n, :]

    # Sample from the top 50 tokens
    topk_probs, topk_indices = jax.lax.top_k(logits, 50)

    ix = jnp.expand_dims(jax.random.categorical(new_rng, topk_probs/temperature, axis=-1), axis=-1)

    #print(ix)

    xcol=jnp.atleast_2d(jnp.take_along_axis(topk_indices, ix.astype(jnp.int32), axis=-1))

    x = jax.lax.dynamic_update_slice(x, xcol, (0, n+1))

    return x, rng

x, rng = jax.lax.fori_loop(init_length-1, max_length-1, get_sentence, (initial_input, rng))


In [11]:
# Print the results
for i in range(0, x.shape[0]):
    print(enc.decode(x[i,:].tolist()))


The time it will stay.




LUCYTER:
He had, I will you have?
'TABETH:
And, it; I pritinks you.



What are thee


In [16]:
# Just comparing the shapes of our model vs the GPT2 trained weights.
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(y['params']['transformer'])))
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(model_hf.params['transformer'])))

initialized parameter shapes:
 {'h': {'0': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '1': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '10': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '11': {'attn': {

In [3]:
# Now let's train our own GPT on the tiny Shakespear dataset:
#!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

with open("input.txt", 'r') as f:
    text = f.read()

In [4]:
import tiktoken

enc = tiktoken.get_encoding('gpt2')
data = text[:1000]
tokens = enc.encode(data)
print(tokens[:24])

[5962, 22307, 25, 198, 8421, 356, 5120, 597, 2252, 11, 3285, 502, 2740, 13, 198, 198, 3237, 25, 198, 5248, 461, 11, 2740, 13]


In [3]:
# B, T = 4, 32
# buf = jnp.array(tokens[:(B*T)+1])
# x = buf[:-1].reshape((B,T)) # input
# y = buf[1:].reshape((B,T)) # targets (one ahead because GPT2 is an autogressive model 
#                            # - i.e. it's always predicting the next token in parallel)

# Initialize
GPT2 = GPT(GPTConfig)

key1, key2 = jax.random.split(jax.random.key(2), 2)

# We never want to use int64 as it signficantly slows down training
x_init = np.random.randint(0, GPTConfig.vocab_size, (1, GPTConfig.block_size), dtype='i4')

y_init = GPT2.init(key2, x_init)

#logits = GPT2.apply({'params': y_init['params']}, x)

In [66]:
logits.shape

(4, 32, 50257)

In [67]:
y.shape

(4, 32)

In [20]:
# Let's calculate one loss
import optax

# This is a little different than Karpathy because jax can accept the batch dimension. We then mean the losses
# of the batches to get the full batch loss. I would assume that the pytorch cross_entropy is doing the same
# thing internally
loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y))
# The loss is close to the expected value of 10.8
print(loss)

11.311539


In [4]:
# In jax we create a separate apply function to calculate the loss

@jax.jit
def apply_model(state, x, labels):
    def loss_fn(params):
        logits = state.apply_fn(
            {'params': params}, 
            x)
        #one_hot = jax.nn.one_hot(labels, 10)
        loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels))
        return loss, logits
    
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (logits)), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy

In [5]:
# Now let's simulate a small training algorithm
from typing import Any
from flax.training import train_state
import optax

learning_rate = 0.0003

# Start with AdamW
tx = optax.adamw(learning_rate=learning_rate)

# Flax handles the training state for us. Generally, you should subclass the TrainState and 
# add anything you need to it. Here, we don't need any additions, but we may in the future
# so I will leave it in for now. Things you might add are the batch_stats or dropout keys
class TrainState(train_state.TrainState):
  key: jax.Array
  batch_stats: Any

# Replace this with our subclassed TrainState when we need it.
#state = train_state.TrainState.create(apply_fn=GPT2.apply, params=y_init['params'], tx=tx)

# Now try our apply function
#grads, loss, accuracy = apply_model(state, x, y)
#state = state.apply_gradients(grads=grads)


In [14]:
loss

Array(11.300755, dtype=float32)

In [28]:
# Okay, now let's create a small number of epochs to make sure it's doing something
# We get roughly the same results as Karpathy
num_epochs = 50

for epoch in range(1, num_epochs + 1):
    grads, loss, accuracy = apply_model(state, x, y)
    state = state.apply_gradients(grads=grads)

    print(f'step {epoch}, loss: {loss}')

step 1, loss: 11.310973167419434
step 2, loss: 8.371686935424805
step 3, loss: 7.288360118865967
step 4, loss: 7.384831428527832
step 5, loss: 6.379255294799805
step 6, loss: 5.839833736419678
step 7, loss: 5.47161340713501
step 8, loss: 5.100825309753418
step 9, loss: 4.55207633972168
step 10, loss: 4.081092834472656
step 11, loss: 3.5328078269958496
step 12, loss: 2.8801956176757812
step 13, loss: 2.2455806732177734
step 14, loss: 1.6977598667144775
step 15, loss: 1.2593261003494263
step 16, loss: 0.900545597076416
step 17, loss: 0.6233440041542053
step 18, loss: 0.40818876028060913
step 19, loss: 0.2590666115283966
step 20, loss: 0.16967424750328064
step 21, loss: 0.1210031658411026
step 22, loss: 0.0914439707994461
step 23, loss: 0.07090624421834946
step 24, loss: 0.05630887299776077
step 25, loss: 0.04591205716133118
step 26, loss: 0.03834603726863861
step 27, loss: 0.0326601043343544
step 28, loss: 0.02821948006749153
step 29, loss: 0.0246470645070076
step 30, loss: 0.02171604707

In [25]:
accuracy

Array(1., dtype=float32)

In [6]:
# Now let's do this for the entire tiny Shakespeare dataset
# We can follow Karpathy's style, but we will use the tensorflow datasets
# tools as our data loader as it is the natural thing to use with jax/flax
import tiktoken

class DataLoaderLite:
    def __init__(self, B, T):
        self.B = B
        self.T = T

        with open("input.txt", 'r') as f:
            text = f.read()
        enc = tiktoken.get_encoding('gpt2')
        tokens = enc.encode(text)
        # Could load it into gpu memory here if you want for speed. It would 
        # depend on how much you have to spare
        self.tokens = tokens
        print(f'Loaded {len(self.tokens)} tokens')
        print(f'One epoch = {len(self.tokens) // (B*T) } batches')

        self.current_position = 0

    def next_batch(self):
        B, T = self.B, self.T
        buf = np.array(self.tokens[self.current_position:self.current_position+B*T+1])
        x = buf[:-1].reshape((B, T))
        y = buf[1:].reshape((B,T))
        self.current_position+=B*T

        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0
        
        return x, y

In [7]:
import time

train_loader = DataLoaderLite(16, 256)

learning_rate = 0.0003

# Start with AdamW
tx = optax.adamw(learning_rate=learning_rate)

# Flax handles the training state for us. Generally, you should subclass the TrainState and 
# add anything you need to it. Here, we don't need any additions, but we may in the future
# so I will leave it in for now. Things you might add are the batch_stats or dropout keys
class TrainState(train_state.TrainState):
  key: jax.Array
  batch_stats: Any

# Replace this with our subclassed TrainState when we need it.
state = train_state.TrainState.create(apply_fn=GPT2.apply, params=y_init['params'], tx=tx)

Loaded 338025 tokens
One epoch = 82 batches


In [8]:
num_epochs = 1000

for epoch in range(1, num_epochs + 1):
    t0 = time.time()
    x, y = train_loader.next_batch()
    grads, loss, accuracy = apply_model(state, x, y)
    state = state.apply_gradients(grads=grads)

    t1 = time.time()

    dt = (t1 - t0)*1000
    tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
    print(f'step {epoch}, loss: {loss}, dt: {dt:.2f}ms, tok/sec: {tokens_per_sec}')

step 1, loss: 10.875, dt: 14283.27ms, tok/sec: 286.7690876994621
step 2, loss: 10.25, dt: 626.04ms, tok/sec: 6542.705640365526
step 3, loss: 10, dt: 407.88ms, tok/sec: 10042.191075475415
step 4, loss: 9.75, dt: 364.51ms, tok/sec: 11237.126219790181
step 5, loss: 9.5625, dt: 437.70ms, tok/sec: 9358.032545284399
step 6, loss: 9.3125, dt: 412.43ms, tok/sec: 9931.398363098851
step 7, loss: 9.25, dt: 446.00ms, tok/sec: 9183.923192308966
step 8, loss: 9.125, dt: 436.62ms, tok/sec: 9381.052850458436
step 9, loss: 9, dt: 458.89ms, tok/sec: 8925.845394659935
step 10, loss: 8.875, dt: 376.08ms, tok/sec: 10891.381532184341
step 11, loss: 8.8125, dt: 400.12ms, tok/sec: 10237.032110441614
step 12, loss: 8.75, dt: 400.87ms, tok/sec: 10217.883564010652
step 13, loss: 8.6875, dt: 448.63ms, tok/sec: 9130.069503847626
step 14, loss: 8.625, dt: 363.84ms, tok/sec: 11257.59659884317
step 15, loss: 8.375, dt: 375.39ms, tok/sec: 10911.199551354479
step 16, loss: 8.25, dt: 419.38ms, tok/sec: 9766.788677033488