In [1]:
from absl import logging
import flax
import jax.numpy as jnp
import jax
import numpy as np

logging.set_verbosity(logging.INFO)

In [2]:
from transformers import FlaxGPT2LMHeadModel

model_hf = FlaxGPT2LMHeadModel.from_pretrained("openai-community/gpt2")

In [5]:
# Here is the dictionary tree we have to replicate. pytrees are just nested dictionaries.
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(model_hf.params)))

initialized parameter shapes:
 {'transformer': {'h': {'0': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '1': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '10': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, 

In [3]:
from flax import linen as nn
# define a full GTP2 class in Flax and see if we can replicate the original paper along with
# Karpathy

# TODO:
# - Add options for different dtypes of the model
# - Add options for difference precisions (and understand that better - I think it has to do with the mantissa of the floats
# - Add istraining variable to allow dropout and other operations that are different during training and inference
# - Maybe: Flash attention?
# - Maybe: Sharding for multiple gpu training?
# - Use the Jax configuration utilities instead of the GPTConfig class
# - Develop my own training routing with all of the bells and whistles required: timing, checkpointinng, etc...
# - Move everything to python files for more professional-like code from the command line

class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50257
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768

class GPTMLP(nn.Module):
    config: GPTConfig

    def setup(self):
        # Simple MLP that upscales, runs through a gelu activation,
        # and then resamples back to the n_embd size (the model size)
        #self.c_fc = nn.Dense(4 * self.config.n_embd)
        # Had to use Einsum to match the matrix multiplication of GPT2 and pytorch. They accept 
        # both shapes and then multiply by the transpose of the matrix (basically means the 
        # shape of the matrix is transpose, but the operation is the same). I confirmed that this
        # produces the same result as the complicated huggingface conv1d version (conv1d is also
        # just a linear matrix operation as well). They do add a lot of variables for 
        # mixed-precision training, that I do not.
        # TODO: Move this into a new module as I am repeating it everywhere
        self.c_fc = nn.Einsum((4 * self.config.n_embd, self.config.n_embd), '...ij,...kj->...ik')
        #self.c_proj = nn.Dense(self.config.n_embd)
        self.c_proj = nn.Einsum((self.config.n_embd, 4 * self.config.n_embd), '...ij,...kj->...ik')

    def __call__(self, inputs):
        x = inputs

        x = self.c_fc(x)
        x = nn.gelu(x, approximate=True)
        x = self.c_proj(x)

        return x
    
class GPTAttention(nn.Module):
    config: GPTConfig

    # we will need to roll our own attention module because the built in one has a bunch of different
    # naming and structure compared to the original GPT, which just references the projection layers
    def setup(self):
        # The first thing we do is project up to 3x the model size, because we are going to split
        # the data into q, v, k
        self.c_attn = nn.Einsum((3 * self.config.n_embd, self.config.n_embd), '...ij,...kj->...ik')

        # At the end we have to project everything back to the regular model size of n_embd
        self.c_proj = nn.Einsum((self.config.n_embd, self.config.n_embd), '...ij,...kj->...ik')

    def __call__(self, inputs):
        x = inputs
        B, T, C = jnp.shape(x)
        
        # Project to qkv
        qkv = self.c_attn(x)

        q, k, v = jnp.split(qkv, 3, axis=2)

        query_length, key_length = q.shape[1], k.shape[1]

        # Now reshape with a new head "batch" dimension
        k = jnp.reshape(k, (B, T, self.config.n_head, C // self.config.n_head)) # Shape is (batch, tokens, num_heads, size of head)
        q = jnp.reshape(q, (B, T, self.config.n_head, C // self.config.n_head)) # Shape is (batch, tokens, num_heads, size of head)
        v = jnp.reshape(v, (B, T, self.config.n_head, C // self.config.n_head)) # Shape is (batch, tokens, num_heads, size of head)

        # Make the attention mask
        # TODO: For round 1 I'm just trusting the linen functions. They seem to be doing the correct thing here, but I may have to 
        # return to this for a closer look at the math if I'm not getting the GPT2 results
        # Just copied this from the huggingface code to be consistent. First part just broadcasts the causal masks to the batch
        # dimensions (replicating it as a lower triangular matrix of truths). The attention_bias is a stripped down version of the 
        # huggingface code, but the bias has to be floats. They bias the attention softmax, so basically set to -inf where you
        # want it ignored. This will need to be OR'd with an attention mask in certain situations, like encoder/decoder networks.
        # TODO: Also, this mask does not need to be applied during inference, so we could have an 'istraining' variable passed 
        # down the network for these cases and then ignore the mask calculations for increased speed. In the Flax examples, they 
        # also appear to cache the results of previous calculations, which I guess makes sense because we are just adding one 
        # token at a time to the input sequence and then calculating again. There is no point in recalculating the previous 
        # token outputs every time. I should probbaly implement some version of that too.
        c_mask = nn.make_causal_mask(jnp.ones((1, self.config.block_size), dtype="bool"), dtype="bool")
        c_mask = c_mask[:, :, :query_length, :key_length]
        c_mask = jnp.broadcast_to(c_mask, (B,) + c_mask.shape[1:])

        attention_bias = jax.lax.select(
                c_mask > 0,
                jnp.full(c_mask.shape, 0.0).astype(jnp.float32),
                jnp.full(c_mask.shape, jnp.finfo(jnp.float32).min).astype(jnp.float32),
            )

        # use the built in flax libraries to calculate the attention matrix - attention weights are not returned, but could be 
        # I think bias gives more control compared to mask - i.e. bias can be a float. They might result in the same output with a 
        # boolean mask, but I will have to test that.
        # TODO: I don't think Flax has a flash attention module. Is there any way to add that for Flax that will actually be 
        # optimized for hardware? I don't know.
        y = nn.dot_product_attention(q, k, v, bias=attention_bias)

        # Merge the heads back together
        y = y.reshape((B, T, C))

        # Project output with a FC layer
        y = self.c_proj(y)

        return y

class GPTBlock(nn.Module):
    config: GPTConfig

    def setup(self):
        self.ln_1 = nn.LayerNorm(epsilon=1e-05)
        # I might have to write this manually to get the proper number of parameters, as the old GPT2 code 
        # migh have subtle differences from the Flax implementation
        self.attn = GPTAttention(self.config)
        self.ln_2 = nn.LayerNorm(epsilon=1e-05)
        self.mlp = GPTMLP(self.config)

    def __call__(self, inputs):
        x = inputs
        x = self.ln_1(x)
        x = inputs + self.attn(x)
        inputs2 = x
        x = self.ln_2(x)
        x = inputs2 + self.mlp(x)
        return x

class GPTLayers(nn.Module):
    config: GPTConfig
    
    def setup(self):
        self.blocks = [ GPTBlock(self.config, name=str(i)) for i in range(self.config.n_layer) ]

    def __call__(self, inputs):
        x = inputs

        for block in self.blocks:
            x = block(x)
        return x

class GPTModel(nn.Module):
    config: GPTConfig

    def setup(self):
        # This is a little confusing. vocab size is the number of embeddings we need.
        # n_embd is the dimension of each embedding (i.e. capacity for learning properties
        # about this embedding token)
        # input size = (1 x self.block_size) - 1 int for each token
        # output size = then (self.block_size x self.n_embd)
        self.wte = nn.Embed(self.config.vocab_size, self.config.n_embd)
        # embed is just a randomzied parameter matrix, so can be used for positional 
        # encoding as well. I think block size is the token length.
        # This has to match the size of the previous output, as we are just adding.
        self.wpe = nn.Embed(self.config.block_size, self.config.n_embd)
        # The attention layers
        self.h = GPTLayers(self.config)
        self.ln_f = nn.LayerNorm(epsilon=1e-05)

    def __call__(self, inputs):
        x = inputs
        input_shape = jnp.shape(x)       

        x = self.wte(x)

        # For the positional encodings we need an index that is simple the position
        # of each token. This will be the same shape as the input, but will simply
        # be repeating and increasing numbers from 1.
        # jnp.atleast_2d is needed so we can initialize with a batch size of 1
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(inputs).shape[-1]), input_shape)
        x_wpe = self.wpe(position_ids)
        
        x += x_wpe

        x = self.h(x)

        x = self.ln_f(x)

        return x


class GPT(nn.Module):
    config: GPTConfig

    def setup(self):
        self.transformer = GPTModel(self.config)
        # So the Flax model does not return parameters for lm_head, because lm_head is simply the 
        # inverse operation of the initial word embedding, so we can just reuse those weights. They
        # use the term 'tied' for this. The weights are tied together.
        self.lm_head = nn.Dense(self.config.vocab_size, use_bias=False)

    def __call__(self, inputs):
        x = inputs
        x = self.transformer(x)

        # This is from the huggingface code - might as well reuse it here.
        shared_kernel = self.transformer.variables['params']['wte']['embedding'].T
        lm_logits = self.lm_head.apply({'params': {'kernel': shared_kernel}}, x)
        return lm_logits

In [11]:
# Initialize
GPT2 = GPT(GPTConfig)

key1, key2 = jax.random.split(jax.random.key(0), 2)

# We never want to use int64 as it signficantly slows down training
x = np.random.randint(0, GPTConfig.vocab_size, (1, GPTConfig.block_size), dtype='i4')

y = GPT2.init(key2, x)

In [8]:
# Inference
import tiktoken
import jax
import jax.numpy as jnp
#from tensorflow_probability.substrates import jax as tfp

#tfd = tfp.distributions

# Generate a starting sequence and we will generate the rest of the sequence with GPT2 for
# however many batches we need
num_return_sequences = 20
max_length = 50

enc = tiktoken.get_encoding('gpt2')
tokens = enc.encode("The advance of machine learning")

tokens = jnp.array(tokens, dtype="i4")

initial_input = jnp.atleast_2d(tokens).repeat(num_return_sequences, 0)

init_length=jnp.shape(initial_input)[-1]

# Has to be a constant length for the fast-autogression, so we need to pad (we will ignore the pads until 
# those indicies are actually needed by the model as it grows)
initial_input=jnp.pad(initial_input, ((0, 0), (0, max_length-init_length)), mode='constant', constant_values=0)

In [12]:
# Jax-style fast autoregression using the jax looping tools. This takes a bit of time to get correct
# with each model as you have to design it using the XLA tools and every single parameter needs to 
# be deterministic. This leads to lots of gotchas. For example, dynamic_slice and dynamic_slice_update
# need to have a constant for the size, but can be placed anywhere in the matrix (i.e., you can't use
# a variable with this). And, take_along_axis needs to be used for the top_k type indexing as it uses 
# the XLA gather method underneath, which is super complicated to use by hand. Every model produces 
# new difficulties and I'm sure this can be optimized quite a bit if I understood jax a bit better.
rng = jax.random.key(2)

temperature = 0.7

def get_sentence(n, carry):
    x, rng = carry

    rng, new_rng = jax.random.split(rng)

    # pull out only the sequence up until the currently filled in
    # values. It would be better if we could pass in a variable 
    # length sequence, but I don't think that's allowed by jax
    logits = GPT2.apply({'params': model_hf.params}, x)

    # GPT2 predicted the new token int he last column of the output, 
    # so we only need this.
    logits = logits[:, n, :]

    # Sample from the top 50 tokens
    topk_probs, topk_indices = jax.lax.top_k(logits, 50)

    ix = jnp.expand_dims(jax.random.categorical(rng, topk_probs/temperature, axis=-1), axis=-1)

    #print(ix)

    xcol=jnp.atleast_2d(jnp.take_along_axis(topk_indices, ix.astype(jnp.int32), axis=-1))

    x = jax.lax.dynamic_update_slice(x, xcol, (0, n+1))

    return x, rng

x, rng = jax.lax.fori_loop(init_length-1, max_length-1, get_sentence, (initial_input, rng))


In [13]:
# Print the results
for i in range(0, x.shape[0]):
    print(enc.decode(x[i,:].tolist()))

The advance of machine learning algorithms and the development of machine learning platforms allow us to rapidly develop and test new techniques for solving complex problems in real-time.

Our customers can be expected to contribute to the development of a new algorithm.


The advance of machine learning and machine learning with algorithms is a long way away. Machine learning is one of the most important tools for improving data science and machine learning development. However, one of the many challenges of machine learning is learning for information. Data
The advance of machine learning has brought with it a new era of artificial intelligence and machine learning is now one of the most exciting developments in recent memory.

The machine learning toolkit that we are using is called Deep Learning, and it is a
The advance of machine learning has made it possible to design artificial intelligence to handle complex situations. But that is only one part of the story. The next part is th

In [16]:
# Just comparing the shapes of our model vs the GPT2 trained weights.
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(y['params']['transformer'])))
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(model_hf.params['transformer'])))

initialized parameter shapes:
 {'h': {'0': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '1': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '10': {'attn': {'c_attn': {'bias': (2304,), 'kernel': (2304, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 768)}}, 'ln_1': {'bias': (768,), 'scale': (768,)}, 'ln_2': {'bias': (768,), 'scale': (768,)}, 'mlp': {'c_fc': {'bias': (3072,), 'kernel': (3072, 768)}, 'c_proj': {'bias': (768,), 'kernel': (768, 3072)}}}, '11': {'attn': {