## Port model from maxtext to clip

In [None]:
import os
from functools import partial
from types import SimpleNamespace

import flax.linen as nn
import jax
import jax.numpy as jnp
import orbax
from flax.training import orbax_utils
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from partitions import logical_axis_rules
from maxtext.layers.models import Transformer

## Step 1: port checkpoint to maxtext

First, when converting the model, save the weights with a custom `_save_checkpoint`, called on `ckpt = {"params": jax_weights}`.

## Step 2: create a config

In [None]:
from max_utils import unbox_logicallypartioned
import pyconfig

In [None]:
# maxtext config

# Runtime knobs applied before MaxText parses its configuration.
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

# Command-line style arguments handed to pyconfig. The first two entries look
# like a script path followed by the base YAML file, and the remaining ones
# are ``key=value`` overrides.
# NOTE(review): confirm this layout against MaxText's pyconfig.initialize.
maxtext_argv = [
    "/home/boris/maxtext/MaxText/decode.py",
    "/home/boris/maxtext/MaxText/configs/base.yml",
    "load_parameters_path=/home/boris/maxtext/test/2024-03-18-16-53/decode-ckpt-maxtext/0/items",
    "run_name=runner_direct_2024-03-18-16-53",
    "per_device_batch_size=1",
    "model_name=mistral-7b",
    "tokenizer_path=/home/boris/maxtext/input/mistral-7B-v0.1/tokenizer.model",
    "ici_tensor_parallelism=4",
    "max_prefill_predict_length=4",
    "max_target_length=16",
    "prompt=I love to",
    "autoregressive_decode_assert=read. I love to read about the Bible. I love",
    "attention=dot_product",
]
pyconfig.initialize(maxtext_argv)

# Reference configuration object; `show` reads attribute values from it.
pconfig = pyconfig.config
#pconfig.get_keys()

In [None]:
# create mesh
# Size of the "model" mesh axis; the remaining devices form the "data" axis.
mp_devices = 8

# Size the mesh from the *global* device count. create_device_mesh lays out
# jax.devices() (every host) by default, so deriving dp_devices from
# jax.local_device_count() would disagree with the assertion below and give a
# wrong (possibly zero-sized) mesh shape on multi-host runs.
assert jax.device_count() % mp_devices == 0, (
    f"device count {jax.device_count()} is not divisible by mp_devices={mp_devices}"
)
dp_devices = jax.device_count() // mp_devices
dev_mesh = create_device_mesh((dp_devices, mp_devices))
mesh = Mesh(dev_mesh, ("data", "model"))

# input
# Base PRNG key passed to the (abstract) parameter initialisation.
rng = jax.random.PRNGKey(0)

In [None]:
# for updating config
def show(k):
    """Print attribute ``k`` of the reference ``pconfig`` as a ``k=value,`` line.

    The trailing comma makes the line resemble one entry of a keyword-argument
    list, which helps when copying missing settings into a hand-written config.

    Args:
        k: Name of the attribute to look up on ``pconfig``.
    """
    value = getattr(pconfig, k)
    #pyperclip.copy(val)
    print("{}={},".format(k, value))

In [None]:
def do_try(config, do_raise=False):
    """Try to abstractly initialise the Transformer with ``config``.

    The model is built on the global ``mesh`` and its parameters are traced
    with ``jax.eval_shape`` (using the global ``rng``), so no parameter memory
    is allocated.

    Args:
        config: Object exposing the attributes the model reads.
        do_raise: When True, re-raise any initialisation error unchanged.

    Returns:
        ``(llm_shape, model)`` on success. ``None`` when initialisation failed
        because an attribute was missing and ``do_raise`` is False; in that
        case the missing attribute is printed via ``show``.

    Raises:
        Exception: The original error when ``do_raise`` is True, or when the
            failure is not a "no attribute '<name>'" error (previously such
            errors were masked by an IndexError or reported a bogus name).
    """
    # Local import keeps this cell self-contained.
    import re

    input_shape = (1, 16)

    #model = Transformer(pconfig, mesh, quant=None)
    model = Transformer(config, mesh, quant=None)

    def init_llm(key):
        return model.init(
            {"params": key, "dropout": key, "aqt": key},
            jnp.ones(input_shape, dtype=jnp.int32),
            jnp.ones(input_shape, dtype=jnp.int32),
        )["params"]

    try:
        llm_shape = jax.eval_shape(init_llm, rng)
    except Exception as e:
        if do_raise:
            raise
        # Pull the attribute name out of "... has no attribute '<name>'".
        # The last match is used, mirroring the original split-based parsing.
        missing = re.findall(r"no attribute '([^']+)'", str(e))
        if not missing:
            # Not a missing-attribute problem: surface the real error.
            raise
        show(missing[-1])
        return None
    return llm_shape, model

In [None]:
# Hand-written stand-in for the pyconfig object: a plain namespace exposing
# the attributes passed to the Transformer. Key order is kept as originally
# written.
config = SimpleNamespace(
    **{
        # architecture
        "decoder_block": "mistral",
        "num_experts": 1,
        "vocab_size": 32_000,
        "emb_dim": 4096,
        "mlp_dim": 14336,
        "num_decoder_layers": 32,
        "num_query_heads": 32,
        "normalization_layer_epsilon": 1e-05,
        "head_dim": 128,
        "num_kv_heads": 8,
        "mlp_activations": ['silu', 'linear'],
        # model behaviour flags
        "logits_dot_in_fp32": True,
        "use_iota_embed": False,
        "use_untrainable_positional_embedding": False,
        "trainable_position_size": -1,
        "enable_dropout": False,
        "dropout_rate": 0,
        "scan_layers": True,
        "attention": "dot_product",
        "quantize_kvcache": False,
        "fused_qkv": False,
        "fused_mlp": False,
        "record_internal_nn_metrics": 0,
        "logits_via_embedding": False,
        # TODO: change
        "param_scan_axis": 1,
        # customizable
        "remat_policy": "full",
        "dtype": "bfloat16",
        "weight_dtype": "float32",
        "max_target_length": 16,
        # unused
        "max_prefill_predict_length": 4,
    }
)

# Probe run: on failure do_try prints the attribute the model could not find
# on `config` (with its value from pconfig) instead of raising.
do_try(config)

In [None]:
# Strict run: with do_raise=True any failure inside do_try propagates, so
# reaching the next line means abstract initialisation succeeded.
llm_shape, model = do_try(config, do_raise=True)
# Bare expression: displayed as the cell output when run in a notebook.
llm_shape

In [None]:
# Partition specs extracted from the abstract parameter tree. They are
# converted to mesh axes further down with nn.logical_to_mesh.
logical_spec = nn.get_partition_spec(llm_shape)
logical_spec

In [None]:
# Distinct entries appearing across all partition specs.
{x for v in flatten_dict(logical_spec).values() for x in v}

In [None]:
# Logical-axis -> mesh-axis rules from the project's `partitions` module.
rules = logical_axis_rules(activation_partitioning_dims=1, parameter_partitioning_dims=1)

In [None]:
# Translate the logical specs into mesh-axis specs using `rules`; the result
# is used as out_shardings when the parameters are materialised.
llm_params_spec = nn.logical_to_mesh(logical_spec, rules)
llm_params_spec

In [None]:
@partial(pjit, in_shardings=None, out_shardings=llm_params_spec)
def init_params(logical_params):
    """Build a zero-filled parameter tree with the structure of ``logical_params``.

    Only ``.shape`` and ``.dtype`` of each leaf are read. The outputs are
    placed according to ``llm_params_spec`` through pjit's ``out_shardings``.

    Args:
        logical_params: Pytree whose leaves expose ``shape`` and ``dtype``.

    Returns:
        A pytree of the same structure with ``jnp.zeros`` leaves.
    """
    # jax.tree_map is deprecated and removed in newer JAX releases;
    # jax.tree_util.tree_map is the equivalent available in every version.
    return jax.tree_util.tree_map(
        lambda x: jnp.zeros(x.shape, dtype=x.dtype), logical_params
    )

In [None]:
# Materialise the zero-initialised parameters inside the mesh context so the
# pjit out_shardings of init_params are resolved against `mesh`.
with mesh:
    llm_params = init_params(llm_shape)

In [None]:
# this is the maxtext dir (source checkpoint that gets restored below)
model_dir = "/home/boris/maxtext/mistral_7b/"

In [None]:
# Shared checkpointer instance used by the checkpoint manager below.
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()

def _restore_checkpoint(ckpt, dir, step):
    """Restore the checkpoint stored under ``dir`` at ``step``.

    Args:
        ckpt: Target pytree passed to the restore call; the restore arguments
            are derived from it (and the global ``mesh``) with
            ``orbax_utils.restore_args_from_target``.
        dir: Checkpoint directory handed to the ``CheckpointManager``.
        step: Step number to restore.

    Returns:
        Whatever ``CheckpointManager.restore`` returns for that step.
    """
    print(f"Restoring checkpoint from {dir} at step {step}")
    restore_args = orbax_utils.restore_args_from_target(ckpt, mesh)
    orbax_options = orbax.checkpoint.CheckpointManagerOptions()
    checkpoint_manager = orbax.checkpoint.CheckpointManager(dir, orbax_checkpointer, orbax_options)
    # No structural transforms are applied. The earlier `transforms = {}`
    # assignment was dead code, immediately overwritten with None.
    transforms = None
    return checkpoint_manager.restore(
        step, ckpt, restore_kwargs={"restore_args": restore_args, "transforms": transforms}
    )

In [None]:
# we need to unbox for it to work: the restore target is built from the
# unboxed version of the zero-initialised params, wrapped under "params",
# and read from step 0 of model_dir.
ckpt = _restore_checkpoint({"params":unbox_logicallypartioned(llm_params)}, model_dir, 0)

In [None]:
# Report every restored leaf whose absolute values sum to zero.
# NOTE(review): since the restore target was zero-initialised, such a leaf
# was presumably not loaded from the checkpoint -- confirm.
for k, v in flatten_dict(ckpt).items():
    abs_total = jnp.abs(v).sum()
    if abs_total == 0:
        print(k, v.shape)

In [None]:
# Restored parameter tree (the restore target was built from unboxed params).
unboxed_params = ckpt["params"]

In [None]:
# Peek at the first (path, value) entry of the restored params.
list(flatten_dict(unboxed_params).items())[0]

In [None]:
# Peek at the first (path, value) entry of the model-side params, for
# comparison with the restored entry (these values are accessed through
# `.value` in the cells below).
list(flatten_dict(llm_params).items())[0]

In [None]:
# Flatten both trees to {path-tuple: leaf} dicts so they can be compared
# and merged key by key.
flattened_ckpt = flatten_dict(unboxed_params)
flattened_llm = flatten_dict(llm_params)
# Entry counts of the two trees, displayed as the cell output.
len(flattened_ckpt), len(flattened_llm)

In [None]:
# Paths present on only one side: (checkpoint-only, model-only).
flattened_ckpt.keys() - flattened_llm.keys(), flattened_llm.keys() - flattened_ckpt.keys()

In [None]:
# Side-by-side shapes: model-side value vs. restored checkpoint leaf.
for k in flattened_llm:
    print(k, flattened_llm[k].value.shape, flattened_ckpt[k].shape)

In [None]:
# Side-by-side dtypes: model-side value vs. restored checkpoint leaf.
for k in flattened_llm:
    print(k, flattened_llm[k].value.dtype, flattened_ckpt[k].dtype)

In [None]:
# Before the merge: label each model-side parameter as all-zero or not.
for k, v in flattened_llm.items():
    v = v.value
    print("*** zero ***" if jnp.sum(jnp.abs(v)) == 0 else "** non-zero **")
    print(k, v.shape)

In [None]:
# Swap the restored arrays into the model-side entries via replace_boxed.
# Only existing keys are reassigned, so the dict does not change size while
# it is being iterated.
for k in flattened_llm:
    flattened_llm[k] = flattened_llm[k].replace_boxed(flattened_ckpt[k])

In [None]:
# After the merge: label each model-side parameter as all-zero or not.
for k, v in flattened_llm.items():
    v = v.value
    print("*** zero ***" if jnp.sum(jnp.abs(v)) == 0 else "** non-zero **")
    print(k, v.shape)

In [None]:
# Rebuild the nested parameter tree from the merged flat dict.
llm = unflatten_dict(flattened_llm)

In [None]:
# this is the version converted for clip
model_dir = "/home/boris/maxtext/mistral_7b_pretrain"

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()

def _save_checkpoint(ckpt, dir, step):
    """Write ``ckpt`` as checkpoint ``step`` under ``dir``.

    The manager is created with ``create=True``; the save arguments are
    derived from ``ckpt`` with ``orbax_utils.save_args_from_target``.

    Args:
        ckpt: Pytree to save.
        dir: Checkpoint directory handed to the ``CheckpointManager``.
        step: Step number to save under.
    """
    manager = orbax.checkpoint.CheckpointManager(
        dir,
        orbax_checkpointer,
        orbax.checkpoint.CheckpointManagerOptions(create=True),
    )
    manager.save(
        step,
        ckpt,
        save_kwargs={"save_args": orbax_utils.save_args_from_target(ckpt)},
    )

# Save the merged parameter tree under the "params" key as step 0.
ckpt = {"params": llm}
_save_checkpoint(ckpt, model_dir, 0)

In [None]:
# check that we can restore: read step 0 back from the new model_dir, this
# time passing llm_params directly as the target (no unboxing).
ckpt = _restore_checkpoint({"params":llm_params}, model_dir, 0)

## Step 3: Load the model

In [None]:
# Model configuration as a plain namespace; same keys, values and order as
# the hand-written config used during the porting steps.
config = SimpleNamespace(
    **{
        # architecture
        "decoder_block": "mistral",
        "num_experts": 1,
        "vocab_size": 32_000,
        "emb_dim": 4096,
        "mlp_dim": 14336,
        "num_decoder_layers": 32,
        "num_query_heads": 32,
        "normalization_layer_epsilon": 1e-05,
        "head_dim": 128,
        "num_kv_heads": 8,
        "mlp_activations": ['silu', 'linear'],
        # model behaviour flags
        "logits_dot_in_fp32": True,
        "use_iota_embed": False,
        "use_untrainable_positional_embedding": False,
        "trainable_position_size": -1,
        "enable_dropout": False,
        "dropout_rate": 0,
        "scan_layers": True,
        "attention": "dot_product",
        "quantize_kvcache": False,
        "fused_qkv": False,
        "fused_mlp": False,
        "record_internal_nn_metrics": 0,
        "logits_via_embedding": False,
        # TODO: change
        "param_scan_axis": 1,
        # customizable
        "remat_policy": "full",
        "dtype": "bfloat16",
        "weight_dtype": "float32",
        "max_target_length": 16,
        # unused
        "max_prefill_predict_length": 4,
    }
)

In [None]:
# create mesh
# Size of the "model" mesh axis; the remaining devices form the "data" axis.
mp_devices = 8

# Size the mesh from the *global* device count. create_device_mesh lays out
# jax.devices() (every host) by default, so deriving dp_devices from
# jax.local_device_count() would disagree with the assertion below and give a
# wrong (possibly zero-sized) mesh shape on multi-host runs.
assert jax.device_count() % mp_devices == 0, (
    f"device count {jax.device_count()} is not divisible by mp_devices={mp_devices}"
)
dp_devices = jax.device_count() // mp_devices
dev_mesh = create_device_mesh((dp_devices, mp_devices))
mesh = Mesh(dev_mesh, ("data", "model"))

# input
# Base PRNG key passed to the (abstract) parameter initialisation.
rng = jax.random.PRNGKey(0)

In [None]:
# Shape of the dummy integer inputs used to trace the model.
input_shape = (1, 16)

#model = Transformer(pconfig, mesh, quant=None)
model = Transformer(config, mesh, quant=None)

def init_llm(key):
    """Initialise the model with dummy inputs and return its ``"params"`` collection.

    The same ``key`` seeds the "params", "dropout" and "aqt" RNG streams, and
    both positional inputs are int32 arrays of ones with shape ``input_shape``.
    """
    rngs = {"params": key, "dropout": key, "aqt": key}
    first_input = jnp.ones(input_shape, dtype=jnp.int32)
    second_input = jnp.ones(input_shape, dtype=jnp.int32)
    variables = model.init(rngs, first_input, second_input)
    return variables["params"]

# Trace the initialiser abstractly with jax.eval_shape.
llm_shape = jax.eval_shape(init_llm, rng)