# 加载基础dense模型

In [None]:
# !pip install tensorflow==2.16.1
# pip install numpy==1.26.4

import json
import os
import sys
import asyncio
import argparse
from collections import defaultdict
import time

os.environ["JAX_PLATFORMS"] = "cpu"
from etils import epath
import json
import base64

import torch
import numpy as np
import jax.numpy as jnp
import jax
import orbax
import orbax.checkpoint as ocp
from etils import epath
from jax.sharding import PartitionSpec as PS
from flax.traverse_util import flatten_dict, unflatten_dict


# File names that are looked up inside the checkpoint's `state` directory.
METADATA_FILE = '_METADATA'
_CHECKPOINT_FILE = 'checkpoint'

# Previously used source/target pairs, kept for reference:
# read_dir = 'gs://llm_base_models_us-east5/v5p_256/7B/xm_45x7B_moe/checkpoints/0/state'
# save_dir = 'gs://llm_base_models_us-east5/v5p_256/7B/xm_45x7B_moe_0919_test/checkpoints/'

# read_dir = 'gs://jax_llm_data_us-east5/test/pile_moe_0919/checkpoints/600/state'
# save_dir = 'gs://llm_base_models_us-east5/v5p_256/7B/xm_45x7B_moe_0920_1.2k/checkpoints/'

# Active pair: checkpoint state to read from, and checkpoint root to write to.
read_dir = epath.Path(
    'gs://llm_base_models_us-east5/v5p_256/7B/PileDCSlimLlama7B32Kx4x256x1v5p_0713/ocdbt/checkpoints/448800/state'
)
save_dir = epath.Path(
    'gs://llm_base_models_us-east5/v5p_256/7B/xm_45x7B_moe_0922/checkpoints/'
)

# Restore only the model parameters of the checkpoint at `read_dir`.
#
# To do so, the `_METADATA` file is temporarily moved aside and the `checkpoint`
# structure file is temporarily rewritten without the entries listed in
# `remove_keys`. Both on-disk changes are undone in `finally` blocks so that a
# failed restore does not leave the source checkpoint modified.
metadata_path = read_dir / METADATA_FILE
back_metadata_path = read_dir / f'{METADATA_FILE}.back'
try:
    metadata_path.rename(back_metadata_path)
except Exception as e:
    # Best effort: the metadata file may not exist, or a backup may already exist.
    print(f'Skipped moving {metadata_path} to {back_metadata_path}: {e!r}')
# Only delete a leftover metadata file when a backup exists, so the only copy
# is never destroyed after a failed rename.
if back_metadata_path.exists():
    metadata_path.unlink(missing_ok=True)

structure_path = read_dir / _CHECKPOINT_FILE
msgpack = ocp.aggregate_handlers.MsgpackHandler(0)
try:
    structure = msgpack.deserialize(structure_path)
    # Back up the original checkpoint structure file (written only once).
    back_structure_path = read_dir / 'checkpoint_back'
    back_structure = structure.copy()  # shallow copy: only top-level keys are popped below
    if not back_structure_path.exists():
        asyncio.run(msgpack.serialize(back_structure_path, item=back_structure))
    print(f'Old structure file keys: {structure.keys()}')
    # Top-level entries that should NOT be loaded.
    # All top-level entries: opt_state, step, params.
    remove_keys = ['opt_state', 'step']
    for key in remove_keys:
        structure.pop(key, None)
    print(f'New structure file keys: {structure.keys()}')
    try:
        # Rewrite the structure file so the restore only sees the kept entries.
        asyncio.run(msgpack.serialize(structure_path, item=structure))

        # Load the model based on the structure. Note: the mesh axes must be
        # the same as the ones used for training.
        mesh_axes = ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
        # One entry of size 1 per mesh axis (requires exactly one visible device).
        devices = np.asarray(jax.devices()).reshape([1] * len(mesh_axes))
        mesh = jax.sharding.Mesh(devices, mesh_axes)
        # Empty PartitionSpec: nothing is partitioned because the weights are loaded on CPU.
        sharding = jax.sharding.NamedSharding(mesh, PS())
        weight_dtype = jnp.bfloat16  # dtype the restored weights are cast to
        restore_args = {}
        for k in flatten_dict(structure):
            restore_args[k] = ocp.ArrayRestoreArgs(restore_type=jax.Array, dtype=weight_dtype, sharding=sharding)
        restore_args = unflatten_dict(restore_args)
        ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
        w = ckptr.restore(read_dir, args=ocp.args.PyTreeRestore(restore_args=restore_args))
    finally:
        # Always put the original structure file back, otherwise an error
        # occurs when training is continued from this checkpoint.
        asyncio.run(msgpack.serialize(structure_path, item=back_structure))
finally:
    try:
        back_metadata_path.rename(metadata_path)
    except Exception as e:
        # Best effort: there is nothing to move back if no backup was created.
        print(f'Skipped moving {back_metadata_path} back to {metadata_path}: {e!r}')

# Strip every leading 'params' level, then flatten to {'a.b.c': numpy array}.
while 'params' in w:
    w = w['params']
xm3p5_w = {'.'.join(k): np.array(v) for k, v in flatten_dict(w).items()}

In [None]:
# Sanity check: list every restored tensor with its shape and dtype.
for name, tensor in xm3p5_w.items():
    print(name, tensor.shape, tensor.dtype)

In [2]:
## 基于dense模型保存和moe相同名字的参数
def convert_to_jnp(params, remove_keys=()):
    """Re-key a flat parameter mapping by tuple paths, dropping filtered entries.

    Despite the name, values are passed through unchanged: no array or dtype
    conversion is performed.

    Args:
        params: Mapping from dot-separated parameter names (``'a.b.c'``) to
            values exposing ``shape`` and ``dtype`` attributes.
        remove_keys: Iterable of substrings. A parameter is skipped when its
            name contains any of them.

    Returns:
        A dict mapping ``tuple(name.split('.'))`` to the original value, in
        the iteration order of ``params``.
    """
    convert_params = {
        tuple(name.split('.')): value
        for name, value in params.items()
        if not any(marker in name for marker in remove_keys)
    }
    # Log what was kept.
    for path, value in convert_params.items():
        print(path, value.shape, value.dtype)
    return convert_params


def save_params(step, save_dir, params):
    """Save a flat parameter mapping as checkpoint step ``step`` under ``save_dir``.

    The mapping is unflattened into a nested tree. If that tree has no
    top-level ``'params'`` key it is wrapped in one, and the result is saved
    under the ``'state'`` item as ``{'params': tree}``.

    Args:
        step: Step number passed to the checkpoint manager.
        save_dir: Checkpoint root directory passed to ``CheckpointManager``.
        params: Mapping from tuple key paths to values exposing ``shape`` and
            ``dtype`` attributes.
    """
    for k, v in params.items():
        print(k, v.shape, v.dtype)
    # new_params = {tuple(k.split('.')): v for k,v in params.items()}
    unflatten_params = unflatten_dict(params)
    if 'params' not in unflatten_params:
        unflatten_params = {'params': unflatten_params}
    item = {
        'state': orbax.checkpoint.AsyncCheckpointer(
            orbax.checkpoint.PyTreeCheckpointHandler(use_ocdbt=False)
        ),
    }
    mngr = orbax.checkpoint.CheckpointManager(save_dir, item)
    mngr.save(step, items={'state': {'params': unflatten_params}})
    # The AsyncCheckpointer writes in the background; block until the step is
    # fully written so that later code can safely read or copy its files.
    mngr.wait_until_finished()

# Drop every parameter whose name contains 'mgate' and save the rest as step 0.
convert_params = convert_to_jnp(xm3p5_w, remove_keys=['mgate'])
save_params(step=0, save_dir=save_dir, params=convert_params)

k: decoder.decoder_norm.scale take: 0.000s
k: decoder.layers.mlp_0.mgate.kernel take: 0.001s
0
k: decoder.layers.mlp_0.wi_0.kernel take: 0.001s
0
copy_w: (44, 12, 4096, 5632) copy_w: float16 v: (1, 12, 4096, 5632) unshared_mlp: decoder.layers.unshared_mlp_0.wi_0.kernel
k: decoder.layers.mlp_0.wi_1.kernel take: 122.405s
0
copy_w: (44, 12, 4096, 5632) copy_w: float16 v: (1, 12, 4096, 5632) unshared_mlp: decoder.layers.unshared_mlp_0.wi_1.kernel
k: decoder.layers.mlp_0.wo.kernel take: 244.480s
0
copy_w: (44, 12, 5632, 4096) copy_w: float16 v: (1, 12, 5632, 4096) unshared_mlp: decoder.layers.unshared_mlp_0.wo.kernel
k: decoder.layers.mlp_1.mgate.kernel take: 368.050s
1
k: decoder.layers.mlp_1.wi_0.kernel take: 368.050s
1
k: decoder.layers.mlp_1.wi_1.kernel take: 368.050s
1
k: decoder.layers.mlp_1.wo.kernel take: 368.050s
1
k: decoder.layers.mlp_2.mgate.kernel take: 368.050s
2
k: decoder.layers.mlp_2.wi_0.kernel take: 368.050s
2
k: decoder.layers.mlp_2.wi_1.kernel take: 368.050s
2
k: decode

In [2]:
## Build the MoE-part parameters. After saving, they are moved manually in the
## bucket console.
start_time = time.time()
unshared_experts = 44

scale = 1
mlp_dim = 5632 // scale
model_dim = 4096 // scale
copy_dim = 128 // scale
copy_expert_dim = mlp_dim // unshared_experts
# fp16_dtype = np.dtype('float16')
# 4 sub-layers: one parameter dict per `mlp_<i>` index.
total_moe_params = [{} for _ in range(4)]
for k, v in xm3p5_w.items():
    print(f'k: {k} take: {time.time() - start_time:.3f}s')
    if 'decoder.layers.mlp_' not in k:
        continue
    mlp_inx = k.find('mlp_')
    # NOTE: only a single digit is read, so indices >= 10 are not supported.
    layer_idx = int(k[mlp_inx + 4: mlp_inx + 5])
    # Skip sub-layers 2 and 3 as well as every odd sub-layer.
    if layer_idx in (2, 3) or layer_idx % 2 != 0:
        continue
    # Convert only the tensors that are actually used.
    v = jnp.array(v).astype(jnp.bfloat16)
    moe_params = total_moe_params[layer_idx]
    unshared_mlp = k.replace('decoder.layers.mlp_', 'decoder.layers.unshared_mlp_')
    if 'mgate' in unshared_mlp:
        # The gate is copied as is and keeps its '.kernel' suffix.
        moe_params[unshared_mlp] = v
        continue
    unshared_mlp = unshared_mlp.replace('.kernel', '')
    # Move axis 1 to the front.
    # unshared: 44 * 12 * model_dim * mlp_dim,  mlp: model_dim * 12 * mlp_dim
    v = v.transpose(1, 0, 2)
    n_stacked = v.shape[0]  # 12 in the shapes above; presumably the stacked-layer axis -- TODO confirm
    if '.wo.' in k:
        # Split the mlp axis into `unshared_experts` chunks of `copy_dim` rows
        # and tile each chunk `unshared_experts` times to fill a full mlp axis.
        copy_w = (
            v.reshape(n_stacked, unshared_experts, 1, copy_dim, model_dim)
            .repeat(unshared_experts, 2)
            .reshape(n_stacked, unshared_experts, -1, model_dim)
            .transpose(1, 0, 2, 3)
        )
    else:
        # Same tiling, with the mlp axis in the last position.
        copy_w = (
            v.reshape(n_stacked, model_dim, unshared_experts, 1, copy_dim)
            .repeat(unshared_experts, 3)
            .reshape(n_stacked, model_dim, unshared_experts, -1)
            .transpose(2, 0, 1, 3)
        )
    v = v[None]  # add a leading axis so it broadcasts against copy_w
    # Each expert starts as the average of its tiled chunk and the full dense kernel.
    init_w = (copy_w + v) / 2
    print(f'copy_w: {copy_w.shape} copy_w: {copy_w.dtype} init_w: {init_w.dtype} unshared_mlp: {unshared_mlp}')
    moe_params[unshared_mlp] = init_w


k: decoder.decoder_norm.scale take: 0.000s
k: decoder.layers.mlp_0.mgate.kernel take: 0.000s
0
k: decoder.layers.mlp_0.wi_0.kernel take: 0.000s
0
k: decoder.layers.mlp_0.wi_1.kernel take: 0.000s
0
k: decoder.layers.mlp_0.wo.kernel take: 0.000s
0
k: decoder.layers.mlp_1.mgate.kernel take: 0.000s
1
k: decoder.layers.mlp_1.wi_0.kernel take: 0.000s
1
k: decoder.layers.mlp_1.wi_1.kernel take: 0.000s
1
k: decoder.layers.mlp_1.wo.kernel take: 0.000s
1
k: decoder.layers.mlp_2.mgate.kernel take: 0.000s
2
k: decoder.layers.mlp_2.wi_0.kernel take: 0.000s
2
k: decoder.layers.mlp_2.wi_1.kernel take: 0.000s
2
k: decoder.layers.mlp_2.wo.kernel take: 0.000s
2
k: decoder.layers.mlp_3.mgate.kernel take: 0.000s
3
k: decoder.layers.mlp_3.wi_0.kernel take: 0.000s
3
k: decoder.layers.mlp_3.wi_1.kernel take: 0.000s
3
k: decoder.layers.mlp_3.wo.kernel take: 0.000s
3
k: decoder.layers.post_self_attention_layer_norm_0.scale take: 0.000s
k: decoder.layers.post_self_attention_layer_norm_1.scale take: 0.000s
k: de

# 分批保存模型参数

In [8]:
# Save every non-empty per-layer MoE dict as its own checkpoint step (index + 1).
for i, params in enumerate(total_moe_params):
    if params:
        print(params.keys(), '\n')
        # Checkpoint trees are keyed by tuple paths rather than dotted names.
        save_moe_params = {
            tuple(name.split('.')): tensor for name, tensor in params.items()
        }
        print(f'Save step: {i+1}')
        save_params(i + 1, save_dir, save_moe_params)



True

In [None]:
# Save a placeholder checkpoint (step 5) that contains the key names of the
# dense parameters plus the `unshared_mlp_*` entries. Its `_METADATA` and
# `checkpoint` files are then moved into the actually saved checkpoint folder
# to replace the ones there.
start_time = time.time()
unshared_experts = 44
mlp_dim = 5632 // scale
copy_expert_dim = mlp_dim // unshared_experts
# Every entry gets the same tiny placeholder value (presumably only the tree
# structure matters for the copied files -- TODO confirm).
placeholder = jnp.array([100]).astype(jnp.bfloat16)
moe_params = {}
for k in xm3p5_w:
    print(f'k: {k} take: {time.time() - start_time:.3f}s')
    if 'mgate' not in k:
        moe_params[k] = placeholder
    if 'decoder.layers.mlp_' not in k:
        continue
    mlp_inx = k.find('mlp_')
    l = k[mlp_inx+4: mlp_inx+5]
    if int(l) % 2 != 0:
        continue
    unshared_mlp = k.replace('decoder.layers.mlp_', 'decoder.layers.unshared_mlp_')
    # mgate needs '.kernel' while the MoE weights do not: the MoE parameters are
    # created via self.params, whereas mgate uses flax's DenseGeneral.
    if 'mgate' not in unshared_mlp:
        unshared_mlp = unshared_mlp.replace('.kernel', '')
    moe_params[unshared_mlp] = placeholder

moe_params = {tuple(k.split('.')): v for k, v in moe_params.items()}
save_params(5, save_dir, moe_params)

In [None]:
import subprocess

# save_dir1 = str(save_dir).rstrip('/')
source_dir = 'gs://llm_base_models_us-east5/v5p_256/7B/xm_45x7B_moe_0922/checkpoints'
target_dir = 'gs://llm_base_models_us-east5/v5p_256/7B/xm_45x7B_moe_0922/checkpoints'
source_step = 5
# Overwrite the structure files of step 0 with those of `source_step`.
# '_sharding' is intentionally not part of this list.
for file_name in ('_METADATA', 'checkpoint'):
    # Argument list instead of a shell string: no quoting issues, no shell needed.
    command = [
        'gsutil', 'cp',
        f'{source_dir}/{source_step}/state/{file_name}',
        f'{target_dir}/0/state/',
    ]
    r = subprocess.run(command, stdout=subprocess.PIPE)
    if r.returncode != 0:
        # Report the failure instead of silently ignoring it.
        print(f'gsutil cp failed for {file_name} (exit code {r.returncode})')

In [None]:
# Build the `_sharding` file for a given TPU type.
# The string below documents the `_sharding` file format with one sample entry.
'''
_sharding文件格式如下：
{
  b3B0X3N0YXRlLm11LnBhcmFtcy50b2tlbl9lbWJlZGRlci5lbWJlZGRpbmc=': {'sharding_type': 'NamedSharding',
  'shape': [1, 1, 4, 1, 1, 1, 1],
  'axis_names': ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor','autoregressive'],
  'partition_spec': [['tensor', 'autoregressive'], ['fsdp', 'fsdp_transpose', 'sequence']],
   2: 4},
   ...
   }
   '''
# MoE sharding template: read the existing `_sharding` file.
_sharding_path = epath.Path(
    'gs://llm_base_models_us-east5/v5p_256/7B/xm3p5_moe_params_no_opt_v5p_64_sharding'
)
with _sharding_path.open('r') as f:
    _sharding = json.load(f)

tpu_type = 'v5p-64'
# The number after the dash is used as is for 'v3' types and halved otherwise.
listed_cores = int(tpu_type.split('-')[-1])
core_nums = listed_cores if 'v3' in tpu_type else listed_cores // 2
print(f'core_nums: {core_nums}')

# Each value is itself a JSON string: decode it, overwrite entry 2 of 'shape'
# (the 'fsdp' axis in the sample above) and encode it again.
updated_sharding = {}
for encoded_name, encoded_entry in _sharding.items():
    entry = json.loads(encoded_entry)
    entry['shape'][2] = core_nums
    updated_sharding[encoded_name] = json.dumps(entry)

# Write the result next to the step-0 checkpoint state.
updated_sharding_path = epath.Path(f'{save_dir}/0/state/_sharding')
with updated_sharding_path.open('w') as f:
    json.dump(updated_sharding, f)

In [None]:
# import base64

# def decode_base64(encoded_str):
#     decoded_bytes = base64.b64decode(encoded_str)
#     decoded_str = decoded_bytes.decode('utf-8')
#     return decoded_str
    
# for k, v in _sharding.items():
#     decode_k = decode_base64(k)
#     if 'opt_state' in decode_k or 'step' in decode_k: continue
#     print(decode_k)