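# Test Hugging Face Save/Load Workflow

Loads pretrained MobileLLM weights into the Flax `Mobile_LLM` and into the Hugging Face PyTorch reference model, truncates both to their first decoder block, and compares their outputs.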

In [1]:
import flax.nnx as nnx
import jax
import jax.numpy as jnp
import numpy as np
import torch
import torch.nn as nn

from jaxpt.models.mobile_llm import (
    Mobile_LLM,
    MobileLLM_Config,
    convert_to_hf,
    from_hf_pretrained,
    load_hf_pretrained,
)

# Seed for the Flax model's parameter RNG streams.
SEED = 1337
key = jax.random.PRNGKey(SEED)
rngs = nnx.Rngs(key)

# Architecture hyper-parameters. These mirror the shapes of the pretrained
# checkpoint (cf. the LlamaModel repr printed by the load cell: vocab 49152,
# hidden 576, k/v projections 576 -> 192, MLP hidden 1536).
config = MobileLLM_Config(
    dtype=jnp.float32,
    vocab_size=49152,
    n_embed=576,
    n_head=9,
    n_kv_head=3,
    n_mlp_hidden=1536,
    sdpa_implementation="xla",
)

# To persist a converted model, call `hf_m.save_pretrained(<dir>)` once `hf_m`
# exists; use a path relative to the notebook rather than an absolute one.


In [2]:

# Disabled alternative path: construct a fresh Flax model (presumably randomly
# initialised from `rngs`), inspect its parameters, and convert it to a
# Hugging Face model with `convert_to_hf`. It is left commented out because
# the next cell loads pretrained weights instead; re-enable it to exercise the
# Flax -> HF conversion direction.
#m = Mobile_LLM(config, rngs)
#graphdef, params, state = nnx.split(m, nnx.Param, ...)
#nnx.display(params)
#hf_m = convert_to_hf(m)
#hf_state = hf_m.state_dict()

In [3]:
# Build the Flax model from the Hugging Face pretrained weights (presumably
# `from_hf_pretrained` maps the HF checkpoint into a `config`-shaped model),
# and load the PyTorch reference model for comparison.
from jaxpt.models.mobile_llm import from_hf_pretrained,load_hf_pretrained
m = from_hf_pretrained(config, rngs)
# Separate the nnx.Param leaves so only the weights are displayed;
# `graphdef` and `other_state` are not used further in this cell.
graphdef, params, other_state = nnx.split(m, nnx.Param, ...)
nnx.display(params)
hf_m = load_hf_pretrained()
# Print the inner LlamaModel only (the LlamaForCausalLM wrapper and its
# lm_head are not shown).
print(hf_m.model)

LlamaModel(
  (embed_tokens): Embedding(49152, 576)
  (layers): ModuleList(
    (0-29): 30 x LlamaDecoderLayer(
      (self_attn): LlamaAttention(
        (q_proj): Linear(in_features=576, out_features=576, bias=False)
        (k_proj): Linear(in_features=576, out_features=192, bias=False)
        (v_proj): Linear(in_features=576, out_features=192, bias=False)
        (o_proj): Linear(in_features=576, out_features=576, bias=False)
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
        (up_proj): Linear(in_features=576, out_features=1536, bias=False)
        (down_proj): Linear(in_features=1536, out_features=576, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
    )
  )
  (norm): LlamaRMSNorm((576,), eps=1e-05)
  (rotary_emb): LlamaRotaryEmbedding()
)


In [4]:
import torch.nn as nn

class IdentityLayer(nn.Module):
    """Pass-through stand-in for a Hugging Face decoder layer.

    All arguments other than ``hidden_states`` are ignored. The input comes
    back with a new leading axis of size 1, so indexing the result with
    ``[0]`` recovers the original tensor.

    NOTE(review): the extra axis presumably exists because the caller reads
    ``layer_outputs[0]`` from a decoder layer's return value. Confirm this
    against the installed ``transformers`` version. This is NOT a true
    identity for slots whose caller uses the returned tensor directly.
    """

    def forward(self, hidden_states, *args, **kwargs):
        return hidden_states[None, ...]

class IdentityAttn(nn.Module):
    """Pass-through stand-in for a decoder layer's self-attention module.

    Returns a pair: ``hidden_states`` with a new leading axis of size 1, and
    ``None`` (presumably filling the attention-weights slot).

    NOTE(review): if the caller consumes the first element directly (e.g. in a
    residual add), the extra leading axis would broadcast rather than cancel
    out. Verify this before re-enabling the stub.
    """

    def forward(self, hidden_states, *args, **kwargs):
        expanded = hidden_states[None, ...]
        return expanded, None

class ZeroRotation(nn.Module):
    """Stand-in rotary embedding that applies no rotation.

    Returns ``(cos, sin)`` with ``cos == 1`` and ``sin == 0`` everywhere. A
    RoPE application of the form ``x * cos + rotate(x) * sin`` therefore
    leaves ``x`` unchanged. Both tensors have the shape of
    ``hidden_states[:, :, :head_dim]`` and inherit the dtype and device of
    ``hidden_states``.

    Args:
        head_dim: width of the last axis of the returned tensors. Defaults to
            64 (= n_embed 576 / n_head 9 for this notebook's config).
    """

    def __init__(self, head_dim: int = 64):
        super().__init__()
        self.head_dim = head_dim

    def forward(self, hidden_states, *args, **kwargs):
        # Extra args (e.g. position ids) are accepted and ignored.
        template = hidden_states[:, :, : self.head_dim]
        # The *_like constructors keep the input's dtype and device; building
        # from np.ones/np.zeros would yield float64 CPU tensors instead.
        return torch.ones_like(template), torch.zeros_like(template)

# Truncate the Hugging Face model to its first decoder layer: layers 1..N-1
# are swapped for IdentityLayer stubs, leaving layer 0 as the only live block.
for layer_idx in range(1, len(hf_m.model.layers)):
    hf_m.model.layers[layer_idx] = IdentityLayer()

# Further isolation knobs, disabled by default (uncomment to stub out parts of
# layer 0 or the modules around the decoder stack).
# NOTE(review): IdentityLayer adds a leading axis, so it is probably not a
# drop-in identity for the mlp / norm / lm_head slots; check before enabling.
#hf_m.model.layers[0].self_attn = IdentityAttn()
#hf_m.model.layers[0].mlp = IdentityLayer()
#hf_m.model.layers[0].input_layernorm = IdentityLayer()
#hf_m.model.layers[0].post_attention_layernorm = IdentityLayer()
#hf_m.model.norm = IdentityLayer()
#hf_m.model.rotary_emb = ZeroRotation()
#hf_m.lm_head = IdentityLayer()
#hf_m.model.embed_tokens = None

# Truncate the Flax model the same way: blocks 1..N-1 become pass-through
# callables, leaving block 0 as the only live block.
for block_idx in range(1, len(m.h)):
    m.h[block_idx] = lambda x: x

# Matching isolation knobs for block 0 and the final norm (uncomment to use).
#m.h[0].attn = lambda x: x
#m.h[0].mlp = lambda x: x
#m.h[0].rms_n_1 = lambda x: x
#m.h[0].rms_n_2 = lambda x: x
#m.rms_n_f = lambda x: x

In [5]:
# Show both truncated models to confirm that only block/layer 0 remains live
# (the HF repr should list layers 1-29 as IdentityLayer).
nnx.display(m)
print(hf_m)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
      (1-29): 29 x IdentityLayer()
    )
    (norm): LlamaRMSNorm((576,), eps=1e-05)
    

In [6]:
import torch
import numpy as np

m.eval()
hf_m.eval()

# Feed the same random token batch to both implementations. torch.long is the
# canonical index dtype for nn.Embedding.
tokens_jax = jax.random.randint(jax.random.key(1337), shape=(4, 10), minval=0, maxval=config.vocab_size)
tokens_pt = torch.tensor(np.asarray(tokens_jax), dtype=torch.long)

y_flax = np.asarray(m(tokens_jax))
# no_grad only affects the PyTorch forward pass, so scope it to that call.
with torch.no_grad():
    y_hf = hf_m(tokens_pt).logits.detach().numpy()

# Equal global means do not imply matching logits, so also report the
# element-wise error. A shape mismatch (e.g. a stray leading axis) fails loudly.
assert y_flax.shape == y_hf.shape, f"logit shape mismatch: {y_flax.shape} vs {y_hf.shape}"
abs_err = np.abs(y_flax - y_hf)
print(np.mean(y_flax))
print(np.mean(y_hf))
print(f"max |flax - hf| = {abs_err.max():.6g}, mean |flax - hf| = {abs_err.mean():.6g}")
print("allclose (atol=1e-4, rtol=1e-4):", np.allclose(y_flax, y_hf, atol=1e-4, rtol=1e-4))

-4.305381
-4.3187404


In [7]:

# Check a single projection in isolation: block 0's attention output
# projection in Flax (`wproj`) vs. HF's `o_proj`, on identical random inputs.
x_jax = jax.random.uniform(jax.random.key(1337), shape=(4, 10, config.n_embed))
print(x_jax.shape)
flax_y = m.h[0].attn.wproj(x_jax)
x_pt = torch.tensor(np.asarray(x_jax))
# Inference only: skip autograd bookkeeping so the result converts to NumPy.
with torch.no_grad():
    hf_y = hf_m.model.layers[0].self_attn.o_proj(x_pt)
print(hf_y)
print(flax_y)
# Numeric agreement instead of eyeballing the two truncated tensor prints.
print("max |flax - hf| =", float(np.abs(np.asarray(flax_y) - hf_y.numpy()).max()))

(4, 10, 576)
tensor([[[-0.3656,  0.0217,  0.5213,  ..., -0.7870,  0.7028, -0.2325],
         [-0.0431,  0.0868,  0.6330,  ..., -1.4498,  0.5974, -0.7207],
         [ 0.0599,  0.3069,  0.6297,  ..., -1.2262, -0.3062, -0.5763],
         ...,
         [-0.9529,  0.4569,  0.7445,  ...,  0.0286,  0.3824, -0.0235],
         [-0.0175,  0.1905,  0.4627,  ..., -0.4968,  0.4359,  0.0786],
         [-1.0518,  0.3811,  0.6600,  ..., -0.1894,  0.7658, -0.1792]],

        [[-0.9684,  0.4107,  0.6200,  ...,  0.3200,  1.0673, -0.4461],
         [-0.9214,  0.8558,  0.8164,  ..., -0.4816, -0.0987, -0.2003],
         [ 0.2795,  0.7134,  0.3092,  ..., -0.0078,  0.7970, -1.3062],
         ...,
         [-0.5236,  0.3985,  1.0472,  ..., -0.5034,  0.1931, -0.6530],
         [ 0.1337,  0.1855,  0.8975,  ..., -1.0661,  0.3871, -0.2554],
         [ 0.0607,  0.8516,  0.4251,  ..., -0.7242,  0.2048, -0.2855]],

        [[-0.0798,  0.6596,  1.1176,  ..., -0.7048,  1.0005,  0.4859],
         [-0.0376,  0.9934,  0.6