## Tokenize each shard and save into tokenizer-specific directories
- This step is optional, but it saves significant time during training because it pre-tokenizes the data up front using multiprocessing!
- If you run it, load the data during training with `helpers.dataset.load_pt_shard_as_dataloader` instead of `helpers.dataset.load_shard_as_dataloader`.

In [None]:
import torch
import multiprocessing
from functools import partial
from glob import glob
import os
import re
import json
from tqdm import tqdm
# Turn off the HuggingFace tokenizers library's own internal parallelism; parallelism in this
# notebook comes from the multiprocessing pool instead. NOTE: presumably set to avoid the
# library's fork-after-parallelism warning/deadlock — see the tokenizers docs for this variable.
os.environ.update(TOKENIZERS_PARALLELISM = 'false')

def _tokenize_line(line, tokenizer_encode, eos_id):
    """
    Encode a single line and append the EOS separator id.

    Intended to run inside a multiprocessing worker.

    Params:
        @line: One text record from a shard.
        @tokenizer_encode: Callable mapping text -> list of token ids. The caller passes
            tokenizer.encode with add_special_tokens=False already bound.
        @eos_id: Separator token id appended after the line's tokens.

    Returns:
        A new list: the encoded tokens followed by eos_id.
    """
    return tokenizer_encode(line) + [eos_id]

def _parallel_tokenize_shard(shard_path: str, tokenizer, eos_seperator_id: int, num_processes: int):
    """
    Read one JSON shard, tokenize its lines in parallel, and concatenate the results.

    Each line is encoded without special tokens and followed by `eos_seperator_id`,
    so documents remain separated in the flattened token stream.

    Params:
        @shard_path: Path to a JSON file. Assumed to hold a list of text strings
            (one per document) — TODO confirm against the shard writer.
        @tokenizer: Object exposing `encode(text, add_special_tokens=False)`
            (e.g. a HuggingFace tokenizer).
        @eos_seperator_id: Token id appended after every line.
        @num_processes: Worker processes for the pool (None -> all CPU cores, as Pool does).

    Returns:
        A 1D Python list of token ids for the whole shard, in file order.
    """
    # 1) Read lines from JSON. JSON text is UTF-8, so don't depend on the platform default encoding.
    with open(shard_path, 'r', encoding = 'utf-8') as f:
        lines = json.load(f)

    # 2) Parallel tokenization
    tokenizer_encode = partial(tokenizer.encode, add_special_tokens = False)
    worker = partial(_tokenize_line, tokenizer_encode = tokenizer_encode, eos_id = eos_seperator_id)

    # Batch lines the way Pool.map does by default (~4 batches per worker). The tokenizer is
    # pickled once per batch, so a one-line-per-task schedule would be very slow.
    workers = num_processes if num_processes is not None else (os.cpu_count() or 1)
    chunksize, remainder = divmod(len(lines), workers * 4)
    if remainder:
        chunksize += 1
    chunksize = max(1, chunksize)  # imap rejects chunksize < 1 (e.g. for an empty shard)

    # 3) Concatenate results as they arrive. imap preserves input order, and consuming it
    #    incrementally avoids holding every per-line list alongside the full buffer.
    big_token_buffer = []
    with multiprocessing.Pool(processes = num_processes) as pool:
        for toks in pool.imap(worker, lines, chunksize = chunksize):
            big_token_buffer.extend(toks)

    return big_token_buffer

def _shard_index(path):
    """Return N from a '.../train_shard_N.json' path, or 999999 (sorts last) if the name has no number."""
    m = re.search(r'train_shard_(\d+)\.json', os.path.basename(path))
    if not m:
        return 999999  # fallback if not found
    return int(m.group(1))


def _collect_shard_files(input_dir):
    """Return [val_shard.json if present] + train_shard_*.json paths in numeric shard order."""
    # Local import: the module-level name `glob` is bound to the glob() function, not the module.
    from glob import escape
    # Escape the directory so characters like '[', '*', '?' in it are matched literally.
    base = escape(input_dir)
    # Lexicographic sort first, so files without a shard number keep a deterministic order
    # after the stable numeric sort.
    train_shard_files = sorted(glob(os.path.join(base, "train_shard_*.json")))
    train_shard_files.sort(key = _shard_index)
    val_shard_file = glob(os.path.join(base, "val_shard.json"))
    return val_shard_file + train_shard_files


def preprocess_all_shards(input_dir: str, output_dir: str, tokenizer, eos_seperator_id: int, num_processes: int = None):
    """
    Pre-tokenize every shard in 'input_dir' and save one .pt file per shard.

    Shards processed: 'val_shard.json' (if present) first, then 'train_shard_*.json' in numeric
    order (train_shard_2 before train_shard_10). Each shard is tokenized in parallel into one
    flat 1D list of token ids (lines separated by the EOS id). The result is saved as
    '<output_dir>/<shard_name>.pt' containing {'tokens': 1D LongTensor}.
    No chunking is done here; chunking happens at load time.
    Shards whose .pt file already exists are skipped, so an interrupted run can be resumed.

    Params:
        @input_dir: Directory containing e.g. val_shard.json, train_shard_0.json, train_shard_1.json, ...
        @output_dir: Directory for the .pt files (e.g. named after the tokenizer); created if missing.
        @tokenizer: A HuggingFace tokenizer object
        @eos_seperator_id: e.g. tokenizer.eos_token_id
        @num_processes: How many CPU processes to use for parallel tokenization. If None, defaults to half the CPU cores.

    Returns:
        The path to the directory with the .pt files (output_dir).

    Raises:
        RuntimeError: if no shard files are found in input_dir.
    """
    # 1) Create an output dir for the tokenizer
    os.makedirs(output_dir, exist_ok = True)

    # 2) Find all shard files (val first, then train in numeric order)
    shard_files = _collect_shard_files(input_dir)
    if not shard_files:
        raise RuntimeError(f"No shard files found in {input_dir} with pattern train_shard_*.json or val_shard.json")

    # 3) Determine num_processes if not set
    if num_processes is None:
        num_processes = max(1, multiprocessing.cpu_count() // 2)

    # 4) Process each shard
    for shard_path in tqdm(shard_files):
        shard_name = os.path.splitext(os.path.basename(shard_path))[0]
        out_pt_path = os.path.join(output_dir, f"{shard_name}.pt")

        if os.path.isfile(out_pt_path):
            # Shard already processed -> skip
            tqdm.write(f"Skipping {shard_name} because {out_pt_path} already exists.")
            continue

        token_list = _parallel_tokenize_shard(shard_path, tokenizer, eos_seperator_id, num_processes)

        # Convert to 1D LongTensor; drop the (much larger) Python list before serializing.
        tokens_1d = torch.tensor(token_list, dtype = torch.long)
        del token_list

        # Write to a temp file, then atomically rename it into place. An interrupted save
        # therefore never leaves a truncated .pt file that the skip check above would accept
        # on the next run.
        tmp_pt_path = out_pt_path + ".tmp"
        torch.save({"tokens": tokens_1d}, tmp_pt_path)
        os.replace(tmp_pt_path, out_pt_path)

    print(f"All shards processed -> .pt in {output_dir}")
    return output_dir

In [None]:
from transformers import AutoTokenizer

# Guard the entry point. Under the "spawn" start method (e.g. on Windows and macOS), multiprocessing
# re-imports the main module in each worker, and an unguarded call would try to start the
# preprocessing again inside every child. In a notebook, __name__ is "__main__", so this still runs.
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained('allenai/OLMoE-1B-7B-0924', add_eos_token = False, add_bos_token = False)

    # Every document is terminated with the EOS id. Fail fast if the tokenizer has none, rather
    # than letting None reach the LongTensor built for each shard.
    if tokenizer.eos_token_id is None:
        raise ValueError("Tokenizer has no eos_token_id; pass an explicit eos_seperator_id.")

    preprocessed_dir = preprocess_all_shards(
        input_dir = './raw-data',
        output_dir = './olmoe-tokenizer',
        tokenizer = tokenizer,
        eos_seperator_id = tokenizer.eos_token_id,
        num_processes = 4
    )