In [None]:
import os
import sys
import yaml
import json
import base64
from collections import defaultdict
from typing import Tuple
import functools

sys.path.append('/home/lishengping/projects/maxtext/MaxText')
os.environ['HARDWARE'] = 'tpu'

from layers import models
import max_utils
import jax
import orbax
import jax.numpy as jnp
from jax.sharding import Mesh
from flax.traverse_util import flatten_dict, unflatten_dict
from flax import linen as nn
from transformers import AutoTokenizer
from etils import epath
import orbax.checkpoint as ocp

import pyconfig
from jax.sharding import PartitionSpec
from flax.linen import partitioning as nn_partitioning


# Local tokenizer directory. When it does not exist yet, the `!` line below (an
# IPython shell escape) copies the remote `tokenizer` directory into
# /home/lishengping/ with gsutil.
TOKENIZER_PATH = '/home/lishengping/tokenizer'
if not os.path.exists(TOKENIZER_PATH):
    !gsutil cp -r gs://llm_base_models_us-east5/qwen/tokenizer /home/lishengping/
# NOTE(review): trust_remote_code=True allows custom code shipped with the
# tokenizer files to run - confirm the tokenizer source is trusted.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, use_fast=True, trust_remote_code=True)

# Candidate checkpoint roots. Each assignment overwrites the previous one, so
# only the last uncommented line takes effect.
read_dir = "gs://llm_base_models_europe-west4/v5p_256/7B/PileDCSlimLlama7B32Kx4x256x1v5p_0713/checkpoints"
read_dir = 'gs://llm_base_models_europe-west4/v5p_256/7B/xm3.5-7b-chat-v6/checkpoints/'
read_dir = 'gs://llm_base_models_europe-west4/v5p_256/7B/xm3.5-7b-chat-v7/checkpoints/'

# read_dir = 'gs://llm_base_models_europe-west4/v5p_256/7B/dense_continue_1209/checkpoints'
# read_dir = "gs://llm_base_models_europe-west4/v5p_256/7B/test"

# Wrap the chosen root in an epath.Path object.
read_dir = epath.Path(read_dir)

config_name = '/home/lishengping/projects/maxtext/MaxText/configs/dc_8x7b_moe.yml'

# pyconfig.initialize is given an argv-style list: slot 0 is a None placeholder
# (presumably where the program name would normally sit) and slot 1 is the
# YAML config path.
argv = [None, config_name]
pyconfig.initialize(argv)
config = pyconfig.config
# validate_train_config(config)
# Build the device array from the config and wrap it in a Mesh named by
# config.mesh_axes; `mesh` is what the sharding rules are attached to later.
devices_array = max_utils.create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)

In [None]:
def decode_base64(encoded_str):
    """Decode base64 data and return the result as UTF-8 text.

    Args:
        encoded_str: Base64-encoded input accepted by ``base64.b64decode``.

    Returns:
        The decoded bytes interpreted as a UTF-8 string.
    """
    return base64.b64decode(encoded_str).decode('utf-8')


def mesh_shard_rules(mesh, rules, remove_keys=[]):
    """Turn stored sharding rules into a flat ``{path tuple: NamedSharding}`` dict.

    Args:
        mesh: Mesh passed to every ``jax.sharding.NamedSharding`` that is built.
        rules: Mapping whose keys are base64-encoded, dot-separated parameter
            names and whose values are either a dict or a JSON string; the
            ``'partition_spec'`` entry of each value is read.
        remove_keys: Path components to filter out. A parameter is skipped
            when any of these equals one component of its split name. The
            argument is only read, never modified.

    Returns:
        Dict keyed by the decoded parameter name split on ``'.'``.
    """
    shardings = {}
    for encoded_name, raw_rule in rules.items():
        # Rules may arrive JSON-encoded; parse them before use.
        parsed_rule = json.loads(raw_rule) if isinstance(raw_rule, str) else raw_rule
        path = tuple(decode_base64(encoded_name).split('.'))
        if any(key in path for key in remove_keys):
            continue
        # JSON has no tuples, so list-valued entries are converted back to tuples.
        axes = [
            tuple(axis) if isinstance(axis, list) else axis
            for axis in parsed_rule['partition_spec']
        ]
        partition = jax.sharding.PartitionSpec(*axes)
        shardings[path] = jax.sharding.NamedSharding(mesh, partition)
    return shardings


def rewrite_bucket_sharding(mesh, old_sharding, save_path):
    """Write sharding entries to ``save_path`` with ``'shape'`` set from ``mesh``.

    Every entry's ``'shape'`` field is replaced by ``mesh.device_ids.shape``
    and the resulting mapping is dumped as JSON.

    Args:
        mesh: Object exposing ``device_ids.shape``.
        old_sharding: Mapping from key to a sharding entry, given either as a
            dict or as a JSON string. Dict entries are updated in place.
        save_path: Destination path; wrapped in ``epath.Path`` before opening.
    """
    updated = {}
    for key, entry in old_sharding.items():
        entry = json.loads(entry) if isinstance(entry, str) else entry
        entry['shape'] = mesh.device_ids.shape
        updated[key] = entry
    with epath.Path(save_path).open('w') as out_file:
        json.dump(updated, out_file)
    
# Location of the JSON sharding-rules file. The assignments below overwrite
# each other, so only the last uncommented one is used.
# mesh length is 8
_sharding_path = 'gs://llm_base_models_europe-west4/v5p_256/7B/xm3p5_7b_sharding'
# dense
_sharding_path = 'gs://llm_base_models_europe-west4/v5p_256/7B/xm3.5-7b-chat-v6/checkpoints/451600/state/_sharding'
# moe with unshare
# _sharding_path = 'gs://llm_base_models_europe-west4/v5p_256/7B/xm3p5_7b_with_perlayer_moe_sharding'
_sharding_path = epath.Path(_sharding_path)

# Path components to exclude from the sharding dict; empty means keep everything.
# remove_keys = ['opt_state', 'step']
remove_keys = []
with _sharding_path.open('r') as f:
    _sharding_rules = json.load(f)
# Flat {path tuple: NamedSharding} mapping, then nested into a tree.
_sharding_dict = mesh_shard_rules(mesh, _sharding_rules, remove_keys=remove_keys)
_sharding_dict = unflatten_dict(_sharding_dict)
# Build one ArrayRestoreArgs per parameter: restore as jax.Array, cast to
# weight_dtype, and place according to the parameter's sharding.
restore_args = {}
weight_dtype = jnp.bfloat16
for k, v in flatten_dict(_sharding_dict).items():
    joink = '.'.join(k)
    # Parameters whose dotted path contains 'unshare' get no restore args.
    if 'unshare' in joink: continue
    restore_args[k] =  ocp.ArrayRestoreArgs(restore_type=jax.Array, dtype=weight_dtype, sharding=v)
# Show which parameter paths received restore args.
for k, v in restore_args.items():
    print(k)
# Nest the flat mapping back into a tree keyed like the parameters.
restore_args = unflatten_dict(restore_args)

In [None]:
# load
# Source checkpoint to read. Each assignment overwrites the previous one, so
# only the last line is actually restored.
read_dir = 'gs://llm_base_models_europe-west4/v5p_256/7B/PileDCSlimLlama7B32Kx4x256x1v5p_0713/checkpoints/440000/state'
read_dir = 'gs://llm_base_models_us-east5/v5p_256/7B/xm3.5-7b-chat-v1/checkpoints/450000/state'
read_dir = 'gs://llm_base_models_us-east5/v5p_256/7B/xm3.5-7b-chat-v2/checkpoints/453200/state'
read_dir = 'gs://llm_base_models_us-east5/v5p_256/7B/xm3.5-7b-chat-v3/checkpoints/453200/state'
read_dir = 'gs://llm_base_models_us-east5/v5p_256/7B/xm3.5-7b-chat-v4/checkpoints/451000/state'
read_dir = 'gs://llm_base_models_us-east5/v5p_256/7B/xm3.5-7b-chat-v5/checkpoints/450800/state'
read_dir = 'gs://llm_base_models_us-central2/v5p_256/7B/PileDCSlimLlama7B4Kx4x256x1v5p/checkpoints/checkpoint_00100000/state'
read_dir = epath.Path(read_dir)
# Toggle for the checkpoint handler; it also selects which restore call runs below.
use_ocdbt = False
ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler(use_ocdbt=use_ocdbt))
# restore_args: its structure must match the model's structure
if use_ocdbt:
    state = ckptr.restore(read_dir, args=ocp.args.PyTreeRestore(restore_args=restore_args))
else:
    # With use_ocdbt False the checkpoint is restored without restore_args.
    state = ckptr.restore(read_dir)
# Keep only the 'params' entry of the restored state, nested under a 'params' key.
params = {'params': state['params']}


In [None]:
# Collect the parameters whose dotted key path contains 'layers.mlp'.
# ffn_params stays flat: it is keyed by the tuple paths produced by flatten_dict.
ffn_params = {}
for k, v in flatten_dict(params).items():
    joink = '.'.join(k)
    if 'layers.mlp' in joink:
        ffn_params[k] = v
    # Every parameter path is printed, not only the ones that matched.
    print(joink)
# Number of parameters selected for saving.
print(len(ffn_params))

In [None]:
# save
# Step index under which each source checkpoint's FFN parameters are saved:
# xm3.5-7b-chat-v6: 0
# xm3.5-7b-chat-v7: 1
# PileDCSlimLlama7B32Kx4x256x1v5p_0713: 440000: 2
# xm3.5-7b-chat-v1: 3
# xm3.5-7b-chat-v2: 4
# xm3.5-7b-chat-v3: 5
# xm3.5-7b-chat-v4: 6
# xm3.5-7b-chat-v5: 7
# The ones below cannot be converted: they are in pax format.
# PileDCSlimLlama7B32Kx4x256x1v5p: 100000: 8
# PileDCSlimLlama7B32Kx4x256x1v5p: 200000: 9
# PileDCSlimLlama7B32Kx4x256x1v5p: 300000: 10
# PileDCSlimLlama7B32Kx4x256x1v5p: 400000: 11
# NOTE(review): step 8 is listed above among the entries marked as not
# convertible - confirm this is the intended step before running.
save_step = 8
save_dir = "gs://llm_base_models_europe-west4/v5p_256/7B/moe_ffn_init_base_multi_dmodels_1210/checkpoints"
save_dir = epath.Path(save_dir)
# One checkpointer per saved item name; 'state' is written in the non-OCDBT
# PyTree layout.
item = {
    "state": ocp.Checkpointer(ocp.PyTreeCheckpointHandler(use_ocdbt=False))
}
# Fail with a clear message instead of a bare KeyError when the 'layers.mlp'
# filter matched nothing and there is nothing to save.
assert ffn_params, "ffn_params is empty: no parameter path contained 'layers.mlp'"
# ffn_params is flat (tuple paths); rebuild the nested tree before saving.
save_params = unflatten_dict(ffn_params)
# Guard against an extra level of 'params' nesting in the tree being saved.
assert 'params' not in save_params['params']['params'], (
    "unexpected extra 'params' nesting under save_params['params']['params']"
)
max_mngr = ocp.CheckpointManager(save_dir, item)
max_mngr.save(save_step, {'state': save_params})