In [1]:
from typing import Optional

import functools

import numpy as np

import ml_collections
import tensorflow as tf
import tensorflow_io as tfio

import tokenizer
import sequence_packing
import random
import os

2024-05-29 12:16:02.352934: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-05-29 12:16:02.493551: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-05-29 12:16:02.495781: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Sentinel telling tf.data to tune parallelism / buffer sizes at runtime
# (same value as the older tf.data.experimental.AUTOTUNE alias).
AUTOTUNE = tf.data.AUTOTUNE

In [3]:
def loadjson_and_rekey(ds, key_map=None, text_max_len=10 * 1024 * 1024):
    """Decode JSON records, truncate their text, and rename the fields.

    Each element of `ds` is decoded with tfio's `decode_json` using a spec that
    contains a single scalar string field, "text". That field is truncated to
    at most `text_max_len` bytes and then exposed under the keys given by
    `key_map`.

    Args:
      ds: a tf.data.Dataset whose elements are JSON strings, one object per
        element (e.g. the lines of a TextLineDataset).
      key_map: dict mapping output keys to decoded keys; entries whose value
        is None are dropped. Only "text" is decoded, so it is the only valid
        non-None value. Defaults to {"inputs": None, "targets": "text"}, i.e.
        the output contains only "targets".
      text_max_len: maximum number of bytes kept from "text" (default 10 MiB).

    Returns:
      A dataset of dicts keyed by the non-None entries of `key_map`.
    """
    if key_map is None:
        key_map = {"inputs": None, "targets": "text"}
    json_specs = {
        "text": tf.TensorSpec(tf.TensorShape([]), tf.string, name="text"),
    }

    def _truncate_text(text):
        # Byte-level truncation can split a multi-byte UTF-8 character (e.g.
        # CJK text), leaving invalid trailing bytes. Only for strings that were
        # actually cut, re-encode with errors="ignore" to drop those bytes;
        # shorter strings pass through unchanged.
        truncated = tf.strings.substr(text, 0, text_max_len, unit="BYTE")
        return tf.cond(
            tf.strings.length(text, unit="BYTE") > text_max_len,
            lambda: tf.strings.unicode_transcode(
                truncated,
                input_encoding="UTF-8",
                output_encoding="UTF-8",
                errors="ignore",
            ),
            lambda: truncated,
        )

    def _loadjson_and_rekey(x):
        """Decode one JSON string, truncate "text", and apply `key_map`."""
        x = tfio.experimental.serialization.decode_json(x, specs=json_specs)
        x["text"] = _truncate_text(x["text"])
        return {
            new_key: x[old_key] for new_key, old_key in key_map.items() if old_key
        }

    return ds.map(_loadjson_and_rekey, num_parallel_calls=AUTOTUNE)


In [4]:
# ds = tf.data.TextLineDataset(
#     ['/home/genggui001/code/maxtext/test.jsonl'],
#     # compression_type="GZIP",
#     buffer_size=8 * 1024 * 1024,
#     num_parallel_reads=2,
# )

# ds = loadjson_and_rekey(ds)

In [5]:
def reduce_concat_tokens(
    dataset,
    feature_key="targets",
    batch_size=128,
):
    """Concatenate groups of unrelated documents into single long examples.

    Every `batch_size` consecutive examples are zero-padded into one batch,
    flattened, and stripped of zero ids, yielding one 1-D token sequence per
    group. Follow with `split_tokens` to cut the result into fixed-length
    segments without wasting space on padding.

    Note: padding is removed by dropping every 0 id, so genuine tokens with id
    0 are dropped as well.

    Args:
      dataset: a tf.data.Dataset of dicts containing the key `feature_key`.
      feature_key: a string, the feature to concatenate (other keys are dropped).
      batch_size: an integer, how many documents to concatenate into one.
    Returns:
      a dataset of dicts with the single key `feature_key`.
    """

    def _keep_feature(example):
        return {feature_key: example[feature_key]}

    def _flatten_and_unpad(batch):
        flat = tf.reshape(batch[feature_key], [-1])
        nonzero = tf.cast(flat, tf.bool)
        return {feature_key: tf.boolean_mask(flat, nonzero)}

    grouped = dataset.map(_keep_feature, num_parallel_calls=AUTOTUNE).padded_batch(
        batch_size, padded_shapes={feature_key: [-1]}
    )
    return grouped.map(_flatten_and_unpad, num_parallel_calls=AUTOTUNE)


def split_tokens(
    dataset,
    max_tokens_per_segment=128,
    feature_key="targets",
):
    """Split each example into multiple examples of bounded length.

    The intended use case is to break up long (typically concatenated, see
    `reduce_concat_tokens`) token sequences for unsupervised training.
    Empty examples are filtered out first.

    Args:
      dataset: a tf.data.Dataset with dictionaries containing the key feature_key.
      max_tokens_per_segment: an integer, the maximum number of tokens in each
        segment. Only the final segment of an example is normally shorter;
        segments also shrink if they contain tokens with id 0, because padding
        is stripped by removing every 0 id.
      feature_key: a string, the feature to split.
    Returns:
      a dataset of dicts with the single key `feature_key`.
    """

    def _split_tokens(x):
        """Split one token sequence into rows of `max_tokens_per_segment` tokens."""
        tokens = x[feature_key]
        n_tokens = tf.size(tokens)
        length = max_tokens_per_segment

        # Integer ceiling division. The previous float32 version lost precision
        # once n_tokens exceeded 2**24, could under-count segments, and then
        # produced a negative (invalid) pad amount.
        num_segments = (n_tokens + length - 1) // length
        padding = num_segments * length - n_tokens
        # Pad to a multiple of length so the reshape yields num_segments rows.
        tokens = tf.pad(tokens, [[0, padding]])
        return tf.reshape(tokens, [-1, length])

    def _strip_padding(x):
        return {feature_key: tf.boolean_mask(x, tf.cast(x, tf.bool))}

    # Filter empty examples.
    dataset = dataset.filter(lambda x: tf.not_equal(tf.size(x[feature_key]), 0))
    dataset = dataset.map(_split_tokens, num_parallel_calls=AUTOTUNE)
    dataset = dataset.unbatch()
    return dataset.map(_strip_padding, num_parallel_calls=AUTOTUNE)


def split_tokens_to_targets_length(dataset, sequence_length):
    """Split the "targets" feature into segments of at most `sequence_length` tokens."""
    return split_tokens(
        dataset,
        max_tokens_per_segment=sequence_length,
        feature_key="targets",
    )


def load_base_dataset(
    pattern,
    seed,
    data_num_shards=1,
    data_index=0,
):
    """Create a line-level dataset over the GZIP files matching `pattern`.

    Matching files are sorted and sharded round-robin (file i goes to shard
    i % data_num_shards). The selected shard's file order is then shuffled
    deterministically with `seed + data_index`. The shuffle uses a private
    `random.Random`, so the global `random` module state is not modified.

    NOTE(review): `tf.io.gfile.glob` may not expand `**` recursively (the
    matched paths printed in this notebook are one directory level deep) —
    confirm before relying on deeper nesting.

    Args:
      pattern: glob pattern passed to `tf.io.gfile.glob`.
      seed: base seed for the file-order shuffle.
      data_num_shards: number of shards to split the file list into; the
        default of 1 keeps every file.
      data_index: which shard to keep, in [0, data_num_shards).

    Returns:
      A GZIP-decompressing `tf.data.TextLineDataset` over the shard's files,
      yielding one line (tf.string) per element and reading 2 files in parallel.

    Raises:
      ValueError: if the shard arguments are out of range, no file matches
        `pattern`, or the selected shard contains no files.
    """
    if data_num_shards < 1 or not 0 <= data_index < data_num_shards:
        raise ValueError(
            f"invalid shard: data_index={data_index}, data_num_shards={data_num_shards}"
        )

    data_paths = sorted(tf.io.gfile.glob(pattern))
    print((pattern, "all_file_count", len(data_paths)))
    if not data_paths:
        # Otherwise an empty dataset is returned, which can silently drop out
        # of a sampled mixture and skew it.
        raise ValueError(f"no files match pattern {pattern!r}")

    # Round-robin sharding; same selection as `i % data_num_shards == data_index`.
    data_paths = data_paths[data_index::data_num_shards]
    if not data_paths:
        raise ValueError(
            f"shard {data_index}/{data_num_shards} of {pattern!r} contains no files"
        )
    # Same permutation as random.seed(s); random.shuffle(...), without
    # touching global RNG state.
    random.Random(seed + data_index).shuffle(data_paths)

    print((pattern, "share_file_count", data_num_shards, data_index, len(data_paths), data_paths[:2]))

    return tf.data.TextLineDataset(
        data_paths,
        compression_type="GZIP",
        buffer_size=8 * 1024 * 1024,
        num_parallel_reads=2,
    )

In [6]:
# Root directory containing every pre-training corpus used below.
dataset_path = "/home/genggui001/gdrive/gg-nlp-lm-new-3"
# Seed shared by the file-order shuffle, example shuffle and mixture sampling.
data_shuffle_seed = 1234
# Tokens per packed training row.
max_target_length = 4_096

In [7]:
from input_pipeline._tfds_data_processing_gg_mlperf import map_with_tokenize

# Single-source pipeline over the-stack-v2, used below to measure its mean
# tokenized document length.
train_ds = loadjson_and_rekey(
    load_base_dataset(
        pattern=os.path.join(dataset_path, "the-stack-v2-train-full/**/*.gz"),
        seed=data_shuffle_seed,
    )
)
train_ds = map_with_tokenize(
    train_ds,
    vocab_path="/home/genggui001/code/maxtext/assets/llama_add_world.model",
    add_bos=True,
    add_eos=True,
    number_of_parallel_calls=16,
)



('/home/genggui001/gdrive/gg-nlp-lm-new-3/the-stack-v2-train-full/**/*.gz', 'all_file_count', 496)
('/home/genggui001/gdrive/gg-nlp-lm-new-3/the-stack-v2-train-full/**/*.gz', 'share_file_count', 1, 0, 496, ['/home/genggui001/gdrive/gg-nlp-lm-new-3/the-stack-v2-train-full/fss_record/chunk-00005.json.gz', '/home/genggui001/gdrive/gg-nlp-lm-new-3/the-stack-v2-train-full/fss_record/chunk-00463.json.gz'])


2024-05-29 12:16:10.109954: I tensorflow_io/core/kernels/cpu_check.cc:128] Your CPU supports instructions that this TensorFlow IO binary was not compiled to use: AVX AVX2 AVX512F FMA
2024-05-29 12:16:10.114003: W tensorflow_io/core/kernels/audio_video_mp3_kernels.cc:271] libmp3lame.so.0 or lame functions are not available


In [8]:
# Iterate the pipeline eagerly, yielding NumPy values.
iterator = train_ds.as_numpy_iterator()

In [9]:
from tqdm.auto import tqdm

# Token count of each of the first 16384 tokenized documents, then their mean.
# NOTE(review): with a shuffled file list but num_parallel_reads=2, these come
# from only the first few files read — confirm the sample is representative.
dds = [len(next(iterator)["targets"]) for _ in tqdm(range(16384))]

np.mean(dds)

  from .autonotebook import tqdm as notebook_tqdm
  0%|          | 0/16384 [00:00<?, ?it/s]

  0%|          | 10/16384 [00:05<1:56:46,  2.34it/s]

('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')




('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')

100%|██████████| 16384/16384 [09:38<00:00, 28.34it/s] 


10899.837463378906

In [10]:
# Desired share of *tokens* contributed by each source.
probabilities = [0.5, 0.2, 0.2, 0.1]

# Mean tokenized document length of each source, in the same order.
# The last entry is 4096 rather than the measured 10899.837463378906 (the
# the-stack-v2 sample mean above); presumably because documents are truncated
# to max_target_length + 1 tokens downstream — confirm.
data_mean_lens = [
    501.73480224609375,
    716.809814453125,
    964.8450927734375,
    4096,
]

In [11]:
# Convert per-source *token* shares into per-document *sampling* probabilities.
# If source i is drawn with probability w_i and its documents average L_i
# tokens, it contributes a w_i * L_i / sum_j(w_j * L_j) share of tokens.
# Setting that share equal to probabilities[i] gives w_i proportional to
# probabilities[i] / L_i, the closed-form solution of the linear system this
# cell previously solved with np.linalg.solve.
# The result is stored under a new name instead of overwriting `probabilities`,
# so re-running this cell does not compound the transform.
token_probabilities = np.asarray(probabilities, dtype=np.float64)
mean_lens = np.asarray(data_mean_lens, dtype=np.float64)
assert token_probabilities.shape == mean_lens.shape, "need one mean length per source"
assert np.all(mean_lens > 0), "mean lengths must be positive"
assert np.isclose(token_probabilities.sum(), 1.0), "token shares must sum to 1"

per_doc_weight = token_probabilities / mean_lens
sample_probabilities = (per_doc_weight / per_doc_weight.sum()).tolist()

print(("old probabilities", probabilities))
print(("new probabilities", sample_probabilities))

('old probabilities', [0.5, 0.2, 0.2, 0.1])
('new probabilities', [0.6611626024947151, 0.18511369734427188, 0.13752602986730192, 0.016197670293710872])


In [12]:
# Sanity check: the re-weighted probabilities printed above should sum to ~1.
sum((0.6611626024947151, 0.18511369734427188, 0.13752602986730192, 0.016197670293710872))

0.9999999999999998

In [7]:
"""Load and return dataset of batched examples for use during training."""
# One line-level dataset per pre-training source, loaded in this order:
# English, Chinese, CulturaX (other languages), code.
en_ds, zh_ds, other_ds, code_ds = (
    load_base_dataset(
        pattern=os.path.join(dataset_path, sub_pattern),
        seed=data_shuffle_seed,
    )
    for sub_pattern in (
        "gg_en/**/*.jsonl.gz",
        "gg_zh/**/*.jsonl.gz",
        "uonlp_culturax_shuffle/**/*.jsonl.gz",
        "the-stack-dedup/**/*.jsonl.gz",
    )
)

# Interleave the sources by weighted random sampling; each source is
# repeated indefinitely.
train_ds = tf.data.Dataset.sample_from_datasets(
    datasets=[source.repeat() for source in (en_ds, zh_ds, other_ds, code_ds)],
    weights=[0.45, 0.2, 0.25, 0.1],  # en, zh, other, code
    seed=data_shuffle_seed,
)

# Decode each JSON line; with the default key map the text lands in "targets".
train_ds = loadjson_and_rekey(train_ds)

('/home/genggui001/gdrive/gg-nlp-lm-new-3/gg_en/**/*.jsonl.gz', 'all_file_count', 2088)
('/home/genggui001/gdrive/gg-nlp-lm-new-3/gg_en/**/*.jsonl.gz', 'share_file_count', 1, 0, 2088, ['/home/genggui001/gdrive/gg-nlp-lm-new-3/gg_en/WebText-en/chunk-00029.jsonl.gz', '/home/genggui001/gdrive/gg-nlp-lm-new-3/gg_en/falcon-refinedweb/chunk-00497.jsonl.gz'])


('/home/genggui001/gdrive/gg-nlp-lm-new-3/gg_zh/**/*.jsonl.gz', 'all_file_count', 873)
('/home/genggui001/gdrive/gg-nlp-lm-new-3/gg_zh/**/*.jsonl.gz', 'share_file_count', 1, 0, 873, ['/home/genggui001/gdrive/gg-nlp-lm-new-3/gg_zh/TeleChat-PTD/chunk-00292.jsonl.gz', '/home/genggui001/gdrive/gg-nlp-lm-new-3/gg_zh/WebText-cn/chunk-00153.jsonl.gz'])
('/home/genggui001/gdrive/gg-nlp-lm-new-3/uonlp_culturax_shuffle/**/*.jsonl.gz', 'all_file_count', 5000)
('/home/genggui001/gdrive/gg-nlp-lm-new-3/uonlp_culturax_shuffle/**/*.jsonl.gz', 'share_file_count', 1, 0, 5000, ['/home/genggui001/gdrive/gg-nlp-lm-new-3/uonlp_culturax_shuffle/s/schunk-02004.jsonl.gz', '/home/genggui001/gdrive/gg-nlp-lm-new-3/uonlp_culturax_shuffle/s/schunk-00879.jsonl.gz'])
('/home/genggui001/gdrive/gg-nlp-lm-new-3/the-stack-dedup/**/*.jsonl.gz', 'all_file_count', 200)
('/home/genggui001/gdrive/gg-nlp-lm-new-3/the-stack-dedup/**/*.jsonl.gz', 'share_file_count', 1, 0, 200, ['/home/genggui001/gdrive/gg-nlp-lm-new-3/the-stac

2024-05-29 11:24:00.630326: I tensorflow_io/core/kernels/cpu_check.cc:128] Your CPU supports instructions that this TensorFlow IO binary was not compiled to use: AVX AVX2 AVX512F FMA
2024-05-29 11:24:00.635278: W tensorflow_io/core/kernels/audio_video_mp3_kernels.cc:271] libmp3lame.so.0 or lame functions are not available


In [8]:
from input_pipeline._tfds_data_processing_gg_mlperf import map_with_tokenize

# Shuffle examples with a 128-element buffer, then tokenize (add_bos/add_eos).
train_ds = map_with_tokenize(
    train_ds.shuffle(128, seed=data_shuffle_seed),
    vocab_path="/home/genggui001/code/maxtext/assets/llama_add_world.model",
    add_bos=True,
    add_eos=True,
    number_of_parallel_calls=16,
)


def format_fn(x):
    """Turn one tokenized document into a next-token-prediction pair.

    Only the first `max_target_length + 1` tokens are kept (the rest are
    discarded). "inputs" drops the last kept token and "targets" drops the
    first, so targets[t] is the token that follows inputs[t].
    """
    window = x["targets"][: max_target_length + 1]
    inputs, targets = window[:-1], window[1:]
    return {"inputs": inputs, "targets": targets}


# Shift every document, then pack to max_target_length with the project's
# sequence_packing helper.
train_ds = sequence_packing.pack_dataset(
    train_ds.map(format_fn, num_parallel_calls=AUTOTUNE),
    max_target_length,
)

# train_ds = reduce_concat_tokens(train_ds, feature_key="targets", batch_size=512)
# train_ds = split_tokens_to_targets_length(train_ds, max_target_length+1)

# # # note eval_ds is pre tokenized, reduce_concated and split to target_length
# # #   mainly to avoid eval sequences change depending on the number of hosts
# # train_ds = sequence_packing.pack_dataset(train_ds, max_target_length+1)


# def format_fn(x):
#     tokens = x["targets"]
#     x["inputs"] = tokens[:-1]
#     x["targets"] = tokens[1:]
    
#     x["inputs_segmentation"] = tf.ones_like(x["inputs"])
#     x["targets_segmentation"] = x["inputs_segmentation"]

#     position = tf.range(tf.size(tokens)-1, dtype=tf.int32)

#     x["inputs_position"] = position
#     x["targets_position"] = position

#     return x

# train_ds = train_ds.map(format_fn, num_parallel_calls=AUTOTUNE)


# Batches of 8 packed rows (a trailing partial batch is dropped), with up to
# 32 batches prefetched.
train_ds = train_ds.batch(8, drop_remainder=True).prefetch(32)



In [9]:
iterator = train_ds.as_numpy_iterator()
# Integer step counter, checkpointed together with the iterator below.
step = tf.Variable(0, dtype=tf.int32)

In [10]:
# Track the step counter and the input-pipeline position; keep the 3 newest.
ckpt = tf.train.Checkpoint(step=step, iterator=iterator)
manager = tf.train.CheckpointManager(
    ckpt,
    directory="/home/genggui001/code/maxtext/tmp/tf_ckpts",
    max_to_keep=3,
)

In [11]:
# Resume from the newest checkpoint in the manager's directory, if any.
ckpt.restore(save_path=manager.latest_checkpoint)





('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')




('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')


<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7f53ba7010f0>

In [12]:
# dd = next(iterator)
# dd

{'inputs': array([[     2,  19898,   2156, ...,      0,      0,      0],
        [     2,  29872,  32885, ...,      0,      0,      0],
        [     2,  23862,   8853, ...,      0,      0,      0],
        ...,
        [     2,  80979,   2048, ...,      0,      0,      0],
        [     2,    320,  21372, ...,      0,      0,      0],
        [     2, 143678,   4747, ...,      0,      0,      0]], dtype=int32),
 'inputs_position': array([[0, 1, 2, ..., 0, 0, 0],
        [0, 1, 2, ..., 0, 0, 0],
        [0, 1, 2, ..., 0, 0, 0],
        ...,
        [0, 1, 2, ..., 0, 0, 0],
        [0, 1, 2, ..., 0, 0, 0],
        [0, 1, 2, ..., 0, 0, 0]], dtype=int32),
 'targets': array([[ 19898,   2156,  13377, ...,      0,      0,      0],
        [ 29872,  32885,  30384, ...,      0,      0,      0],
        [ 23862,   8853,    298, ...,      0,      0,      0],
        ...,
        [ 80979,   2048,  21079, ...,      0,      0,      0],
        [   320,  21372,  63334, ...,      0,      0,      0],


In [12]:
from tqdm.auto import tqdm



  from .autonotebook import tqdm as notebook_tqdm


In [13]:
# Consume 64 batches (contents unused), counting each one in `step`.
for _ in tqdm(range(64)):
    tmp = next(iterator)
    step.assign_add(1)





('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')




('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')
('/home/genggui001/code/maxtext/assets/llama_add_world.model', True, True, 'load success')


100%|██████████| 64/64 [00:09<00:00,  6.62it/s]


In [14]:
# Save, numbering the checkpoint by the current step value.
manager.save(checkpoint_number=step)

'/home/genggui001/code/maxtext/tmp/tf_ckpts/ckpt-64'

In [15]:
# Fetch and display the batch that follows the saved position.
dd = next(iterator)
dd

{'inputs': array([[     2,  19898,   2156, ...,      0,      0,      0],
        [     2,  29872,  32885, ...,      0,      0,      0],
        [     2,  23862,   8853, ...,      0,      0,      0],
        ...,
        [     2,  80979,   2048, ...,      0,      0,      0],
        [     2,    320,  21372, ...,      0,      0,      0],
        [     2, 143678,   4747, ...,      0,      0,      0]], dtype=int32),
 'inputs_position': array([[0, 1, 2, ..., 0, 0, 0],
        [0, 1, 2, ..., 0, 0, 0],
        [0, 1, 2, ..., 0, 0, 0],
        ...,
        [0, 1, 2, ..., 0, 0, 0],
        [0, 1, 2, ..., 0, 0, 0],
        [0, 1, 2, ..., 0, 0, 0]], dtype=int32),
 'targets': array([[ 19898,   2156,  13377, ...,      0,      0,      0],
        [ 29872,  32885,  30384, ...,      0,      0,      0],
        [ 23862,   8853,    298, ...,      0,      0,      0],
        ...,
        [ 80979,   2048,  21079, ...,      0,      0,      0],
        [   320,  21372,  63334, ...,      0,      0,      0],


In [29]:
# Print (input, target) token pairs of the second packed row; each target
# should reappear as the next pair's input (one-token shift from format_fn).
for item in zip(dd["inputs"][1], dd["targets"][1]):
    print(item)

(2, 7741)
(7741, 3107)
(3107, 322)
(322, 2910)
(2910, 21972)
(21972, 29902)
(29902, 322)
(322, 5468)
(5468, 113978)
(113978, 13553)
(13553, 365)
(365, 13265)
(13265, 302)
(302, 9601)
(9601, 283)
(283, 2586)
(2586, 13873)
(13873, 3473)
(3473, 14)
(14, 13198)
(13198, 748)
(748, 2742)
(2742, 9429)
(9429, 314)
(314, 29910)
(29910, 3194)
(3194, 29872)
(29872, 29897)
(29897, 29956)
(29956, 29893)
(29893, 29872)
(29872, 29897)
(29897, 29956)
(29956, 29956)
(29956, 29907)
(29907, 786)
(786, 5534)
(5534, 29872)
(29872, 29907)
(29907, 29907)
(29907, 29893)
(29893, 29872)
(29872, 29897)
(29897, 29948)
(29948, 29946)
(29946, 29897)
(29897, 29898)
(29898, 472)
(472, 386)
(386, 3083)
(3083, 4361)
(4361, 1610)
(1610, 358)
(358, 714)
(714, 279)
(279, 1190)
(1190, 714)
(714, 323)
(323, 12252)
(12252, 473)
(473, 1589)
(1589, 3743)
(3743, 20667)
(20667, 451)
(451, 5997)
(5997, 9445)
(9445, 3822)
(3822, 29890)
(29890, 941)
(941, 472)
(472, 6346)
(6346, 298)
(298, 4276)
(4276, 6164)
(6164, 29893)
(29893, 8