In [None]:
import glob as glob
import time 
from transformers import AutoTokenizer

# Collect the raw JSON training shards in sorted order.
shard_files = glob.glob("./../../data/train_shard_*.json")
shard_files.sort()
print(f"Found {len(shard_files)} shards.")

# Load the tokenizer with the add_eos_token / add_bos_token options switched off.
tokenizer = AutoTokenizer.from_pretrained(
    'allenai/OLMoE-1B-7B-0924',
    add_eos_token = False,
    add_bos_token = False,
)

In [None]:
from helpers.dataset import load_shard_as_dataloader

# Time how long it takes to build a dataloader for each of the first four JSON shards.
for shard_path in shard_files[0:4]:
    # perf_counter() is a monotonic, high-resolution clock intended for measuring
    # intervals; time.time() can jump if the system clock is adjusted.
    start_time = time.perf_counter()
    print('Start')
    shard_dl = load_shard_as_dataloader(shard_path, tokenizer, batch_size = 64 * 4, seq_len = 2048, eos_seperator_id = tokenizer.eos_token_id)
    print(time.perf_counter() - start_time)

## Alt Version 1 - Multiprocess import

In [None]:
from helpers.dataset import load_shard_as_dataloader2

# Time the alternative loader (load_shard_as_dataloader2) on the first five JSON shards.
for shard_path in shard_files[0:5]:
    # perf_counter() is a monotonic, high-resolution clock intended for measuring
    # intervals; time.time() can jump if the system clock is adjusted.
    start_time = time.perf_counter()
    print('Start')
    shard_dl = load_shard_as_dataloader2(shard_path, tokenizer, batch_size = 64 * 4, seq_len = 2048, eos_seperator_id = tokenizer.eos_token_id)
    print(time.perf_counter() - start_time)

## Alt Version 2 - Preprocess first, shard later

In [None]:
import torch
import multiprocessing
from functools import partial
from glob import glob
import os
import json
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

def _tokenize_line(line, tokenizer_encode, eos_id):
    """
    Helper function for parallel tokenization.
    'tokenizer_encode' is a partial or bound method that does:tokenizer.encode(..., add_special_tokens=False)
    """
    line_toks = tokenizer_encode(line)
    return line_toks + [eos_id]

def _parallel_tokenize_shard(shard_path: str, tokenizer, eos_seperator_id: int, num_processes: int) -> list:
    """
    Read the 'shard_path' JSON file, tokenize its entries in parallel, and concatenate the results.

    Every entry is encoded with tokenizer.encode(..., add_special_tokens = False) and followed by
    'eos_seperator_id', so consecutive entries stay separated in the flat token stream.

    Params:
        @shard_path: Path to a JSON file. The decoded value is handed directly to pool.map, so it is
            presumably a list of strings (TODO confirm against whatever writes the shards).
        @tokenizer: Object exposing encode(text, add_special_tokens = ...). It is pickled and sent to the workers.
        @eos_seperator_id: Token ID appended after each entry's tokens.
        @num_processes: Size of the multiprocessing pool.

    Returns:
        A flat (1D) Python list of token IDs. No tensors are built and no chunking is done here.

    NOTE(review): worker processes must be able to import _tokenize_line. With the 'spawn' start
    method, functions defined in a notebook cell may not be importable; confirm on the target platform.
    """
    # 1) Read lines from JSON. Decode explicitly as UTF-8 rather than relying on the platform's default encoding.
    with open(shard_path, 'r', encoding = 'utf-8') as f:
        lines = json.load(f)

    # 2) Parallel tokenization
    tokenizer_encode = partial(tokenizer.encode, add_special_tokens = False)
    tokenize_fn = partial(_tokenize_line, tokenizer_encode = tokenizer_encode, eos_id = eos_seperator_id)
    with multiprocessing.Pool(processes = num_processes) as pool:
        tokenized_lines = pool.map(tokenize_fn, lines)

    # 3) Concatenate the per-entry token lists into one flat list
    big_token_buffer = []
    for toks in tokenized_lines:
        big_token_buffer.extend(toks)

    return big_token_buffer

def preprocess_all_shards(input_dir: str, output_base_dir: str, tokenizer, eos_seperator_id: int, num_processes: int = None):
    """
    Pre-process all shards in 'input_dir' that match 'train_shard_*.json'.
    For each shard, parallel-tokenize it into one big 1D list of token IDs, then save that list as a
    .pt file with {'tokens': 1D LongTensor}. No chunking is done here; chunking happens at load time.

    The .pt files are written to a subdirectory of 'output_base_dir' named after the tokenizer.
    Params:
        @input_dir: Directory containing e.g. train_shard_0.json, train_shard_1.json, ...
        @output_base_dir: The base directory. A subdir with the tokenizer name is created inside it.
        @tokenizer: A HuggingFace tokenizer object. Its 'name_or_path' attribute (last path component)
            names the output subdir; 'custom_tokenizer' is used when that is missing or empty.
        @eos_seperator_id: e.g. tokenizer.eos_token_id
        @num_processes: How many CPU processes to use for parallel tokenization. If None, defaults to half the CPU cores (at least 1).

    Returns:
        The path to the directory holding the .pt files.

    Raises:
        RuntimeError: If no file in 'input_dir' matches train_shard_*.json.
    """
    # 1) Create an output dir named after the tokenizer.
    # 'name_or_path' can exist but be empty, and a local path may end with a separator; without
    # the 'or' fallbacks and normpath both cases would yield an empty subdir name and the .pt
    # files would land directly in 'output_base_dir'.
    tokenizer_name = getattr(tokenizer, "name_or_path", None) or "custom_tokenizer"
    tokenizer_name_clean = os.path.basename(os.path.normpath(tokenizer_name)) or "custom_tokenizer"
    output_dir = os.path.join(output_base_dir, f"{tokenizer_name_clean}")
    os.makedirs(output_dir, exist_ok = True)

    # 2) Find all shard files
    shard_files = sorted(glob(os.path.join(input_dir, "train_shard_*.json")))
    if len(shard_files) == 0:
        raise RuntimeError(f"No shard files found in {input_dir} with pattern train_shard_*.json")

    # 3) Determine num_processes if not set
    if num_processes is None:
        num_processes = max(1, multiprocessing.cpu_count() // 2)

    # 4) Process each shard
    for shard_path in shard_files:
        print(f"Preprocessing shard: {shard_path}")
        token_list = _parallel_tokenize_shard(shard_path, tokenizer, eos_seperator_id, num_processes)

        # Convert to 1D LongTensor
        tokens_1d = torch.tensor(token_list, dtype = torch.long)
        # Keep the shard's base name and swap the .json extension for .pt
        shard_name = os.path.splitext(os.path.basename(shard_path))[0]
        out_pt_path = os.path.join(output_dir, f"{shard_name}.pt")

        print(f"Saving {len(token_list)} tokens to {out_pt_path}")
        torch.save({"tokens": tokens_1d}, out_pt_path)

    print(f"All shards processed -> .pt in {output_dir}")
    return output_dir


from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('allenai/OLMoE-1B-7B-0924', add_eos_token = False, add_bos_token = False)
# Read the JSON shards from, and write the .pt shards under, the same data directory the
# other cells of this notebook glob ("./../../data"). The outputs then land in
# ./../../data/OLMoE-1B-7B-0924/, which is where the .pt loading cell looks for them.
preprocessed_dir = preprocess_all_shards(
    input_dir = './../../data',
    output_base_dir = './../../data',
    tokenizer = tokenizer,
    eos_seperator_id = tokenizer.eos_token_id
)

In [None]:
# Depending on which cells have run, the name `glob` is bound either to the module
# (`import glob`) or to the function (`from glob import glob`); in the latter case
# `glob.glob(...)` raises AttributeError. Import the module under a private alias so this
# cell works either way.
import glob as _glob_module

shard_pt_files = sorted(_glob_module.glob("./../../data/OLMoE-1B-7B-0924/train_shard_*.pt"))
print(f"Found {len(shard_pt_files)} shards.")
tokenizer = AutoTokenizer.from_pretrained('allenai/OLMoE-1B-7B-0924', add_eos_token = False, add_bos_token = False)

In [None]:
from torch.utils.data import DataLoader
from helpers.dataset import TextDataset

def load_shard_as_dataloader(shard_pt_path: str, tokenizer,  batch_size: int, seq_len: int, shuffle: bool = True):
    """
    Loads the .pt file holding a 1D 'tokens' tensor, splits it into rows of 'seq_len' tokens,
    and returns a DataLoader over a TextDataset built from 'input_ids' and 'attention_mask'.

    If the token count is not a multiple of 'seq_len', the last row is right-padded with
    'tokenizer.pad_token_id'. The attention mask is 1 everywhere except on those padded
    positions, so real tokens that happen to equal the pad ID are never masked out.

    Params:
        @shard_pt_path: Path to a .pt file containing a dict with a 'tokens' entry (1D tensor).
        @tokenizer: Object exposing 'pad_token_id'; only consulted when padding is needed.
        @batch_size: Batch size passed to the DataLoader.
        @seq_len: Number of tokens per row; must be positive.
        @shuffle: Passed to the DataLoader.

    Returns:
        A DataLoader over TextDataset({"input_ids": ..., "attention_mask": ...}), both of
        shape (num_chunks, seq_len).

    Raises:
        ValueError: If 'seq_len' is not positive, the shard holds no tokens, or padding is
            required but the tokenizer has no pad token.
    """
    if seq_len <= 0:
        # A non-positive seq_len can never consume the token stream.
        raise ValueError(f"seq_len must be a positive integer, got {seq_len}")

    # NOTE(review): weights_only = False lets torch.load unpickle arbitrary objects; only
    # load shard files produced by a trusted preprocessing run.
    data_dict = torch.load(shard_pt_path, weights_only = False)
    tokens_1d = data_dict["tokens"]

    total_tokens = tokens_1d.shape[0]
    if total_tokens == 0:
        raise ValueError(f"Shard {shard_pt_path} contains no tokens")

    # 1) Pad once at the end so the length becomes a multiple of seq_len
    remainder = total_tokens % seq_len
    if remainder:
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            raise ValueError("tokenizer.pad_token_id is None but the final chunk needs padding")
        padding = torch.full((seq_len - remainder,), pad_id, dtype = torch.long)
        tokens_1d = torch.cat([tokens_1d, padding])

    # 2) Reshape into a [num_chunks, seq_len] Tensor in one step instead of chunk-by-chunk
    input_ids = tokens_1d.reshape(-1, seq_len)

    # 3) Only the padded tail of the last row is masked out
    attention_mask = torch.ones_like(input_ids, dtype = torch.long)
    if remainder:
        attention_mask[-1, remainder:] = 0

    ds = TextDataset({"input_ids": input_ids, "attention_mask": attention_mask})
    dl = DataLoader(ds, batch_size = batch_size, shuffle = shuffle)
    return dl

# Time how long it takes to build a dataloader from each of the first five pre-tokenized .pt shards.
for shard_path in shard_pt_files[0:5]:
    # perf_counter() is a monotonic, high-resolution clock intended for measuring
    # intervals; time.time() can jump if the system clock is adjusted.
    start_time = time.perf_counter()
    print('Start')
    shard_dl = load_shard_as_dataloader(shard_path, tokenizer, batch_size = 64 * 4, seq_len = 2048)
    print(time.perf_counter() - start_time)