In [10]:

import sys
import os

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '../../maxtext')))
os.environ["SKIP_JAX_PRECOMPILE"] = "1"

import jax.numpy as jnp
from flax import nnx
import sys
import os
import flax.linen as nn
import logging

import MaxText as mt
from MaxText import pyconfig
from tunix.rl.rollout.vllm_rollout import VllmRollout
from tunix.rl.rollout import base_rollout
import transformers
import jax 
from MaxText.integration.tunix.tunix_adaptor import TunixMaxTextLlama
from tunix.rl import utils


In [11]:

show_hbm_usage = utils.show_hbm_usage

show_hbm_usage("Before loading model")
def get_ref_maxtext_model():

  #TODO: @mazumdera: change this to use Gemma2-2b-it
  config = pyconfig.initialize(
      ["", "../../maxtext/MaxText/configs/base.yml"], #TODO: @mazumdera: why decode.py?
      base_output_directory="gs://dummy_output_dir",  # This is not used in Tunix.
      run_name="none",
      tokenizer_path="../../maxtext/assets/tokenizer.gemma",
      per_device_batch_size=1,
      max_target_length=1024,
      steps=10,
      async_checkpointing="false",
      model_name="llama3.1-8b", #"llama3.1-8b"
      checkpoint_period=5, 
      skip_jax_distributed_system="true",
      weight_dtype="bfloat16",
      attention="dot_product"
  )
  
  def create_model(config):
    return mt.from_pretrained(config, rngs=nnx.Rngs(params=0, dropout=1))

  model = nnx.eval_shape(create_model, config=config)

  abstract_model = nnx.eval_shape(create_model, config=config)
  graphdef, abstract_state = nnx.split(abstract_model)
  print('The abstract NNX state (all leaves are abstract arrays):')
  nnx.display(abstract_state)

  @nnx.jit
  def partial_init(config):
    model = create_model(config)
    # nnx.update(model, checkpoint)
    # shard model
    state = nnx.state(model)
    specs = nnx.get_partition_spec(state)
    state = jax.lax.with_sharding_constraint(state, specs)
    nnx.update(model, state)
    return model

  with jax.sharding.use_mesh(model.mesh), nn.logical_axis_rules(config.logical_axis_rules):
    model = partial_init(config)
  print(model)

  tunix_model = TunixMaxTextLlama(
        base_model=model,
        use_attention_mask=False,  # trust Tunix loss masking
    )
  mesh  = tunix_model.base.mesh
  
  tunix_model.to_hf_mappings = lambda *args: {}
  tunix_model.to_hf_transpose_keys = lambda *args: {}
  tunix_model.lora_to_hf_mappings = lambda *args: {}

  # Add these lines to properly get the graph definition and state
  graphdef, state = nnx.split(tunix_model)
  tunix_model = nnx.merge(graphdef, state)  # Recreate model in proper NNX format
    
  return tunix_model, mesh

model, mesh = get_ref_maxtext_model()

print(model)
show_hbm_usage("After loading model")


TOTAL_GENERATION_STEPS = 64
MAX_PROMPT_LENGTH = 64  
TEMPERATURE = 0.9
TOP_P = 1.0
TOP_K = None
cache_config = base_rollout.RolloutConfig(max_tokens_to_generate=TOTAL_GENERATION_STEPS, max_prompt_length=MAX_PROMPT_LENGTH, kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256, temperature=TEMPERATURE, top_p=TOP_P, top_k=TOP_K)




Updating keys from env and command line: ['run_name', 'model_name', 'async_checkpointing', 'checkpoint_period', 'weight_dtype', 'attention', 'base_output_directory', 'tokenizer_path', 'per_device_batch_size', 'steps', 'skip_jax_distributed_system', 'max_target_length']
Running Model: llama3.1-8b
Updating following parameters in config

base_emb_dim: 4096
base_num_query_heads: 32
base_num_kv_heads: 8
base_num_decoder_layers: 32
base_mlp_dim: 14336
head_dim: 128
mlp_activations: ['silu', 'linear']
vocab_size: 128256
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1e-05
rope_max_timescale: 500000
decoder_block: llama2
Updating keys from model: ['base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_num_decoder_layers', 'base_mlp_dim', 'head_dim', 'mlp_activations', 'vocab_size', 'enable_dropout', 'logits_via_embedding', 'normalization_layer_epsilon', 'rope_max_timescale', 'decoder_block']
Skipping jax distributed system due to skip_jax_distribute

Config param dcn_context_parallelism: 1
Config param dcn_data_parallelism: -1
Config param dcn_expert_parallelism: 1
Config param dcn_fsdp_parallelism: 1
Config param dcn_fsdp_transpose_parallelism: 1
Config param dcn_parallelism: [-1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Config param dcn_pipeline_parallelism: 1
Config param dcn_sequence_parallelism: 1
Config param dcn_tensor_parallelism: 1
Config param dcn_tensor_sequence_parallelism: 1
Config param dcn_tensor_transpose_parallelism: 1
Config param decode_sampling_nucleus_p: -1
Config param decode_sampling_strategy: greedy
Config param decode_sampling_temperature: 1.0
Config param decode_sampling_top_k: 0
Config param decoder_block: DecoderBlockType.LLAMA2
Config param decoder_layer_input: device
Config param dpo_beta: 0.1
Config param dpo_label_smoothing: 0.0
Config param dropout_rate: 0.0
Config param dtype: bfloat16
Config param dtype_mm: float32
Config param dump_hlo: False
Config param dump_hlo_delete_local_after: True
Config param du

Num_devices: 8, shape (1, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
[38;2;79;201;177mTransformer[0m[38;2;255;213;3m([0m[38;2;105;105;105m # Param: 8,030,261,248 (16.1 GB), RngState: 4 (24 B), Total: 8,030,261,252 (16.1 GB)[0m
  [38;2;156;220;254mconfig[0m[38;2;212;212;212m=[0m<MaxText.pyconfig.HyperParameters object at 0x7af1b69b5640>,
  [38;2;156;220;254mdecoder[0m[38;2;212;212;212m=[0m[38;2;79;201;177mToNNX[0m[38;2;255;213;3m([0m[38;2;105;105;105m # Param: 7,504,924,672 (15.0 GB), RngState: 4 (24 B), Total: 7,504,924,676 (15.0 GB)[0m
    [38;2;156;220;254mdecoder_norm[0m[38;2;212;212;212m=[0m[38;2;255;213;3m{[0m[38;2;207;144;120m'scale'[0m: [38;2;79;201;177mParam[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 4,096 (8.2 KB)[0m
      [38;2;156;220;254mvalue[0m[38;2;212;212;212m=[0m[38;2;79;201;177mArray[0m[38;2;255;213;3m([0m[38;2;156;220;254mshape[0m[38;2;212;212;212m=[0m[38;2;255;213;3m([0m[38;2;182;207;169m4096[0m,[38;2;255;213;3m)[0m, [38;2;

In [12]:
nnx.display(nnx.state(model))

In [13]:
def create_maxtext_to_vllm_mappings():
    """Create mappings for transferring MaxText scanned state to vLLM unscanned state."""
    return {
        # Token embeddings - shard vocab dimension for TP
        'base.token_embedder.embedding': ('embed.embedding', ('model', None)),
        
        # Final layer norm - no sharding needed
        'base.decoder.decoder_norm.scale': ('model.norm.scale', (None,)),
        
        # LM head (logits projection) - shard vocab dimension for TP
        'base.decoder.logits_dense.kernel': ('lm_head', (None, 'model')),
        
        # Layer-specific mappings (scanned -> unscanned)
        # MLP components - shard hidden dimensions for TP
        'base.decoder.layers.mlp.wi_0.kernel': ('model.layers.*.mlp.gate_proj.kernel', (None, 'layer', 'model')),  # gate_proj: (4096, 14336) - shard output
        'base.decoder.layers.mlp.wi_1.kernel': ('model.layers.*.mlp.up_proj.kernel', (None, 'layer', 'model')),    # up_proj: (4096, 14336) - shard output  
        'base.decoder.layers.mlp.wo.kernel': ('model.layers.*.mlp.down_proj.kernel', ('model', 'layer', None)),    # down_proj: (14336, 4096) - shard input
        
        # Layer norms - no sharding needed
        'base.decoder.layers.pre_self_attention_layer_norm.scale': ('model.layers.*.input_layernorm.scale', (None, 'layer')),
        'base.decoder.layers.post_self_attention_layer_norm.scale': ('model.layers.*.post_attention_layernorm.scale', (None, 'layer')),
        
        # Attention components - shard head dimensions for TP
        'base.decoder.layers.self_attention.query.kernel': ('model.layers.*.self_attn.q_proj.kernel', (None, 'layer', 'model', None)),  # q_proj: shard num_heads
        'base.decoder.layers.self_attention.key.kernel': ('model.layers.*.self_attn.k_proj.kernel', (None, 'layer', 'model', None)),    # k_proj: shard num_kv_heads
        'base.decoder.layers.self_attention.value.kernel': ('model.layers.*.self_attn.v_proj.kernel', (None, 'layer', 'model', None)),  # v_proj: shard num_kv_heads
        'base.decoder.layers.self_attention.out.kernel': ('model.layers.*.self_attn.o_proj.kernel', ('model', 'layer', None, None)),    # o_proj: shard input heads
    }



In [14]:
from functools import lru_cache


In [15]:
# # Debug

# from tunix.generate import vllm_sampler

# sampler = vllm_sampler.VllmSampler(
#     tokenizer=model_tokenizer,
#     max_model_len=64,
#     mesh=mesh,
#     model_version="meta-llama/Llama-3.1-8B",
#     hbm_utilization=0.3,
#     mapping_config=vllm_sampler.MappingConfig(
#         to_hf_mappings=mappings,
#         to_hf_transpose_keys={},
#         lora_to_hf_mappings=None,
#         lora_config=None,
#     ),
#     tp = 8,
# )


# # Define transpose operations needed for shape compatibility
# transpose_keys = {
#     # # MLP transposes (after layer extraction)
#     # 'wo.kernel': (1, 0),  # down_proj: (14336, 4096) - transpose needed
    
#     # # Attention output transpose (after layer extraction) 
#     # 'out.kernel': (1, 2, 0),  # o_proj: (32, 128, 4096) -> (32, 128, 4096) - reorder dimensions
# }

# from tunix.generate.utils import transfer_state_with_mappings_scanned
# transfer_state_with_mappings(
#     nnx.state(model),
#     sampler._model_runner.state,
#     mappings,
#     transpose_keys=transpose_keys,
# )


In [16]:
mappings = create_maxtext_to_vllm_mappings()

transpose_keys = {
    # # MLP transposes (after layer extraction)
    # 'wo.kernel': (1, 0),  # down_proj: (14336, 4096) - transpose needed
    
    # # Attention output transpose (after layer extraction) 
    # 'out.kernel': (1, 2, 0),  # o_proj: (32, 128, 4096) -> (32, 128, 4096) - reorder dimensions
}


model.to_hf_mappings = create_maxtext_to_vllm_mappings
model.to_hf_transpose_keys = lambda *args: transpose_keys
model.lora_to_hf_mappings = lambda *args: None  # No LoRA

In [17]:
model_tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")

In [18]:
rollout = VllmRollout(model=model,tokenizer=model_tokenizer,cache_config_or_size=64, tp = 8, mesh=mesh,lora_config=None,model_version="meta-llama/Llama-3.1-8B")


TypeError: VllmRollout.__init__() got an unexpected keyword argument 'tp'

In [9]:
from tunix.rl.rollout.base_rollout import RolloutConfig



rollout.generate(["hello world", "how are you?"], rollout_config= RolloutConfig(n=1))

NameError: name 'rollout' is not defined

hi
