In [1]:
from typing import Optional

import functools

import numpy as np

import ml_collections
import tensorflow as tf
import tensorflow_io as tfio

import tokenizer
import sequence_packing
import random
import os

2024-04-25 07:40:03.540359: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-04-25 07:40:03.679689: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-04-25 07:40:03.680962: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Let tf.data pick parallelism / buffer sizes dynamically (stable alias of the
# deprecated tf.data.experimental.AUTOTUNE).
AUTOTUNE = tf.data.AUTOTUNE

In [3]:
def loadjson_and_rekey(ds):
    """Decode one JSON document per element, cap its text length and rename keys.

    Each element of `ds` is parsed with tfio's JSON decoder using a spec that
    only contains a scalar string field "text". That text is cut to at most
    10 MiB (counted in bytes). Keys are then renamed via a fixed key map,
    {"inputs": None, "targets": "text"}: "targets" receives the text and
    "inputs" (mapped to None) is dropped.

    Args:
      ds: a tf.data.Dataset of JSON strings.
    Returns:
      A dataset whose elements are {"targets": <scalar tf.string>}.
    """
    # Byte-level cap: truncation may split a multi-byte UTF-8 character, so
    # consumers that decode the text must tolerate a broken trailing char.
    text_max_len = 10 * 1024 * 1024  # 10M bytes
    key_map = {"inputs": None, "targets": "text"}
    json_specs = {
        "text": tf.TensorSpec(tf.TensorShape([]), tf.string, name="text"),
    }

    def _decode_and_rename(line, json_specs, key_map=None):
        """Parse one JSON line, truncate "text", and keep only mapped keys.

        A key whose mapping is None (or otherwise falsy) is dropped; every
        other `new_key` takes the value of `record[old_key]`.
        """
        record = tfio.experimental.serialization.decode_json(line, specs=json_specs)
        record["text"] = tf.strings.substr(record["text"], 0, text_max_len, unit='BYTE')

        renamed = {}
        for new_key, old_key in key_map.items():
            if old_key:
                renamed[new_key] = record[old_key]
        return renamed

    parse = functools.partial(_decode_and_rename, json_specs=json_specs, key_map=key_map)
    return ds.map(parse, num_parallel_calls=AUTOTUNE)


def reduce_concat_tokens(
    dataset,
    feature_key="targets",
    batch_size=128,
):
    """Token-preprocessor that concatenates several unrelated documents.

    Used before `split_tokens` when examples of an exact length are wanted
    (so no space is wasted on padding). Up to `batch_size` consecutive
    documents are padded into one batch, flattened to a single 1-D sequence,
    and the padding is removed again.

    Every token equal to 0 is treated as padding and removed, including any
    genuine id-0 tokens. NOTE(review): confirm the vocabulary never emits id 0
    as a real token.

    Args:
      dataset: a tf.data.Dataset of dicts containing the key `feature_key`.
      feature_key: a string naming the token feature to concatenate.
      batch_size: an integer, how many documents to concatenate into one.
    Returns:
      A dataset of {feature_key: 1-D token tensor}; all other keys are dropped.
    """

    def _keep_feature(example):
        return {feature_key: example[feature_key]}

    def _flatten_and_strip(batch):
        flat = tf.reshape(batch[feature_key], [-1])
        is_real_token = tf.cast(flat, tf.bool)  # nonzero -> True
        return {feature_key: tf.boolean_mask(flat, is_real_token)}

    only_feature = dataset.map(_keep_feature, num_parallel_calls=AUTOTUNE)
    batched = only_feature.padded_batch(batch_size, padded_shapes={feature_key: [-1]})
    return batched.map(_flatten_and_strip, num_parallel_calls=AUTOTUNE)


def split_tokens(
    dataset,
    max_tokens_per_segment=128,
    feature_key="targets",
):
    """Split each example's token sequence into consecutive fixed-size segments.

    Intended to break long (typically concatenated, see `reduce_concat_tokens`)
    token streams into training examples. Each sequence is zero-padded up to a
    multiple of `max_tokens_per_segment`, reshaped into rows, unbatched, and
    the zero padding is stripped again. As a result every segment has exactly
    `max_tokens_per_segment` tokens except possibly the last one.

    Any token equal to 0 is removed by the final strip step, including genuine
    id-0 tokens. NOTE(review): confirm the vocabulary never emits id 0 as a
    real token.

    Args:
      dataset: a tf.data.Dataset of dicts containing the key `feature_key`.
      max_tokens_per_segment: a positive integer, the maximum number of tokens
        in each segment.
      feature_key: a string, the feature to split.
    Returns:
      A dataset of {feature_key: 1-D token tensor}.
    Raises:
      ValueError: if `max_tokens_per_segment` is not positive.
    """
    if max_tokens_per_segment <= 0:
        raise ValueError(
            f"max_tokens_per_segment must be positive, got {max_tokens_per_segment}"
        )
    length = max_tokens_per_segment

    def _split_tokens(x):
        """Reshape one token sequence into rows of `length` tokens."""
        tokens = x[feature_key]
        n_tokens = tf.size(tokens)
        # Pad up to the next multiple of `length`. Exact integer arithmetic:
        # a float32 ceil(n / length) loses precision above 2**24 tokens and can
        # undercount the segments, producing a negative pad that tf.pad rejects.
        padding = tf.math.floormod(-n_tokens, length)
        tokens = tf.pad(tokens, [[0, padding]])
        return tf.reshape(tokens, [-1, length])

    def _strip_padding(x):
        return {feature_key: tf.boolean_mask(x, tf.cast(x, tf.bool))}

    # Filter empty examples.
    dataset = dataset.filter(lambda x: tf.not_equal(tf.size(x[feature_key]), 0))
    dataset = dataset.map(_split_tokens, num_parallel_calls=AUTOTUNE)
    dataset = dataset.unbatch()
    return dataset.map(_strip_padding, num_parallel_calls=AUTOTUNE)


def split_tokens_to_targets_length(dataset, sequence_length):
    """Cut the "targets" feature into segments of at most `sequence_length` tokens."""
    segmented = split_tokens(
        dataset,
        max_tokens_per_segment=sequence_length,
        feature_key="targets",
    )
    return segmented


def load_base_dataset(
    pattern,
    seed,
    data_num_shards=1,
    data_index=0,
):
    """Build a line-level dataset over the GZIP files matching `pattern`.

    Matching files are sorted and sharded round-robin (file i belongs to shard
    i % data_num_shards). Only shard `data_index` is kept, and its files are
    shuffled with seed `seed + data_index`. Lines are read with
    tf.data.TextLineDataset (GZIP, 8 MiB buffer, 2 files read in parallel).

    Args:
      pattern: glob pattern passed to tf.io.gfile.glob.
      seed: base seed for the file-order shuffle.
      data_num_shards: number of shards the sorted file list is split into.
      data_index: 0-based index of the shard to read.
    Returns:
      A tf.data.Dataset yielding one raw text line (tf.string) per element.
    Raises:
      ValueError: if the shard arguments are invalid, or no file is selected.
        An empty source would otherwise vanish silently once repeated and mixed
        with other datasets, skewing the sampling weights.
    """
    if data_num_shards < 1 or not 0 <= data_index < data_num_shards:
        raise ValueError(
            f"invalid shard: data_index={data_index}, data_num_shards={data_num_shards}"
        )

    data_paths = sorted(tf.io.gfile.glob(pattern))
    print((pattern, "all_file_count", len(data_paths)))

    # Round-robin sharding: keep every data_num_shards-th file starting at data_index.
    data_paths = data_paths[data_index::data_num_shards]
    # A private RNG gives the same order as seeding the global `random` module,
    # without clobbering global random state.
    random.Random(seed + data_index).shuffle(data_paths)

    print((pattern, "share_file_count", data_num_shards, data_index, len(data_paths), data_paths[:2]))

    if not data_paths:
        raise ValueError(
            f"no files for pattern {pattern!r} (shard {data_index} of {data_num_shards})"
        )

    return tf.data.TextLineDataset(
        data_paths,
        compression_type="GZIP",
        buffer_size=8 * 1024 * 1024,
        num_parallel_reads=2,
    )

In [4]:
# Root directory holding the per-source *.jsonl.gz files. Overridable via the
# GG_NLP_DATASET_PATH environment variable so the notebook is not tied to one machine.
dataset_path = os.environ.get("GG_NLP_DATASET_PATH", "/home/genggui001/gdrive/gg-nlp-lm-new-3")
# Seed shared by file shuffling, source mixing and the example shuffle buffer.
data_shuffle_seed = 1234
# Model sequence length; raw segments are cut to max_target_length + 1 tokens
# so that inputs/targets can be shifted by one position.
max_target_length = 4096

In [5]:
"""Load and return dataset of batched examples for use during training."""
# (glob relative to dataset_path, sampling weight), in en / zh / other / code order.
data_sources = [
    ("gg_en/**/*.jsonl.gz", 0.45),
    ("gg_zh/**/*.jsonl.gz", 0.2),
    ("uonlp_culturax_shuffle/**/*.jsonl.gz", 0.25),
    ("the-stack-dedup/**/*.jsonl.gz", 0.1),
]

source_datasets = [
    load_base_dataset(
        pattern=os.path.join(dataset_path, relative_glob),
        seed=data_shuffle_seed,
    )
    for relative_glob, _ in data_sources
]
en_ds, zh_ds, other_ds, code_ds = source_datasets

# Each source repeats forever; elements are drawn from them with the weights above.
train_ds = tf.data.Dataset.sample_from_datasets(
    datasets=[source.repeat() for source in source_datasets],
    weights=[weight for _, weight in data_sources],
    seed=data_shuffle_seed,
)

train_ds = loadjson_and_rekey(train_ds)

('/home/genggui001/gdrive/gg-nlp-lm-new-3/gg_en/**/*.jsonl.gz', 'all_file_count', 2088)
('/home/genggui001/gdrive/gg-nlp-lm-new-3/gg_en/**/*.jsonl.gz', 'share_file_count', 1, 0, 2088, ['/home/genggui001/gdrive/gg-nlp-lm-new-3/gg_en/WebText-en/chunk-00029.jsonl.gz', '/home/genggui001/gdrive/gg-nlp-lm-new-3/gg_en/falcon-refinedweb/chunk-00497.jsonl.gz'])
('/home/genggui001/gdrive/gg-nlp-lm-new-3/gg_zh/**/*.jsonl.gz', 'all_file_count', 873)
('/home/genggui001/gdrive/gg-nlp-lm-new-3/gg_zh/**/*.jsonl.gz', 'share_file_count', 1, 0, 873, ['/home/genggui001/gdrive/gg-nlp-lm-new-3/gg_zh/TeleChat-PTD/chunk-00292.jsonl.gz', '/home/genggui001/gdrive/gg-nlp-lm-new-3/gg_zh/WebText-cn/chunk-00153.jsonl.gz'])
('/home/genggui001/gdrive/gg-nlp-lm-new-3/uonlp_culturax_shuffle/**/*.jsonl.gz', 'all_file_count', 5000)
('/home/genggui001/gdrive/gg-nlp-lm-new-3/uonlp_culturax_shuffle/**/*.jsonl.gz', 'share_file_count', 1, 0, 5000, ['/home/genggui001/gdrive/gg-nlp-lm-new-3/uonlp_culturax_shuffle/s/schunk-02004

2024-04-25 07:40:32.292905: I tensorflow_io/core/kernels/cpu_check.cc:128] Your CPU supports instructions that this TensorFlow IO binary was not compiled to use: AVX AVX2 AVX512F FMA
2024-04-25 07:40:32.297564: W tensorflow_io/core/kernels/audio_video_mp3_kernels.cc:271] libmp3lame.so.0 or lame functions are not available


In [6]:
from sentencepiece import SentencePieceProcessor

# SentencePiece model used by tokenize_fn below.
tokenizer_model_path = "/home/genggui001/code/maxtext/assets/llama_add_world.model"
tokenize_processor = SentencePieceProcessor()
tokenize_processor.Load(tokenizer_model_path)


def tokenize_fn(t: bytes, processor=None) -> np.ndarray:
    """Encode one raw UTF-8 document into int32 token ids wrapped in BOS/EOS.

    Called through `tf.numpy_function`, so `t` arrives as a bytes-like scalar.

    Args:
      t: the document text as UTF-8 encoded bytes.
      processor: object providing `EncodeAsIds`, `bos_id` and `eos_id`
        (a SentencePieceProcessor). Defaults to the module-level
        `tokenize_processor`.
    Returns:
      1-D np.int32 array: [bos_id, *EncodeAsIds(text), eos_id].
    """
    if processor is None:
        processor = tokenize_processor
    # Texts are capped upstream with a byte-level substr, which can cut a
    # multi-byte UTF-8 character in half. A strict decode would raise
    # UnicodeDecodeError inside tf.numpy_function and abort the whole input
    # pipeline, so undecodable bytes are dropped instead.
    text = t.decode("utf-8", errors="ignore")
    token_ids = [processor.bos_id(), *processor.EncodeAsIds(text), processor.eos_id()]
    return np.asarray(token_ids, dtype=np.int32)


def _tokenize_example(example):
    """Replace the raw "targets" text of one example with its token ids."""
    token_ids = tf.numpy_function(
        func=tokenize_fn,
        inp=[example['targets']],
        Tout=tf.int32,
        stateful=False,
    )
    return {'targets': token_ids}


train_ds = train_ds.shuffle(128, seed=data_shuffle_seed)
train_ds = train_ds.map(_tokenize_example, num_parallel_calls=AUTOTUNE)

# Concatenate 512 tokenized documents at a time, then cut the stream into
# segments of max_target_length + 1 tokens (one extra for the input/target shift).
train_ds = reduce_concat_tokens(train_ds, feature_key="targets", batch_size=512)
train_ds = split_tokens_to_targets_length(train_ds, max_target_length + 1)

# # note eval_ds is pre tokenized, reduce_concated and splitted to target_length
# #   mainly to avoid eval sequences change depending on the number of hosts
# train_ds = sequence_packing.pack_dataset(train_ds, max_target_length+1)


def format_fn(x):
    """Turn one segment of N tokens into shifted next-token-prediction features.

    Updates `x` in place and returns it. "inputs" holds tokens[:-1] and
    "targets" holds tokens[1:]. Both segmentation features are all ones
    (a single segment), and both position features are 0..N-2.
    """
    tokens = x["targets"]
    inputs = tokens[:-1]
    segmentation = tf.ones_like(inputs)
    position = tf.range(tf.size(tokens) - 1, dtype=tf.int32)

    # "targets" already exists, so it keeps its place as the first key.
    x.update(
        inputs=inputs,
        targets=tokens[1:],
        inputs_segmentation=segmentation,
        targets_segmentation=segmentation,
        inputs_position=position,
        targets_position=position,
    )
    return x

train_ds = train_ds.map(format_fn, num_parallel_calls=AUTOTUNE)

# Pad every feature to exactly max_target_length and drop any incomplete final
# batch of 8, so all batches share one static shape.
batch_feature_names = (
    "inputs",
    "targets",
    "inputs_segmentation",
    "targets_segmentation",
    "inputs_position",
    "targets_position",
)
train_ds = train_ds.padded_batch(
    8,
    padded_shapes={name: max_target_length for name in batch_feature_names},
    drop_remainder=True,
)

train_ds = train_ds.prefetch(32)

In [7]:
# Counter of consumed batches; checkpointed together with the iterator below.
step = tf.Variable(initial_value=0)
# NumPy-yielding iterator over the batched training data.
iterator = train_ds.as_numpy_iterator()

In [8]:
# Track the step counter and the data iterator in one checkpoint, keeping the
# 3 most recent saves in the directory below.
ckpt = tf.train.Checkpoint(step=step, iterator=iterator)
manager = tf.train.CheckpointManager(
    ckpt,
    directory='/home/genggui001/code/maxtext/tmp/tf_ckpts',
    max_to_keep=3,
)

In [9]:
# Resume from the newest checkpoint tracked by `manager` (None if there is none yet).
latest_ckpt_path = manager.latest_checkpoint
restore_status = ckpt.restore(latest_ckpt_path)
restore_status

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7f27a119bca0>

In [1]:
# Show the path of the newest checkpoint known to `manager` (requires the
# checkpoint-manager cell to have run first in this kernel).
manager.latest_checkpoint

NameError: name 'manager' is not defined

In [10]:
# Show the current (possibly restored) step counter.
step

<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=64>

In [10]:
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# Pull 64 batches from the iterator, counting each one in `step`.
num_batches_to_advance = 64
for _ in tqdm(range(num_batches_to_advance)):
    tmp = next(iterator)
    step.assign_add(delta=1)

  0%|          | 0/64 [00:00<?, ?it/s]2024-04-25 07:24:01.904674: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 3 of 128
2024-04-25 07:24:02.787369: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 4 of 128
2024-04-25 07:24:02.787420: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 5 of 128
2024-04-25 07:24:03.355352: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 12 of 128
2024-04-25 07:24:03.464820: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] Shuffle buffer filled.
100%|██████████| 64/64 [01:00<00:00,  1.05it/s]


In [12]:
# Save with the current step value as the checkpoint number; shows the saved path.
manager.save(checkpoint_number=step)

'/home/genggui001/code/maxtext/tmp/tf_ckpts/ckpt-64'

In [26]:
# Show the CheckpointManager object.
manager

<tensorflow.python.checkpoint.checkpoint_management.CheckpointManager at 0x7fecec765240>

In [13]:
# Pull one batch to inspect it. Advance `step` as well, as the loop above does,
# so the step counter checkpointed with the iterator keeps matching its position.
batch = next(iterator)
step.assign_add(1)
batch

{'targets': array([[  527, 29872, 29946, ...,   293,  6959,  7582],
        [ 2966, 29872, 29946, ...,  4468,  5173,   334],
        [  317,  1303,   736, ..., 31038, 30745, 33118],
        ...,
        [29946, 29893, 29872, ..., 30215, 32176, 36023],
        [30504, 33955, 30215, ...,  1493, 12232,  2134],
        [   14, 14805,  2306, ..., 30331, 37771, 35690]], dtype=int32),
 'inputs': array([[29880,   527, 29872, ...,   568,   293,  6959],
        [  324,  2966, 29872, ..., 60983,  4468,  5173],
        [13806,   317,  1303, ..., 52506, 31038, 30745],
        ...,
        [29956, 29946, 29893, ..., 34880, 30215, 32176],
        [32570, 30504, 33955, ...,   279,  1493, 12232],
        [29890,    14, 14805, ..., 40356, 30331, 37771]], dtype=int32),
 'inputs_segmentation': array([[1, 1, 1, ..., 1, 1, 1],
        [1, 1, 1, ..., 1, 1, 1],
        [1, 1, 1, ..., 1, 1, 1],
        ...,
        [1, 1, 1, ..., 1, 1, 1],
        [1, 1, 1, ..., 1, 1, 1],
        [1, 1, 1, ..., 1, 1, 1]], dtyp