### 说明
- 基于440000预训练的dense模型进行初始化，1个共享专家，8个非共享专家，共享和非共享专家均直接复制dense的mlp。
- 其次采用的是openmoe的实现方案

In [1]:
# !pip install tensorflow==2.16.1
# pip install numpy==1.26.4
# 运行2遍
import json
import os
import sys
import asyncio
import argparse
from collections import defaultdict
import time

os.environ["JAX_PLATFORMS"] = "cpu"
from etils import epath
import json
import base64

import torch
import numpy as np
import jax.numpy as jnp
import jax
import orbax
import orbax.checkpoint as ocp
from etils import epath
from jax.sharding import PartitionSpec as PS
from flax.traverse_util import flatten_dict, unflatten_dict


METADATA_FILE = '_METADATA'
_CHECKPOINT_FILE = 'checkpoint'


read_dir = 'gs://llm_base_models_europe-west4/v5p_256/7B/PileDCSlimLlama7B32Kx4x256x1v5p_0713/checkpoints/440000/state'
save_dir = 'gs://llm_base_models_europe-west4/v5p_256/7B/xm_45x7B_moe_129/checkpoints/'

read_dir = epath.Path(read_dir)
save_dir = epath.Path(save_dir)

# Move _METADATA aside for the duration of the restore; it is moved back at
# the end of this cell.
metadata_path = read_dir / METADATA_FILE
back_metadata_path = read_dir / f'{METADATA_FILE}.back'
try:
    metadata_path.rename(back_metadata_path)
except Exception as err:  # best effort, e.g. already moved by an earlier run
    print(f'Skip {METADATA_FILE} backup ({type(err).__name__}: {err})')
# Only delete a leftover _METADATA when a backup exists, so a failed rename
# can never destroy the only copy.
if back_metadata_path.exists():
    metadata_path.unlink(missing_ok=True)
else:
    print(f'No {back_metadata_path} found; leaving {metadata_path} untouched')

structure_path = read_dir / _CHECKPOINT_FILE
msgpack = ocp.aggregate_handlers.MsgpackHandler(0)
structure = msgpack.deserialize(structure_path)
# Back up the original checkpoint structure file (written once, never overwritten).
back_structure_path = read_dir / 'checkpoint_back'
# A shallow copy is enough: only top-level keys are removed from `structure`.
back_structure = structure.copy()
if not back_structure_path.exists():
    asyncio.run(msgpack.serialize(back_structure_path, item=back_structure))
print(f'Old structure file keys: {structure.keys()}')
# Select the weight names you don't want to load; all weight names: opt_state, step, params
remove_keys = ['opt_state', 'step']
for key in remove_keys:
    structure.pop(key, None)
print(f'New structure file keys: {structure.keys()}')

# Load the model based on the structure; note: axes must be the same as in training.
mesh_axes = ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
devices = np.asarray(jax.devices()).reshape([1] * len(mesh_axes))
mesh = jax.sharding.Mesh(devices, mesh_axes)
sharding = jax.sharding.NamedSharding(mesh, PS())  # Sharding is None because we use cpu to load weights
weight_dtype = jnp.bfloat16  # set restore weights dtype
restore_args = unflatten_dict({
    path: ocp.ArrayRestoreArgs(restore_type=jax.Array, dtype=weight_dtype, sharding=sharding)
    for path in flatten_dict(structure)
})
ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
try:
    # Temporarily rewrite the structure file so that only the kept keys are restored.
    asyncio.run(msgpack.serialize(structure_path, item=structure))
    w = ckptr.restore(read_dir, args=ocp.args.PyTreeRestore(restore_args=restore_args))
finally:
    # Always put the original files back, even if the restore failed; otherwise
    # an error occurs when training continues from this checkpoint.
    try:
        asyncio.run(msgpack.serialize(structure_path, item=back_structure))
    finally:
        try:
            back_metadata_path.rename(metadata_path)
        except Exception as err:  # best effort, mirrors the backup step above
            print(f'Could not restore {METADATA_FILE} ({type(err).__name__}: {err})')

# Strip every nested 'params' wrapper, then flatten to {'a.b.c': ndarray}.
while 'params' in w:
    w = w['params']
xm3p5_w = {'.'.join(k): np.array(v) for k, v in flatten_dict(w).items()}

2024-11-27 19:10:06.572276: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-27 19:10:06.588055: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-27 19:10:06.588083: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Old structure file keys: dict_keys(['params'])
New structure file keys: dict_keys(['params'])


I0000 00:00:1732734613.567408   29825 gcs_resource.cc:109] Using default AdmissionQueue with limit 32
I0000 00:00:1732734613.571445   30404 google_auth_provider.cc:180] Running on GCE, using service account 887571727717-compute@developer.gserviceaccount.com


In [2]:
## 基于dense模型保存和moe相同名字的参数
def convert_to_jnp(params, remove_keys=()):
    """Re-key a flat ``{'a.b.c': array}`` dict by tuple paths, dropping filtered entries.

    Despite the name, values are passed through unchanged (no dtype or array
    conversion is performed).

    Args:
        params: Mapping from dot-separated parameter names to array-like
            values exposing ``.shape`` and ``.dtype``.
        remove_keys: Iterable of substrings; any parameter whose name contains
            one of them is skipped.

    Returns:
        A new dict mapping ``tuple(name.split('.'))`` to the original value.
    """
    convert_params = {
        tuple(name.split('.')): value
        for name, value in params.items()
        if not any(pattern in name for pattern in remove_keys)
    }
    # Log what was kept so the resulting tree can be checked by eye.
    for k, v in convert_params.items():
        print(k, v.shape, v.dtype)
    return convert_params


def save_params(step, save_dir, params):
    """Save a flat parameter dict as the ``state`` item of checkpoint ``step``.

    Args:
        step: Checkpoint step number to save under.
        save_dir: Root checkpoint directory handed to the CheckpointManager.
        params: Flat dict mapping tuple key paths to arrays exposing
            ``.shape`` and ``.dtype``.

    The tree is written as ``{'params': <tree>}`` inside the ``state`` item;
    one extra ``'params'`` level is added when the tree has no top-level
    ``'params'`` key. The call blocks until the asynchronous write has
    finished, so files of the saved step can safely be read or copied
    right after it returns.
    """
    item = {
        'state': orbax.checkpoint.AsyncCheckpointer(
            orbax.checkpoint.PyTreeCheckpointHandler(use_ocdbt=False)),
    }
    unflatten_params = unflatten_dict(params)
    for k, v in params.items():
        print(k, v.shape, v.dtype)
    mngr = orbax.checkpoint.CheckpointManager(save_dir, item)
    if 'params' not in unflatten_params:
        unflatten_params = {'params': unflatten_params}
    mngr.save(step, items={'state': {'params': unflatten_params}})
    # The checkpointer is asynchronous: wait so callers never observe a
    # partially written step.
    mngr.wait_until_finished()

# convert_params = convert_to_jnp(xm3p5_w, remove_keys=[])
# save_params(0, save_dir, convert_params)

In [3]:
## Build the MoE parameters; after saving, they are moved manually in the bucket console
start_time = time.time()
unshared_experts = 8

scale = 1
mlp_dim = 5632 // scale
model_dim = 4096 // scale
# fp16_dtype = np.dtype('float16')
# One dict per MLP sub-layer (4 sub-layers: mlp_0 .. mlp_3)
total_moe_params = [{} for _ in range(4)]
mlp_prefix = 'decoder.layers.mlp_'
for k, v in xm3p5_w.items():
    print(f'k: {k} take: {time.time() - start_time:.3f}s')
    if mlp_prefix not in k:
        continue
    # Convert only the tensors that are actually used; non-MLP weights are skipped.
    v = jnp.array(v).astype(jnp.bfloat16)
    # Sub-layer index is the text between 'mlp_' and the next '.', so
    # multi-digit indices parse correctly as well.
    l = k.split(mlp_prefix, 1)[1].split('.', 1)[0]
    # if int(l) % 2 != 0: continue
    moe_params = total_moe_params[int(l)]
    unshared_mlp = k.replace(mlp_prefix, 'decoder.layers.unshared_mlp_')
    unshared_mlp = unshared_mlp.replace('.kernel', '')
    # Swap the first two axes, then prepend an expert axis holding
    # `unshared_experts` identical copies of the dense kernel.
    copy_w = v.transpose(1, 0, 2)[None].repeat(unshared_experts, 0)
    print(f'unshared_mlp: {unshared_mlp} v: {v.shape} copy_w: {copy_w.shape}')
    moe_params[unshared_mlp] = copy_w
    if 'mgate' in unshared_mlp:
        # The router kernel reuses the first `unshared_experts` entries of the
        # dense mgate kernel's last axis.
        router_gate = unshared_mlp.replace('mgate', 'router_gate.kernel')
        moe_params[router_gate] = v[..., :unshared_experts]

k: decoder.decoder_norm.scale take: 0.050s
k: decoder.layers.mlp_0.mgate.kernel take: 0.064s
unshared_mlp: decoder.layers.unshared_mlp_0.mgate v: (4096, 12, 44) copy_w: (8, 12, 4096, 44)
k: decoder.layers.mlp_0.wi_0.kernel take: 0.244s
unshared_mlp: decoder.layers.unshared_mlp_0.wi_0 v: (4096, 12, 5632) copy_w: (8, 12, 4096, 5632)
k: decoder.layers.mlp_0.wi_1.kernel take: 1.265s
unshared_mlp: decoder.layers.unshared_mlp_0.wi_1 v: (4096, 12, 5632) copy_w: (8, 12, 4096, 5632)
k: decoder.layers.mlp_0.wo.kernel take: 1.900s
unshared_mlp: decoder.layers.unshared_mlp_0.wo v: (5632, 12, 4096) copy_w: (8, 12, 5632, 4096)
k: decoder.layers.mlp_1.mgate.kernel take: 2.617s
unshared_mlp: decoder.layers.unshared_mlp_1.mgate v: (4096, 12, 44) copy_w: (8, 12, 4096, 44)
k: decoder.layers.mlp_1.wi_0.kernel take: 2.626s
unshared_mlp: decoder.layers.unshared_mlp_1.wi_0 v: (4096, 12, 5632) copy_w: (8, 12, 4096, 5632)
k: decoder.layers.mlp_1.wi_1.kernel take: 2.626s
unshared_mlp: decoder.layers.unshared_ml

In [5]:
# Write each non-empty sub-layer dict as its own checkpoint step
# (sub-layer index i is saved as step i + 1).
for layer_idx, layer_params in enumerate(total_moe_params):
    if layer_params:
        step = layer_idx + 1
        print(f'Save step: {step}')
        # print(layer_params.keys(), '\n')
        save_moe_params = {}
        for name, tensor in layer_params.items():
            # save_params expects tuple key paths rather than dotted names.
            key_path = tuple(name.split('.'))
            save_moe_params[key_path] = tensor
            print(key_path, tensor.shape)
        save_params(step, save_dir, save_moe_params)



Save step: 1
('decoder', 'layers', 'unshared_mlp_0', 'mgate') (8, 12, 4096, 44)
('decoder', 'layers', 'unshared_mlp_0', 'router_gate', 'kernel') (4096, 12, 8)
('decoder', 'layers', 'unshared_mlp_0', 'wi_0') (8, 12, 4096, 5632)
('decoder', 'layers', 'unshared_mlp_0', 'wi_1') (8, 12, 4096, 5632)
('decoder', 'layers', 'unshared_mlp_0', 'wo') (8, 12, 5632, 4096)
('decoder', 'layers', 'unshared_mlp_0', 'mgate') (8, 12, 4096, 44) bfloat16
('decoder', 'layers', 'unshared_mlp_0', 'router_gate', 'kernel') (4096, 12, 8) bfloat16
('decoder', 'layers', 'unshared_mlp_0', 'wi_0') (8, 12, 4096, 5632) bfloat16
('decoder', 'layers', 'unshared_mlp_0', 'wi_1') (8, 12, 4096, 5632) bfloat16
('decoder', 'layers', 'unshared_mlp_0', 'wo') (8, 12, 5632, 4096) bfloat16




Save step: 2
('decoder', 'layers', 'unshared_mlp_1', 'mgate') (8, 12, 4096, 44)
('decoder', 'layers', 'unshared_mlp_1', 'router_gate', 'kernel') (4096, 12, 8)
('decoder', 'layers', 'unshared_mlp_1', 'wi_0') (8, 12, 4096, 5632)
('decoder', 'layers', 'unshared_mlp_1', 'wi_1') (8, 12, 4096, 5632)
('decoder', 'layers', 'unshared_mlp_1', 'wo') (8, 12, 5632, 4096)
('decoder', 'layers', 'unshared_mlp_1', 'mgate') (8, 12, 4096, 44) bfloat16
('decoder', 'layers', 'unshared_mlp_1', 'router_gate', 'kernel') (4096, 12, 8) bfloat16
('decoder', 'layers', 'unshared_mlp_1', 'wi_0') (8, 12, 4096, 5632) bfloat16
('decoder', 'layers', 'unshared_mlp_1', 'wi_1') (8, 12, 4096, 5632) bfloat16
('decoder', 'layers', 'unshared_mlp_1', 'wo') (8, 12, 5632, 4096) bfloat16




Save step: 3
('decoder', 'layers', 'unshared_mlp_2', 'mgate') (8, 12, 4096, 44)
('decoder', 'layers', 'unshared_mlp_2', 'router_gate', 'kernel') (4096, 12, 8)
('decoder', 'layers', 'unshared_mlp_2', 'wi_0') (8, 12, 4096, 5632)
('decoder', 'layers', 'unshared_mlp_2', 'wi_1') (8, 12, 4096, 5632)
('decoder', 'layers', 'unshared_mlp_2', 'wo') (8, 12, 5632, 4096)
('decoder', 'layers', 'unshared_mlp_2', 'mgate') (8, 12, 4096, 44) bfloat16
('decoder', 'layers', 'unshared_mlp_2', 'router_gate', 'kernel') (4096, 12, 8) bfloat16
('decoder', 'layers', 'unshared_mlp_2', 'wi_0') (8, 12, 4096, 5632) bfloat16
('decoder', 'layers', 'unshared_mlp_2', 'wi_1') (8, 12, 4096, 5632) bfloat16
('decoder', 'layers', 'unshared_mlp_2', 'wo') (8, 12, 5632, 4096) bfloat16




Save step: 4
('decoder', 'layers', 'unshared_mlp_3', 'mgate') (8, 12, 4096, 44)
('decoder', 'layers', 'unshared_mlp_3', 'router_gate', 'kernel') (4096, 12, 8)
('decoder', 'layers', 'unshared_mlp_3', 'wi_0') (8, 12, 4096, 5632)
('decoder', 'layers', 'unshared_mlp_3', 'wi_1') (8, 12, 4096, 5632)
('decoder', 'layers', 'unshared_mlp_3', 'wo') (8, 12, 5632, 4096)
('decoder', 'layers', 'unshared_mlp_3', 'mgate') (8, 12, 4096, 44) bfloat16
('decoder', 'layers', 'unshared_mlp_3', 'router_gate', 'kernel') (4096, 12, 8) bfloat16
('decoder', 'layers', 'unshared_mlp_3', 'wi_0') (8, 12, 4096, 5632) bfloat16
('decoder', 'layers', 'unshared_mlp_3', 'wi_1') (8, 12, 4096, 5632) bfloat16
('decoder', 'layers', 'unshared_mlp_3', 'wo') (8, 12, 5632, 4096) bfloat16


In [14]:
# Save a tiny placeholder checkpoint that has the full dense + MoE key set.
# Its structure files (_METADATA and checkpoint; the original note counts
# three files, presumably including _sharding) are then moved into the real
# checkpoint folder to replace the ones there.
start_time = time.time()
mlp_dim = 5632 // scale
# Every leaf is the same 1-element bfloat16 placeholder; presumably only the
# key structure matters for the files generated from this step.
placeholder = jnp.array([100]).astype(jnp.bfloat16)
moe_params = {}
for k in xm3p5_w:
    print(f'k: {k} take: {time.time() - start_time:.3f}s')
    # if 'mgate' not in k:
    moe_params[k] = placeholder
    if 'decoder.layers.mlp_' in k:
        # Mirror the MoE key naming: mlp_N.<name>.kernel -> unshared_mlp_N.<name>
        unshared_mlp = k.replace('decoder.layers.mlp_', 'decoder.layers.unshared_mlp_')
        unshared_mlp = unshared_mlp.replace('.kernel', '')
        moe_params[unshared_mlp] = placeholder
        if 'mgate' in unshared_mlp:
            router_gate = unshared_mlp.replace('mgate', 'router_gate.kernel')
            moe_params[router_gate] = placeholder
moe_params = {tuple(k.split('.')): v for k, v in moe_params.items()}
example_step = 8
save_params(example_step, save_dir, moe_params)




k: decoder.decoder_norm.scale take: 0.001s
k: decoder.layers.mlp_0.mgate.kernel take: 0.002s
k: decoder.layers.mlp_0.wi_0.kernel take: 0.002s
k: decoder.layers.mlp_0.wi_1.kernel take: 0.002s
k: decoder.layers.mlp_0.wo.kernel take: 0.002s
k: decoder.layers.mlp_1.mgate.kernel take: 0.002s
k: decoder.layers.mlp_1.wi_0.kernel take: 0.003s
k: decoder.layers.mlp_1.wi_1.kernel take: 0.003s
k: decoder.layers.mlp_1.wo.kernel take: 0.003s
k: decoder.layers.mlp_2.mgate.kernel take: 0.003s
k: decoder.layers.mlp_2.wi_0.kernel take: 0.003s
k: decoder.layers.mlp_2.wi_1.kernel take: 0.003s
k: decoder.layers.mlp_2.wo.kernel take: 0.004s
k: decoder.layers.mlp_3.mgate.kernel take: 0.004s
k: decoder.layers.mlp_3.wi_0.kernel take: 0.004s
k: decoder.layers.mlp_3.wi_1.kernel take: 0.004s
k: decoder.layers.mlp_3.wo.kernel take: 0.004s
k: decoder.layers.post_self_attention_layer_norm_0.scale take: 0.004s
k: decoder.layers.post_self_attention_layer_norm_1.scale take: 0.004s
k: decoder.layers.post_self_attention

In [None]:
import subprocess

# save_dir1 = str(save_dir).rstrip('/')
source_dir = 'gs://llm_base_models_europe-west4/v5p_256/7B/xm_45x7B_moe_129/checkpoints'
target_dir = 'gs://llm_base_models_europe-west4/v5p_256/7B/xm_45x7B_moe_129/checkpoints'
source_step = example_step
# Copy the structure files of the placeholder step into step 0.
# '_sharding' is intentionally not copied (that copy was already disabled).
for file_name in ('_METADATA', 'checkpoint'):
    # Argument list instead of a shell string: no quoting issues, no shell needed.
    command = [
        'gsutil', 'cp',
        f'{source_dir}/{source_step}/state/{file_name}',
        f'{target_dir}/0/state/',
    ]
    # check=True raises CalledProcessError when a copy fails instead of
    # silently continuing with a missing file.
    r = subprocess.run(command, stdout=subprocess.PIPE, check=True)

Copying gs://llm_base_models_europe-west4/v5p_256/7B/xm_45x7B_moe_129/checkpoints/8/state/_METADATA [Content-Type=application/octet-stream]...
/ [1 files][ 34.0 KiB/ 34.0 KiB]                                                
Operation completed over 1 objects/34.0 KiB.                                     


In [10]:
# 基于tpu type 构建_sharding文件
import base64

def decode_base64(encoded_str):
    """Decode a Base64 string and return its payload as UTF-8 text."""
    return str(base64.b64decode(encoded_str), 'utf-8')

def encode_base64(decoded_str):
    """Encode text as UTF-8 and return its Base64 representation as a str.

    Example input: "opt_state.mu.params.token_embedder.embedding"
    """
    utf8_bytes = decoded_str.encode('utf-8')
    b64_bytes = base64.b64encode(utf8_bytes)
    return str(b64_bytes, 'utf-8')

'''
_sharding文件格式如下：
{
  b3B0X3N0YXRlLm11LnBhcmFtcy50b2tlbl9lbWJlZGRlci5lbWJlZGRpbmc=': {'sharding_type': 'NamedSharding',
  'shape': [1, 1, 4, 1, 1, 1, 1],
  'axis_names': ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor','autoregressive'],
  'partition_spec': [['tensor', 'autoregressive'], ['fsdp', 'fsdp_transpose', 'sequence']],
   2: 4},
   ...
   }
   '''
# MoE sharding template: an existing _sharding file used as the starting point
_sharding_path = epath.Path(
    'gs://llm_base_models_us-east5/v5p_256/7B/xm_45x7B_moe_1017/xm3p5_moe_params_no_opt_v5p_64_sharding.copy'
)
# Read the existing _sharding file
with _sharding_path.open('r') as f:
    _sharding = json.load(f)

tpu_type = 'v5p-64'
core_nums = int(tpu_type.rsplit('-', 1)[-1])
# Every type except v3 gets half of the number in its name.
if 'v3' not in tpu_type:
    core_nums //= 2
print(f'core_nums: {core_nums}')

updated_sharding = {}
for encoded_name, raw_spec in _sharding.items():
    # Each value is itself a JSON string; patch entry 2 of the mesh shape
    # (the 'fsdp' axis in the format example above -- TODO confirm).
    spec = json.loads(raw_spec)
    spec['shape'][2] = core_nums
    serialized_spec = json.dumps(spec)
    updated_sharding[encoded_name] = serialized_spec

    decoded_name = decode_base64(encoded_name)
    if 'unshared_mlp_0' not in decoded_name:
        continue
    # The existing sharding file has MoE only on every other layer, so the
    # unshared_mlp_0 entries are duplicated for unshared_mlp_1 and unshared_mlp_3.
    unshared_mlp_1 = decoded_name.replace('unshared_mlp_0', 'unshared_mlp_1')
    unshared_mlp_3 = decoded_name.replace('unshared_mlp_0', 'unshared_mlp_3')
    print(f'encode_unshared_mlp_1: {unshared_mlp_1}')
    print(f'encode_unshared_mlp_3: {unshared_mlp_3}')
    updated_sharding[encode_base64(unshared_mlp_1)] = serialized_spec
    updated_sharding[encode_base64(unshared_mlp_3)] = serialized_spec

# Write the patched sharding file next to the step-0 state.
updated_sharding_path = epath.Path(f'{save_dir}/0/state/_sharding')
with updated_sharding_path.open('w') as f:
    json.dump(updated_sharding, f)

core_nums: 32


In [13]:
# p = 'gs://llm_base_models_us-east5/v5p_256/7B/xm_45x7B_moe_1017/checkpoints/0/state/_METADATA'
# # p = 'gs://llm_base_models_us-east5/v5p_256/7B/moe_test/checkpoints/0/state/_METADATA'
# # p = 'gs://llm_base_models_us-east5/v5p_256/7B/moe_test/checkpoints/0/state/checkpoint'
# p = epath.Path(p)

# structure_path = p
# msgpack = ocp.aggregate_handlers.MsgpackHandler(0)
# structure = msgpack.deserialize(structure_path)

# p = 'gs://llm_base_models_us-east5/v5p_256/7B/xm_45x7B_moe_1017/checkpoints/6/state/_METADATA'
# # p = 'gs://llm_base_models_us-east5/v5p_256/7B/moe_test/checkpoints/0/state/_METADATA'

# p = epath.Path(p)
# with p.open('r') as f:
#     meta = json.load(f)

# def decode_base64(encoded_str):
#     decoded_bytes = base64.b64decode(encoded_str)
#     decoded_str = decoded_bytes.decode('utf-8')
#     return decoded_str

# ps = []
# for k, v in updated_sharding.items():
#     a = decode_base64(k)
#     if 'opt_state' in a or 'step' in a: continue
#     ps.append(a)
#     print(a)

73