In [None]:
%pip install datasets

In [None]:
import jax
from torch import load
from requests import get
from jax.lax import pmean
from flax import linen as nn
from functools import partial
from flax import traverse_util
from tqdm.notebook import tqdm
from datasets import load_dataset
from dataclasses import dataclass
from flax.jax_utils import replicate
from transformers import AutoTokenizer
from jax import Array, pmap, jit, random
from  jax.nn.initializers import lecun_normal
from flax.training.train_state import TrainState
from optax import adamw, set_to_zero, multi_transform
from jax import numpy as jnp, ensure_compile_time_eval
from flax.training.orbax_utils import save_args_from_target
from orbax.checkpoint import Checkpointer, PyTreeCheckpointHandler
from jax import numpy as jnp, value_and_grad, ensure_compile_time_eval

# Show the accelerators JAX can see; training below pmaps over jax.local_device_count() of them.
jax.devices()

In [None]:
@dataclass
class PhiConfig:
    """Hyper-parameters for the Flax Phi model and its distilled (student) MLP branch.

    NOTE(review): the defaults appear to mirror the microsoft/phi-1_5 checkpoint that
    this notebook loads — confirm against that checkpoint's config.json.
    """
    n_head: int = 32  # attention heads; per-head width is n_embed // n_head
    n_layer: int = 24  # number of transformer Blocks stacked in Phi
    n_embed: int = 2048  # hidden (residual stream) width
    rotary_dim: int = 32  # leading per-head features that get rotary position embeddings
    ln_eps: float = 1e-5  # LayerNorm epsilon
    n_positions: int = 2048  # rows of the cos/sin tables; also the padded sequence length used for training
    vocab_size: int = 51200  # token-embedding rows and LM-head output size
    target_hidden_size: int = 2048  # hidden width of the trainable student MLP (Dense_2 -> Dense_3)
    param_dtype: jnp.dtype = jnp.bfloat16  # dtype the Flax modules create their parameters in

In [None]:
def compute_cos_sin(config: PhiConfig) -> (Array, Array):
    """Precompute the rotary-embedding cosine and sine tables.

    Frequencies follow the usual 10000 ** (-2i / rotary_dim) schedule. The angle
    table is built in float32 and only the final cos/sin values are cast to
    ``config.param_dtype``.

    Returns:
        (cos, sin), each of shape (config.n_positions, ceil(config.rotary_dim / 2)).
    """
    even_dims = jnp.arange(0, config.rotary_dim, 2, dtype=jnp.float32)
    inv_freq = 1 / (10000 ** (even_dims / config.rotary_dim))
    positions = jnp.arange(config.n_positions, dtype=jnp.float32)
    # Outer product: one row of angles per position.
    angles = positions[:, None] * inv_freq[None, :]
    cos_table = jnp.cos(angles).astype(config.param_dtype)
    sin_table = jnp.sin(angles).astype(config.param_dtype)
    return cos_table, sin_table

def apply_rotary_emb(qkv: Array, cos: Array, sin: Array) -> Array:
    """Apply rotate-half rotary embeddings to the q and k slots of a fused QKV tensor.

    Args:
        qkv: rank-5 array (batch, seq_len, 3, n_head, head_dim); slot 0 is q, 1 is k, 2 is v.
        cos, sin: 2-D tables with at least seq_len rows; their column count is half
            of the number of leading features that get rotated.

    Returns:
        Array shaped like ``qkv``. For q and k, the leading features are rotated in
        float32 and cast back to ``qkv.dtype``, and the rest are passed through. The v
        slot is copied unchanged.
    """
    _batch, seq_len, _slots, _heads, _head_dim = qkv.shape
    _rows, half_width = cos.shape
    rot_width = 2 * half_width

    cos_f32 = cos[:seq_len, None, :].astype(jnp.float32)
    sin_f32 = sin[:seq_len, None, :].astype(jnp.float32)

    def rotate(x: Array) -> Array:
        # x: (batch, seq_len, n_head, head_dim) — one of the q/k slots.
        head, tail = x[..., :rot_width], x[..., rot_width:]
        x1, x2 = jnp.split(head.astype(jnp.float32), 2, axis=-1)
        turned = jnp.concatenate([x1 * cos_f32 - x2 * sin_f32, x1 * sin_f32 + x2 * cos_f32], axis=-1)
        return jnp.concatenate([turned.astype(qkv.dtype), tail], axis=-1)

    rotated_q = rotate(qkv[:, :, 0])[:, :, None]
    rotated_k = rotate(qkv[:, :, 1])[:, :, None]
    return jnp.concatenate([rotated_q, rotated_k, qkv[:, :, 2:3]], axis=2)

class SelfAttention(nn.Module):
    """Causal multi-head self-attention with partial rotary position embeddings.

    Submodules are created in a fixed order, so Flax auto-names them ``Dense_0``
    (fused QKV projection) and ``Dense_1`` (output projection). The checkpoint
    loader and the optimizer partitioning rely on those names.
    """
    config: PhiConfig

    @nn.compact
    def __call__(self, x: Array) -> Array:
        """Attend over ``x`` of shape (batch, seq_len, n_embed); returns the same shape."""
        batch_size, seq_len, n_embed = x.shape
        head_dim = n_embed // self.config.n_head

        with ensure_compile_time_eval():
            cos, sin = compute_cos_sin(self.config)

        qkv = nn.Dense(features=3 * self.config.n_embed, use_bias=True, param_dtype=self.config.param_dtype)(x)
        qkv = jnp.reshape(qkv, (batch_size, seq_len, 3, self.config.n_head, head_dim))
        qkv = apply_rotary_emb(qkv, cos, sin)
        # -> (3, batch, n_head, seq_len, head_dim). Index rather than jnp.split so that
        # q/k/v are rank 4. The swapaxes(1, 2) below assumes (batch, n_head, seq, dim).
        qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Logits and softmax in float32: low-precision params make bf16/fp16 logits lossy.
        scores = (q.astype(jnp.float32) @ jnp.swapaxes(k.astype(jnp.float32), -2, -1)) * head_dim ** -0.5
        causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
        scores = jnp.where(causal, scores, jnp.finfo(jnp.float32).min)
        weights = nn.softmax(scores, axis=-1).astype(v.dtype)

        out = weights @ v  # (batch, n_head, seq_len, head_dim)
        out = jnp.swapaxes(out, 1, 2).reshape(batch_size, seq_len, n_embed)
        return nn.Dense(features=n_embed, use_bias=True, param_dtype=self.config.param_dtype)(out)

class MLP(nn.Module):
    """Pretrained Phi feed-forward layer plus a trainable student MLP distilled from it.

    The submodule creation order fixes the auto-generated names that the checkpoint
    loader and the optimizer partitioning rely on:

    * ``Dense_0`` / ``Dense_1``: the pretrained fc1 / fc2 (the teacher).
    * ``Dense_2`` / ``Dense_3``: the student.
    """
    config: PhiConfig

    @nn.compact
    def __call__(self, x: Array) -> tuple[Array, Array]:
        """Return (teacher output, scalar float32 MSE between student and teacher outputs).

        Only the teacher output continues through the network; the student output
        is used solely for the loss.
        """
        h = nn.Dense(features=self.config.n_embed * 4, use_bias=True, param_dtype=self.config.param_dtype)(x)
        h = nn.Dense(features=self.config.n_embed, use_bias=True, param_dtype=self.config.param_dtype)(nn.gelu(h))

        # The student sees the same input as the teacher. The loss only needs to train
        # the student, so gradients are stopped at its input and at the teacher target.
        # Student gradients are unchanged, and XLA can skip backprop through the
        # pretrained network.
        student_in = jax.lax.stop_gradient(x)
        l = nn.Dense(features=self.config.target_hidden_size, use_bias=True, param_dtype=self.config.param_dtype)(student_in)
        l = nn.Dense(features=self.config.n_embed, use_bias=True, param_dtype=self.config.param_dtype)(nn.gelu(l))

        target = jax.lax.stop_gradient(h)
        # Accumulate the squared error in float32 so bf16/fp16 params don't degrade the loss.
        loss = jnp.mean(jnp.square(l.astype(jnp.float32) - target.astype(jnp.float32)))
        return h, loss

class Block(nn.Module):
    """Phi "parallel" transformer block.

    Attention and MLP both read the same LayerNorm output, and their outputs are
    added to the residual input. The block also returns the MLP's distillation loss.
    """
    config: PhiConfig

    @nn.compact
    def __call__(self, x: Array) -> tuple[Array, Array]:
        """Return (block output with x's shape, this block's scalar distillation loss)."""
        h = nn.LayerNorm(epsilon=self.config.ln_eps, param_dtype=self.config.param_dtype)(x)
        attn_out = SelfAttention(self.config)(h)
        mlp_out, distill_loss = MLP(self.config)(h)
        # Residual connection: the parallel block computes x + attn(ln(x)) + mlp(ln(x)).
        return x + attn_out + mlp_out, distill_loss

class Phi(nn.Module):
    """Flax Phi decoder: token embedding, ``n_layer`` Blocks, final LayerNorm and LM head."""
    config: PhiConfig

    @nn.compact
    def __call__(self, x: Array) -> tuple[Array, Array]:
        """Run the model on integer token ids ``x``.

        Returns:
            (logits over ``vocab_size``, sum of the per-Block distillation losses).
        """
        total_loss = 0
        h = nn.Embed(num_embeddings=self.config.vocab_size, features=self.config.n_embed, param_dtype=self.config.param_dtype)(x)
        for _ in range(self.config.n_layer):
            h, block_loss = Block(self.config)(h)
            total_loss += block_loss
        # Final LayerNorm + LM head. The training step only uses the distillation loss,
        # so these layers matter only for inference with the full model.
        h = nn.LayerNorm(epsilon=self.config.ln_eps, param_dtype=self.config.param_dtype)(h)
        logits = nn.Dense(self.config.vocab_size, use_bias=True, param_dtype=self.config.param_dtype)(h)
        return logits, total_loss

In [None]:
# Per-block PyTorch key suffix -> (Flax path inside Block_i, is a torch.nn.Linear weight).
# torch.nn.Linear stores its weight as (out_features, in_features), whereas a Flax Dense
# kernel is (in_features, out_features), so Linear weights must be transposed.
_PHI_BLOCK_PARAMS = {
    "ln.weight": (("LayerNorm_0", "scale"), False),
    "ln.bias": (("LayerNorm_0", "bias"), False),
    "mixer.Wqkv.weight": (("SelfAttention_0", "Dense_0", "kernel"), True),
    "mixer.Wqkv.bias": (("SelfAttention_0", "Dense_0", "bias"), False),
    "mixer.out_proj.weight": (("SelfAttention_0", "Dense_1", "kernel"), True),
    "mixer.out_proj.bias": (("SelfAttention_0", "Dense_1", "bias"), False),
    "mlp.fc1.weight": (("MLP_0", "Dense_0", "kernel"), True),
    "mlp.fc1.bias": (("MLP_0", "Dense_0", "bias"), False),
    "mlp.fc2.weight": (("MLP_0", "Dense_1", "kernel"), True),
    "mlp.fc2.bias": (("MLP_0", "Dense_1", "bias"), False),
}
# Embedding key suffix -> top-level Flax path (embedding tables are not transposed).
_PHI_EMBED_PARAMS = {"wte.weight": (("Embed_0", "embedding"), False)}
# LM-head key suffix -> top-level Flax path (final LayerNorm + output projection).
_PHI_HEAD_PARAMS = {
    "ln.weight": (("LayerNorm_0", "scale"), False),
    "ln.bias": (("LayerNorm_0", "bias"), False),
    "linear.weight": (("Dense_0", "kernel"), True),
    "linear.bias": (("Dense_0", "bias"), False),
}


def _lookup_phi_param(table, suffix, prefix=()):
    """Resolve ``suffix`` in ``table`` to (prefixed Flax path, needs_transpose), or None if unmapped."""
    entry = table.get(suffix)
    if entry is None:
        return None
    path, needs_transpose = entry
    return (*prefix, *path), needs_transpose


def _torch_key_to_flax_path(name, n_layer):
    """Map a PyTorch Phi state-dict key to (Flax param path, needs_transpose), or None to skip it.

    Two key layouts are understood:
      * ``layers.<k>.<suffix>``: k == 0 is the embedding, 1..n_layer are the blocks
        (``layers.1.mixer.Wqkv.weight`` -> ``Block_0``), and n_layer + 1 is the LM head.
      * ``transformer.embd.*`` / ``transformer.h.<i>.*`` / ``lm_head.*`` with 0-indexed
        blocks. NOTE(review): assumed to match later phi-1_5 revisions; verify.
    """
    parts = name.split(".")
    if len(parts) >= 3 and parts[0] == "layers" and parts[1].isdigit():
        index, suffix = int(parts[1]), ".".join(parts[2:])
        if index == 0:
            return _lookup_phi_param(_PHI_EMBED_PARAMS, suffix)
        if index <= n_layer:
            return _lookup_phi_param(_PHI_BLOCK_PARAMS, suffix, (f"Block_{index - 1}",))
        if index == n_layer + 1:
            return _lookup_phi_param(_PHI_HEAD_PARAMS, suffix)
        return None
    if len(parts) >= 4 and parts[:2] == ["transformer", "h"] and parts[2].isdigit():
        index = int(parts[2])
        if index >= n_layer:
            return None
        return _lookup_phi_param(_PHI_BLOCK_PARAMS, ".".join(parts[3:]), (f"Block_{index}",))
    if len(parts) >= 3 and parts[:2] == ["transformer", "embd"]:
        return _lookup_phi_param(_PHI_EMBED_PARAMS, ".".join(parts[2:]))
    if len(parts) >= 2 and parts[0] == "lm_head":
        return _lookup_phi_param(_PHI_HEAD_PARAMS, ".".join(parts[1:]))
    return None


def _expected_phi_paths(n_layer):
    """Return every pretrained (non-student) Flax param path a Phi with ``n_layer`` blocks needs."""
    expected = {path for path, _ in _PHI_EMBED_PARAMS.values()}
    expected |= {path for path, _ in _PHI_HEAD_PARAMS.values()}
    for i in range(n_layer):
        expected |= {(f"Block_{i}", *path) for path, _ in _PHI_BLOCK_PARAMS.values()}
    return expected


def _set_nested(tree, path, value):
    """Store ``value`` at ``path`` inside the nested dict ``tree``, creating levels as needed."""
    *parents, leaf = path
    node = tree
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value


def _fetch_checkpoint(model_path, download_path):
    """Return a local checkpoint path, downloading ``model_path`` to ``download_path`` unless it is a local file."""
    import os

    if os.path.isfile(model_path):
        return model_path
    # Stream to disk instead of buffering the whole (multi-GB) file in memory. Fail
    # loudly on HTTP errors rather than handing an error page to torch.load.
    with get(model_path, allow_redirects=True, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(download_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    return download_path


def load_model_into_flax(config, model_path, download_path="/kaggle/working/pytorch_model.bin", seed=0) -> dict:
    """Build the Flax parameter tree for ``Phi`` from a PyTorch Phi checkpoint.

    Pretrained weights are copied from the checkpoint and cast to
    ``config.param_dtype``:

    * the embedding;
    * per-block LayerNorm, attention and fc1/fc2;
    * the final LayerNorm and LM head.

    The student MLP layers (``MLP_0/Dense_2`` and ``MLP_0/Dense_3`` in each block) are
    freshly initialised: LeCun-normal kernels and zero biases.

    Args:
        config: PhiConfig-like object (reads n_layer, n_embed, target_hidden_size, param_dtype).
        model_path: URL to download, or the path of an existing local checkpoint file.
        download_path: where a downloaded checkpoint is written.
        seed: PRNG seed for the student initialisation.

    Returns:
        Nested dict of parameters keyed like the Flax module tree.

    Raises:
        ValueError: if the checkpoint lacks any expected pretrained parameter.
        requests.HTTPError: if the download fails.
    """
    print("loading pytorch model")
    checkpoint_path = _fetch_checkpoint(model_path, download_path)
    # SECURITY: torch.load unpickles the file, which can execute arbitrary code.
    # Only load checkpoints from sources you trust.
    state_dict = load(checkpoint_path, map_location="cpu")
    print("pytorch model loaded")

    print("loading model into flax")
    print("init trainable params")
    params = {}
    initializer = lecun_normal()
    # One key per layer and per Dense. A single shared key would give every layer the
    # same student init (and identical Dense_2/Dense_3 kernels when
    # target_hidden_size == n_embed).
    keys = random.split(random.PRNGKey(seed), 2 * config.n_layer)
    for i in range(config.n_layer):
        mlp = params.setdefault(f"Block_{i}", {}).setdefault("MLP_0", {})
        mlp["Dense_2"] = {
            "kernel": initializer(keys[2 * i], (config.n_embed, config.target_hidden_size), dtype=config.param_dtype),
            "bias": jnp.zeros((config.target_hidden_size,), dtype=config.param_dtype),
        }
        mlp["Dense_3"] = {
            "kernel": initializer(keys[2 * i + 1], (config.target_hidden_size, config.n_embed), dtype=config.param_dtype),
            "bias": jnp.zeros((config.n_embed,), dtype=config.param_dtype),
        }

    print("init frozen params")
    loaded = set()
    for name, tensor in state_dict.items():
        mapped = _torch_key_to_flax_path(name, config.n_layer)
        if mapped is None:
            continue  # e.g. buffers or other keys this Flax model has no slot for
        path, needs_transpose = mapped
        # Upcast first: numpy has no torch.bfloat16 equivalent, so .numpy() would fail on it.
        array = tensor.detach().cpu().float().numpy()
        if needs_transpose:
            array = array.T
        _set_nested(params, path, jnp.asarray(array, dtype=config.param_dtype))
        loaded.add(path)

    missing = sorted(_expected_phi_paths(config.n_layer) - loaded)
    if missing:
        shown = ", ".join("/".join(path) for path in missing[:10])
        raise ValueError(f"checkpoint is missing {len(missing)} expected parameter(s): {shown}")
    print("model loaded into flax")
    return params

def init_train_state(config, batch_size, model_path, learning_rate: float = 3e-4) -> TrainState:
    """Create a device-replicated TrainState in which only the student MLPs are trained.

    Args:
        config: PhiConfig for the model.
        batch_size: currently unused; kept for call-site compatibility.
        model_path: checkpoint URL or path, forwarded to load_model_into_flax.
        learning_rate: AdamW learning rate for the trainable (student) parameters.

    Returns:
        TrainState replicated over the local devices (leading device axis on each leaf).
    """
    phi = Phi(config)
    params = load_model_into_flax(config, model_path)

    def label(path, _leaf):
        # Trainable = the student layers Block_i/MLP_0/Dense_2 and Dense_3. Also
        # requiring the MLP_0 parent avoids training any other module that happens to
        # be auto-named Dense_2/Dense_3.
        is_student = len(path) >= 3 and path[-3] == "MLP_0" and path[-2] in ("Dense_2", "Dense_3")
        return "trainable" if is_student else "frozen"

    # Frozen (pretrained) parameters get zero updates; only the students see AdamW.
    tx = multi_transform(
        {"trainable": adamw(learning_rate), "frozen": set_to_zero()},
        traverse_util.path_aware_map(label, params),
    )
    state = TrainState.create(apply_fn=phi.apply, tx=tx, params=params)
    return replicate(state)

In [None]:
@partial(pmap, axis_name="device")
def train_step(state: TrainState, batch: Array):
    """One data-parallel optimisation step on the model's summed distillation loss.

    Runs once per device under pmap. Loss and gradients are averaged over the
    "device" axis before the update, so every replica applies the same gradients.

    Returns:
        (updated state, device-averaged loss).
    """
    def distill_loss(params):
        _logits, total_loss = state.apply_fn({"params": params}, batch)
        return total_loss

    loss, grads = value_and_grad(distill_loss)(state.params)
    loss, grads = pmean((loss, grads), axis_name="device")
    return state.apply_gradients(grads=grads), loss

In [None]:
def train_epoch(state: TrainState, n_iter: int, batch_size: int, config: PhiConfig) -> TrainState:
    """Run up to ``n_iter`` distillation steps over consecutive UltraChat prompts.

    Args:
        state: device-replicated TrainState (as returned by init_train_state).
        n_iter: requested number of steps. It is clamped to the number of full global
            batches in the split, so no prompt index runs past the end.
        batch_size: per-device batch size. Each step consumes
            batch_size * jax.local_device_count() prompts.
        config: model config; ``n_positions`` is the padded/truncated sequence length.

    Returns:
        The updated replicated TrainState.
    """
    dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_gen")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")
    tokenizer.pad_token = tokenizer.eos_token
    n_devices = jax.local_device_count()
    global_batch = batch_size * n_devices

    max_iter = len(dataset) // global_batch
    if n_iter > max_iter:
        print(f"n_iter={n_iter} exceeds the {max_iter} full batches available; clamping to {max_iter}")
        n_iter = max_iter

    progress = tqdm(range(n_iter))
    for i in progress:
        # Consecutive, non-overlapping windows: each step advances by the full global batch.
        start = i * global_batch
        prompts = dataset[start:start + global_batch]["prompt"]
        tokens = tokenizer(prompts, padding="max_length", max_length=config.n_positions,
                           truncation=True, return_tensors="np")["input_ids"]
        # Leading axis = device, as pmap expects.
        batch = jnp.asarray(tokens, dtype=jnp.int32).reshape(n_devices, batch_size, -1)

        state, loss = train_step(state, batch)
        # pmap returns one (pmean-identical) loss per device; report the first.
        progress.set_description(f"Loss: {float(loss[0]):.4f}")

    return state
        
def train(batch_size, n_iter=1000, checkpoint_dir="/kaggle/working/phi-jax_2",
          model_path="https://huggingface.co/microsoft/phi-1_5/resolve/main/pytorch_model.bin"):
    """Load phi-1.5, distil the student MLPs, then save the parameters and step with Orbax.

    Args:
        batch_size: per-device batch size.
        n_iter: number of training steps passed to train_epoch.
        checkpoint_dir: absolute directory for the Orbax checkpoint (overwritten if present).
        model_path: PyTorch checkpoint URL or local path.
    """
    config = PhiConfig()
    state = init_train_state(config, batch_size, model_path)
    state = train_epoch(state, n_iter, batch_size, config)

    # The state is replicated across devices (leading axis); save one host-side copy.
    ckpt = jax.device_get({
        "params": jax.tree_util.tree_map(lambda leaf: leaf[0], state.params),
        "step": state.step[0],
    })
    checkpointer = Checkpointer(PyTreeCheckpointHandler())
    checkpointer.save(checkpoint_dir, ckpt, save_args=save_args_from_target(ckpt), force=True)
    print(f"checkpoint saved to {checkpoint_dir}")
    
PER_DEVICE_BATCH_SIZE = 128  # each training step consumes this many prompts per local device
train(PER_DEVICE_BATCH_SIZE)

In [None]:
# Scratch cell, intentionally left without code so "Restart & Run All" completes.
# (It held an incomplete `jax.random.nor` expression that raised AttributeError.)