## Import raw data as JSON shards

In [None]:
from datasets import load_dataset
import json

def get_fineweb_edu_data_sharded(
    shard_size = 50000,
    max_samples = 40000000,
    out_prefix = "./train_shard",
    val_filename = "./val_shard.json",
    val_size = 1000
):
    """
    Stream the FineWeb-Edu dataset and write out training samples in shards. Also create a validation shard of
    'val_size' samples, taken from the beginning of the stream (so they never appear in a training shard).
    Only the raw "text" field of each sample is stored; every output file is a JSON list of strings.

    Params:
        @shard_size: Number of samples per training shard (the final shard may be smaller).
        @max_samples: Total samples for training. If `None`, read until dataset ends. The limit is exact:
                      the last shard is truncated so that no more than `max_samples` samples are written.
        @out_prefix: Filename prefix for train shards; shard i is written to `{out_prefix}_{i}.json`.
        @val_filename: Filename for the validation shard.
        @val_size: Number of samples in the validation set.
    """
    # Local import so this cell stays self-contained.
    from itertools import islice

    ds = load_dataset("HuggingFaceFW/fineweb-edu", name="default", split="train", streaming=True)
    # Keep English samples only. (A stricter quality filter on the "score" field could be added here.)
    ds = ds.filter(lambda x: x.get("language") == "en")
    ds_iter = iter(ds)

    # ------------------------------------------------
    # Collect validation samples
    # ------------------------------------------------
    # islice stops early if the stream ends; max() mirrors range()'s handling of negative sizes.
    val_data = [sample["text"] for sample in islice(ds_iter, max(val_size, 0))]

    with open(val_filename, "w", encoding="utf-8") as f:
        json.dump(val_data, f, ensure_ascii=False)
    print(f"Saved {len(val_data)} validation samples to {val_filename}")

    # ------------------------------------------------
    # Collect training shards in a single pass
    # ------------------------------------------------
    total_written = 0
    shard_idx = 0

    while True:
        # Never take more than what is left of the max_samples budget, so the total is not
        # overshot when max_samples is not a multiple of shard_size.
        shard_limit = shard_size
        if max_samples is not None:
            shard_limit = min(shard_limit, max_samples - total_written)
        if shard_limit <= 0:
            break  # Budget exhausted (or a non-positive shard_size was given)

        # Keep only the text of each sample; the rest of the record is not needed
        texts = [sample["text"] for sample in islice(ds_iter, shard_limit)]
        if not texts:
            break  # We reached EOF on the stream

        # Write shard
        shard_path = f"{out_prefix}_{shard_idx}.json"
        with open(shard_path, "w", encoding="utf-8") as f:
            json.dump(texts, f, ensure_ascii=False)

        shard_idx += 1
        total_written += len(texts)
        print(f"Wrote shard {shard_path} with {len(texts)} samples (total so far: {total_written}).")

    print("Done generating shards.")

# Run the export with the default settings (writes ./val_shard.json and ./train_shard_<i>.json files).
get_fineweb_edu_data_sharded()

## Tokenize each shard and save into tokenizer-specific directories
- This step is optional, but it saves significant time during training because the data is tokenized once up front (in parallel, using multiprocessing) instead of on the fly.
- If you run this, during training you should import data using `helpers.dataset.load_pt_shard_as_dataloader`, instead of `helpers.dataset.load_shard_as_dataloader`.

In [None]:
import torch
import multiprocessing
from functools import partial
from glob import glob
import os
import re
import json
from tqdm import tqdm
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

def _tokenize_line(line, tokenizer_encode, eos_id):
    """
    Helper function for parallel tokenization.
    'tokenizer_encode' is a partial or bound method that does:tokenizer.encode(..., add_special_tokens=False)
    """
    line_toks = tokenizer_encode(line)
    return line_toks + [eos_id]

def _parallel_tokenize_shard(shard_path: str, tokenizer, eos_seperator_id: int, num_processes: int) -> list:
    """
    Read the JSON shard at 'shard_path', tokenize its entries in parallel, and concatenate the results.

    Params:
        @shard_path: Path to a JSON file holding a list of text samples.
        @tokenizer: Tokenizer object exposing encode(text, add_special_tokens=...).
        @eos_seperator_id: Token ID appended after every sample to separate documents.
        @num_processes: Number of worker processes in the pool.

    Returns:
        A single flat (1D) Python list of token IDs: each sample's tokens followed by 'eos_seperator_id',
        in the same order as the samples in the shard.
    """
    # 1) Read lines from JSON. The shards are written as UTF-8 with ensure_ascii=False, so the
    #    encoding must be given explicitly rather than left to the platform default.
    with open(shard_path, 'r', encoding = 'utf-8') as f:
        lines = json.load(f)

    # 2) Parallel tokenization (pool.map preserves input order)
    tokenizer_encode = partial(tokenizer.encode, add_special_tokens = False)
    tokenize_one = partial(_tokenize_line, tokenizer_encode = tokenizer_encode, eos_id = eos_seperator_id)
    with multiprocessing.Pool(processes = num_processes) as pool:
        tokenized_lines = pool.map(tokenize_one, lines)

    # 3) Concatenate
    big_token_buffer = []
    for toks in tokenized_lines:
        big_token_buffer.extend(toks)

    return big_token_buffer

def preprocess_all_shards(input_dir: str, output_dir: str, tokenizer, eos_seperator_id: int, num_processes: int = None):
    """
    Pre-process all shards in 'input_dir' that match 'train_shard_*.json'.
     For each shard, parallel-tokenize it into one big 1D list of token IDs, then save that list as a .pt file
     with {'tokens': 1D LongTensor}.
     No chunking is done here. We'll chunk at load time.
     Shards whose .pt file already exists in 'output_dir' are skipped, so an interrupted run can be resumed.

    Params:
        @input_dir: Directory containing e.g. train_shard_0.json, train_shard_1.json, ...
        @output_dir: The output directory, use the tokenizer name
        @tokenizer: A HuggingFace tokenizer object
        @eos_seperator_id: e.g. tokenizer.eos_token_id
        @num_processes: How many CPU processes to use for parallel tokenization. If None, defaults to half the CPU cores.

    Returns:
        The path to the newly created directory with .pt files.

    Raises:
        RuntimeError: If no file in 'input_dir' matches 'train_shard_*.json'.
    """
    # 1) Create an output dir for the tokenizer
    os.makedirs(output_dir, exist_ok = True)

    # 2) Find all shard files, ordered numerically by shard index
    def extract_shard_num(path):
        # Match on the file name only, so digits in directory names cannot interfere
        m = re.search(r'train_shard_(\d+)\.json', os.path.basename(path))
        if not m:
            return 999999  # fallback if not found
        return int(m.group(1))

    # The path is the tie-breaker, which keeps the order deterministic
    shard_files = sorted(
        glob(os.path.join(input_dir, "train_shard_*.json")),
        key = lambda path: (extract_shard_num(path), path)
    )

    if not shard_files:
        raise RuntimeError(f"No shard files found in {input_dir} with pattern train_shard_*.json")

    # 3) Determine num_processes if not set
    if num_processes is None:
        num_processes = max(1, multiprocessing.cpu_count() // 2)

    # 4) Process each shard
    for shard_path in tqdm(shard_files):

        shard_name = os.path.splitext(os.path.basename(shard_path))[0]
        out_pt_path = os.path.join(output_dir, f"{shard_name}.pt")

        if os.path.isfile(out_pt_path):
            # Shard already processed -> skip
            tqdm.write(f"Skipping {shard_name} because {out_pt_path} already exists.")
            continue

        token_list = _parallel_tokenize_shard(shard_path, tokenizer, eos_seperator_id, num_processes)

        # Convert to 1D LongTensor
        tokens_1d = torch.tensor(token_list, dtype = torch.long)

        # Write to a temporary file first and rename it into place. If the run is interrupted
        # mid-save, no truncated .pt file is left behind to be mistaken for a finished shard
        # by the skip check above.
        tmp_pt_path = out_pt_path + ".tmp"
        torch.save({"tokens": tokens_1d}, tmp_pt_path)
        os.replace(tmp_pt_path, out_pt_path)

    print(f"All shards processed -> .pt in {output_dir}")
    return output_dir

In [None]:
from transformers import AutoTokenizer

# Load the OLMoE tokenizer with add_eos_token/add_bos_token set to False; the EOS separator is
# instead appended explicitly after every sample during preprocessing (see _tokenize_line).
tokenizer = AutoTokenizer.from_pretrained('allenai/OLMoE-1B-7B-0924', add_eos_token = False, add_bos_token = False)

# Tokenize every train_shard_*.json in the current directory and write one .pt file per shard
# into ./olmoe-tokenizer. num_processes is left unset, so half the CPU cores are used.
preprocessed_dir = preprocess_all_shards(
    input_dir = '.',
    output_dir = './olmoe-tokenizer',
    tokenizer = tokenizer,
    eos_seperator_id = tokenizer.eos_token_id
)