In [None]:
## 说明
- 基于440000预训练的dense模型进行初始化，0个共享专家，11个非共享专家，top2, moe dim=4096, 初始化采用自设计方案
- 其次采用的是megablox的实现方案

In [None]:
# !pip install tensorflow==2.16.1
# pip install numpy==1.26.4
# 运行2遍
import json
import os
import sys
import asyncio
import argparse
from collections import defaultdict
import time
import subprocess

os.environ["JAX_PLATFORMS"] = "cpu"
from etils import epath
import json
import base64

import numpy as np
import jax.numpy as jnp
import jax
import orbax
import orbax.checkpoint as ocp
from etils import epath
from jax.sharding import PartitionSpec as PS
from flax.traverse_util import flatten_dict, unflatten_dict


METADATA_FILE = '_METADATA'
_CHECKPOINT_FILE = 'checkpoint'


read_dir = 'gs://llm_base_models_europe-west4/v5p_256/7B/PileDCSlimLlama7B32Kx4x256x1v5p_0713/checkpoints/440000/state'
save_dir = 'gs://llm_base_models_europe-west4/v5p_256/7B/xm_E8x7B_OnlyUnshareNoMgate_megablox_1215/checkpoints/'

read_dir = epath.Path(read_dir)
save_dir = epath.Path(save_dir)

# Load only the wanted top-level entries of the source checkpoint.
#
# The source checkpoint is modified *in place* while loading: its `_METADATA`
# file is moved aside and its `checkpoint` structure file is rewritten without
# the entries listed in `remove_keys`. Both are put back in `finally` blocks so
# that a failed restore does not leave the training checkpoint altered.
metadata_path = read_dir / METADATA_FILE
back_metadata_path = read_dir / f'{METADATA_FILE}.back'
structure_path = read_dir / _CHECKPOINT_FILE
back_structure_path = read_dir / 'checkpoint_back'
msgpack = ocp.aggregate_handlers.MsgpackHandler(0)

# Best effort: the rename may fail (e.g. the file was already moved by an
# earlier run); report the reason instead of silently swallowing it.
try:
    metadata_path.rename(back_metadata_path)
except Exception as err:
    print(f'Could not move {metadata_path} to {back_metadata_path}: {err!r}')
metadata_path.unlink(missing_ok=True)  # delete

try:
    structure = msgpack.deserialize(structure_path)
    # backup original checkpoint file (only once; an existing backup is kept)
    back_structure = structure.copy()
    if not back_structure_path.exists():
        asyncio.run(msgpack.serialize(back_structure_path, item=back_structure))
    print(f'Old structure file keys: {structure.keys()}')
    # select the weight name you don't want to load, all weight name: opt_state, step, params
    remove_keys = ['opt_state', 'step']
    for key in remove_keys:
        structure.pop(key, None)
    print(f'New structure file keys: {structure.keys()}')
    asyncio.run(msgpack.serialize(structure_path, item=structure))  # rewrite struct file

    try:
        # load model based struct, note: axes must same as training
        mesh_axes = ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
        devices = np.asarray(jax.devices()).reshape([1] * len(mesh_axes))
        mesh = jax.sharding.Mesh(devices, mesh_axes)
        # Empty PartitionSpec: nothing is partitioned because the weights are loaded on cpu
        sharding = jax.sharding.NamedSharding(mesh, PS())
        weight_dtype = jnp.bfloat16  # set restore weights dtype
        restore_args = unflatten_dict({
            path: ocp.ArrayRestoreArgs(restore_type=jax.Array, dtype=weight_dtype, sharding=sharding)
            for path in flatten_dict(structure)
        })
        ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
        w = ckptr.restore(read_dir, args=ocp.args.PyTreeRestore(restore_args=restore_args))
    finally:
        # rewrite struct file, otherwise occur error when continue training
        asyncio.run(msgpack.serialize(structure_path, item=back_structure))
finally:
    # Always try to put `_METADATA` back, even when loading failed.
    try:
        back_metadata_path.rename(metadata_path)
    except Exception as err:
        print(f'Could not move {back_metadata_path} back to {metadata_path}: {err!r}')

# Strip every nested 'params' wrapper, then flatten to {'a.b.c': ndarray}.
while 'params' in w:
    w = w['params']
xm3p5_w = {'.'.join(k): np.array(v) for k, v in flatten_dict(w).items()}

In [None]:
## Save the dense-model parameters that have the same names in the MoE model
def convert_to_jnp(params, remove_keys=[]):
    """Re-key a flat parameter dict by tuple paths, dropping unwanted entries.

    Args:
        params: mapping from dotted names (``'a.b.c'``) to array-like values
            exposing ``.shape`` and ``.dtype``.
        remove_keys: substrings; any entry whose dotted name contains one of
            them is left out of the result.

    Returns:
        A new dict mapping ``('a', 'b', 'c')`` tuples to the original values.
        Despite the function name, values are passed through unchanged (no
        conversion or cast is applied).
    """
    kept = {
        tuple(name.split('.')): value
        for name, value in params.items()
        if not any(pattern in name for pattern in remove_keys)
    }
    # Log what will be saved.
    for path, value in kept.items():
        print(path, value.shape, value.dtype)
    return kept


def save_params(step, save_dir, params):
    """Write a flat parameter dict as an Orbax checkpoint and wait for it to land.

    Args:
        step: checkpoint step; passed to ``CheckpointManager.save``.
        save_dir: root directory handed to ``CheckpointManager``.
        params: flat dict mapping tuple paths to arrays exposing ``.shape``
            and ``.dtype``.

    The tree is saved under the ``'state'`` item. It is wrapped in ``'params'``
    if its top level has no such key, and then wrapped in ``'params'`` once
    more, so the saved tree is ``{'params': {'params': ...}}``.
    """
    item = {
        'state': orbax.checkpoint.AsyncCheckpointer(
            orbax.checkpoint.PyTreeCheckpointHandler(use_ocdbt=False)),
    }
    unflatten_params = unflatten_dict(params)
    for k, v in params.items():
        print(k, v.shape, v.dtype)
    mngr = orbax.checkpoint.CheckpointManager(save_dir, item)
    if 'params' not in unflatten_params:
        unflatten_params = {'params': unflatten_params}
    mngr.save(step, items={'state': {'params': unflatten_params}})
    # The checkpointer is asynchronous: block until the write is committed so
    # that code running right after this call (e.g. copying the checkpoint
    # files) never sees a partially written checkpoint.
    mngr.wait_until_finished()

# Step under which the dense weights are written; later step numbers are
# derived from this value.
params_save_step = 0
# Keep every weight whose dotted name does not contain 'mlp_', then save it.
convert_params = convert_to_jnp(params=xm3p5_w, remove_keys=['mlp_'])
save_params(step=params_save_step, save_dir=save_dir, params=convert_params)

In [None]:
## Save the MoE parameters; after saving they are moved manually in the bucket console
start_time = time.time()
unshared_experts = 8
expert_dim = 5632

scale = 1
mlp_dim = 5632 // scale
model_dim = 4096 // scale
copy_dim = mlp_dim // unshared_experts  # width of the dense-MLP slice each expert starts from

# fp16_dtype = np.dtype('float16')
# 4 sub-layers
moe_params = {}       # real MoE weights built from the dense MLP weights
example_params = {}   # one tiny placeholder per final parameter name
np.random.seed(42)
# Placeholder value stored for every entry of `example_params`; it does not
# depend on the weight, so it is built once.
ev = jnp.array([100]).astype(jnp.bfloat16)
for k, v in xm3p5_w.items():
    if 'decoder.layers.mlp_' not in k:
        # Non-MLP weights keep their name and only get a placeholder here.
        example_params[k] = ev
        continue

    # Only MLP weights are needed as bf16 jax arrays.
    v = jnp.array(v).astype(jnp.bfloat16)
    unshared_mlp = k.replace('decoder.layers.mlp_', 'decoder.layers.unshared_mlp_')
    unshared_mlp = unshared_mlp.replace('.kernel', '')
    if 'mgate' not in unshared_mlp:
        v = v.transpose(1, 0, 2)
        # unshared: unshared_experts * 12 * model_dim * mlp_dim,  mlp: model_dim * 12 * mlp_dim
        # NOTE(review): the literal 12 is the size of the middle axis of the
        # dense weight (presumably the stacked-layer axis) - confirm.
        if '.wo.' in k:
            # Split the mlp axis into `unshared_experts` chunks of `copy_dim`
            # rows, tile each chunk until it has `expert_dim` rows, and put
            # the expert axis first.
            copy_w = v.reshape(12, unshared_experts, 1, copy_dim, model_dim).repeat(expert_dim // copy_dim, 2).reshape(
                12, unshared_experts, -1, model_dim).transpose(1, 0, 2, 3)
        else:
            # Same as above, but the mlp axis is the last one.
            copy_w = v.reshape(12, model_dim, unshared_experts, 1, copy_dim).repeat(expert_dim // copy_dim, 3).reshape(
                12, model_dim, unshared_experts, -1).transpose(2, 0, 1, 3)
        # Add noise: every element is scaled by 1 + m, with m drawn from
        # {-0.2, -0.1, 0, 0.1, 0.2}.
        mask = jnp.array(np.random.randint(-2, 3, copy_w.shape) * 0.1, dtype=jnp.bfloat16)
        init_w = copy_w + copy_w * mask
        moe_params[unshared_mlp] = init_w
        example_params[unshared_mlp] = ev
        print(f'unshared_mlp: {unshared_mlp} init_w: {init_w.shape} {init_w.dtype} take: {time.time()-start_time:.3f}s')
    else:
        # Router gate: average the last axis of the dense `mgate` weight in
        # `unshared_experts` groups, giving one output column per expert.
        router_gate = unshared_mlp.replace('mgate', 'router_gate.kernel')
        init_w = v.reshape(v.shape[0], v.shape[1], unshared_experts, -1).mean(-1)
        moe_params[router_gate] = init_w
        print(f'router_gate: {router_gate} init_w: {init_w.shape}  {init_w.dtype}, take: {time.time()-start_time:.3f}s')
        example_params[router_gate] = ev


# Dotted names -> tuple paths, as expected by `save_params`.
moe_params = {tuple(k.split('.')): v for k, v in moe_params.items()}
example_params = {tuple(k.split('.')): v for k, v in example_params.items()}

for k, v in moe_params.items():
    print(k, v.shape, v[0].sum(), v[1].sum())

moe_save_step = params_save_step + 1
example_step = params_save_step + 2

save_params(moe_save_step, save_dir, moe_params)
save_params(example_step, save_dir, example_params)

In [None]:
import subprocess

# Assemble the final checkpoint under step `params_save_step` by copying in:
#   * `_METADATA` and `checkpoint` from the placeholder checkpoint (`example_step`)
#   * the `params.params*` entries from the MoE checkpoint (`moe_save_step`)
source_dir = str(save_dir).rstrip('/')
target_dir = str(save_dir).rstrip('/')
merged_state_dir = f'{target_dir}/{params_save_step}/state/'

# Arguments are passed as lists (no shell); gsutil expands the `*` wildcard in
# the gs:// URL itself. `check=True` makes a failed copy raise
# `subprocess.CalledProcessError` instead of leaving an incomplete checkpoint
# unnoticed.
for command in (
    ['gsutil', 'cp', f'{source_dir}/{example_step}/state/_METADATA', merged_state_dir],
    ['gsutil', 'cp', f'{source_dir}/{example_step}/state/checkpoint', merged_state_dir],
    ['gsutil', '-m', 'cp', '-r', f'{source_dir}/{moe_save_step}/state/params.params*', merged_state_dir],
):
    r = subprocess.run(command, stdout=subprocess.PIPE, check=True)