In [1]:
import os
import sys
import yaml
import json
import base64
from collections import defaultdict
from typing import Tuple
import functools

sys.path.append('/home/lishengping/projects/maxtext/MaxText')
os.environ['HARDWARE'] = 'tpu'

from layers import models
import max_utils
import jax
import orbax
import jax.numpy as jnp
from jax.sharding import Mesh
from flax.traverse_util import flatten_dict, unflatten_dict
from flax import linen as nn
from transformers import AutoTokenizer
from etils import epath

import pyconfig
from jax.sharding import PartitionSpec
from flax.linen import partitioning as nn_partitioning


TOKENIZER_PATH = '/home/lishengping/tokenizer'
if not os.path.exists(TOKENIZER_PATH):
    !gsutil cp -r gs://llm_base_models_us-east5/qwen/tokenizer /home/lishengping/
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, use_fast=True, trust_remote_code=True)

read_dir = "gs://llm_base_models/maxtext_align_pax_dc/maxtext_align2/checkpoints"
read_dir = "gs://llm_base_models_us-east5/v5p_256/7B/PileDCSlimLlama7B32Kx4x256x1v5p_0716_test/checkpoints"
read_dir = "gs://llm_base_models_us-east5/v5p_256/7B/PileDCSlimLlama7B32Kx4x256x1v5p_0705/checkpoints"
read_dir = epath.Path(read_dir)

# config_name = '/home/lishengping/projects/maxtext/MaxText/configs/dcformer_pp_405m.yml'
config_name = '/home/lishengping/projects/maxtext/MaxText/configs/dc_7b.yml'

argv = [None, config_name]
pyconfig.initialize(argv)
config = pyconfig.config
# validate_train_config(config)
devices_array = max_utils.create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)

2024-07-18 09:15:03.407736: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-18 09:15:03.425884: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-18 09:15:03.425923: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


SPECIAL_TOKENS: ((151643, '<|endoftext|>'), (151644, '<|im_start|>'), (151645, '<|im_end|>'), (151646, '<|extra_0|>'), (151647, '<|extra_1|>'), (151648, '<|extra_2|>'), (151649, '<|extra_3|>'), (151650, '<|extra_4|>'), (151651, '<|extra_5|>'), (151652, '<|extra_6|>'), (151653, '<|extra_7|>'), (151654, '<|extra_8|>'), (151655, '<|extra_9|>'), (151656, '<|extra_10|>'), (151657, '<|extra_11|>'), (151658, '<|extra_12|>'), (151659, '<|extra_13|>'), (151660, '<|extra_14|>'), (151661, '<|extra_15|>'), (151662, '<|extra_16|>'), (151663, '<|extra_17|>'), (151664, '<|extra_18|>'), (151665, '<|extra_19|>'), (151666, '<|extra_20|>'), (151667, '<|extra_21|>'), (151668, '<|extra_22|>'), (151669, '<|extra_23|>'), (151670, '<|extra_24|>'), (151671, '<|extra_25|>'), (151672, '<|extra_26|>'), (151673, '<|extra_27|>'), (151674, '<|extra_28|>'), (151675, '<|extra_29|>'), (151676, '<|extra_30|>'), (151677, '<|extra_31|>'), (151678, '<|extra_32|>'), (151679, '<|extra_33|>'), (151680, '<|extra_34|>'), (15168



## 构建sharding和metadata信息

In [47]:
def decode_base64(encoded_str):
    """Decode base64 input (str or bytes) and return the payload as UTF-8 text.

    Used to recover the dot-separated parameter paths that key the orbax
    ``_sharding`` file.
    """
    return base64.b64decode(encoded_str).decode('utf-8')


def mesh_shard_rules(mesh, rules, remove_keys=()):
    """Build a ``NamedSharding`` for every entry of an orbax ``_sharding`` mapping.

    Args:
      mesh: ``jax.sharding.Mesh`` the resulting shardings are bound to.
      rules: mapping from base64-encoded, dot-separated parameter paths to
        sharding metadata, given either as a dict or as its JSON string. Each
        entry must contain a ``'partition_spec'`` list; nested lists are turned
        into tuples before building the ``PartitionSpec``.
      remove_keys: path components to skip. Any entry whose decoded path
        contains one of them (e.g. ``'opt_state'``) is dropped. ``None`` is
        treated as "skip nothing".

    Returns:
      Flat dict mapping parameter-path tuples to ``jax.sharding.NamedSharding``.
    """
    # Immutable default (was a shared mutable list); also tolerate None.
    remove_keys = tuple(remove_keys or ())
    sharding_dict = {}
    for encoded_name, rule in rules.items():
        if isinstance(rule, str):
            rule = json.loads(rule)
        param_key = tuple(decode_base64(encoded_name).split('.'))
        if any(key in param_key for key in remove_keys):
            continue
        # JSON has no tuples, so multi-axis entries come back as lists.
        prule = [tuple(r) if isinstance(r, list) else r for r in rule['partition_spec']]
        spec = jax.sharding.PartitionSpec(*prule)
        sharding_dict[param_key] = jax.sharding.NamedSharding(mesh, spec)
    return sharding_dict


def rewrite_bucket_sharding(mesh, old_sharding, save_path):
    """Rewrite an orbax ``_sharding`` file so its recorded mesh shape matches ``mesh``.

    Every entry's ``'shape'`` field is set to ``mesh.device_ids.shape`` and the
    resulting mapping is written as JSON to ``save_path``.

    Each entry keeps its original encoding. Entries given as JSON strings are
    written back as JSON strings, and dict entries stay dicts, so the file format
    seen by whatever reads it next is unchanged. The caller's ``old_sharding``
    and its values are not mutated.

    Args:
      mesh: ``jax.sharding.Mesh`` of the current machine.
      old_sharding: mapping from parameter keys to sharding metadata; each value
        is a dict or a JSON-encoded dict.
      save_path: destination path (local or ``gs://``), opened through ``epath``.
    """
    mesh_shape = list(mesh.device_ids.shape)
    cur_machine_sharding = {}
    for key, value in old_sharding.items():
        was_string = isinstance(value, str)
        # Copy dict entries so the caller's mapping is left untouched.
        entry = json.loads(value) if was_string else dict(value)
        entry['shape'] = mesh_shape
        cur_machine_sharding[key] = json.dumps(entry) if was_string else entry
    with epath.Path(save_path).open('w') as f:
        json.dump(cur_machine_sharding, f)
    
import ast  # used to parse tuple-repr keys in _METADATA safely

# Checkpoint step whose sharding/metadata files are inspected.
# NOTE(review): the restore cell further down reassigns `load_step` to a different
# step before restoring; confirm the `_sharding`/`_METADATA` read here belongs to
# the checkpoint that is actually restored.
load_step = 440000
_sharding_path = read_dir / str(load_step) / 'state/_sharding'
_metadata_path = read_dir / str(load_step) / 'state/_METADATA'

# delete file or dir
# _sharding_path.unlink()

# Top-level subtrees to skip when restoring; only `params` is used downstream.
remove_keys = ['opt_state', 'step']

# Both restore targets start as None so later cells can test either one with
# `is not None`, no matter which branch below runs.
_sharding_dict = None
_metadata_dict = None
if _sharding_path.exists():
    with _sharding_path.open('r') as f:
        _sharding_rules = json.load(f)
    # Rewrite the _sharding file so its mesh shape matches the current machine.
    rewrite_bucket_sharding(mesh, _sharding_rules, _sharding_path)
    _sharding_dict = unflatten_dict(
        mesh_shard_rules(mesh, _sharding_rules, remove_keys=remove_keys))
elif _metadata_path.exists():
    with _metadata_path.open('r') as f:
        _metadata = json.load(f)
    flat_metadata = {}
    for param_key in _metadata['tree_metadata']:
        if isinstance(param_key, str):
            # Keys are presumably tuple reprs such as "('params', 'decoder', ...)".
            # literal_eval parses them without executing code read from the file.
            param_key = ast.literal_eval(param_key)
        if any(key in param_key for key in remove_keys):
            continue
        # NOTE(review): every leaf gets a scalar float32 placeholder; real shapes and
        # dtypes are not taken from the metadata. Confirm this is intended.
        flat_metadata[param_key] = jax.ShapeDtypeStruct(shape=(), dtype=jnp.float32)
    _metadata_dict = unflatten_dict(flat_metadata)

In [None]:
# Alternative restore path (disabled). It restores the checkpoint with an explicit
# per-array `restore_args` tree instead of CheckpointManager. The steps are:
#   1. Temporarily strip 'opt_state'/'step' from the msgpack `checkpoint` structure file.
#   2. Restore using the shardings taken from `_sharding_dict`.
#   3. Write the original structure back.
# NOTE(review): this edits checkpoint files in place (renames _METADATA, rewrites
# `checkpoint`); back them up before enabling.
# # !pip install tensorflow==2.16.1
# # pip install numpy==1.26.4
# # Run this cell twice
# import json
# import os
# import sys
# import asyncio
# import argparse
# from collections import defaultdict
# import time

# # os.environ["JAX_PLATFORMS"] = "cpu"
# from etils import epath
# import json
# import base64

# import torch
# import numpy as np
# import jax.numpy as jnp
# import jax
# import orbax
# import orbax.checkpoint as ocp
# from etils import epath
# from jax.sharding import PartitionSpec as PS
# from flax.traverse_util import flatten_dict, unflatten_dict


# METADATA_FILE = '_METADATA'
# _CHECKPOINT_FILE = 'checkpoint'

# read_dir = "gs://llm_base_models_europe-west4/v5p_256/7B/xm_M8x7B_E8_UnshareWithMgate_ShareWithMlp_AllCopymlp_1201/checkpoints"
# read_dir = epath.Path(read_dir) 
# read_dir = read_dir / '0/state'

# metadata_path = read_dir / METADATA_FILE
# back_metadata_path = read_dir / f'{METADATA_FILE}.back'
# try:
#     metadata_path.rename(back_metadata_path)
# except:
#     pass
# metadata_path.unlink(missing_ok=True) # delete
# structure_path = read_dir / _CHECKPOINT_FILE
# msgpack = ocp.aggregate_handlers.MsgpackHandler(0)
# structure = msgpack.deserialize(structure_path)
# # backup original checkpoint file
# back_structure_path = read_dir / 'checkpoint_back'
# back_structure = structure.copy()
# if not back_structure_path.exists():
#     asyncio.run(msgpack.serialize(back_structure_path, item=back_structure))
# print(f'Old structure file keys: {structure.keys()}')
# remove_keys = ['opt_state', 'step'] # select the weight name you don't want to load, all weight name: opt_state, step, params
# _ = [structure.pop(key) for key in remove_keys if key in structure]
# print(f'New structure file keys: {structure.keys()}')
# asyncio.run(msgpack.serialize(structure_path, item=structure))  # rewrite struct file

# # load model based on struct; note: mesh axes must be the same as in training
# mesh_axes = ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive']
# axes = [1] * len(mesh_axes)
# axes[2] = 4
# devices = np.asarray(jax.devices()).reshape(axes)
# mesh = jax.sharding.Mesh(devices, mesh_axes)
# sharding = jax.sharding.NamedSharding(mesh, PS()) # Sharding is None because we use cpu to load weights
# weight_dtype = jnp.bfloat16 # set restore weights dtype
# restore_args = {}
# for k, v in flatten_dict(structure).items():
#     restore_args[k] =  ocp.ArrayRestoreArgs(restore_type=jax.Array, dtype=weight_dtype, sharding=sharding)
# restore_args = unflatten_dict(restore_args)

# # Newly added: override each restore arg's sharding with the one from `_sharding_dict`
# right_shard = flatten_dict(_sharding_dict)
# new_restore_args = {}
# for k, v in flatten_dict(restore_args).items():
#     print(k, v)
#     v.sharding = right_shard[k]
#     new_restore_args[k] = v
# restore_args = unflatten_dict(new_restore_args)


# ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
# w = ckptr.restore(read_dir, args=ocp.args.PyTreeRestore(restore_args=restore_args))
# structure_path = read_dir / _CHECKPOINT_FILE
# # rewrite struct file, otherwise an error occurs when training continues
# asyncio.run(msgpack.serialize(structure_path, item=back_structure))
# while 'params' in w:
#     w = w['params']
# # xm3p5_w = {'.'.join(k): np.array(v) for k, v in flatten_dict(w).items()}

# # NOTE(review): with the lines below commented out, _METADATA stays renamed to _METADATA.back.
# # try:
# #     back_metadata_path.rename(metadata_path)
# # except:
# #     pass

In [49]:
# Restore the checkpoint through CheckpointManager. If this does not work, fall back
# to the commented-out manual restore cell above.
options = orbax.checkpoint.CheckpointManagerOptions()
item = {
    "state": orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler(use_ocdbt=False))
}
max_mngr = orbax.checkpoint.CheckpointManager(read_dir, item, options)
# NOTE(review): this overrides the `load_step` used earlier to locate the
# _sharding/_METADATA files; confirm both steps share the same tree and sharding.
load_step = 441804
if _sharding_dict is not None:
    state = max_mngr.restore(load_step, items={"state": _sharding_dict})
elif _metadata_dict is not None:
    state = max_mngr.restore(load_step, items={"state": _metadata_dict})
else:
    # No restore target is available. `item` maps names to Checkpointer objects,
    # not to target pytrees, so it must not be passed as `items`. Without `items`
    # the manager presumably restores every registered item as saved; confirm
    # against the installed orbax version.
    state = max_mngr.restore(load_step)
params = state['state']['params']

I0000 00:00:1721296672.545915   40440 gcs_resource.cc:109] Using default AdmissionQueue with limit 32
I0000 00:00:1721296672.550512   45282 google_auth_provider.cc:180] Running on GCE, using service account 887571727717-compute@developer.gserviceaccount.com
tcmalloc: large alloc 1107296256 bytes == 0x97b4e000 @  0x7fd89e187680 0x7fd89e1a7ff4 0x7fd87770e500 0x7fd87770e592 0x7fd8764d0c20 0x7fd876d4e1ac 0x7fd876d51051 0x7fd877137623 0x7fd877139baf 0x7fd87790a1d0 0x7fd89e14d609 0x7fd89df16133
tcmalloc: large alloc 2491416576 bytes == 0xb43ac4000 @  0x7fd89e187680 0x7fd89e1a7ff4 0x7fd87770e500 0x7fd87770e592 0x7fd8764d0c20 0x7fd876d4e1ac 0x7fd876d51051 0x7fd877137623 0x7fd877139baf 0x7fd87790a1d0 0x7fd89e14d609 0x7fd89df16133
tcmalloc: large alloc 2491416576 bytes == 0xcdbd5e000 @  0x7fd89e187680 0x7fd89e1a7ff4 0x7fd87770e500 0x7fd87770e592 0x7fd8764d0c20 0x7fd876d4e1ac 0x7fd876d51051 0x7fd877137623 0x7fd877139baf 0x7fd87790a1d0 0x7fd89e14d609 0x7fd89df16133
tcmalloc: large alloc 2491416576

## 模型初始化和参数shard

In [52]:
# Re-place the restored parameters on the mesh using the per-parameter shardings
# parsed from the checkpoint, then build the model for evaluation.
assert _sharding_dict is not None


def _identity(tree):
    """Return ``tree`` unchanged; jitting it with ``out_shardings`` re-places the arrays."""
    return tree


shard_to_tpu = jax.jit(_identity, in_shardings=None, out_shardings=_sharding_dict['params'])
tpu_params = shard_to_tpu(params)
flat_params = flatten_dict(tpu_params)
for name, leaf in flat_params.items():
    print(name, leaf.shape)
# Spot-check device placement using the last parameter listed above.
print(f'devices: {leaf.devices()}')

quant = None  # no quantization
is_train = False  # evaluation only; disables dropout in model_forward
Transformer = models.Transformer
model = Transformer(config, mesh, quant=quant)
rng1, aqt_rng = jax.random.split(jax.random.key(9876))

('params', 'decoder', 'decoder_norm', 'scale') (4096,)
('params', 'decoder', 'layers', 'mlp_0', 'mgate', 'kernel') (4096, 12, 44)
('params', 'decoder', 'layers', 'mlp_0', 'wi_0', 'kernel') (4096, 12, 5632)
('params', 'decoder', 'layers', 'mlp_0', 'wi_1', 'kernel') (4096, 12, 5632)
('params', 'decoder', 'layers', 'mlp_0', 'wo', 'kernel') (5632, 12, 4096)
('params', 'decoder', 'layers', 'mlp_1', 'mgate', 'kernel') (4096, 12, 44)
('params', 'decoder', 'layers', 'mlp_1', 'wi_0', 'kernel') (4096, 12, 5632)
('params', 'decoder', 'layers', 'mlp_1', 'wi_1', 'kernel') (4096, 12, 5632)
('params', 'decoder', 'layers', 'mlp_1', 'wo', 'kernel') (5632, 12, 4096)
('params', 'decoder', 'layers', 'mlp_2', 'mgate', 'kernel') (4096, 12, 44)
('params', 'decoder', 'layers', 'mlp_2', 'wi_0', 'kernel') (4096, 12, 5632)
('params', 'decoder', 'layers', 'mlp_2', 'wi_1', 'kernel') (4096, 12, 5632)
('params', 'decoder', 'layers', 'mlp_2', 'wo', 'kernel') (5632, 12, 4096)
('params', 'decoder', 'layers', 'mlp_3', '

## 加载数据

In [83]:
import os
import time
import argparse
import socket
import random
from collections import defaultdict

import tensorflow as tf
import jax
import numpy as np

import math
from typing import Dict, List, Optional

from google.cloud import storage


# Tokens kept per record. The batch cell splits each record into inputs (ids[:, :-1])
# and next-token targets (ids[:, 1:]), each seq_len - 1 = 4096 tokens long.
seq_len = 4097

def extract_v3p5_longdata_files(dataset_path):  # lsp
    """List files under a GCS prefix and build a long/short training mix.

    Files whose ``gs://`` URL contains 'valid' become validation files. The rest
    are split into long files (URL contains '.long') and short files. All long
    files are kept, a seeded random subset of short files is mixed in, and the
    combined training list is shuffled.

    Args:
      dataset_path: ``gs://bucket/prefix`` URL.

    Returns:
      ``(train_files, valid_files)`` as lists of ``gs://`` URLs; ``valid_files`` is sorted.
    """
    # A private RNG seeded with 9876 gives the same sample/shuffle as seeding the
    # global `random` module, without clobbering global random state.
    rng = random.Random(9876)
    client = storage.Client()
    # v3: us-east1-d -> common_datasets, v4: us-central2-b -> common_datasets_us-central2-b
    bucket_name, _, directory_path = dataset_path.replace('gs://', '').partition('/')
    if not directory_path.endswith('/'):
        directory_path += '/'
    valid_files, train_long_files, train_short_files = [], [], []
    for blob in client.list_blobs(bucket_name, prefix=directory_path):
        path = f'gs://{os.path.join(bucket_name, blob.name)}'
        # NOTE: 'valid' is matched anywhere in the URL, including the prefix itself.
        if 'valid' in path:
            valid_files.append(path)
        elif '.long' in path:
            train_long_files.append(path)
        else:
            train_short_files.append(path)
    # Short files are mixed in at 3/14 of the long-file count, capped by the number
    # of short files available.
    # NOTE(review): the stated rationale is "file size short:long = 1.5:1; to get
    # short:long tokens = 3:7, take (1 / 1.5) * (3 / 7) = 2/7". The code uses 3/14,
    # which corresponds to a 2:1 size ratio. Confirm which is intended.
    short_k = min(3 * len(train_long_files) // 14, len(train_short_files))
    selected_short_files = rng.sample(train_short_files, k=short_k)
    train_files = selected_short_files + train_long_files
    print(f'selected_short_files: {len(selected_short_files)} train_long_files: {len(train_long_files)}')
    rng.shuffle(train_files)
    print(f'first 10 train files: {train_files[:10]}')
    valid_files = sorted(valid_files)
    print(f'valid_files: {valid_files}')
    return train_files, valid_files


def extract_v3p5_data_files(dataset_path):
    """List the files under a ``gs://bucket/prefix`` URL and split them into train/valid.

    Any file whose full ``gs://`` URL contains 'valid' counts as validation data;
    everything else is training data. Both lists are returned sorted.
    """
    client = storage.Client()
    bucket_name, _, prefix = dataset_path.replace('gs://', '').partition('/')
    if not prefix.endswith('/'):
        prefix = prefix + '/'
    all_paths = [
        f'gs://{os.path.join(bucket_name, blob.name)}'
        for blob in client.list_blobs(bucket_name, prefix=prefix)
    ]
    train_files = sorted(p for p in all_paths if 'valid' not in p)
    valid_files = sorted(p for p in all_paths if 'valid' in p)
    print(f'Train file: {len(train_files)},  test file: {len(valid_files)}')
    return train_files, valid_files
    

def _parse_function(example_proto):
    """Parse one serialized ``tf.train.Example`` into dense int32 token tensors.

    Each key of the global ``task_features`` is read as a variable-length int64
    feature. The values are cast to int32, densified (gaps filled with 0) and
    truncated to the global ``seq_len``.

    Returns:
      Dict mapping feature name to a 1-D int32 tensor of length <= ``seq_len``.
    """
    feature_desc = {key: tf.io.VarLenFeature(tf.int64) for key in task_features}
    example = tf.io.parse_single_example(example_proto, feature_desc)
    parsed = {}
    for name, tensor in example.items():
        if tensor.dtype == tf.int64:
            tensor = tf.cast(tensor, dtype=tf.int32)
        parsed[name] = tf.sparse.to_dense(tensor, default_value=0)[:seq_len]
    # (A leftover debug print of each traced tensor was removed; it only fired at
    # trace time and printed symbolic tensors.)
    return parsed

# Input-pipeline settings.
task_features = {'input_ids': None}
train_seed = 1234
num_infeed_hosts = 1
shuffle_buffer_size = None  # None disables shuffling
pad_id = 0
batch_size = 8

# Alternative data sources (inactive); only the assignments below are live.
# fname = ['gs://jax_llm_data/xiaomeng/sft_target/tfrecord_len2k/en.test.continue_write.tfrecord']
# datadir = 'gs://jax_llm_data_us-east5/xiaomeng/v3.5/tfids_4k_32k_0622/valid_tfrecord'
# train_files, eval_files = extract_v3p5_longdata_files(datadir)

datadir = 'gs://jax_llm_data_us-east5/xiaomeng/v3.5/tfids0527'
train_files, eval_files = extract_v3p5_data_files(datadir)

# Evaluate on the validation split.
fname = eval_files

tf.random.set_seed(train_seed)
ds = tf.data.Dataset.from_tensor_slices(fname)
ds = ds.apply(tf.data.TFRecordDataset)
# shard host data
# NOTE(review): the shard index is hard-coded to 0, so with num_infeed_hosts > 1 every
# host would read the same shard; use the host's process index if going multi-host.
ds = ds.shard(num_infeed_hosts, 0)
ds = ds.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
if shuffle_buffer_size is not None:
    ds = ds.shuffle(buffer_size=shuffle_buffer_size)
padded_shapes = {key: seq_len for key in task_features}
padding_values = {key: pad_id for key in task_features}
ds = ds.padded_batch(
    # np.prod allows batch_size to be a tuple of dims; int() hands TF a plain Python int.
    batch_size=int(np.prod(batch_size)),
    padded_shapes=padded_shapes,
    padding_values=padding_values,
    drop_remainder=True,
)
# ds = ds.map(self.convert)
# ds = ds.prefetch(tf.data.AUTOTUNE)
iter_ds = ds.as_numpy_iterator()

Train file: 4000,  test file: 1
example[name]: Tensor("strided_slice:0", shape=(None,), dtype=int32)


In [None]:
def build_data_sharding(features, shard_names=('fsdp', None)):
    """Map each batch feature name to the same ``NamedSharding`` on the global ``mesh``.

    Args:
      features: iterable of feature names (keys of the batch dict).
      shard_names: ``PartitionSpec`` entries, one per array dimension. The batch
        arrays built below are 2-D (batch, length); the default shards the batch
        dimension over the 'fsdp' mesh axis and replicates the length dimension.

    Returns:
      Dict ``{feature_name: NamedSharding}``.
    """
    # The spec is the same for every feature, so build the sharding once. (The
    # original overwrote `shard_names` with ('fsdp', None), ignoring the argument.)
    sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*shard_names))
    return {k: sharding for k in features}

data_features = ['inputs', 'inputs_position', 'inputs_segmentation', 'targets']
# Batch dim over 'fsdp', length dim replicated. This is the sharding that was
# effectively in use; ('data', None) would shard the batch over 'data' instead.
data_shard_names = ('fsdp', None)
data_sharding = build_data_sharding(data_features, data_shard_names)

@functools.partial(jax.jit, in_shardings=(data_sharding, _sharding_dict['params']), out_shardings=None)
def model_forward(data, params):
    """Run the model on one batch and return the per-token cross-entropy.

    ``data`` must provide 'inputs', 'inputs_position', 'inputs_segmentation' and
    'targets'; ``params`` is the variable tree handed to ``model.apply``. The
    third argument ``0.0`` is passed unchanged to
    ``max_utils.cross_entropy_with_logits``; presumably it is the z-loss weight
    (confirm). The result carries a logical sharding constraint over
    (batch, length).
    """
    # Dropout only when training *and* enabled in the config.
    dropout_on = is_train and config.enable_dropout
    logits, _intermediates = model.apply(
        params,
        data["inputs"],
        data["inputs_position"],
        decoder_segment_ids=data["inputs_segmentation"],
        enable_dropout=dropout_on,
        rngs={"dropout": rng1, "params": aqt_rng},
        mutable="intermediates",
    )
    targets_one_hot = jax.nn.one_hot(data["targets"], config.vocab_size)
    xent, _ = max_utils.cross_entropy_with_logits(logits, targets_one_hot, 0.0)
    return nn.with_logical_constraint(xent, ("activation_embed_and_logits_batch", "activation_length"))

In [89]:
# Take one batch; inputs and targets are the token ids shifted by one position.
x = next(iter_ds)
input_ids = x['input_ids']
print(f'input_ids: {input_ids.shape}')
inputs = input_ids[:, :-1]
seq_positions = jnp.arange(inputs.shape[1]).reshape(1, -1)
data = {
    'inputs': inputs,
    'inputs_position': jnp.broadcast_to(seq_positions, (batch_size, seq_positions.shape[-1])),
    'inputs_segmentation': jnp.ones_like(inputs),
    'targets': input_ids[:, 1:],
}
# NOTE(review): every position, including padded (pad_id) positions, gets segment id 1
# and contributes to the mean loss below. Mask them if records can be shorter than seq_len.

# loss compute
loss = model_forward(data, tpu_params)
print(f'loss shape: {loss.shape} mean: {loss.mean()}')

input_ids: (8, 4097)
loss shape: (8, 4096) mean: 2.4977970123291016
