In [1]:
# Tab-separated table that is read below with "filename" and "tokens" columns.
INPUT_FILE = "../data/num_tokens.tsv"
# Destination of the generated shell snippet that defines TRAIN_DATA_PATH.
OUTPUT_FILE = "../configs/llm-jp-4_en_15.6t.sh"

In [2]:
import collections
import os
import numpy as np
import pandas as pd

# Render NumPy scalars in plain-text cell output with explicit format specs:
# int64 as a decimal integer ("d"), float64 in fixed-point form ("f").
formatter = get_ipython().display_formatter.formatters["text/plain"]
for scalar_type, spec in ((np.int64, "d"), (np.float64, "f")):
    # Bind `spec` as a default argument so each printer keeps its own format.
    formatter.for_type(
        scalar_type,
        lambda value, printer, cycle, spec=spec: printer.text(format(value, spec)),
    )

In [3]:
# Load the token-count table: one row per file, with its path and token count.
num_tokens = pd.read_csv(INPUT_FILE, sep="\t", dtype={"filename": str, "tokens": np.int64})

# The file name (last path component) is "<language>_<source>_<suffix>".
# Take the last component instead of a fixed index so that the directory depth
# of the paths does not matter.
stem = num_tokens["filename"].str.rsplit("/", n=1).str[-1]
stem_parts = stem.str.split("_")
# Language is the first "_"-separated piece; subset is the first two pieces.
num_tokens["language"] = stem_parts.str[0]
num_tokens["subset"] = stem_parts.str[:2].str.join("_")

num_tokens

Unnamed: 0,filename,tokens,language,subset
0,corpus/tokenized/code/code_olmo-starcoder_0000,104427769064,code,code_olmo-starcoder
1,corpus/tokenized/code/code_stack_0000,114051163723,code,code_stack
2,corpus/tokenized/en/en_dolma-books_0000,5494262694,en,en_dolma-books
3,corpus/tokenized/en/en_dolma-pes2o_0000,62853772802,en,en_dolma-pes2o
4,corpus/tokenized/en/en_dolma-reddit_0000,83015186637,en,en_dolma-reddit
...,...,...,...,...
323,corpus/tokenized/zh/zh_fineweb2_0001,192160217528,zh,zh_fineweb2
324,corpus/tokenized/zh/zh_fineweb2_0002,191629318921,zh,zh_fineweb2
325,corpus/tokenized/zh/zh_fineweb2_0003,198652395168,zh,zh_fineweb2
326,corpus/tokenized/zh/zh_fineweb2_0004,15248244538,zh,zh_fineweb2


In [4]:
# Select subsets
selected_subsets = [
    "code_olmo-starcoder",
    "code_stack",
    "en_dolma-books",
    "en_dolma-pes2o",
    "en_dolma-reddit",
    "en_dolma-wiki",
    "en_dolmino-stackexchange",
    "en_finemath-4plus",
    "en_fineweb-eduscore2",
    "en_fineweb-rest",
    "en_gsm8k",
    "en_mathpile",
    "en_olmo-algebraicstack",
    "en_olmo-arxiv",
    "en_olmo-openwebmath",
    "en_wiki",
    "ja_fineweb2",
    "ja_wiki",
    "ko_fineweb2",
    "ko_wiki",
    "zh_fineweb2",
    "zh_wiki",
]

# Fail loudly on a misspelled or absent subset name: isin() would otherwise
# ignore it silently and every later total would be computed from incomplete data.
missing_subsets = sorted(set(selected_subsets) - set(num_tokens["subset"]))
if missing_subsets:
    raise ValueError(f"Selected subsets not found in the token table: {missing_subsets}")

# Keep only the rows of the selected subsets; copy() so later column
# assignments operate on an independent frame.
num_tokens = num_tokens[num_tokens["subset"].isin(selected_subsets)].copy()

# Number of files
len(num_tokens)

235

In [5]:
# Total tokens over every remaining row of the table
total_tokens = num_tokens.loc[:, "tokens"].sum()
total_tokens

19125616043765

In [6]:
# Tokens per language
# Select the "tokens" column before aggregating so that only it is summed;
# summing the whole frame would also aggregate the string columns.
tokens_per_language = num_tokens.groupby("language")["tokens"].sum()
tokens_per_language

language
code      218478932787
en      17784024165462
ja        282176188715
ko         52097144842
zh        788839611959
Name: tokens, dtype: int64

In [7]:
# Tokens per subset
# Select the "tokens" column before aggregating so that only it is summed;
# summing the whole frame would also aggregate the string columns.
tokens_per_subset = num_tokens.groupby("subset")["tokens"].sum()
tokens_per_subset

subset
code_olmo-starcoder           104427769064
code_stack                    114051163723
en_dolma-books                  5494262694
en_dolma-pes2o                 62853772802
en_dolma-reddit                83015186637
en_dolma-wiki                   3896965449
en_dolmino-stackexchange        1464772187
en_finemath-4plus              10335599308
en_fineweb-eduscore2         6187818835090
en_fineweb-rest             11366326157218
en_gsm8k                           2781710
en_mathpile                     9176535715
en_olmo-algebraicstack         13280211413
en_olmo-arxiv                  22219529548
en_olmo-openwebmath            13395295861
en_wiki                         4744259830
ja_fineweb2                   280894286561
ja_wiki                         1281902154
ko_fineweb2                    51780848623
ko_wiki                          316296219
zh_fineweb2                   787999334628
zh_wiki                          840277331
Name: tokens, dtype: int64

In [8]:
# Any subset without an explicit entry gets a sampling rate of 1.0.
sampling_rates = collections.defaultdict(lambda: 1.0)

# Calculate FineWeb sampling rate
train_tokens = 15_600_000_000_000
fw_rest_tokens = tokens_per_subset["en_fineweb-rest"]
# Tokens left for en_fineweb-rest after all other subsets are counted in full.
fw_rest_train_tokens = train_tokens - (total_tokens - fw_rest_tokens)

# Guard the division and reject a negative rate, which would otherwise end up
# as negative token counts downstream.
if fw_rest_tokens <= 0:
    raise ValueError("en_fineweb-rest has no tokens; cannot compute its sampling rate")
if fw_rest_train_tokens < 0:
    raise ValueError(
        f"The other subsets already exceed the budget of {train_tokens} tokens "
        f"by {-fw_rest_train_tokens}"
    )

sampling_rates["en_fineweb-rest"] = fw_rest_train_tokens / fw_rest_tokens
sampling_rates

defaultdict(<function __main__.<lambda>()>, {'en_fineweb-rest': 0.689819})

In [9]:
# Per-file sampling rate, looked up by subset (the defaultdict supplies 1.0 for
# subsets without an explicit entry).
num_tokens["sampling_rate"] = num_tokens["subset"].apply(lambda x: sampling_rates[x])
# Each file is rounded up to a whole token, so the total can exceed the target
# by a few tokens.
num_tokens["sampled_tokens"] = np.ceil(num_tokens["tokens"] * num_tokens["sampling_rate"]).astype(np.int64)
# Vectorized sum instead of the builtin sum(), which iterates element by element.
num_tokens["sampled_tokens"].sum()

15600000000048

In [10]:
# Output corpus config file

# os.path.dirname() is "" for a bare file name and os.makedirs("") raises,
# so only create the directory when there is one.
output_dir = os.path.dirname(OUTPUT_FILE)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

corpus_prefix="/home/shared/experiments/0111_v4-setup"  # Sakura

# Writes a shell array: a "--data-path" flag followed by one
# "<sampled token count> <dataset path>" line per file.
with open(OUTPUT_FILE, "w") as fp:
    print("TRAIN_DATA_PATH=(", file=fp)
    print("    --data-path", file=fp)

    # Iterate only the two columns that are written; this avoids building a
    # mixed-dtype row object per file as iterrows() does.
    for sampled_tokens, filename in zip(num_tokens["sampled_tokens"], num_tokens["filename"]):
        print(f"    {sampled_tokens:16d} {corpus_prefix}/{filename}_text_document", file=fp)

    print(")", file=fp)