In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import os

# Put ../../maxtext (resolved against the current working directory) at the
# front of sys.path so the imports below search that checkout first.
# NOTE(review): this depends on the directory the kernel was started from.
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../maxtext")))
# NOTE(review): presumably asks a downstream library to skip JAX
# precompilation - confirm which package reads this variable. It is set
# before the library imports below so it is visible when they are imported.
os.environ["SKIP_JAX_PRECOMPILE"] = "1"

import functools
from etils import epath


import transformers
import numpy as np

import jax
from flax import nnx
from flax import linen as nn

import MaxText as mt
from MaxText import pyconfig
from MaxText.integration.tunix.tunix_adaptor import TunixMaxTextLlama

from tunix.rl.rollout.vllm_rollout import VllmRollout
from tunix.rl.rollout.base_rollout import RolloutConfig

from tunix.rl.rollout import base_rollout
from tunix.models.llama3 import model as llama3_lib

from vllm import LLM
import orbax.checkpoint as ocp

In [None]:
# Hugging Face model id used to load the tokenizer.
MODEL = "meta-llama/Llama-3.1-8B"
# Number of devices taken from the front of jax.devices() for the mesh.
TOTAL_TPU_TO_USE = 8
# Positional arguments for jax.make_mesh: axis sizes first, then axis names.
MESH = [(1, TOTAL_TPU_TO_USE), ("fsdp", "tp")]


model_tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)

# `mesh` and `model_mesh` are two names for the same mesh object.
model_mesh = jax.make_mesh(
    MESH[0], MESH[1], devices=jax.devices()[:TOTAL_TPU_TO_USE]
)
mesh = model_mesh

In [None]:
def get_ref_maxtext_model(config):
  """Builds a sharded MaxText model and wraps it in the Tunix Llama adaptor.

  The model is first evaluated abstractly with `nnx.eval_shape` to obtain its
  graph definition and partition specs. It is then materialized by a jitted
  function whose `out_shardings` are derived from those specs. If
  `config.load_parameters_path` is truthy, parameters are restored from that
  path with Orbax and written into the model via `nnx.update`.

  Side effects: prints a header line and calls `nnx.display` on the abstract
  state.

  Args:
    config: MaxText config object. This function reads its
      `logical_axis_rules` and `load_parameters_path` attributes and passes
      the object to `mt.from_pretrained`.

  Returns:
    A `(tunix_model, mesh, model_config)` tuple: the `TunixMaxTextLlama`
    wrapper, the mesh taken from the abstract model, and the
    `llama3_lib.ModelConfig.llama3_1_8b()` instance that is also attached to
    the wrapper as `tunix_model.config`.

  Raises:
    ValueError: If restoring the checkpoint or applying it to the model
      fails. The underlying exception is chained as `__cause__`.
  """

  def create_model(config):
    # Constant RNG seeds (params=0, dropout=1) are used for initialization.
    return mt.from_pretrained(config, rngs=nnx.Rngs(params=0, dropout=1))

  abstract_model = nnx.eval_shape(create_model, config=config)
  graphdef, abstract_state = nnx.split(abstract_model)
  print("The abstract NNX state (all leaves are abstract arrays):")
  nnx.display(abstract_state)
  specs = nnx.get_partition_spec(abstract_state)
  # Local `mesh` comes from the model itself, not from any notebook global.
  mesh = abstract_model.mesh

  # JIT a function that creates the model state with proper sharding from the start.
  # By providing out_shardings, we instruct JAX to produce sharded output directly,
  # avoiding a large intermediate allocation on a single device.
  with nn.logical_axis_rules(config.logical_axis_rules):
    out_shardings = nn.logical_to_mesh_sharding(specs, mesh)

  @functools.partial(jax.jit, out_shardings=out_shardings)
  def create_sharded_state():
    # This will be JIT-compiled. JAX knows the output sharding and can
    # initialize the parameters directly on the target devices in a sharded way.
    model = create_model(config)
    return nnx.state(model)

  with mesh:
    # Create the model with sharded parameters.
    sharded_state = create_sharded_state()
    model = nnx.merge(graphdef, sharded_state)

    if config.load_parameters_path:
      # Replace each nnx.Variable leaf with its `.value` to build the tree
      # that is handed to Orbax as the restore target.
      target_for_restore = jax.tree.map(
          lambda v: v.value,
          sharded_state,
          is_leaf=lambda n: isinstance(n, nnx.Variable),
      )

      try:
        ckptr = ocp.Checkpointer(
            ocp.PyTreeCheckpointHandler(
                restore_concurrent_gb=None,
                save_concurrent_gb=None,
                use_ocdbt=True,
                use_zarr3=True,
            )
        )
        # This is a memory optimization. We don't want to restore the entire checkpoint - only the params.
        # Rather than pass the entire abstract state, which could unnecessarily restore opt_state and such and waste
        # memory, we instead specify here that we are just restoring the params field of the checkpoint
        # (which itself may be a dictionary containing a key named 'params').
        restore_args = ocp.checkpoint_utils.construct_restore_args(
            target_for_restore
        )
        restored = ckptr.restore(
            epath.Path(config.load_parameters_path),
            item={"params": {"params": target_for_restore}},
            transforms={},
            restore_args={"params": {"params": restore_args}},
        )
        checkpoint = restored["params"]["params"]

        # An empty restored tree leaves the freshly initialized params as-is.
        if checkpoint:
          nnx.update(model, checkpoint)

      except Exception as e:
        # Chain explicitly so the traceback reports the restore error as the
        # direct cause of this ValueError.
        raise ValueError(f"Checkpointing failed: {e}") from e

    tunix_model = TunixMaxTextLlama(
        base_model=model,
        use_attention_mask=False,  # trust Tunix loss masking
    )

    # NOTE(review): the Tunix model config is fixed to Llama 3.1 8B here and is
    # not derived from `config` - confirm this is intended for other models.
    model_config = llama3_lib.ModelConfig.llama3_1_8b()
    tunix_model.config = model_config

  return tunix_model, mesh, model_config


# NOTE(review): this repeats the TunixMaxTextLlama import already present in
# the notebook's import cell; it can likely be dropped from this cell.
from MaxText.integration.tunix.tunix_adaptor import TunixMaxTextLlama

# Build the MaxText config from an argv-style list (an empty first entry
# followed by the base YAML path) plus keyword overrides.
# NOTE(review): the YAML path is relative, so it depends on the directory the
# kernel was started from.
# NOTE(review): boolean-like overrides are passed as the strings "false" and
# "true" - presumably pyconfig parses them; confirm.
# NOTE(review): the "offload" values for decoder_layer_input / query_proj /
# key_proj / value_proj presumably take effect because remat_policy is
# "custom" - confirm against the MaxText config documentation.
config_ref = pyconfig.initialize(
    [
        "",
        "../../maxtext/MaxText/configs/base.yml",
    ],  # TODO: @mazumdera: why decode.py?
    base_output_directory="gs://dummy_output_dir",  # This is not used in Tunix.
    run_name="test-tunix-maxtext-llama3.1-8b",
    tokenizer_type="tiktoken",
    tokenizer_path="assets/tokenizer_llama3.tiktoken",
    load_parameters_path="gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items",
    per_device_batch_size=1,
    max_prefill_predict_length=4,
    max_target_length=16,
    steps=10,
    async_checkpointing="false",
    model_name="llama3.1-8b",
    checkpoint_period=5,
    skip_jax_distributed_system="true",
    weight_dtype="bfloat16",
    attention="dot_product",
    remat_policy="custom",
    decoder_layer_input="offload",
    query_proj="offload",
    key_proj="offload",
    value_proj="offload",
    opt_type="sgd",
)

# Load the model. The second element of the returned tuple (the mesh taken
# from the model) is discarded here.
maxtext_model, _, model_config = get_ref_maxtext_model(config_ref)
nnx.display(maxtext_model)

In [None]:
# Length and sampling settings used to build the rollout config below.
TOTAL_GENERATION_STEPS = 64
MAX_PROMPT_LENGTH = 64
TEMPERATURE = 0.9
TOP_P = 1.0
TOP_K = None
# The KV cache size is the prompt length plus the generation length plus a
# fixed 256.
# NOTE(review): the 256 margin is a magic number - document its rationale or
# promote it to a named constant.
# NOTE(review): `cache_config` is not referenced by any later cell visible in
# this notebook - confirm whether it was meant to be passed to the rollout.
cache_config = base_rollout.RolloutConfig(
    max_tokens_to_generate=TOTAL_GENERATION_STEPS,
    max_prompt_length=MAX_PROMPT_LENGTH,
    kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    top_k=TOP_K,
)

In [None]:
# Construct the vLLM rollout from the loaded MaxText model, the Hugging Face
# tokenizer, and the notebook-level `mesh`.
# NOTE(review): `cache_config_or_size` is the literal 1024, while the
# `cache_config` object built in the preceding cell is not passed - confirm
# which of the two is intended.
# NOTE(review): `mesh` here is the notebook-level mesh, not the mesh returned
# by get_ref_maxtext_model (which the caller discards) - confirm they agree.
vllm_rollout = VllmRollout(
    init_with_random_weights=False,
    hbm_utilization=0.3,
    tpu_backend_type="tpu",
    model=maxtext_model,
    tokenizer=model_tokenizer,
    cache_config_or_size=1024,
    mesh=mesh,
    lora_config=None,
    model_version=MODEL,
)

In [None]:
# Generate one completion for a single prompt.
# NOTE(review): max_tokens_to_generate=64 and temperature=0.1 are literals
# and do not use TOTAL_GENERATION_STEPS / TEMPERATURE (0.9) defined in the
# earlier config cell - confirm the difference is intentional.
output = vllm_rollout.generate(
    ["The capital of France is"],
    rollout_config=RolloutConfig(
        n=1, max_tokens_to_generate=64, temperature=0.1
    ),
)

In [None]:
# Show the value returned by generate() as this cell's output.
output