# FlaxGPT

FlaxGPT is a simple implementation of a GPT (decoder-only transformer) model in [Flax](https://flax.readthedocs.io/en/latest/quick_start.html).

The code is minimal and fits in a single notebook, which makes it well suited to hacking and educational purposes.

In [1]:
!pip install -q -U einops

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━[0m [32m41.0/44.6 kB[0m [31m1.1 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m814.8 kB/s[0m eta [36m0:00:00[0m
[?25h

## Library

In [1]:
import os
import dataclasses
import json
from pathlib import Path
from typing import List

import rich
from sentencepiece import SentencePieceProcessor

import jax
import jax.numpy as jnp, jax.random as jrandom
from flax import linen as nn
from flax import traverse_util
from flax.training import orbax_utils
import orbax.checkpoint
import einops


@dataclasses.dataclass
class Config:
  """GPT Config.

  `head_size` is not a constructor argument: `__post_init__` derives it as
  `n_embed // n_head` after validating the sizes.
  """
  max_seq_length: int      # maximum context length
  vocab_size: int          # vocabulary size
  n_embed: int             # embedding size (= n_head * head_size)
  n_layer: int             # number of transformer blocks
  intermediate_size: int   # intermediate size of FFN

  # multi head / multi query attention
  n_head: int              # number of heads
  n_query_groups: int      # number of query groups in multi-query attention

  # RoPE positional embedding
  rope_condense_ratio: int # rope condense ratio
  rope_base: int           # rope base

  def __post_init__(self):
    assert self.n_embed % self.n_head == 0, f'Embedding size(n_embed={self.n_embed}) should be divisible by (n_head={self.n_head})'
    # Attention splits the heads evenly across the query groups
    # (q_per_kv = n_head // n_query_groups), so reject uneven splits here
    # instead of failing later with a shape error inside the model.
    assert self.n_query_groups > 0 and self.n_head % self.n_query_groups == 0, (
        f'Number of heads(n_head={self.n_head}) should be divisible by (n_query_groups={self.n_query_groups})')
    self.head_size = self.n_embed // self.n_head

  @classmethod
  def llama2_7b(cls):
    """Returns the configuration used in this file for Llama-2 7B."""
    return cls(
        max_seq_length=4096,
        vocab_size=32000,
        n_layer=32,
        n_head=32,
        n_embed=4096,
        n_query_groups=32,
        intermediate_size=11008,
        rope_base=10000,
        rope_condense_ratio=1,
    )


class RMSNorm(nn.Module):
  """Root Mean Square Layer Normalization.

  Divides the input by the root of its mean square (taken over `axis`, with
  `eps` added under the square root) and multiplies by a learned `weight`
  that is initialised to ones and sized by the last dimension of the input.

  Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
  https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
  """
  axis: int = -1
  eps: float = 1e-5

  @nn.compact
  def __call__(self, x):
    def ones_init(rng, shape):
      # The rng is unused: the scale always starts at one.
      return jnp.ones(shape)

    scale = self.param('weight', ones_init, x.shape[-1])
    mean_square = jnp.mean(x * x, axis=self.axis, keepdims=True)
    rms = jnp.sqrt(mean_square + self.eps)
    return scale * (x / rms)


class SelfAttention(nn.Module):
  """Multi head / Multi query / Grouped Query Attention.

  About n_query_groups
  to use multi-head attention (MHA), set this to `n_head` (default)
  to use multi-query attention (MQA), set this to 1
  to use grouped-query attention (GQA), set this to a value in between
  Example with `n_head=4`
  ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
  │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
  └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    │    │    │    │         │        │                 │
  ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
  │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
  └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
  ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
          MHA                    GQA                   MQA
    n_query_groups=4       n_query_groups=2      n_query_groups=1
    q_per_kv=1             q_per_kv=2            q_per_kv=4
    n_head=4               n_head=4              n_head=4
    n_qkv=3                n_qkv=4               n_qkv=6
  credit https://arxiv.org/pdf/2305.13245.pdf

  Call arguments:
    x: input activations, (B, T, C).
    rope_emb: rotary embedding handed to `apply_rope` for q and k.
    mask: optional attention mask. Positions where it is false/zero get a
      score of -inf before the softmax. When omitted, a lower-triangular
      (T, T) mask is used, i.e. every position attends to itself and to
      earlier positions only.

  Returns:
    The attention output projected to `config.n_embed` features, (B, T, n_embed).
  """
  config: Config

  @nn.compact
  def __call__(self, x, rope_emb, mask=None):
    T = x.shape[-2]  # x: (B, T, C)
    # Test against None explicitly: `mask or ...` would evaluate the truth
    # value of a caller-supplied array, which raises for multi-element arrays.
    if mask is None:
      mask = jnp.tril(jnp.ones((T, T)))

    # nq = n_head, nk = nv = n_query_groups
    qkv_dim = (self.config.n_head + 2 * self.config.n_query_groups) * self.config.head_size
    qkv_proj = nn.Dense(features=qkv_dim, use_bias=False, name='proj_qkv')(x)

    # number of q's per group
    q_per_kv = self.config.n_head // self.config.n_query_groups
    # number of qkvs per group, k=v=1
    n_qkv = q_per_kv + 2
    # break embedding into (n_groups, n_qkv, head_size)
    qkv = einops.rearrange(qkv_proj, 'b t (n_groups n_qkv h) -> b n_groups n_qkv t h',
                           n_groups=self.config.n_query_groups,
                           n_qkv=n_qkv)
    # split q, k, v within groups
    q, k, v = einops.unpack(qkv, [[q_per_kv], [1], [1]], 'b n_groups * t h')

    if q_per_kv != 1:
      # repeat k and v in each group
      k = einops.repeat(k, 'b n_groups 1 t h -> b n_groups q_per_kv t h', q_per_kv=q_per_kv)
      v = einops.repeat(v, 'b n_groups 1 t h -> b n_groups q_per_kv t h', q_per_kv=q_per_kv)

    # merge groups into heads
    q = einops.rearrange(q, 'b n_groups q_per_kv t h -> b (n_groups q_per_kv) t h')
    k = einops.rearrange(k, 'b n_groups q_per_kv t h -> b (n_groups q_per_kv) t h')
    v = einops.rearrange(v, 'b n_groups q_per_kv t h -> b (n_groups q_per_kv) t h')

    # apply position embedding
    # NOTE: only apply to q and k, but not v
    q = apply_rope(q, rope_emb)
    k = apply_rope(k, rope_emb)

    # multi head scaled dot attention
    weights = einops.einsum(q, k, 'b nh tq h, b nh tk h -> b nh tq tk')
    weights = weights / jnp.sqrt(self.config.head_size)
    weights = jnp.where(mask, weights, float('-inf'))
    weights = nn.softmax(weights, axis=-1)
    out = einops.einsum(weights, v, 'b nh tq tv, b nh tv h -> b tq nh h')

    # concat heads
    out = einops.rearrange(out, 'b t nh h -> b t (nh h)')

    # final projection
    out = nn.Dense(self.config.n_embed, use_bias=False, name='proj_out')(out)

    return out


class MLP(nn.Module):
  """LLaMA style MLP.

  Two bias-free projections to `intermediate_size` are computed from the same
  input; the first goes through SiLU and gates the second elementwise, and the
  product is projected back to `n_embed`.
  """
  config: Config

  @nn.compact
  def __call__(self, x):
    hidden_dim = self.config.intermediate_size
    gate = nn.Dense(hidden_dim, use_bias=False, name='fc_1')(x)
    up = nn.Dense(hidden_dim, use_bias=False, name='fc_2')(x)
    gated = nn.silu(gate) * up
    return nn.Dense(self.config.n_embed, use_bias=False, name='proj_out')(gated)


class Block(nn.Module):
  """A Transformer Block of attention followed by MLP.

  Each sub-layer sees an RMS-normalised copy of its input, and its output is
  added back onto the un-normalised input (residual connection).
  """
  config: Config

  @nn.compact
  def __call__(self, x, rope_emb, mask=None):
    # attention sub-layer
    attn_in = RMSNorm(name='norm_1')(x)
    attn_out = SelfAttention(self.config, name='attn')(attn_in, rope_emb, mask=mask)
    x = attn_out + x

    # feed-forward sub-layer
    mlp_in = RMSNorm(name='norm_2')(x)
    mlp_out = MLP(self.config, name='mlp')(mlp_in)
    return mlp_out + x


class GPT(nn.Module):
  """The full decoder-only transformer.

  Embeds the token ids, runs `config.n_layer` blocks, applies a final RMSNorm
  and projects to `config.vocab_size` logits. The RoPE table is kept in the
  `cache` collection under `rope_emb`; the embedding output, every block
  output and the final norm output are sown into `intermediates`.
  """

  config: Config

  @nn.compact
  def __call__(self, x):
    cfg = self.config
    seq_len = x.shape[-1]  # x: (B, T)

    rope_cache = self.variable(
        'cache', 'rope_emb', build_rope_cache,
        cfg.max_seq_length,
        cfg.head_size,
        cfg.rope_base,
        cfg.rope_condense_ratio)
    # only the first `seq_len` positions of the table are needed
    rope_emb = rope_cache.value[:seq_len]

    hidden = nn.Embed(num_embeddings=cfg.vocab_size,
                      features=cfg.n_embed, name='emb')(x)
    self.sow('intermediates', 'emb_out', hidden)

    for i in range(cfg.n_layer):
      hidden = Block(config=cfg, name=f'block_{i}')(hidden, rope_emb)
      self.sow('intermediates', f'block_{i}_out', hidden)

    # final layer norm
    hidden = RMSNorm(name='ln_f')(hidden)
    self.sow('intermediates', 'ln_out', hidden)

    # language model head
    return nn.Dense(cfg.vocab_size, name='lm_head', use_bias=False)(hidden)


class Tokenizer:
  """Thin wrapper around a SentencePiece model for encoding/decoding text.

  Taken from llama codebase: https://github.com/facebookresearch/llama/blob/main/llama/tokenizer.py
  """
  def __init__(self, model_path: str):
    """Loads the SentencePiece model stored at `model_path`."""
    assert os.path.isfile(model_path), model_path
    sp_model = SentencePieceProcessor(model_file=model_path)
    self.sp_model = sp_model

    # vocabulary size and special token ids, as reported by the model
    self.n_words: int = sp_model.vocab_size()
    self.bos_id: int = sp_model.bos_id()
    self.eos_id: int = sp_model.eos_id()
    self.pad_id: int = sp_model.pad_id()

    assert sp_model.vocab_size() == sp_model.get_piece_size()

  def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
    """Encodes `s` to token ids, optionally wrapped in BOS / EOS ids."""
    assert type(s) is str
    prefix = [self.bos_id] if bos else []
    suffix = [self.eos_id] if eos else []
    return prefix + self.sp_model.encode(s) + suffix

  def decode(self, t: List[int]) -> str:
    """Decodes the token ids `t` back to text."""
    return self.sp_model.decode(t)


def generate(key, model, tokenizer, variables, prompt, max_tokens=100):
  """Example implementation for sampling from the model.

  Encodes `prompt` with a leading BOS id, then samples up to `max_tokens`
  tokens one at a time from the logits at the last position. The text decoded
  so far is printed at every step, and sampling stops early when the EOS id is
  drawn (the EOS id itself is not kept).

  Args:
    key: base PRNG key; it is folded with the step index at every step.
    model: module exposing `config.max_seq_length` and `apply(variables, x=...)`.
    tokenizer: object exposing `encode`, `decode` and `eos_id`.
    variables: variables passed to `model.apply`.
    prompt: text to continue.
    max_tokens: maximum number of tokens to sample.

  Returns:
    The decoded prompt plus the sampled tokens.
  """
  prompt_ids = tokenizer.encode(prompt, bos=True, eos=False)
  x = jnp.array(prompt_ids).reshape(1, -1)
  token_ids = x[0].tolist()

  for step in range(max_tokens):
    # keep at most the last `max_seq_length` tokens as context
    x = x[..., -model.config.max_seq_length:]
    last_logits = model.apply(variables, x=x)[:, -1, :]
    step_key = jrandom.fold_in(key, step)
    next_token = jrandom.categorical(step_key, last_logits)
    print(f't={step}', tokenizer.decode(token_ids))

    sampled_id = next_token.item()
    if sampled_id == tokenizer.eos_id:
      break
    token_ids.append(sampled_id)
    x = jnp.concatenate((x, next_token.reshape(-1, 1)), axis=-1)

  return tokenizer.decode(token_ids)


def build_rope_cache(
    seq_len: int,
    n_elem: int,
    base: int = 10000,
    condense_ratio: int = 1
):
  """Enhanced Transformer with Rotary Position Embedding.

  Builds the cos/sin lookup table for `seq_len` positions. The result has
  shape `(seq_len, 2 * len(range(0, n_elem, 2)), 2)`, i.e. `(seq_len, n_elem, 2)`
  for even `n_elem`; index 0 of the last axis holds the cosines and index 1
  the sines. The second half of the feature axis repeats the first half.

  Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
  transformers/rope/__init__.py. MIT License:
  https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
  """
  # $\Theta = \{\theta_i = \text{base}^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]\}$
  exponents = jnp.arange(0, n_elem, 2) / n_elem
  theta = 1.0 / (base ** exponents)

  # Position indexes `[0, 1, ..., seq_len - 1]`, scaled down by the condense ratio
  positions = jnp.arange(seq_len) / condense_ratio

  # Angle for every (position, $\theta_i$) pair, duplicated along the feature axis
  angles = jnp.outer(positions, theta)
  angles = jnp.concatenate([angles, angles], axis=1)

  cos_sin = [jnp.cos(angles), jnp.sin(angles)]
  return jnp.stack(cos_sin, axis=-1)


def apply_rope(x, rope_emb):
  """Apply rope embedding to input x.

  `rope_emb` carries the cosines at index 0 and the sines at index 1 of its
  last axis. The result is `x * cos + rotate_half(x) * sin`, where
  `rotate_half` swaps the two halves of the last axis of `x` and negates the
  half that moves to the front.
  """
  cos = rope_emb[..., 0]
  sin = rope_emb[..., 1]
  half = x.shape[-1] // 2
  # x: (B, nh, T, hs) -> (-second half, first half) along hs
  rotated = jnp.concatenate((-x[..., half:], x[..., :half]), axis=-1)
  return (x * cos) + (rotated * sin)


def save_checkpoint(variables, path: Path):
  """Saves model to checkpoint.

  Every leaf of `variables` is written as its own orbax checkpoint under
  `path / <dot-joined leaf name>`, and the leaf names are recorded in
  `path / 'model_index.json'` so that they can be enumerated when loading.
  """
  checkpointer = orbax.checkpoint.PyTreeCheckpointer()
  flat_variables = traverse_util.flatten_dict(variables, sep='.')

  saved_names = {}
  for name in flat_variables:
    item = {'value': flat_variables[name]}
    save_args = orbax_utils.save_args_from_target(item)
    checkpointer.save(path / name, item, save_args=save_args)
    saved_names[name] = 'true'
    print(f'Saved {name}')

  with open(path / 'model_index.json', 'w') as index_file:
    json.dump(saved_names, index_file)
  print('Save success')


def load_checkpoint(path: Path):
  """Loads model from checkpoint.

  Reads the leaf names from `path / 'model_index.json'`, restores each leaf
  from `path / <name>`, puts it on device and rebuilds the nested variable
  dict from the dot-joined names.
  """
  index_path = path / 'model_index.json'
  with open(index_path) as index_file:
    leaf_names = json.load(index_file)

  checkpointer = orbax.checkpoint.PyTreeCheckpointer()

  flat_variables = {}
  for name in leaf_names:
    restored = checkpointer.restore(path / name)
    flat_variables[name] = jax.device_put(restored['value'])
    print(f'Loaded variable: {name}')

  return traverse_util.unflatten_dict(flat_variables, sep='.')

## Utils

In [None]:
# @title Checkpoint Converter
# @markdown Download the model and tokenizer files from [Hugging Face](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main).
# @markdown
# @markdown The needed files are
# @markdown - `pytorch_model-*.bin`
# @markdown - `pytorch_model.bin.index.json`
# @markdown - `tokenizer.model`

# Input: directory holding the downloaded Llama files listed above; it is
# passed to `init_from_llama_checkpoint` further down in this cell.
llama_checkpoint = '/content/drive/MyDrive/checkpoints/meta-llama/Llama-2-7b-chat-hf' # @param {type:"string"}
# Output: directory the converted variables are written to by `save_checkpoint`.
flaxgpt_checkpoint = '/content/drive/MyDrive/checkpoints/flax-gpt/Llama-2-7b-chat-hf' # @param {type:"string"}

import re
import torch

def init_from_llama_checkpoint(
    model: GPT,
    config: Config,
    llama_checkpoint: Path):
  """Builds the FlaxGPT `params` tree from a Hugging Face Llama checkpoint.

  The shard files listed in `pytorch_model.bin.index.json` are loaded one at a
  time. Every tensor is renamed to the matching FlaxGPT parameter, linear
  weights are transposed into flax `kernel` layout, and the separate q/k/v
  projections of each layer are packed into the single `proj_qkv` kernel that
  `SelfAttention` expects. Each tensor's shape is checked against the shape
  the model declares.

  Args:
    model: the GPT module; only used to obtain the expected parameter shapes.
    config: configuration the model was built with.
    llama_checkpoint: directory holding `pytorch_model.bin.index.json` and the
      shard files it references.

  Returns:
    The nested `params` dict for `model`.

  Raises:
    ValueError: if the directory does not exist, if a checkpoint tensor has no
      known mapping, or if some model parameter was not found in the checkpoint.
  """
  # eval_shape only traces `model.init`, so the expected shapes are obtained
  # without allocating any parameter.
  variable_shapes = jax.eval_shape(model.init, jrandom.key(0), jnp.zeros((1,1), dtype=jnp.int32))
  param_shapes = traverse_util.flatten_dict(variable_shapes['params'], sep='.')

  if not llama_checkpoint.is_dir():
    raise ValueError(f'llama checkpoint directory {llama_checkpoint} does not exist')

  with open(llama_checkpoint / 'pytorch_model.bin.index.json', 'r') as f:
    llama_model_index = json.load(f)

  files = sorted(set(llama_model_index['weight_map'].values()))

  params = {}

  def load_params(name, value):
    """Validates `value` against the model's shape for `name` and stores it."""
    if hasattr(value, 'numpy'):
      value = value.numpy()
    assert name in param_shapes, f'Param does not exist: {name}'
    assert param_shapes[name].shape == value.shape, f'Shapes not match: {name} Expected: {param_shapes[name].shape} Actual: {value.shape}'
    params[name] = jax.device_put(value)
    # `devices()` is used because the `device()` method is not callable on
    # newer JAX releases.
    print(f'Success: loaded param: {name}, dtype:{params[name].dtype} shape:{params[name].shape} device:{params[name].devices()}')

  # q/k/v live in separate tensors (possibly in different shard files), so
  # they are collected per layer first and packed once all files are read.
  qkv = {i: {} for i in range(config.n_layer)}

  # Compiled once; every dot is escaped so only literal `model.layers.<i>.` matches.
  layer_pattern = re.compile(r'model\.layers\.(\d+)\.(.*)')

  for f in files:
    print(f'Loading checkpoint file {f}')
    # NOTE(security): torch.load unpickles the file; only use checkpoints
    # from a trusted source. Tensors are mapped to CPU so `.numpy()` works.
    states = torch.load(llama_checkpoint / f, map_location='cpu')
    for name, value in states.items():
      if name == 'lm_head.weight':
        load_params('lm_head.kernel', value.T)
      elif name == 'model.embed_tokens.weight':
        load_params('emb.embedding', value)
      elif name == 'model.norm.weight':
        load_params('ln_f.weight', value)
      elif ret := layer_pattern.match(name):
        i, sub_name = ret.groups()
        i = int(i)
        if sub_name == 'input_layernorm.weight':
          load_params(f'block_{i}.norm_1.weight', value)
        elif sub_name == 'post_attention_layernorm.weight':
          load_params(f'block_{i}.norm_2.weight', value)
        elif sub_name == 'mlp.gate_proj.weight':
          load_params(f'block_{i}.mlp.fc_1.kernel', value.T)
        elif sub_name == 'mlp.up_proj.weight':
          load_params(f'block_{i}.mlp.fc_2.kernel', value.T)
        elif sub_name == 'mlp.down_proj.weight':
          load_params(f'block_{i}.mlp.proj_out.kernel', value.T)
        elif sub_name == 'self_attn.o_proj.weight':
          load_params(f'block_{i}.attn.proj_out.kernel', value.T)
        elif sub_name == 'self_attn.q_proj.weight':
          qkv[i]['q'] = value.numpy()
        elif sub_name == 'self_attn.k_proj.weight':
          qkv[i]['k'] = value.numpy()
        elif sub_name == 'self_attn.v_proj.weight':
          qkv[i]['v'] = value.numpy()
        elif sub_name == 'self_attn.rotary_emb.inv_freq':
          # the rotary table is rebuilt by the model, nothing to load
          pass
        else:
          raise ValueError(f'unhandled param: {name}')
      else:
        raise ValueError(f'unhandled param: {name}')
    del states  # save memory

  def combine_qkv(q, k, v, n_heads, n_query_groups=None):
    """Packs q/k/v weights into the `(n_groups n_qkv h)` layout of `proj_qkv`.

    Within every query group the features are ordered as the group's
    `q_per_kv` query heads, then its key head, then its value head. This is
    the order in which `SelfAttention` unpacks the projection. With
    `n_query_groups == n_heads` (the default) this is plain multi-head
    attention.
    """
    if n_query_groups is None:
      n_query_groups = n_heads
    q_per_kv = n_heads // n_query_groups

    q = einops.rearrange(q, '(g qpk h) n_embed -> n_embed g qpk h',
                         g=n_query_groups, qpk=q_per_kv)
    k = einops.rearrange(k, '(g h) n_embed -> n_embed g 1 h', g=n_query_groups)
    v = einops.rearrange(v, '(g h) n_embed -> n_embed g 1 h', g=n_query_groups)

    packed, _ = einops.pack([q, k, v], 'n_embed g * h')
    return einops.rearrange(packed, 'n_embed g n_qkv h -> n_embed (g n_qkv h)')

  for i in range(config.n_layer):
    q, k, v = qkv[i]['q'], qkv[i]['k'], qkv[i]['v']
    proj_qkv = combine_qkv(q, k, v, config.n_head, config.n_query_groups)
    load_params(f'block_{i}.attn.proj_qkv.kernel', proj_qkv)

  # Fail here rather than at the first forward pass if the checkpoint lacked
  # some of the parameters the model declares.
  missing = sorted(set(param_shapes) - set(params))
  if missing:
    raise ValueError(f'params missing from checkpoint: {missing}')

  return traverse_util.unflatten_dict(params, sep='.')


llama2_7b_config = Config.llama2_7b()

model = GPT(llama2_7b_config)
params = init_from_llama_checkpoint(
    model, llama2_7b_config, Path(llama_checkpoint))

# Assemble the variable collections directly. Calling `model.init` here would
# allocate a complete set of randomly initialised parameters only to throw
# them away; the one non-parameter variable the model declares is the RoPE
# table stored under `cache/rope_emb`, built with the same arguments GPT uses.
variables = {
    'cache': {
        'rope_emb': build_rope_cache(
            llama2_7b_config.max_seq_length,
            llama2_7b_config.head_size,
            llama2_7b_config.rope_base,
            llama2_7b_config.rope_condense_ratio),
    },
    'params': params,
}

save_checkpoint(variables, Path(flaxgpt_checkpoint))

# Generation Demo

In [2]:
import jax
import jax.random as jrandom
import jax.numpy as jnp

# Where the SentencePiece tokenizer model and the converted checkpoint live.
tokenizer_model_path = '/content/drive/MyDrive/checkpoints/meta-llama/Llama-2-7b-chat-hf/tokenizer.model'
flaxgpt_checkpoint = '/content/drive/MyDrive/checkpoints/flax-gpt/Llama-2-7b-chat-hf'

model = GPT(Config.llama2_7b())
variables = load_checkpoint(Path(flaxgpt_checkpoint))
tokenizer = Tokenizer(tokenizer_model_path)

# Sampling setup: prompt, number of tokens to draw and the base PRNG key.
prompt = "Hello, my name is"
max_tokens = 10
key = jrandom.PRNGKey(1337)

prompt_ids = tokenizer.encode(prompt, bos=True, eos=False)
x = jnp.array(prompt_ids).reshape(1, -1)
result = x[0].tolist()

for t in range(max_tokens):
  # keep at most the last `max_seq_length` tokens as context
  x = x[..., -model.config.max_seq_length:]
  logits = model.apply(variables, x=x)
  next_token_logits = logits[:, -1, :]
  step_key = jrandom.fold_in(key, t)
  next_token = jrandom.categorical(step_key, next_token_logits)
  print(f't={t}', tokenizer.decode(result))

  token_id = next_token.item()
  if token_id == tokenizer.eos_id:
    break
  result.append(token_id)
  x = jnp.concatenate((x, next_token.reshape(-1, 1)), axis=-1)


Loaded variable: cache.rope_emb
Loaded variable: params.emb.embedding
Loaded variable: params.block_0.attn.proj_out.kernel
Loaded variable: params.block_0.attn.proj_qkv.kernel
Loaded variable: params.block_0.mlp.fc_1.kernel
Loaded variable: params.block_0.mlp.fc_2.kernel
Loaded variable: params.block_0.mlp.proj_out.kernel
Loaded variable: params.block_0.norm_1.weight
Loaded variable: params.block_0.norm_2.weight
Loaded variable: params.block_1.attn.proj_out.kernel
Loaded variable: params.block_1.attn.proj_qkv.kernel
Loaded variable: params.block_1.mlp.fc_1.kernel
Loaded variable: params.block_1.mlp.fc_2.kernel
Loaded variable: params.block_1.mlp.proj_out.kernel
Loaded variable: params.block_1.norm_1.weight
Loaded variable: params.block_1.norm_2.weight
Loaded variable: params.block_2.attn.proj_out.kernel
Loaded variable: params.block_2.attn.proj_qkv.kernel
Loaded variable: params.block_2.mlp.fc_1.kernel
Loaded variable: params.block_2.mlp.fc_2.kernel
Loaded variable: params.block_2.mlp.