In [3]:
import os
from pathlib import Path
from datasets import load_dataset


local_dir =  "data" 
DATA_CACHE_DIR = os.path.join(Path(), local_dir)
os.makedirs(DATA_CACHE_DIR, exist_ok=True)

# download the dataset
dataset = load_dataset("HuggingFaceFW/fineweb-edu", "sample-10BT", split="train", cache_dir=DATA_CACHE_DIR)


Resolving data files:   0%|          | 0/2110 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/98 [00:00<?, ?it/s]

In [62]:
import time
import math
from functools import partial

import multiprocessing
#multiprocessing.set_start_method('fork')   
import concurrent.futures as cf
import numpy as np


num_cpus = os.cpu_count()
total_docs = len(dataset)
docs_per_cpu = int(math.ceil(total_docs/num_cpus))
print(f"""
dataset statistics
------------------
documents: {total_docs:,}
docs_per_cpu: {docs_per_cpu}""")

def count_tokens(dataset, docs_per_cpu, n):

    import tiktoken
    
    # set up tokenizer
    enc = tiktoken.get_encoding("gpt2")
    
    total_length = 0
    docs_processed = 0
    start = n*docs_per_cpu
    if len(dataset) - (n*docs_per_cpu) < docs_per_cpu:
        end = len(dataset)
    else:
        end = start + docs_per_cpu
    for idx, d in enumerate(dataset[start:end]['text']):
        total_length += len(enc.encode_ordinary(d))
        docs_processed += 1
    return total_length, docs_processed

f = partial(count_tokens, dataset, docs_per_cpu)

start = time.time()
with cf.ProcessPoolExecutor(max_workers = num_cpus) as ex:
    results = list(ex.map(f, range(num_cpus)))


dataset statistics
------------------
documents: 9,672,101
docs_per_cpu: 322404


In [63]:
print(f"Processed documents in {time.time()-start:0.2f} seconds")
print(f"Total tokens: {sum([r[0] for r in results]):,}")
print(f"Total documents: {sum([r[1] for r in results]):,}")

Processed documents in 165.16 seconds
Total tokens: 9,944,317,243
Total documents: 9,672,101


In [64]:

SHARD_SIZE = int(1e8)
output_dir = "shards"

def tokenize(docs_per_cpu, shard_size, n):
    #print(shard_size, n)
    import os
    import gc
    
    from pathlib import Path
    from datasets import load_dataset
    import tiktoken
    
    # set up tokenizer
    enc = tiktoken.get_encoding("gpt2")
    eot = enc._special_tokens['<|endoftext|>']
    
    # load the dataset
    dataset = load_dataset("HuggingFaceFW/fineweb-edu", "sample-10BT", split="train", cache_dir=os.path.join(Path(), "data"))

    # create the np array
    shard = np.array([eot]*shard_size)
    shard_idx = 1

    docs_processed = 0
    shards = 0
    tokens = 0
    start = n*docs_per_cpu
    if len(dataset) - (n*docs_per_cpu) < docs_per_cpu:
        end = len(dataset)
    else:
        end = start + docs_per_cpu
    #print(start, end)
    for idx, d in enumerate(dataset[start:end]['text']):
        #print(f"{idx:,}", end="\r")
        new_tokens = enc.encode_ordinary(d)
        tokens += len(new_tokens)
        if shard_idx + len(new_tokens) > shard_size:
            np.savez(os.path.join(Path(), "shards", f"fineweb_{n}_{shards}.npz"), shard)
            shard_idx = 1
            shards += 1
        else:
            shard_idx += len(new_tokens)
            shard[idx:idx+len(new_tokens)] = new_tokens
        docs_processed += 1
    return docs_processed, shards, tokens

t = partial(tokenize, docs_per_cpu, SHARD_SIZE)

start = time.time()
with cf.ProcessPoolExecutor(max_workers = num_cpus) as ex:
    results = list(ex.map(t, range(num_cpus)))
print(results)

Using the latest cached version of the dataset since HuggingFaceFW/fineweb-edu couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'sample-10BT' at data/HuggingFaceFW___fineweb-edu/sample-10BT/0.0.0/4863ab07d7520451e6f73e2912ad8bfee7d97c11 (last modified on Mon Mar  3 20:24:57 2025).
Using the latest cached version of the dataset since HuggingFaceFW/fineweb-edu couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'sample-10BT' at data/HuggingFaceFW___fineweb-edu/sample-10BT/0.0.0/4863ab07d7520451e6f73e2912ad8bfee7d97c11 (last modified on Mon Mar  3 20:24:57 2025).


Loading dataset shards:   0%|          | 0/98 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/98 [00:00<?, ?it/s]

[(322404, 3, 330967344), (322404, 3, 336233239), (322404, 3, 330429730), (322404, 3, 332577826), (322404, 3, 334436000), (322404, 3, 331046095), (322404, 3, 333893381), (322404, 3, 327421391), (322404, 3, 330047401), (322404, 3, 324200848), (322404, 3, 323097647), (322404, 3, 323135816), (322404, 3, 328331221), (322404, 3, 330057048), (322404, 3, 332079638), (322404, 3, 330523595), (322404, 3, 329566197), (322404, 3, 337826108), (322404, 3, 334898498), (322404, 3, 332973900), (322404, 3, 338259167), (322404, 3, 342887738), (322404, 3, 335663348), (322404, 3, 331463361), (322404, 3, 331103321), (322404, 3, 328913342), (322404, 3, 327942477), (322404, 3, 329962145), (322404, 3, 331341726), (322385, 3, 333037695)]
Processed documents in 188.69353890419006 seconds
total shards written: 90
total tokens: 9,944,317,243


In [65]:
print(f"Processed documents in {time.time()-start:.2f} seconds")
assert(sum([r[0] for r in results]) == total_docs)
total_shards = sum([r[1] for r in results])
print(f"total shards written: {total_shards:,}")
total_tokens = sum([r[2] for r in results])
print(f"total tokens: {total_tokens:,}")

Processed documents in 252.66 seconds
total shards written: 90
total tokens: 9,944,317,243
