# Converting the Llama 3.2 1B model from Hugging Face to JAX

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/JAX_porting_HF_model.ipynb)

[![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax-ai-stack/blob/main/docs/JAX_porting_HF_model.ipynb)

This tutorial demonstrates how to convert Meta's [Llama 3.2 1B model](https://huggingface.co/meta-llama/Llama-3.2-1B) from Hugging Face to a JAX model and run it on a T4 GPU.

You need some familiarity with [Flax](https://flax.readthedocs.io/en/latest/index.html), a library for building Neural Networks in JAX, to follow along. If you are getting started, check out the tutorials on [Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html#example-a-simple-neural-network-with-flax) and [Flax's MNIST tutorial](https://flax.readthedocs.io/en/latest/mnist_tutorial.html).

## Setup

Let's install `jax-ai-stack`. We'll use the `jax` and `flax` libraries from the stack in this tutorial. We will also need `huggingface_hub` for downloading model weights and `transformers` for tokenization.

In [1]:
!pip install -q jax-ai-stack
!pip install -Uq transformers huggingface_hub
!huggingface-cli login

Take care of the imports.

In [2]:
import jax
import jax.numpy as jnp
from flax import nnx
from safetensors import safe_open
from pathlib import Path
import os
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from dataclasses import dataclass

## Define the configuration

The Hugging Face Transformers library has [some tips for using the Llama3 model](https://huggingface.co/docs/transformers/main/en/model_doc/llama3). For further reference, the modeling code in the transformers library lives in the [`models/llama/modeling_llama.py` file](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py).

Before we create the model in JAX, we need to define some parameters. You can refer to the [Llama2 documentation for the configuration options](https://huggingface.co/docs/transformers/main/en/model_doc/llama2#transformers.LlamaConfig).

In [3]:
@dataclass
class LlamaConfig:
    # Values for Llama 3.2 1B
    dim: int = 2048
    n_layers: int = 16
    n_heads: int = 32
    n_kv_heads: int = 8
    intermediate_size: int = 8192
    vocab_size: int = 128256
    multiple_of: int = 256
    norm_eps: float = 1e-05
    rope_theta: float = 500000.0

    @property
    def head_dim(self):
        # Per-head dimension, derived from the hidden size and the number of heads
        return self.dim // self.n_heads

config = LlamaConfig()

## Load the model weights

We'll use the `huggingface_hub` library to download the model weights.

Meta requires [acceptance of the license](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/discussions/130) before you can access the files. You will also need a Hugging Face access token. Please refer to the [Hugging Face documentation](https://huggingface.co/docs/hub/en/security-tokens) to set it up.

In [None]:
model_id = "meta-llama/Llama-3.2-1B"
if os.path.exists('/kaggle'):
    weights_base_dir = '/kaggle/tmp'
elif os.path.exists('/content'):
    # Colab
    weights_base_dir = '/content'
else:
    # Local machine
    weights_base_dir = '.'

path_to_model_weights = os.path.join(weights_base_dir, model_id)

snapshot_download(repo_id=model_id, local_dir=path_to_model_weights)

Then extract the model weights from the safetensors file and store them in the `weights` dict. These weights will be loaded into our JAX model soon.

In [5]:
def load_safetensors():
    weights = {}
    safetensors_files = Path(path_to_model_weights).glob('*.safetensors')

    for file in safetensors_files:
        with safe_open(file, framework="jax", device="cpu") as f:
            for key in f.keys():
                weights[key] = f.get_tensor(key)
    return weights

weights = load_safetensors()

Note that the weights are stored as `bfloat16`.
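
To get a feel for what such a checkpoint contains, it can help to list tensor names, shapes, and dtypes. The sketch below does not touch the real checkpoint: `fake_weights` is a small stand-in dict with made-up shapes, and `summarize` is a hypothetical helper. With the downloaded model you would pass `weights` instead.

```python
import jax.numpy as jnp

# Stand-in for the `weights` dict. The key names mirror those used in this
# tutorial, but the shapes are tiny and made up for illustration.
fake_weights = {
    "model.embed_tokens.weight": jnp.zeros((16, 8), dtype=jnp.bfloat16),
    "model.layers.0.self_attn.q_proj.weight": jnp.zeros((8, 8), dtype=jnp.bfloat16),
    "model.norm.weight": jnp.ones((8,), dtype=jnp.bfloat16),
}

def summarize(weights):
    """Hypothetical helper: print each tensor and return the total parameter count."""
    total = 0
    for name, tensor in sorted(weights.items()):
        print(f"{name:45s} {str(tensor.shape):12s} {tensor.dtype}")
        total += tensor.size
    return total

n_params = summarize(fake_weights)
print(f"Total parameters: {n_params}")
```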

## Define the Flax model

Now we can define the model in Flax.

[This Transformer vs Llama diagram](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/_images/transformer_vs_llama.svg) from NVIDIA visualizes the model architecture nicely. We will define each layer using [Flax's `nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module).

We will start by defining the RMS normalization layer. Note how we load the parameters from the `weights` dict.

In [6]:
class LlamaRMSNorm(nnx.Module):

    def __init__(self, name=None, layer_idx=None, rngs=None):
        if name is None and layer_idx is None:
            # Final normalization layer
            self.norm_weights = nnx.Param(weights["model.norm.weight"], rngs=rngs)
        else:
            self.norm_weights = nnx.Param(weights[f"model.layers.{layer_idx}.{name}.weight"], rngs=rngs)

    def __call__(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.astype(jnp.float32)
        squared_mean = jnp.mean(jnp.square(hidden_states), axis=-1, keepdims=True)
        hidden_states = hidden_states * jnp.reciprocal(jnp.sqrt(squared_mean + config.norm_eps))
        return self.norm_weights * hidden_states.astype(input_dtype)
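
As a quick sanity check on the math, the following sketch applies the same normalization as a standalone function and confirms that, with unit weights, each output vector has a root mean square close to 1. The function name `rms_norm` and the random test input are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def rms_norm(x, weight, eps=1e-5):
    # Scale each vector by the reciprocal of its root mean square,
    # computing in float32 and casting back to the input dtype.
    x32 = x.astype(jnp.float32)
    mean_sq = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
    normed = x32 * jax.lax.rsqrt(mean_sq + eps)
    return weight * normed.astype(x.dtype)

x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 64)) * 5.0
y = rms_norm(x, jnp.ones((64,)))
rms = jnp.sqrt(jnp.mean(jnp.square(y), axis=-1))
print(rms)  # every entry should be close to 1.0
```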

Llama 3 uses [Rotary Position Embedding (RoPE)](https://arxiv.org/abs/2104.09864) to encode positional information by rotating the query and key vectors. For a gentle introduction to RoPE, please refer to the [CMU lecture slides](https://www.cs.cmu.edu/~mgormley/courses/10423-s24//slides/lecture5-vit-ink.pdf) and this [EleutherAI blog post](https://blog.eleuther.ai/rotary-embeddings/).

In [7]:
class LlamaRotaryEmbedding(nnx.Module):

    def __init__(self, dim, base=10000, rngs=None):
        self.dim = dim
        self.base = base

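    def __call__(self, position_ids):
        # position_ids: (seq_len,). One inverse frequency per pair of channels.
        inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
        inv_freq_expanded = jnp.expand_dims(inv_freq, axis=(0, 1))  # (1, 1, dim/2)
        position_ids_expanded = jnp.expand_dims(position_ids, axis=(0, 2)).astype(jnp.float32)  # (1, seq_len, 1)
        # Outer product of positions and inverse frequencies: (1, seq_len, 1, dim/2)
        freqs = jnp.einsum('bij,bjk->bijk', position_ids_expanded, inv_freq_expanded)
        # Duplicate the angles so they cover the full head dimension
        emb = jnp.concatenate([freqs, freqs], axis=-1)
        cos = jnp.cos(emb).squeeze(2).astype(jnp.bfloat16)  # (1, seq_len, dim)
        sin = jnp.sin(emb).squeeze(2).astype(jnp.bfloat16)  # (1, seq_len, dim)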
        return cos, sin
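
One property that makes RoPE attractive is that the dot product between a rotated query and a rotated key depends only on the distance between their positions. The sketch below illustrates this with standalone functions in float32; the names `rope_cos_sin`, `apply_rope`, and `score`, as well as the toy dimensions, are assumptions for illustration and are not part of the model code.

```python
import jax
import jax.numpy as jnp

def rope_cos_sin(positions, dim, base=10000.0):
    inv_freq = 1.0 / (base ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))
    freqs = positions[:, None].astype(jnp.float32) * inv_freq[None, :]  # (seq, dim/2)
    emb = jnp.concatenate([freqs, freqs], axis=-1)  # (seq, dim)
    return jnp.cos(emb), jnp.sin(emb)

def rotate_half(x):
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)

def apply_rope(x, cos, sin):
    return x * cos + rotate_half(x) * sin

dim = 8
q = jax.random.normal(jax.random.PRNGKey(0), (dim,))
k = jax.random.normal(jax.random.PRNGKey(1), (dim,))
cos, sin = rope_cos_sin(jnp.arange(16), dim)

def score(m, n):
    # Dot product of q placed at position m with k placed at position n.
    return jnp.dot(apply_rope(q, cos[m], sin[m]), apply_rope(k, cos[n], sin[n]))

# Both pairs are 3 positions apart, so the two scores should match
# up to floating-point error.
print(score(5, 2), score(12, 9))
```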

Now we create the attention layers. Note how we load the weights into the q, k and v projection layers.

In [8]:
class LlamaAttention(nnx.Module):

    def __init__(self, layer_idx, rngs=None):
        self.q_proj = nnx.Linear(config.dim, config.n_heads * config.head_dim, use_bias=False, rngs=rngs)
        self.q_proj.kernel.value = weights[f"model.layers.{layer_idx}.self_attn.q_proj.weight"].T
        self.k_proj = nnx.Linear(config.dim, config.n_kv_heads * config.head_dim, use_bias=False, rngs=rngs)
        self.k_proj.kernel.value = weights[f"model.layers.{layer_idx}.self_attn.k_proj.weight"].T
        self.v_proj = nnx.Linear(config.dim, config.n_kv_heads * config.head_dim, use_bias=False, rngs=rngs)
        self.v_proj.kernel.value = weights[f"model.layers.{layer_idx}.self_attn.v_proj.weight"].T
        self.o_proj = nnx.Linear(config.n_heads * config.head_dim, config.dim, use_bias=False, rngs=rngs)
        self.o_proj.kernel.value = weights[f"model.layers.{layer_idx}.self_attn.o_proj.weight"].T
        self.rotary_emb = LlamaRotaryEmbedding(config.head_dim, base=config.rope_theta, rngs=rngs)

    # Alternative implementation: https://github.com/google/flax/blob/5d896bc1a2c68e2099d147cd2bc18ebb6a46a0bd/examples/gemma/positional_embeddings.py#L45
    def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1):
        cos = jnp.expand_dims(cos, axis=unsqueeze_dim)
        sin = jnp.expand_dims(sin, axis=unsqueeze_dim)
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed

    def rotate_half(self, x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return jnp.concatenate([-x2, x1], axis=-1)

    def repeat_kv(self, hidden_states, n_repeat):
        batch, n_kv_heads, seq_len, head_dim = hidden_states.shape
        if n_repeat == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].repeat(n_repeat, axis=2)
        return hidden_states.reshape(batch, n_kv_heads * n_repeat, seq_len, head_dim)

    def __call__(self, x, position_ids):
        batch_size, seq_len, _ = x.shape
        query = self.q_proj(x).reshape(batch_size, seq_len, config.n_heads, config.head_dim).transpose((0, 2, 1, 3))
        key = self.k_proj(x).reshape(batch_size, seq_len, config.n_kv_heads, config.head_dim).transpose((0, 2, 1, 3))
        value = self.v_proj(x).reshape(batch_size, seq_len, config.n_kv_heads, config.head_dim).transpose((0, 2, 1, 3))
        # Assuming batch_size=1
        cos, sin = self.rotary_emb(position_ids[0])
        query, key = self.apply_rotary_pos_emb(query, key, cos, sin)

        key = self.repeat_kv(key, config.n_heads // config.n_kv_heads)
        value = self.repeat_kv(value, config.n_heads // config.n_kv_heads)

        attn_weights = jnp.matmul(query, jnp.transpose(key, (0, 1, 3, 2)))
        attn_weights = (attn_weights.astype(jnp.float32) / jnp.sqrt(config.head_dim)).astype(jnp.bfloat16)
        attn_weights = jax.nn.softmax(attn_weights.astype(jnp.float32), axis=-1).astype(jnp.bfloat16)
        attn_output = jnp.matmul(attn_weights, value).transpose((0, 2, 1, 3)).reshape(batch_size, seq_len, -1)
        output = self.o_proj(attn_output)
        return output
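
Note that the attention scores above are computed without a mask, so every position in the prompt can attend to every other position. Decoder-only models such as Llama are normally run with a causal mask, where position `i` only sees positions up to `i`. The sketch below shows one way such a mask could be applied before the softmax. The function name `causal_softmax` and the toy shapes are assumptions, and the function is not wired into the `LlamaAttention` class.

```python
import jax
import jax.numpy as jnp

def causal_softmax(scores):
    # scores: (batch, heads, q_len, k_len) raw attention scores.
    q_len, k_len = scores.shape[-2:]
    # Lower-triangular mask: True where a query may attend to a key.
    mask = jnp.tril(jnp.ones((q_len, k_len), dtype=bool))
    masked = jnp.where(mask, scores, -jnp.inf)
    return jax.nn.softmax(masked.astype(jnp.float32), axis=-1)

scores = jax.random.normal(jax.random.PRNGKey(0), (1, 2, 4, 4))
probs = causal_softmax(scores)
print(probs[0, 0])  # entries above the diagonal are 0
```

Adding a mask like this to the model would change the attention outputs for the prompt tokens, so the generated tokens may differ from the ones shown later in this tutorial.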

The MLP layer follows the attention layer. Similarly, we load the weights into the gate, up and down projection layers.

In [9]:
class LlamaMLP(nnx.Module):

    def __init__(self, layer_idx, rngs=None):
        self.gate_proj = nnx.Linear(config.dim, config.intermediate_size, use_bias=False, rngs=rngs)
        self.gate_proj.kernel.value = weights[f"model.layers.{layer_idx}.mlp.gate_proj.weight"].T
        self.up_proj = nnx.Linear(config.dim, config.intermediate_size, use_bias=False, rngs=rngs)
        self.up_proj.kernel.value = weights[f"model.layers.{layer_idx}.mlp.up_proj.weight"].T
        self.down_proj = nnx.Linear(config.intermediate_size, config.dim, use_bias=False, rngs=rngs)
        self.down_proj.kernel.value = weights[f"model.layers.{layer_idx}.mlp.down_proj.weight"].T

    def __call__(self, x):
        return self.down_proj(jax.nn.silu(self.gate_proj(x)) * self.up_proj(x))
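
The `.T` in the assignments above reflects a layout difference: the Hugging Face checkpoint stores each linear weight as `(out_features, in_features)`, following PyTorch's `nn.Linear`, while a Flax `nnx.Linear` kernel has shape `(in_features, out_features)`. The sketch below illustrates this with plain arrays and the same gated MLP formula. The toy sizes and all variable and function names are made up for illustration.

```python
import jax
import jax.numpy as jnp

dim, hidden = 4, 6
k_x, k_gate, k_up, k_down = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k_x, (2, dim))

# Checkpoint-style weights: (out_features, in_features)
w_gate = jax.random.normal(k_gate, (hidden, dim))
w_up = jax.random.normal(k_up, (hidden, dim))
w_down = jax.random.normal(k_down, (dim, hidden))

def mlp_checkpoint_layout(x):
    # PyTorch-style linear layers compute x @ W^T.
    return (jax.nn.silu(x @ w_gate.T) * (x @ w_up.T)) @ w_down.T

# Flax-style kernels: (in_features, out_features), so transpose once up front.
gate_kernel, up_kernel, down_kernel = w_gate.T, w_up.T, w_down.T

def mlp_flax_layout(x):
    return (jax.nn.silu(x @ gate_kernel) * (x @ up_kernel)) @ down_kernel

print(gate_kernel.shape, down_kernel.shape)  # (4, 6) (6, 4)
```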

We assemble the decoder block.

In [10]:
class LlamaTransformerBlock(nnx.Module):

    def __init__(self, layer_idx, rngs=None):
        self.input_layernorm = LlamaRMSNorm(name="input_layernorm", layer_idx=layer_idx, rngs=rngs)
        self.attention = LlamaAttention(layer_idx=layer_idx, rngs=rngs)
        self.post_attention_layernorm = LlamaRMSNorm(name="post_attention_layernorm", layer_idx=layer_idx, rngs=rngs)
        self.mlp = LlamaMLP(layer_idx=layer_idx, rngs=rngs)

    def __call__(self, x, position_ids):
        residual = x
        x = self.input_layernorm(x)
        x = self.attention(x, position_ids)
        x = residual + x

        residual = x
        x = self.post_attention_layernorm(x)
        x = self.mlp(x)
        x = residual + x
        return x

Finally, we have the entire model.

In [11]:
class LlamaForCausalLM(nnx.Module):

    def __init__(self, rngs=None):
        self.token_embed = nnx.Embed(num_embeddings=config.vocab_size, features=config.dim, dtype=jnp.bfloat16, rngs=rngs)
        self.token_embed.embedding.value = weights["model.embed_tokens.weight"]

        self.layers = [LlamaTransformerBlock(layer_idx=idx, rngs=rngs) for idx in range(config.n_layers)]
        self.lm_head = nnx.Linear(config.dim, config.vocab_size, use_bias=False, rngs=rngs)
        self.lm_head.kernel.value = weights["model.embed_tokens.weight"].T
        self.norm = LlamaRMSNorm(name=None, layer_idx=None, rngs=rngs)

    def __call__(self, input_ids, position_ids):
        assert input_ids.shape[0] == 1, "Only batch size 1 is supported"
        x = self.token_embed(input_ids)
        for layer in self.layers:
            x = layer(x, position_ids)
        x = self.norm(x)
        logits = self.lm_head(x)
        return logits

## Run the Flax model

Let's take it for a spin! We are still going to use the tokenizer from Hugging Face (since our primary focus is re-building the model instead of the tokenizer).

In [12]:
model = LlamaForCausalLM(rngs=nnx.Rngs(0))

# We no longer need `weights`
del weights

tokenizer = AutoTokenizer.from_pretrained(model_id)
input_text = "The capital of Japan is"

input_ids = tokenizer(input_text, return_tensors="jax")["input_ids"]
position_ids = jnp.asarray([jnp.arange(input_ids.shape[1])])

for _ in range(15):
    logits = model(input_ids, position_ids)
    next_token = jnp.argmax(logits[:, -1, :], axis=-1)
    input_ids = jnp.concatenate([input_ids, next_token[:, None]], axis=1)
    position_ids = jnp.asarray([jnp.arange(input_ids.shape[1])])
    print(f"Generated token: {next_token[0]}")

print(tokenizer.decode(input_ids[0]))

Generated token: 27286
Generated token: 13
Generated token: 27286
Generated token: 374
Generated token: 279
Generated token: 7928
Generated token: 3363
Generated token: 304
Generated token: 6457
Generated token: 323
Generated token: 279
Generated token: 6864
Generated token: 315
Generated token: 6457
Generated token: 13
<|begin_of_text|>The capital of Japan is Tokyo. Tokyo is the largest city in Japan and the capital of Japan.


There you have it. We have successfully converted the Hugging Face model weights from the safetensors file, loaded them up in our JAX model, and run the model.

For simplicity, we have left out many optimizations that would speed things up (JIT compilation, batch inference, KV cache, accelerators, SPMD, etc.). Feel free to implement them as an exercise.
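
As a starting point for the KV cache exercise, here is a minimal sketch of the idea on a toy single-head attention layer. It is not the tutorial's model: there is no RoPE, no grouped-query attention, and the weights are random. All names (`full_causal_attention`, `cached_step`, `k_cache`, `v_cache`) are assumptions. The sketch relies on causal attention, where the output for a token depends only on that token and the ones before it, so keys and values of earlier tokens can be stored and reused.

```python
import jax
import jax.numpy as jnp

d = 8
keys = jax.random.split(jax.random.PRNGKey(0), 4)
wq, wk, wv = (jax.random.normal(k, (d, d)) / jnp.sqrt(d) for k in keys[:3])
x = jax.random.normal(keys[3], (5, d))  # 5 toy token embeddings

def full_causal_attention(x):
    # Recompute queries, keys and values for the whole sequence.
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / jnp.sqrt(d)
    mask = jnp.tril(jnp.ones((x.shape[0], x.shape[0]), dtype=bool))
    probs = jax.nn.softmax(jnp.where(mask, scores, -jnp.inf), axis=-1)
    return probs @ v

def cached_step(x_t, k_cache, v_cache):
    # Only the new token is projected; older keys and values come from the cache.
    k_cache = jnp.concatenate([k_cache, (x_t @ wk)[None, :]], axis=0)
    v_cache = jnp.concatenate([v_cache, (x_t @ wv)[None, :]], axis=0)
    probs = jax.nn.softmax((x_t @ wq) @ k_cache.T / jnp.sqrt(d), axis=-1)
    return probs @ v_cache, k_cache, v_cache

k_cache = jnp.zeros((0, d))
v_cache = jnp.zeros((0, d))
outputs = []
for t in range(x.shape[0]):
    out_t, k_cache, v_cache = cached_step(x[t], k_cache, v_cache)
    outputs.append(out_t)
incremental = jnp.stack(outputs)

# The token-by-token results should match the full recomputation.
print(jnp.max(jnp.abs(incremental - full_causal_attention(x))))
```

Growing the cache with `jnp.concatenate` changes array shapes at every step, which would trigger recompilation under `jax.jit`; a preallocated fixed-size cache is a common alternative.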

In [13]:
del model
del tokenizer

## Convert weights from other frameworks

You can also convert weights from other frameworks. After all, weights are just numbers. But note that the tensor names and layouts may differ between frameworks (each framework may implement the same model differently), so you may need to adjust your code accordingly.
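
One way to handle differing tensor names is a small table of renaming rules. The sketch below is only an illustration: `RENAME_RULES` and `rename_key` are hypothetical, the embedding rule pairs two names that appear in this tutorial, and the per-layer rule uses a made-up source name. Renaming alone may not be enough, since layouts can differ as well.

```python
import re

# Illustrative renaming rules (source pattern -> target template).
# The second rule uses a made-up source name, purely to show the mechanism.
RENAME_RULES = [
    (r"^tok_embeddings\.weight$", "model.embed_tokens.weight"),
    (r"^layers\.(\d+)\.some_norm\.weight$", r"model.layers.\1.input_layernorm.weight"),
]

def rename_key(name, rules=RENAME_RULES):
    """Hypothetical helper: map a source tensor name to the target naming scheme."""
    for pattern, template in rules:
        if re.match(pattern, name):
            return re.sub(pattern, template, name)
    raise KeyError(f"No renaming rule for {name!r}")

print(rename_key("tok_embeddings.weight"))
print(rename_key("layers.3.some_norm.weight"))
```

Raising an error for unmatched names, rather than silently skipping them, makes it easier to notice tensors that were never mapped.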

**Before proceeding, you may need to restart the runtime to release the GPU memory.**

### Keras Hub

We will use Keras Hub to load the Llama 3.2 1B model and extract the weights.

In [1]:
!pip install -Uq keras-hub

import keras_hub

llama_lm = keras_hub.models.Llama3CausalLM.from_preset("hf://meta-llama/Llama-3.2-1B", dtype="bfloat16")

weights_dict = {}
for layer in llama_lm.backbone.layers:
    weights = layer.get_weights()
    if weights:
        weights_dict[layer.name] = weights

For example, here are the embeddings. You then need to extract every weight tensor and load them up in JAX like we did for the Hugging Face model.

In [2]:
weights_dict["token_embedding"]

[array([[0.0045166, 0.0166016, 0.0209961, ..., -0.00537109, -0.0422363,
         -0.0314941],
        [0.0214844, -0.0238037, 0.0211182, ..., -0.0107422, -0.00106049,
         -0.0373535],
        [0.0136108, 0.010437, 0.0127563, ..., 0.00811768, -0.012207,
         0.00512695],
        ...,
        [0.000850677, 0.0163574, -0.0192871, ..., -0.000326157,
         -0.00297546, 0.00662231],
        [0.000850677, 0.0163574, -0.0192871, ..., -0.000326157,
         -0.00297546, 0.00662231],
        [0.000850677, 0.0163574, -0.0192871, ..., -0.000326157,
         -0.00297546, 0.00662231]], dtype=bfloat16)]

In [3]:
del llama_lm
del weights_dict

### PyTorch

When you download the Hugging Face model, the original PyTorch model weights released by Meta are automatically downloaded as well. They are located in the `original` subfolder (another way to access them is to visit [Meta's Llama website](https://www.llama.com/llama-downloads/) and go from there). We can load the weights like this:

In [4]:
import torch
import os

if os.path.exists('/kaggle'):
    weights_base_dir = '/kaggle/tmp'
elif os.path.exists('/content'):
    # Colab
    weights_base_dir = '/content'
else:
    # Local machine
    weights_base_dir = '.'
path_to_model_weights = os.path.join(weights_base_dir, "meta-llama/Llama-3.2-1B/original")
model_weights = torch.load(os.path.join(path_to_model_weights, "consolidated.00.pth"))


For example, here are the embeddings. You then need to extract every weight tensor and load them up in JAX like we did for the Hugging Face model.

In [5]:
model_weights["tok_embeddings.weight"]

tensor([[ 0.0045,  0.0166,  0.0210,  ..., -0.0054, -0.0422, -0.0315],
        [ 0.0215, -0.0238,  0.0211,  ..., -0.0107, -0.0011, -0.0374],
        [ 0.0136,  0.0104,  0.0128,  ...,  0.0081, -0.0122,  0.0051],
        ...,
        [ 0.0009,  0.0164, -0.0193,  ..., -0.0003, -0.0030,  0.0066],
        [ 0.0009,  0.0164, -0.0193,  ..., -0.0003, -0.0030,  0.0066],
        [ 0.0009,  0.0164, -0.0193,  ..., -0.0003, -0.0030,  0.0066]],
       device='cuda:0', dtype=torch.bfloat16)