# Roadmap
- Implement GPT from scratch using Flax
  - [x] Train a reference PyTorch model (a scaled-down version of LLaMA2, "tiny GPT") using [Lit-GPT](https://github.com/Lightning-AI/lit-gpt)
  - [x] Re-implement all layers using Flax
  - [x] The model should be numerically equivalent to the original model. We can verify this by loading a checkpoint from the PyTorch model and comparing the results.
- [ ] Run prediction on the tiny model
- [ ] Load the LLaMA-7B checkpoint and run prediction
- [ ] Train tiny GPT in JAX to match the metrics of the reference model
- [ ] Implement K-V cache in prediction
- [ ] Finetuning
- [ ] LoRA finetuning
- [ ] Quantization
- [ ] Distributed Training (TPUs)
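
The numerical-equivalence check in the roadmap relies on one layout detail: a PyTorch `Linear` stores its weight as `(out_features, in_features)`, while a Flax `Dense` kernel is `(in_features, out_features)`, so weights have to be transposed when they are copied over. The sketch below illustrates this with plain NumPy only; the shapes, the random data and the tolerances are arbitrary assumptions, not values from the real models.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 8)).astype(np.float32)       # (B, T, C_in)
w_torch = rng.normal(size=(16, 8)).astype(np.float32)   # PyTorch layout: (out, in)

# Reference result in float64: y[b, t, o] = sum_c x[b, t, c] * W[o, c]
y_ref = np.einsum('btc,oc->bto', x.astype(np.float64), w_torch.astype(np.float64))

# Flax-style computation: transpose the weight into an (in, out) kernel
kernel = w_torch.T
y_new = x @ kernel                                       # float32 matmul

# Compare with a tolerance rather than exact equality, because the two
# computations accumulate floating point error differently.
max_abs_diff = float(np.abs(y_ref - y_new).max())
outputs_match = bool(np.allclose(y_ref, y_new, rtol=1e-4, atol=1e-5))
print(f"{max_abs_diff=:.2e} {outputs_match=}")
```

The same pattern (copy the weights, transpose the dense kernels, compare with `allclose`) is what the per-layer tests in this notebook do.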

# Setup

In [None]:
# Lit-GPT (for reference model)
!pip install -U git+https://github.com/Lightning-AI/lit-gpt.git torchaudio torchdata torchtext torchvision

In [None]:
# Jax libraries
!pip install einops git+https://github.com/google/CommonLoopUtils.git


# Train a reference model

Transformers are a bit more complicated than linear models. To make sure our implementation is correct, we first train a scaled-down version of the LLaMA2 model using Lit-GPT (we call it "tiny GPT"), and then use it as a reference.

We use the [Tiny Shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) from Andrej Karpathy to train this model.

## Tiny GPT

This tiny GPT model is a scaled-down version of LLaMA2, with fewer layers, smaller head sizes and smaller embedding sizes. This gives us a fast model for development.

For convenience, I store the checkpoint on Google Drive so I can reuse it without having to train it every time.

In [None]:
from pathlib import Path

checkpoint_root = '/content/drive/MyDrive/checkpoints'
checkpoint_path = Path(checkpoint_root) / 'tiny_gpt'
# parents=True also creates checkpoint_root if it does not exist yet
checkpoint_path.mkdir(parents=True, exist_ok=True)

In [None]:
%%writefile {checkpoint_path}/lit_config.json
{
  "name": "tiny_gpt",
  "block_size": 128,
  "vocab_size": 32000,
  "padding_multiple": 64,
  "padded_vocab_size": 32000,
  "n_layer": 4,
  "n_head": 4,
  "n_embd": 128,
  "rotary_percentage": 1.0,
  "parallel_residual": false,
  "bias": false,
  "lm_head_bias": false,
  "n_query_groups": 4,
  "shared_attention_norm": false,
  "_norm_class": "RMSNorm",
  "norm_eps": 1e-05,
  "_mlp_class": "LLaMAMLP",
  "gelu_approximate": "none",
  "intermediate_size": 512,
  "rope_condense_ratio": 1,
  "rope_base": 10000,
  "n_expert": 0,
  "n_expert_per_token": 0
}

Overwriting /content/drive/MyDrive/checkpoints/tiny_gpt/lit_config.json
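
To get a feel for how small this model is, the parameter count can be worked out from the config by hand. The sketch below is plain arithmetic on the values in `lit_config.json`; it assumes no bias terms (as the config says) and an `lm_head` that is not tied to the token embedding.

```python
# Values copied from lit_config.json above.
vocab_size = 32000
n_layer = 4
n_head = 4
n_embd = 128
n_query_groups = 4
intermediate_size = 512

head_size = n_embd // n_head

# Fused q/k/v projection: n_head query heads plus one k and one v per query group.
qkv_out = (n_head + 2 * n_query_groups) * head_size
attn_params = n_embd * qkv_out + n_embd * n_embd   # qkv projection + output projection
mlp_params = 3 * n_embd * intermediate_size        # fc_1, fc_2 and the output projection
norm_params = 2 * n_embd                           # norm_1 and norm_2 weights
block_params = attn_params + mlp_params + norm_params

embedding_params = vocab_size * n_embd             # token embedding
lm_head_params = vocab_size * n_embd               # assumed untied from the embedding
final_norm_params = n_embd

total_params = embedding_params + lm_head_params + final_norm_params + n_layer * block_params
print(f"{block_params=:,} {total_params=:,}")
# block_params=262,400 total_params=9,241,728
```

Most of the parameters sit in the embedding and the language model head, because the LLaMA vocabulary (32000 tokens) is large compared with the embedding size (128).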


We also need the tokenizer file from LLaMA, which can be downloaded here:

https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main

Download the `tokenizer.model` file and put it into the checkpoint directory.

## Training Tiny GPT in PyTorch

In [None]:
import json
import os
import shutil

import torch
from torch.nn import functional as F
from tqdm.notebook import tqdm
from lit_gpt import GPT, Config, Tokenizer

tokenizer = Tokenizer(checkpoint_path)
config = Config.from_json(checkpoint_path / "lit_config.json")

if not os.path.exists('input.txt'):
  !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

with open('input.txt', 'r', encoding='utf-8') as f:
  text = f.read()

torch.manual_seed(1337)
batch_size = 8
block_size = 16

data = tokenizer.encode(text).long()
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]


def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y


def generate(model, tokenizer, prompt, max_tokens=100, temperature=1.0, context_window=50):
  input = tokenizer.encode(prompt).view(1, -1)
  eos_id=tokenizer.eos_id
  model.eval()

  result = [input[0]]

  for _ in range(max_tokens):
    # truncate
    input = input[:, -context_window:]
    with torch.no_grad():
      logits = model(input)

    next_token_logits = logits[0, -1, :]
    probs = torch.nn.functional.softmax(next_token_logits / temperature, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    result.append(next_token)
    if next_token.cpu().item() == eos_id:
      break
    input = torch.cat((input, next_token.view(1, -1)), dim=1)

  return tokenizer.decode(torch.cat(result).cpu().numpy())


model = GPT(config)


optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
total_steps = 10000

with tqdm(range(total_steps)) as pbar:
  for step in pbar:
    model.train()
    xb, yb = get_batch('train')
    B, T = xb.shape
    logits = model(xb)
    yb = yb.view(-1)
    logits = logits.view(B*T, -1)
    loss = F.cross_entropy(logits, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    pbar.set_description(f'train_loss={loss.item():.3f}')
    optimizer.step()

    if step % (total_steps // 5) == 0 or step == total_steps-1:
      model.eval()
      losses = []
      for _ in range(100):
        with torch.no_grad():
          xb, yb = get_batch('valid')
          B, T = xb.shape
          logits = model(xb)
          yb = yb.view(-1)
          logits = logits.view(B*T, -1)
          loss = F.cross_entropy(logits, yb).item()
          losses.append(loss)
      avg_loss = torch.tensor(losses).mean().item()
      print(f'{step=}, validation loss={avg_loss}')
      # save checkpoint
      model_ckpt_path = checkpoint_path / f'model-{step}.pth'
      torch.save(model.state_dict(), model_ckpt_path)
      shutil.copy(model_ckpt_path, checkpoint_path / 'lit_model.pth')

print('Test sampling model:')
print(generate(model, tokenizer, 'Shakespear:\n', max_tokens=100, context_window=50))


step=0, validation loss=10.444548606872559
step=2000, validation loss=4.689100742340088
step=4000, validation loss=4.564824104309082
step=6000, validation loss=4.602177143096924
step=8000, validation loss=4.63491153717041
step=9999, validation loss=4.567898750305176
Test sampling model:
Shakespear:
My brother is the gates of seventeen,
To wunder-a wage to cheek
Who which he hath moved the court-trees her brother season on my lie
you in them. Friar tune will teach his shame, my good heart!
Antonio, mules extremes, and where poor man may be patiently
ber-chender-twenty your highness uncartius,
And stopsing eye in thy digressing rit the


# GPT Components

A LLaMA style GPT model (decoder-only transformer) looks like this:

[![](https://mermaid.ink/img/pako:eNp9kc9rwyAUx_8VeScDbaA55rBD2WCHRsrWozBcfG0lUYMxjNL0f59R1w025uH9-H4_8sR3hdZKhBqOvf1oz8J5sj1wQ8JRZpg8Wa8fyJN-R0ljLJIX6-jNB9uhWYSZbHvbdptEvNgB6RKKhO3t-BtKNY2JbBYyoLGrslhlcS7LMt9l2WLF96S_ZlT_2-zHE1gEduKCjlmn6b3KI-594ppnFJLumrclfyFRjH5vT8qP3MAKNDotlAz_e10wDv6MGjnUoZTCdRy4uQVOTN6-XkwLtXcTrmAapPD4qMTJCQ31UfRjUFEqb12TFhb3dvsE9HGKjQ?type=png)](https://mermaid.live/edit#pako:eNp9kc9rwyAUx_8VeScDbaA55rBD2WCHRsrWozBcfG0lUYMxjNL0f59R1w025uH9-H4_8sR3hdZKhBqOvf1oz8J5sj1wQ8JRZpg8Wa8fyJN-R0ljLJIX6-jNB9uhWYSZbHvbdptEvNgB6RKKhO3t-BtKNY2JbBYyoLGrslhlcS7LMt9l2WLF96S_ZlT_2-zHE1gEduKCjlmn6b3KI-594ppnFJLumrclfyFRjH5vT8qP3MAKNDotlAz_e10wDv6MGjnUoZTCdRy4uQVOTN6-XkwLtXcTrmAapPD4qMTJCQ31UfRjUFEqb12TFhb3dvsE9HGKjQ)



By inspecting the reference model, we can see which layers of the GPT model need to be implemented: Embedding, Attention, MLP and LayerNorm (RMSNorm):

In [None]:
import lit_gpt
from pathlib import Path
import dataclasses
import rich

checkpoint_root = '/content/drive/MyDrive/checkpoints'
checkpoint_path = Path(checkpoint_root) / 'tiny_gpt'

lit_config = lit_gpt.Config.from_json(checkpoint_path / "lit_config.json")
lit_model = lit_gpt.GPT(lit_config)
rich.print('Model Config:', lit_config)
rich.print('Model:', lit_model)



# Config

To avoid passing parameters from parent layers to sublayers, we define a config (like in Lit-GPT) for the whole transformer and pass it to every sublayer:

In [None]:
import dataclasses
import torch
from flax import linen as nn
import jax
import jax.numpy as jnp, jax.random as jrandom
import einops

@dataclasses.dataclass
class Config:
  max_seq_length: int      # maximum context length
  vocab_size: int          # vocabulary size
  n_embed: int             # embedding size (= n_head * head_size)
  n_layer: int             # number of transformer blocks
  intermediate_size: int   # intermediate size of FFN

  # multi head / multi query attention
  n_head: int              # number of heads
  n_query_groups: int      # number of query groups in multi-query attention

  # RoPE positional embedding
  rope_condense_ratio: int # rope condense ratio
  rope_base: int           # rope base

  def __post_init__(self):
    assert self.n_embed % self.n_head == 0, f'Embedding size(n_embed={self.n_embed}) should be divisible by (n_head={self.n_head})'
    self.head_size = self.n_embed // self.n_head


In [None]:
config = Config(
    vocab_size=lit_config.vocab_size,
    n_embed=lit_config.n_embd,
    n_head=lit_config.n_head,
    n_layer=lit_config.n_layer,
    intermediate_size=lit_config.intermediate_size,
    n_query_groups=lit_config.n_query_groups,
    max_seq_length=lit_config.block_size,
    rope_condense_ratio=lit_config.rope_condense_ratio,
    rope_base=lit_config.rope_base
)

rich.print('Model Config:', config)

# Transformer Block

A transformer block looks like this: (tip: you can use https://mermaid.live/ to create diagrams like this)

[![](https://mermaid.ink/img/pako:eNplj70OgjAQgF-F3CQRBhgZTDSOYgw4dmlolUbaknqNMcC7W2lA1E5333d_7aDSjEMGl0Y_qpoaDHZnogL3hGotBnG8CYq8PGojE8-nbFRbRK5QaOXdnPq-5HdSX_C7YJY2_SyLZLVah-FyUfq1KB1Vfjh56gJfm_73L8dPOv1obdEd4jlEILmRVDD3-e7NCGDNJSeQuZBRcyNA1ODqqEVdPlUFGRrLI7Ato8j3gl4NlR4OLxn1YEI?type=png)](https://mermaid.live/edit#pako:eNplj70OgjAQgF-F3CQRBhgZTDSOYgw4dmlolUbaknqNMcC7W2lA1E5333d_7aDSjEMGl0Y_qpoaDHZnogL3hGotBnG8CYq8PGojE8-nbFRbRK5QaOXdnPq-5HdSX_C7YJY2_SyLZLVah-FyUfq1KB1Vfjh56gJfm_73L8dPOv1obdEd4jlEILmRVDD3-e7NCGDNJSeQuZBRcyNA1ODqqEVdPlUFGRrLI7Ato8j3gl4NlR4OLxn1YEI)

To implement a Block, we need to implement RMSNorm, self-attention (with rotary positional embeddings) and the MLP.

## RMSNorm
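The original LayerNorm normalizes each vector to zero mean and unit variance:
$\frac{x - \bar{x}}{\sqrt{\sigma^2 + \epsilon}}$

RMSNorm skips the mean subtraction and only rescales, so the output is a vector of norm approximately $\sqrt{N}$:
$\frac{x}{\sqrt{\frac{|x|^2}{N} + \epsilon}}$

where $N$ is the layer dimension, $\bar{x}$ and $\sigma^2$ are the mean and variance of the elements of $x$, and $|x|$ is its Euclidean norm.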
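
The difference between the two normalizations can be checked numerically. This is a minimal sketch in plain NumPy that leaves out the learned weight; the input distribution (mean 3, standard deviation 2) and `N = 128` are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 128
eps = 1e-5
x = rng.normal(loc=3.0, scale=2.0, size=(N,))

# LayerNorm without learned scale/shift: subtract the mean, divide by the std.
layer_normed = (x - x.mean()) / np.sqrt(x.var() + eps)

# RMSNorm without learned weight: divide by the root mean square, no centering.
rms_normed = x / np.sqrt(np.mean(x * x) + eps)

print('LayerNorm mean/std:', layer_normed.mean(), layer_normed.std())   # close to 0 and 1
print('RMSNorm norm vs sqrt(N):', np.linalg.norm(rms_normed), np.sqrt(N))
print('RMSNorm mean:', rms_normed.mean())   # not centered, stays away from 0
```

Because RMSNorm does not subtract the mean, it needs one reduction less than LayerNorm and keeps the direction of `x` unchanged.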

In [None]:
class RMSNorm(nn.Module):
  """Root Mean Square Layer Normalization.

  Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
  https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
  """
  axis: int = -1
  eps: float = 1e-5

  @nn.compact
  def __call__(self, x):
    weight = self.param('weight', lambda rng, shape: jnp.ones(shape), x.shape[-1])
    norm_x = jnp.mean(x * x, axis=self.axis, keepdims=True)
    x_normed = x / jnp.sqrt(norm_x + self.eps)
    return weight * x_normed

### Tests

In [None]:
generator = torch.Generator().manual_seed(1337)
rmsnorm_lit = lit_gpt.rmsnorm.RMSNorm(size=lit_config.n_embd)
rmsnorm = RMSNorm()
# input
B, T, C = 1, 5, lit_config.n_embd
x = torch.randn((B, T, C), generator=generator)
# run the reference RMSNorm
with torch.no_grad():
  out1 = rmsnorm_lit(x)

variables = rmsnorm.init(jrandom.key(0), x=x.numpy())
# copy state from reference RMSNorm
variables['params']['weight'] = rmsnorm_lit.state_dict()['weight'].numpy()

out2 = rmsnorm.apply(variables, x.numpy())

assert jnp.allclose(out1.numpy(), out2)
print('Normalized norm should be close to 1:', jnp.linalg.norm(out2[0,0] / variables['params']['weight'] / jnp.sqrt(x.shape[-1])))
print('Test passed: RMS')

Normalized norm should be close to 1: 0.99999535
Test passed: RMS


## Rotary Positional Embedding (RoPE)

In [None]:
def build_rope_cache(
    seq_len: int,
    n_elem: int,
    base: int = 10000,
    condense_ratio: int = 1
):
  """Enhanced Transformer with Rotary Position Embedding.

  Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
  transformers/rope/__init__.py. MIT License:
  https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
  """
  # $\Theta = {\theta_i = base^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ with d = n_elem
  theta = 1.0 / (base ** (jnp.arange(0, n_elem, 2) / n_elem))

  # Create position indexes `[0, 1, ..., seq_len - 1]`
  seq_idx = jnp.arange(seq_len) / condense_ratio

  # Calculate the product of position index and $\theta_i$
  idx_theta = jnp.outer(seq_idx, theta)
  idx_theta = jnp.tile(idx_theta, (1, 2))

  return jnp.stack([jnp.cos(idx_theta), jnp.sin(idx_theta)], axis=-1)


def apply_rope(x, rope_emb):
  cos, sin = rope_emb[..., 0], rope_emb[..., 1]
  head_size = x.shape[-1]
  x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
  x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
  rotated = jnp.concatenate((-x2, x1), axis=-1)  # (B, nh, T, hs)
  roped = (x * cos) + (rotated * sin)
  return roped

### Tests

In [None]:
generator = torch.Generator().manual_seed(1337)
input_shape = (B, n_heads, T, head_size) = (1, 4, 10, 32)
x = torch.randn(input_shape, generator=generator)
cos, sin = lit_gpt.model.build_rope_cache(seq_len=128, n_elem=head_size, base=10000, condense_ratio=1)
expected_output = lit_gpt.model.apply_rope(x, cos[:T], sin[:T])

rope_emb = build_rope_cache(seq_len=128, n_elem=head_size, base=10000, condense_ratio=1)
y = apply_rope(x.numpy(), rope_emb[:T])
assert jnp.allclose(y, expected_output.numpy(), rtol=1e-5), "rope embedding doesn't match"
print('Test passed: RoPE')

Test passed: RoPE
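
Two properties make RoPE convenient for attention: the rotation does not change the length of a vector, and the dot product between a rotated query and a rotated key depends only on the distance between their positions. The sketch below checks both with a small NumPy re-implementation of the same rotate-half scheme; `rope_cache`, `rope` and `score` are illustrative names and not part of Lit-GPT or of the code above.

```python
import numpy as np

def rope_cache(seq_len, n_elem, base=10000):
    # Same angle layout as build_rope_cache: (seq_len, n_elem), two identical halves.
    theta = 1.0 / (base ** (np.arange(0, n_elem, 2) / n_elem))
    angles = np.outer(np.arange(seq_len), theta)
    angles = np.tile(angles, (1, 2))
    return np.cos(angles), np.sin(angles)

def rope(x, cos, sin):
    # Rotate each pair (x[i], x[i + half]) by its position-dependent angle.
    half = x.shape[-1] // 2
    rotated = np.concatenate((-x[..., half:], x[..., :half]), axis=-1)
    return x * cos + rotated * sin

rng = np.random.default_rng(0)
head_size, seq_len = 32, 16
cos, sin = rope_cache(seq_len, head_size)
q = rng.normal(size=(head_size,))
k = rng.normal(size=(head_size,))

def score(m, n):
    # Attention logit between q placed at position m and k placed at position n.
    return float(rope(q, cos[m], sin[m]) @ rope(k, cos[n], sin[n]))

norm_before = np.linalg.norm(q)
norm_after = np.linalg.norm(rope(q, cos[7], sin[7]))

print('norms:', norm_before, norm_after)               # equal up to rounding
print('same offset 3:', score(5, 2), score(10, 7))     # equal up to rounding
```

This also shows why only `q` and `k` are rotated: the positional information is needed in the attention logits, not in the values.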


## Self Attention

In [None]:
class SelfAttention(nn.Module):
  """Multi head / Multi query / Grouped Query Attention.

  About n_query_groups
  to use multi-head attention (MHA), set this to `n_head` (default)
  to use multi-query attention (MQA), set this to 1
  to use grouped-query attention (GQA), set this to a value in between
  Example with `n_head=4`
  ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
  │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
  └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    │    │    │    │         │        │                 │
  ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
  │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
  └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
  ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
          MHA                    GQA                   MQA
    n_query_groups=4       n_query_groups=2      n_query_groups=1
    q_per_kv=1             q_per_kv=2            q_per_kv=4
    n_head=4               n_head=4              n_head=4
    n_qkv=3                n_qkv=4               n_qkv=6
  credit https://arxiv.org/pdf/2305.13245.pdf
  """
  config: Config

  @nn.compact
  def __call__(self, x, rope_emb, mask=None):
    T = x.shape[-2]  # x: (B, T, C)
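    # `mask or ...` raises for array-valued masks, so test for None explicitly
    if mask is None:
      mask = jnp.tril(jnp.ones((T, T)))  # default causal mask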

    # nq = n_head, nk = nv = n_query_groups
    qkv_dim = (self.config.n_head + 2 * self.config.n_query_groups) * self.config.head_size
    qkv_proj = nn.Dense(features=qkv_dim, use_bias=False, name='proj_qkv')(x)

    # number of q's per group
    q_per_kv = self.config.n_head // self.config.n_query_groups
    # number of qkvs per group, k=v=1
    n_qkv = q_per_kv + 2
    # break embedding into (n_groups, n_qkv, head_size)
    qkv = einops.rearrange(qkv_proj, 'b t (n_groups n_qkv h) -> b n_groups n_qkv t h',
                           n_groups=self.config.n_query_groups,
                           n_qkv=n_qkv)
    # split q, k, v within groups
    q, k, v = einops.unpack(qkv, [[q_per_kv], [1], [1]], 'b n_groups * t h')

    if q_per_kv != 1:
      # repeat k and v in each group
      k = einops.repeat(k, 'b n_groups 1 t h -> b n_groups q_per_kv t h', q_per_kv=q_per_kv)
      v = einops.repeat(v, 'b n_groups 1 t h -> b n_groups q_per_kv t h', q_per_kv=q_per_kv)

    # merge groups into heads
    q = einops.rearrange(q, 'b n_groups q_per_kv t h -> b (n_groups q_per_kv) t h')
    k = einops.rearrange(k, 'b n_groups q_per_kv t h -> b (n_groups q_per_kv) t h')
    v = einops.rearrange(v, 'b n_groups q_per_kv t h -> b (n_groups q_per_kv) t h')

    # apply position embedding
    # NOTE: only apply to q and k, but not v
    q = apply_rope(q, rope_emb)
    k = apply_rope(k, rope_emb)

    # multi head scaled dot attention
    weights = einops.einsum(q, k, 'b nh tq h, b nh tk h -> b nh tq tk')
    weights = weights / jnp.sqrt(self.config.head_size)
    weights = jnp.where(mask, weights, float('-inf'))
    weights = nn.softmax(weights, axis=-1)
    out = einops.einsum(weights, v, 'b nh tq tv, b nh tv h -> b tq nh h')

    # concat heads
    out = einops.rearrange(out, 'b t nh h -> b t (nh h)')

    # final projection
    out = nn.Dense(self.config.n_embed, use_bias=False, name='proj_out')(out)

    return out


### Tests

In [None]:
torch.manual_seed(1337)
generator = torch.Generator().manual_seed(1337)
B, T, C = 1, 5, lit_config.n_embd
x = torch.randn(B, T, C, generator=generator)
attn_lit = lit_gpt.model.CausalSelfAttention(lit_config)
with torch.no_grad():
  cos, sin = lit_gpt.model.build_rope_cache(seq_len=lit_config.block_size, n_elem=lit_config.head_size, base=10000, condense_ratio=1)
  out1 = attn_lit(x, cos[:T], sin[:T])

attn = SelfAttention(config)
rope_emb = build_rope_cache(seq_len=config.max_seq_length, n_elem=config.head_size, base=10000, condense_ratio=1)
variables = attn.init(jrandom.key(0), x.numpy(), rope_emb[:T])
# copy weights
variables['params']['proj_qkv']['kernel'] = attn_lit.state_dict()['attn.weight'].T.numpy()
variables['params']['proj_out']['kernel'] = attn_lit.state_dict()['proj.weight'].T.numpy()
out2 = attn.apply(variables, x.numpy(), rope_emb[:T], mask=None)

assert jnp.allclose(out1.numpy(), out2, rtol=1e-4)
print('Test passed: SelfAttention')

Test passed: SelfAttention
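
The trickiest part of `SelfAttention` is the layout of the fused projection: within each query group the output holds `q_per_kv` query heads, then one key head, then one value head. The sketch below replays the split, repeat and merge steps with plain NumPy on a counting tensor so the layout can be inspected. It uses `n_head=4`, `n_query_groups=2` and a made-up `head_size=3`, and replaces the einops calls with `reshape`/`transpose` equivalents.

```python
import numpy as np

n_head, n_query_groups, head_size = 4, 2, 3
B, T = 1, 5
q_per_kv = n_head // n_query_groups      # 2 query heads share each k/v pair
n_qkv = q_per_kv + 2                     # slots per group: q, q, k, v

qkv_dim = (n_head + 2 * n_query_groups) * head_size
qkv_proj = np.arange(B * T * qkv_dim, dtype=np.float32).reshape(B, T, qkv_dim)

# 'b t (n_groups n_qkv h) -> b n_groups n_qkv t h'
qkv = qkv_proj.reshape(B, T, n_query_groups, n_qkv, head_size).transpose(0, 2, 3, 1, 4)

# Split the slot axis into q (q_per_kv slots), k (1 slot) and v (1 slot).
q, k, v = np.split(qkv, [q_per_kv, q_per_kv + 1], axis=2)

# Repeat the single k/v of each group so every query head has a partner.
k = np.repeat(k, q_per_kv, axis=2)
v = np.repeat(v, q_per_kv, axis=2)

# Merge (n_groups, q_per_kv) into one head axis.
q, k, v = (a.reshape(B, n_head, T, head_size) for a in (q, k, v))

print(q.shape, k.shape, v.shape)         # all (1, 4, 5, 3)
print(k[0, 0, 0], k[0, 1, 0])            # heads 0 and 1 share the same key
```

With `n_query_groups == n_head` every group holds exactly one q, k and v, which is ordinary multi-head attention and the setting used by the tiny model.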


## MLP

In [None]:
class MLP(nn.Module):
  """LLaMA style MLP."""
  config: Config

  @nn.compact
  def __call__(self, x):
    x1 = nn.Dense(self.config.intermediate_size, use_bias=False, name='fc_1')(x)
    x2 = nn.Dense(self.config.intermediate_size, use_bias=False, name='fc_2')(x)
    x = nn.silu(x1) * x2
    x = nn.Dense(self.config.n_embed, use_bias=False, name='proj_out')(x)
    return x

### Tests

In [None]:
torch.manual_seed(1337)
generator = torch.Generator().manual_seed(1337)
B, T, C = 1, 5, lit_config.n_embd
x = torch.randn(B, T, C, generator=generator)
mlp_lit = lit_gpt.model.LLaMAMLP(lit_config)
with torch.no_grad():
  out1 = mlp_lit(x)

mlp = MLP(config)
variables = mlp.init(jrandom.key(0), x.numpy())
variables['params']['fc_1']['kernel'] = mlp_lit.state_dict()['fc_1.weight'].T.numpy()
variables['params']['fc_2']['kernel'] = mlp_lit.state_dict()['fc_2.weight'].T.numpy()
variables['params']['proj_out']['kernel'] = mlp_lit.state_dict()['proj.weight'].T.numpy()

out2 = mlp.apply(variables, x.numpy())
assert jnp.allclose(out1.numpy(), out2, rtol=1e-4)
print('Test passed: MLP', out1.shape, out2.shape)

Test passed: MLP torch.Size([1, 5, 128]) (1, 5, 128)


## Block


In [None]:
class Block(nn.Module):
  config: Config

  @nn.compact
  def __call__(self, x, rope_emb, mask=None):
    n1 = RMSNorm(name='norm_1')(x)
    h = SelfAttention(self.config, name='attn')(n1, rope_emb, mask=mask)
    x = h + x
    n2 = RMSNorm(name='norm_2')(x)
    h = MLP(self.config, name='mlp')(n2)
    x = h + x
    return x

### Tests

In [None]:
torch.manual_seed(1337)
generator = torch.Generator().manual_seed(1337)
B, T, C = 1, 5, lit_config.n_embd
x = torch.randn(B, T, C, generator=generator)
block_lit = lit_gpt.model.Block(lit_config)
with torch.no_grad():
  cos, sin = lit_gpt.model.build_rope_cache(seq_len=lit_config.block_size, n_elem=lit_config.head_size, base=10000, condense_ratio=1)
  out1 = block_lit(x, cos[:T], sin[:T])

block = Block(config)
rope_emb = build_rope_cache(seq_len=config.max_seq_length, n_elem=config.head_size, base=10000, condense_ratio=1)
variables = block.init(jrandom.key(0), x.numpy(), rope_emb[:T])
variables['params']['norm_1']['weight'] = block_lit.state_dict()['norm_1.weight'].numpy()
variables['params']['norm_2']['weight'] = block_lit.state_dict()['norm_2.weight'].numpy()
variables['params']['attn']['proj_qkv']['kernel'] = block_lit.state_dict()['attn.attn.weight'].T.numpy()
variables['params']['attn']['proj_out']['kernel'] = block_lit.state_dict()['attn.proj.weight'].T.numpy()
variables['params']['mlp']['fc_1']['kernel'] = block_lit.state_dict()['mlp.fc_1.weight'].T.numpy()
variables['params']['mlp']['fc_2']['kernel'] = block_lit.state_dict()['mlp.fc_2.weight'].T.numpy()
variables['params']['mlp']['proj_out']['kernel'] = block_lit.state_dict()['mlp.proj.weight'].T.numpy()


out2 = block.apply(variables, x.numpy(), rope_emb[:T])

print(out1.shape, out2.shape)
assert jnp.allclose(out1.numpy(), out2, rtol=1e-4)
print('Test passed: Block')

torch.Size([1, 5, 128]) (1, 5, 128)
Test passed: Block


# Transformer

In [None]:
class GPT(nn.Module):
  config: Config

  @nn.compact
  def __call__(self, x):
    T = x.shape[-1]  # (B, T)
    rope_emb = self.variable('cache', 'rope_emb', build_rope_cache,
                             self.config.max_seq_length,
                             self.config.head_size,
                             self.config.rope_base,
                             self.config.rope_condense_ratio)

    x = nn.Embed(num_embeddings=self.config.vocab_size,
                 features=self.config.n_embed, name='emb')(x)
    self.sow('intermediates', 'emb_out', x)

    for i in range(self.config.n_layer):
      x = Block(config=self.config, name=f'block_{i}')(x, rope_emb.value[:T])
      self.sow('intermediates', f'block_{i}_out', x)

    # final layer norm
    x = RMSNorm(name='ln_f')(x)
    self.sow('intermediates', 'ln_out', x)
    # language model head
    x = nn.Dense(self.config.vocab_size, name='lm_head', use_bias=False)(x)

    return x

## Tests

We first collect all intermediate outputs from the reference model, then compare them one by one with those of our implementation.

In [None]:
# our test input
tokenizer = lit_gpt.Tokenizer(checkpoint_path)
idx = tokenizer.encode('Hello, my name is').view(1, -1)

lit_model.load_state_dict(torch.load(checkpoint_path / 'lit_model.pth'))

with torch.no_grad():
  # Calling model end-to-end
  out1 = lit_model(idx)

  # Calling model layer by layer
  # token embedding
  token_emb = lit_model.transformer.wte(idx)
  T = idx.size(1)

  # rope embeddings
  cos, sin = lit_model.cos[:T], lit_model.sin[:T]

  # transformer blocks
  hidden_results = []
  h = token_emb
  for block in lit_model.transformer.h:
      h_out = block(h, cos, sin, mask=None, input_pos=None)
      hidden_results.append({'input': h.numpy(), 'output': h_out.numpy()})
      h = h_out

  # final layer norm
  ln_result = lit_model.transformer.ln_f(h)
  # transformer output
  out2 = lit_model.lm_head(ln_result)

  # store all expected results
  expected_intermediates = {
      'emb': {'input': idx.numpy(), 'output': token_emb.numpy()},
      'blocks': hidden_results,
      'ln': {'input': hidden_results[-1]['output'], 'output': ln_result.numpy()},
      'lm_head': {'input': ln_result.numpy(), 'output': out2.numpy()},
      'rope': {'sin': lit_model.sin.numpy(), 'cos': lit_model.cos.numpy()},
  }

assert torch.allclose(out1, out2), "model output doesn't match"

model = GPT(config)
# variables = model.init(jrandom.key(0), x=idx.numpy())
variables = model.lazy_init(jrandom.key(0), x=idx.numpy())

state_dict = lit_model.state_dict().copy()
variables['params']['lm_head']['kernel'] = state_dict.pop('lm_head.weight').T.numpy()
variables['params']['emb']['embedding'] = state_dict.pop('transformer.wte.weight').numpy()
variables['params']['ln_f']['weight'] = state_dict.pop('transformer.ln_f.weight').numpy()

for i in range(config.n_layer):
  variables['params'][f'block_{i}']['norm_1']['weight'] = state_dict.pop(f'transformer.h.{i}.norm_1.weight').numpy()
  variables['params'][f'block_{i}']['norm_2']['weight'] = state_dict.pop(f'transformer.h.{i}.norm_2.weight').numpy()
  variables['params'][f'block_{i}']['attn']['proj_qkv']['kernel'] = state_dict.pop(f'transformer.h.{i}.attn.attn.weight').T.numpy()
  variables['params'][f'block_{i}']['attn']['proj_out']['kernel'] = state_dict.pop(f'transformer.h.{i}.attn.proj.weight').T.numpy()
  variables['params'][f'block_{i}']['mlp']['fc_1']['kernel'] = state_dict.pop(f'transformer.h.{i}.mlp.fc_1.weight').T.numpy()
  variables['params'][f'block_{i}']['mlp']['fc_2']['kernel'] = state_dict.pop(f'transformer.h.{i}.mlp.fc_2.weight').T.numpy()
  variables['params'][f'block_{i}']['mlp']['proj_out']['kernel'] = state_dict.pop(f'transformer.h.{i}.mlp.proj.weight').T.numpy()

assert len(state_dict.keys()) == 0, f'State not loaded: {state_dict.keys()}'

out3, states = model.apply(variables, x=idx.numpy(), mutable=['intermediates'])
intermediates = states['intermediates']

# embedding
assert jnp.allclose(expected_intermediates['emb']['output'], intermediates['emb_out'][0]), 'Emb not match'

# blocks
for i in range(config.n_layer):
  assert jnp.allclose(expected_intermediates['blocks'][i]['output'], intermediates[f'block_{i}_out'][0], rtol=1e-4, atol=1e-5), f'Block {i} not match'

# final ln
assert jnp.allclose(expected_intermediates['ln']['output'], intermediates['ln_out'][0], rtol=1e-4, atol=1e-6), 'final layer norm not match'

# End to end result
assert jnp.allclose(out1.numpy(), out3, rtol=1e-4, atol=1e-5)

print('All tests passed: Transformer')

All tests passed: Transformer


# Generation

In [None]:
def generate(key, model, tokenizer, variables, prompt, max_tokens=100):
  x = tokenizer.encode(prompt).view(1, -1).numpy()
  result = x[0].tolist()

  for t in range(max_tokens):
    x = x[..., -model.config.max_seq_length:]
    logits = model.apply(variables, x=x)
    next_token_logits = logits[:, -1, :]
    next_token = jrandom.categorical(jrandom.fold_in(key, t), next_token_logits)
    if next_token.item() == tokenizer.eos_id:
      break
    result.append(next_token.item())
    x = jnp.concatenate((x, next_token.reshape(-1,1)), axis=-1)

  return tokenizer.decode(jnp.array(result))

result = generate(jrandom.key(0), model, tokenizer, variables, "Citizen: \n", max_tokens=30)
print(result)

Citizen: 
Further of: but'another at the gates,
Since thou, my cousin, which one would be verified, nothing,--

C
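
The PyTorch `generate` earlier in this notebook divides the logits by a `temperature` before sampling, while the JAX version samples from the raw logits, which corresponds to a temperature of 1. The sketch below shows with made-up logits what temperature does to the sampling distribution. It uses plain NumPy, and the `softmax` helper is defined here for illustration only.

```python
import numpy as np

def softmax(z):
    z = z - z.max()              # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.0, -1.0])   # made-up next-token logits

# Lower temperature sharpens the distribution, higher temperature flattens it.
for temperature in (0.5, 1.0, 2.0):
    probs = softmax(logits / temperature)
    print(temperature, probs.round(3))

# Draw one token id from the sharpened distribution.
rng = np.random.default_rng(0)
next_token = int(rng.choice(len(logits), p=softmax(logits / 0.5)))
```

In the JAX `generate`, the same effect could be obtained by dividing `next_token_logits` by a temperature before passing them to `jrandom.categorical`.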


# Scale to LLaMA2

Now let's use the JAX code to load a LLaMA checkpoint and run prediction.



## All Model definitions

For clarity and convenience, I've just collected all model definitions in one place:

In [None]:
from pathlib import Path
import dataclasses
import json

import rich

import jax
import jax.numpy as jnp, jax.random as jrandom
from flax import linen as nn
from flax import traverse_util
import einops


@dataclasses.dataclass
class Config:
  """GPT Config."""
  max_seq_length: int      # maximum context length
  vocab_size: int          # vocabulary size
  n_embed: int             # embedding size (= n_head * head_size)
  n_layer: int             # number of transformer blocks
  intermediate_size: int   # intermediate size of FFN

  # multi head / multi query attention
  n_head: int              # number of heads
  n_query_groups: int      # number of query groups in multi-query attention

  # RoPE positional embedding
  rope_condense_ratio: int # rope condense ratio
  rope_base: int           # rope base

  def __post_init__(self):
    assert self.n_embed % self.n_head == 0, f'Embedding size(n_embed={self.n_embed}) should be divisible by (n_head={self.n_head})'
    self.head_size = self.n_embed // self.n_head


class RMSNorm(nn.Module):
  """Root Mean Square Layer Normalization.

  Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
  https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
  """
  axis: int = -1
  eps: float = 1e-5

  @nn.compact
  def __call__(self, x):
    weight = self.param('weight', lambda rng, shape: jnp.ones(shape), x.shape[-1])
    norm_x = jnp.mean(x * x, axis=self.axis, keepdims=True)
    x_normed = x / jnp.sqrt(norm_x + self.eps)
    return weight * x_normed


class SelfAttention(nn.Module):
  """Multi head / Multi query / Grouped Query Attention.

  About n_query_groups
  to use multi-head attention (MHA), set this to `n_head` (default)
  to use multi-query attention (MQA), set this to 1
  to use grouped-query attention (GQA), set this to a value in between
  Example with `n_head=4`
  ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
  │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
  └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    │    │    │    │         │        │                 │
  ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
  │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
  └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
  ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
          MHA                    GQA                   MQA
    n_query_groups=4       n_query_groups=2      n_query_groups=1
    q_per_kv=1             q_per_kv=2            q_per_kv=4
    n_head=4               n_head=4              n_head=4
    n_qkv=3                n_qkv=4               n_qkv=6
  credit https://arxiv.org/pdf/2305.13245.pdf
  """
  config: Config

  @nn.compact
  def __call__(self, x, rope_emb, mask=None):
    T = x.shape[-2]  # x: (B, T, C)
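    # `mask or ...` raises for array-valued masks, so test for None explicitly
    if mask is None:
      mask = jnp.tril(jnp.ones((T, T)))  # default causal mask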

    # nq = n_head, nk = nv = n_query_groups
    qkv_dim = (self.config.n_head + 2 * self.config.n_query_groups) * self.config.head_size
    qkv_proj = nn.Dense(features=qkv_dim, use_bias=False, name='proj_qkv')(x)

    # number of q's per group
    q_per_kv = self.config.n_head // self.config.n_query_groups
    # number of qkvs per group, k=v=1
    n_qkv = q_per_kv + 2
    # break embedding into (n_groups, n_qkv, head_size)
    qkv = einops.rearrange(qkv_proj, 'b t (n_groups n_qkv h) -> b n_groups n_qkv t h',
                           n_groups=self.config.n_query_groups,
                           n_qkv=n_qkv)
    # split q, k, v within groups
    q, k, v = einops.unpack(qkv, [[q_per_kv], [1], [1]], 'b n_groups * t h')

    if q_per_kv != 1:
      # repeat k and v in each group
      k = einops.repeat(k, 'b n_groups 1 t h -> b n_groups q_per_kv t h', q_per_kv=q_per_kv)
      v = einops.repeat(v, 'b n_groups 1 t h -> b n_groups q_per_kv t h', q_per_kv=q_per_kv)

    # merge groups into heads
    q = einops.rearrange(q, 'b n_groups q_per_kv t h -> b (n_groups q_per_kv) t h')
    k = einops.rearrange(k, 'b n_groups q_per_kv t h -> b (n_groups q_per_kv) t h')
    v = einops.rearrange(v, 'b n_groups q_per_kv t h -> b (n_groups q_per_kv) t h')

    # apply position embedding
    # NOTE: only apply to q and k, but not v
    q = apply_rope(q, rope_emb)
    k = apply_rope(k, rope_emb)

    # multi-head scaled dot-product attention
    weights = einops.einsum(q, k, 'b nh tq h, b nh tk h -> b nh tq tk')
    weights = weights / jnp.sqrt(self.config.head_size)
    weights = jnp.where(mask, weights, float('-inf'))
    weights = nn.softmax(weights, axis=-1)
    out = einops.einsum(weights, v, 'b nh tq tv, b nh tv h -> b tq nh h')

    # concat heads
    out = einops.rearrange(out, 'b t nh h -> b t (nh h)')

    # final projection
    out = nn.Dense(self.config.n_embed, use_bias=False, name='proj_out')(out)

    return out


class MLP(nn.Module):
  """LLaMA style MLP."""
  config: Config

  @nn.compact
  def __call__(self, x):
    x1 = nn.Dense(self.config.intermediate_size, use_bias=False, name='fc_1')(x)
    x2 = nn.Dense(self.config.intermediate_size, use_bias=False, name='fc_2')(x)
    x = nn.silu(x1) * x2
    x = nn.Dense(self.config.n_embed, use_bias=False, name='proj_out')(x)
    return x


class Block(nn.Module):
  """A Transformer Block of attention followed by MLP."""
  config: Config

  @nn.compact
  def __call__(self, x, rope_emb, mask=None):
    n1 = RMSNorm(name='norm_1')(x)
    h = SelfAttention(self.config, name='attn')(n1, rope_emb, mask=mask)
    x = h + x
    n2 = RMSNorm(name='norm_2')(x)
    h = MLP(self.config, name='mlp')(n2)
    x = h + x
    return x


class GPT(nn.Module):
  """The full decoder-only transformer."""

  config: Config

  @nn.compact
  def __call__(self, x):
    T = x.shape[-1]  # (B, T)
    rope_emb = self.variable('cache', 'rope_emb', build_rope_cache,
                             self.config.max_seq_length,
                             self.config.head_size,
                             self.config.rope_base,
                             self.config.rope_condense_ratio)

    x = nn.Embed(num_embeddings=self.config.vocab_size,
                 features=self.config.n_embed, name='emb')(x)
    self.sow('intermediates', 'emb_out', x)

    for i in range(self.config.n_layer):
      x = Block(config=self.config, name=f'block_{i}')(x, rope_emb.value[:T])
      self.sow('intermediates', f'block_{i}_out', x)

    # final layer norm
    x = RMSNorm(name='ln_f')(x)
    self.sow('intermediates', 'ln_out', x)
    # language model head
    x = nn.Dense(self.config.vocab_size, name='lm_head', use_bias=False)(x)

    return x


def build_rope_cache(
    seq_len: int,
    n_elem: int,
    base: int = 10000,
    condense_ratio: int = 1
):
  """Enhanced Transformer with Rotary Position Embedding.

  Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
  transformers/rope/__init__.py. MIT License:
  https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
  """
  # $\Theta = \{\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]\}$
  theta = 1.0 / (base ** (jnp.arange(0, n_elem, 2) / n_elem))

  # Create position indexes `[0, 1, ..., seq_len - 1]`, scaled by `1 / condense_ratio`
  seq_idx = jnp.arange(seq_len) / condense_ratio

  # Calculate the product of position index and $\theta_i$
  idx_theta = jnp.outer(seq_idx, theta)
  idx_theta = jnp.tile(idx_theta, (1, 2))

  return jnp.stack([jnp.cos(idx_theta), jnp.sin(idx_theta)], axis=-1)


def apply_rope(x, rope_emb):
  """Apply rope embedding to input x."""
  cos, sin = rope_emb[..., 0], rope_emb[..., 1]
  head_size = x.shape[-1]
  x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
  x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
  rotated = jnp.concatenate((-x2, x1), axis=-1)  # (B, nh, T, hs)
  roped = (x * cos) + (rotated * sin)
  return roped
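
The rotary embedding above can be sanity-checked on its own. The following is a minimal sketch, assuming toy sizes (`head_size=8`, `seq_len=16`) and compact copies of `build_rope_cache` and `apply_rope`. It checks two properties one would expect from RoPE: rotation leaves vector norms unchanged, and the attention score between a query and a key depends only on their relative offset.

```python
import jax.numpy as jnp
import jax.random as jrandom


def build_rope_cache(seq_len, n_elem, base=10000, condense_ratio=1):
  # Same computation as the notebook's build_rope_cache, condensed.
  theta = 1.0 / (base ** (jnp.arange(0, n_elem, 2) / n_elem))
  seq_idx = jnp.arange(seq_len) / condense_ratio
  idx_theta = jnp.tile(jnp.outer(seq_idx, theta), (1, 2))
  return jnp.stack([jnp.cos(idx_theta), jnp.sin(idx_theta)], axis=-1)


def apply_rope(x, rope_emb):
  cos, sin = rope_emb[..., 0], rope_emb[..., 1]
  half = x.shape[-1] // 2
  rotated = jnp.concatenate((-x[..., half:], x[..., :half]), axis=-1)
  return x * cos + rotated * sin


head_size, seq_len = 8, 16  # toy sizes, chosen for illustration only
rope = build_rope_cache(seq_len, head_size)  # (T, hs, 2)
q_vec = jrandom.normal(jrandom.PRNGKey(0), (head_size,))
k_vec = jrandom.normal(jrandom.PRNGKey(1), (head_size,))

# Place the same q and k vector at every position, then rotate.
q = apply_rope(jnp.tile(q_vec, (seq_len, 1)), rope)  # (T, hs)
k = apply_rope(jnp.tile(k_vec, (seq_len, 1)), rope)  # (T, hs)

# 1) The rotation preserves the norm of each vector.
norm_preserved = bool(
    jnp.allclose(jnp.linalg.norm(q, axis=-1), jnp.linalg.norm(q_vec), atol=1e-4))

# 2) Scores depend only on the offset (query position - key position).
scores = q @ k.T  # (T, T)
relative_only = bool(jnp.allclose(scores[5, 3], scores[9, 7], atol=1e-4)) and bool(
    jnp.allclose(scores[5, 3], scores[12, 10], atol=1e-4))

print(norm_preserved, relative_only)
```

Because only `q` and `k` are rotated, the values `v` carry no positional signal of their own. Position enters the output only through the attention weights.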


## Convert pytorch checkpoint to Jax

In [None]:
import json
import re
import torch

def init_from_llama_checkpoint(
    model: GPT,
    config: Config,
    llama_checkpoint: Path):
  variable_shapes = jax.eval_shape(model.init, jrandom.key(0), jnp.zeros((1,1), dtype=jnp.int32))
  param_shapes = traverse_util.flatten_dict(variable_shapes['params'], sep='.')

  if not llama_checkpoint.is_dir():
    raise ValueError(f'llama checkpoint directory {llama_checkpoint} does not exist')

  with open(llama_checkpoint / 'pytorch_model.bin.index.json', 'r') as f:
    llama_model_index = json.load(f)

  param_names = sorted(llama_model_index['weight_map'].keys())
  files = sorted(list(set(llama_model_index['weight_map'].values())))

  params = {}

  def load_params(name, value):
    if hasattr(value, 'numpy'):
      value = value.numpy()
    assert name in param_shapes, f'Param does not exist: {name}'
    assert param_shapes[name].shape == value.shape, f'Shape mismatch: {name} Expected: {param_shapes[name].shape} Actual: {value.shape}'
    params[name] = jax.device_put(value)
    print(f'Success: loaded param: {name}, dtype:{params[name].dtype} shape:{params[name].shape} device:{params[name].device()}')

  qkv = {i: {} for i in range(config.n_layer)}

  for f in files:
    print(f'Loading checkpoint file {f}')
    # load onto CPU so the tensors can be converted with .numpy()
    states = torch.load(llama_checkpoint / f, map_location='cpu')
    for name, value in states.items():
      if name == 'lm_head.weight':
        load_params('lm_head.kernel', value.T)
      elif name == 'model.embed_tokens.weight':
        load_params('emb.embedding', value)
      elif name == 'model.norm.weight':
        load_params('ln_f.weight', value)
      elif ret := re.match(r'model\.layers\.(\d+)\.(.*)', name):
        i, sub_name = ret.groups()
        i = int(i)
        if sub_name == 'input_layernorm.weight':
          load_params(f'block_{i}.norm_1.weight', value)
        elif sub_name == 'post_attention_layernorm.weight':
          load_params(f'block_{i}.norm_2.weight', value)
        elif sub_name == 'mlp.gate_proj.weight':
          load_params(f'block_{i}.mlp.fc_1.kernel', value.T)
        elif sub_name == 'mlp.up_proj.weight':
          load_params(f'block_{i}.mlp.fc_2.kernel', value.T)
        elif sub_name == 'mlp.down_proj.weight':
          load_params(f'block_{i}.mlp.proj_out.kernel', value.T)
        elif sub_name == 'self_attn.o_proj.weight':
          load_params(f'block_{i}.attn.proj_out.kernel', value.T)
        elif sub_name == 'self_attn.q_proj.weight':
          qkv[i]['q'] = value.numpy()
        elif sub_name == 'self_attn.k_proj.weight':
          qkv[i]['k'] = value.numpy()
        elif sub_name == 'self_attn.v_proj.weight':
          qkv[i]['v'] = value.numpy()
        elif sub_name == 'self_attn.rotary_emb.inv_freq':
          pass
        else:
          raise ValueError(f'unhandled param: {name}')
      else:
        raise ValueError(f'unhandled param: {name}')
    del states  # free memory before loading the next shard

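  def combine_qkv(q, k, v, n_heads):
    # Interleave per-head q, k, v to match the '(n_groups n_qkv h)' layout used by
    # SelfAttention. Assumes MHA (n_query_groups == n_head), as in LLaMA-2 7B.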
    q = einops.rearrange(q, '(nh h) n_embed -> n_embed nh 1 h', nh=n_heads)
    k = einops.rearrange(k, '(nh h) n_embed -> n_embed nh 1 h', nh=n_heads)
    v = einops.rearrange(v, '(nh h) n_embed -> n_embed nh 1 h', nh=n_heads)

    packed, _ = einops.pack([q,k,v], 'n_embed nh * h')
    qkv = einops.rearrange(packed, 'n_embed nh n_qkv h -> n_embed (nh n_qkv h)')
    return qkv

  for i in range(config.n_layer):
    q, k, v = qkv[i]['q'], qkv[i]['k'], qkv[i]['v']
    proj_qkv = combine_qkv(q, k, v, config.n_head)
    load_params(f'block_{i}.attn.proj_qkv.kernel', proj_qkv)

  return traverse_util.unflatten_dict(params, sep='.')


# deduced from: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main
llama2_7b_config = Config(
    max_seq_length=4096,
    vocab_size=32000,
    n_layer=32,
    n_head=32,
    n_embed=4096,
    n_query_groups=32,
    intermediate_size=11008,
    rope_base=10000,
    rope_condense_ratio=1,
)

model = GPT(llama2_7b_config)
llama_checkpoint = Path('/content/drive/MyDrive/checkpoints/meta-llama/Llama-2-7b-chat-hf')
params = init_from_llama_checkpoint(model, llama2_7b_config, llama_checkpoint)

variables = model.init(jrandom.key(0), jnp.zeros((1,1), dtype=jnp.int32))
del variables['params']
variables['params'] = params

# rich.print(variables)
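
To see why `combine_qkv` interleaves the weights per head, here is a minimal NumPy sketch. It assumes toy sizes and MHA (`n_query_groups == n_head`), and uses plain reshapes in place of einops. The helper `combine_qkv_np` is a hypothetical stand-in written for this illustration. The sketch packs separate q/k/v matrices in `(nh n_qkv h)` order and checks that splitting the fused projection gives the same result as projecting with the separate matrices.

```python
import numpy as np

n_embed, n_head, head_size, T = 8, 2, 4, 5  # toy sizes for illustration
rng = np.random.default_rng(0)

# PyTorch-style weights: (out_features, in_features), as stored in the checkpoint.
wq = rng.normal(size=(n_head * head_size, n_embed))
wk = rng.normal(size=(n_head * head_size, n_embed))
wv = rng.normal(size=(n_head * head_size, n_embed))


def combine_qkv_np(q, k, v, n_heads):
  """Pack weights as [q_0, k_0, v_0, q_1, k_1, v_1, ...] along the output axis."""
  def per_head(w):
    # (nh * h, n_embed) -> (n_embed, nh, 1, h)
    return w.T.reshape(w.shape[1], n_heads, 1, -1)
  packed = np.concatenate([per_head(q), per_head(k), per_head(v)], axis=2)
  return packed.reshape(packed.shape[0], -1)  # (n_embed, nh * 3 * h)


kernel = combine_qkv_np(wq, wk, wv, n_head)  # flax Dense kernel layout: (in, out)

x = rng.normal(size=(T, n_embed))
fused = (x @ kernel).reshape(T, n_head, 3, head_size)
q, k, v = fused[:, :, 0], fused[:, :, 1], fused[:, :, 2]  # each (T, nh, h)

# Reference: project with the separate matrices, then split into heads.
q_ref = (x @ wq.T).reshape(T, n_head, head_size)
k_ref = (x @ wk.T).reshape(T, n_head, head_size)
v_ref = (x @ wv.T).reshape(T, n_head, head_size)
layout_ok = np.allclose(q, q_ref) and np.allclose(k, k_ref) and np.allclose(v, v_ref)
print(kernel.shape, layout_ok)
```

For a GQA checkpoint the k and v matrices would have fewer heads than q, so the packing would need to group `q_per_kv` query heads with one k and one v head instead.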

Now we save using orbax so we don't have to load the pytorch checkpoint next time:

In [None]:
from flax.training import orbax_utils
import orbax.checkpoint
import json

def save_checkpoint(variables, path: Path):
  orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
  model_index = {}
  for name, value in traverse_util.flatten_dict(variables, sep='.').items():
    ckpt = {'value': value}
    orbax_checkpointer.save(path / name, ckpt, save_args=orbax_utils.save_args_from_target(ckpt))
    model_index[name] = 'true'
    print(f'Saved {name}')

  with open(path / 'model_index.json', 'w') as f:
    json.dump(model_index, f)

  print('Save success')

save_checkpoint(variables, Path('/content/drive/MyDrive/checkpoints/flaxgpt-llama-2-7b-hf-chat'))

## Load jax checkpoint

In [None]:
from flax.training import orbax_utils
import orbax.checkpoint
import json

def load_checkpoint(path: Path):
  with open(path / 'model_index.json') as f:
    model_index = json.load(f)

  orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()

  variables = {}
  for name in model_index:
    ckpt = orbax_checkpointer.restore(path / name)
    variables[name] = jax.device_put(ckpt['value'])
    print(f'Loaded variable: {name}')

  return traverse_util.unflatten_dict(variables, sep='.')

## Generate using loaded jax checkpoint

In [None]:
import os
from logging import getLogger
from typing import List

from sentencepiece import SentencePieceProcessor


logger = getLogger()


class Tokenizer:
    """Tokenizes, encodes, and decodes text using SentencePiece."""
    def __init__(self, model_path: str):
        """
        Initializes the Tokenizer with a SentencePiece model.

        Args:
            model_path (str): The path to the SentencePiece model file.
        """
        # reload tokenizer
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        logger.info(f"Reloaded SentencePiece model from {model_path}")

        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.

        Returns:
            List[int]: A list of token IDs.
        """
        assert isinstance(s, str)
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        return self.sp_model.decode(t)


def generate(key, model, tokenizer, variables, prompt, max_tokens=100):
  x = jnp.array(tokenizer.encode(prompt, bos=True, eos=False)).reshape(1, -1)
  result = x[0].tolist()

  for t in range(max_tokens):
    x = x[..., -model.config.max_seq_length:]
    logits = model.apply(variables, x=x)
    next_token_logits = logits[:, -1, :]
    next_token = jrandom.categorical(jrandom.fold_in(key, t), next_token_logits)
    print(f't={t}', tokenizer.decode(result))
    if next_token.item() == tokenizer.eos_id:
      break
    result.append(next_token.item())
    x = jnp.concatenate((x, next_token.reshape(-1,1)), axis=-1)

  return tokenizer.decode(result)
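
`generate` samples from the unscaled logits, which corresponds to a temperature of 1. As a minimal sketch of how the sampling step could be varied, the hypothetical helper `sample_next_token` below (not part of the notebook) adds a temperature and a greedy mode. It also shows that `jrandom.fold_in` yields a reproducible key per decoding step.

```python
import jax.numpy as jnp
import jax.random as jrandom


def sample_next_token(key, logits, temperature=1.0):
  """Hypothetical helper: sample one token id per row of `logits` (B, V)."""
  if temperature == 0.0:
    # Greedy decoding: always pick the most likely token.
    return jnp.argmax(logits, axis=-1)
  # Lower temperatures sharpen the distribution, higher ones flatten it.
  return jrandom.categorical(key, logits / temperature, axis=-1)


key = jrandom.PRNGKey(0)
logits = jnp.array([[0.0, 10.0, 0.0, 0.0]])  # (B=1, V=4), token 1 dominates

greedy = sample_next_token(key, logits, temperature=0.0)

# fold_in derives a distinct but reproducible key for each decoding step.
step_keys = [jrandom.fold_in(key, t) for t in range(3)]
samples = [int(sample_next_token(k, logits, temperature=0.1)[0]) for k in step_keys]
repeat = [int(sample_next_token(k, logits, temperature=0.1)[0]) for k in step_keys]
print(greedy.tolist(), samples, samples == repeat)
```

With a fixed base key, the same prompt produces the same continuation, which helps when comparing outputs against the reference model.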


In [None]:
llama_checkpoint = Path('/content/drive/MyDrive/checkpoints/meta-llama/Llama-2-7b-chat-hf')
tokenizer = Tokenizer(model_path=str(llama_checkpoint / 'tokenizer.model'))
variables = load_checkpoint(Path('/content/drive/MyDrive/checkpoints/flaxgpt-llama-2-7b-hf-chat'))

# deduced from: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main
llama2_7b_config = Config(
    max_seq_length=4096,
    vocab_size=32000,
    n_layer=32,
    n_head=32,
    n_embed=4096,
    n_query_groups=32,
    intermediate_size=11008,
    rope_base=10000,
    rope_condense_ratio=1,
)

model = GPT(llama2_7b_config)

r = generate(jrandom.PRNGKey(0), model, tokenizer, variables, 'Hello, my name is', max_tokens=10)
print(r)

Loaded variable: cache.rope_emb
Loaded variable: params.emb.embedding
Loaded variable: params.block_0.attn.proj_out.kernel
Loaded variable: params.block_0.attn.proj_qkv.kernel
Loaded variable: params.block_0.mlp.fc_1.kernel
Loaded variable: params.block_0.mlp.fc_2.kernel
Loaded variable: params.block_0.mlp.proj_out.kernel
Loaded variable: params.block_0.norm_1.weight
Loaded variable: params.block_0.norm_2.weight
Loaded variable: params.block_1.attn.proj_out.kernel
Loaded variable: params.block_1.attn.proj_qkv.kernel
Loaded variable: params.block_1.mlp.fc_1.kernel
Loaded variable: params.block_1.mlp.fc_2.kernel
Loaded variable: params.block_1.mlp.proj_out.kernel
Loaded variable: params.block_1.norm_1.weight
Loaded variable: params.block_1.norm_2.weight
Loaded variable: params.block_2.attn.proj_out.kernel
Loaded variable: params.block_2.attn.proj_qkv.kernel
Loaded variable: params.block_2.mlp.fc_1.kernel
Loaded variable: params.block_2.mlp.fc_2.kernel
Loaded variable: params.block_2.mlp.