In [5]:
import os
from pathlib import Path
from datasets import load_dataset


# Root directory for the Hugging Face cache and the processed shards.
# Set FINEWEB_DATA_ROOT to relocate it without editing the notebook; the
# default keeps the original location.
DATA_ROOT = os.environ.get("FINEWEB_DATA_ROOT", "/home/ubuntu/traingpt2data2")
local_dir = "data"
DATA_CACHE_DIR = os.path.join(DATA_ROOT, local_dir)
os.makedirs(DATA_CACHE_DIR, exist_ok=True)

# Download (or reuse the cached copy of) the `sample-10BT` config of FineWeb-Edu.
dataset = load_dataset("HuggingFaceFW/fineweb-edu", "sample-10BT", split="train", cache_dir=DATA_CACHE_DIR)


Resolving data files:   0%|          | 0/2110 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/98 [00:00<?, ?it/s]

In [3]:
import time
import math
from functools import partial

import multiprocessing
#multiprocessing.set_start_method('fork')   
import concurrent.futures as cf
import numpy as np

import tiktoken

# os.cpu_count() returns None when the count cannot be determined; fall back to 1.
num_cpus = os.cpu_count() or 1
print(f"""
system statistics:
------------------
cpu count: {num_cpus}""")

total_docs = len(dataset)
# Integer ceiling division (no float rounding on large counts). Clamp to >= 1
# because Executor.map raises ValueError for chunksize < 1 (e.g. empty dataset).
docs_per_cpu = max(1, (total_docs + num_cpus - 1) // num_cpus)
print(f"""
dataset statistics
------------------
documents: {total_docs:,}
docs_per_cpu: {docs_per_cpu}""")

enc = tiktoken.get_encoding("gpt2")


sytem statistics:
-----------------
cpu count: 240

dataset statistics
------------------
documents: 9,672,101
docs_per_cpu: 40301


In [3]:
def count_tokens(dataset, enc, idx):
    """Encode the text of document ``idx`` with ``enc.encode_ordinary``.

    Despite its name, this returns the token list itself rather than its
    length; take ``len()`` of the result to obtain the count.

    Args:
        dataset: indexable collection whose items expose a ``'text'`` field.
        enc: encoder object providing ``encode_ordinary(str)``.
        idx: index of the document to encode.
    """
    document_text = dataset[idx]['text']
    return enc.encode_ordinary(document_text)

def _count_tokens_in_worker(dataset, enc, idx):
    """Return only the number of tokens in document ``idx``.

    Taking len() inside the worker means each result sent back to the parent
    is a single int instead of the document's full list of token ids.
    Defined at top level so it pickles by reference for the worker processes.
    """
    return len(count_tokens(dataset, enc, idx))


f = partial(_count_tokens_in_worker, dataset, enc)

with cf.ProcessPoolExecutor(max_workers=num_cpus) as ex:
    start = time.time()
    documents = 0
    tokens = 0

    # map() yields results in input order, one per document.
    for n_tokens in ex.map(f, range(len(dataset)), chunksize=docs_per_cpu):
        documents += 1
        tokens += n_tokens
        if documents % 100000 == 0:
            print(f"processed {documents:,}", end="\r")

    print(f"processed documents in {time.time()-start:0.2f} seconds")
    print(f"total tokens: {tokens:,}")
    print(f"total documents: {documents:,}")
    assert documents == total_docs

processed 2,500,000

Process ForkProcess-4:
Process ForkProcess-47:
Process ForkProcess-145:
Process ForkProcess-51:
Process ForkProcess-150:
Process ForkProcess-48:
Process ForkProcess-2:
Process ForkProcess-54:
Process ForkProcess-100:
Process ForkProcess-3:
Process ForkProcess-85:
Process ForkProcess-11:
Process ForkProcess-41:
Process ForkProcess-5:
Process ForkProcess-76:
Process ForkProcess-75:
Process ForkProcess-50:
Process ForkProcess-27:
Process ForkProcess-45:
Process ForkProcess-57:
Process ForkProcess-40:
Process ForkProcess-46:
Process ForkProcess-10:
Process ForkProcess-29:
Process ForkProcess-99:
Process ForkProcess-32:
Process ForkProcess-33:
Process ForkProcess-28:
Process ForkProcess-59:
Process ForkProcess-8:
Process ForkProcess-49:
Process ForkProcess-101:
Process ForkProcess-44:
Process ForkProcess-6:
Process ForkProcess-52:
Process ForkProcess-60:
Process ForkProcess-9:
Process ForkProcess-43:
Process ForkProcess-53:
Process ForkProcess-78:
Process ForkProcess-37:
Process ForkProcess

KeyboardInterrupt: 

In [None]:
import tqdm

# Number of tokens stored in each shard file (10**8).
SHARD_SIZE = 100_000_000

# Shards are written to DATA_CACHE_DIR/<output_dir>; create it up front.
output_dir = "processed"
processed_path = os.path.join(DATA_CACHE_DIR, output_dir)
os.makedirs(processed_path, exist_ok=True)

def write_shard(shard, shard_idx):
    """Save one token shard to DATA_CACHE_DIR/output_dir via ``np.savez``.

    Shard 0 is labelled as the ``valid`` split; every later shard is ``train``.
    ``np.savez`` appends the ``.npz`` extension and stores the array under the
    default key ``arr_0``.

    Args:
        shard: array of token ids to persist.
        shard_idx: sequential shard number, also used in the file name.
    """
    split = "valid" if shard_idx == 0 else "train"
    file_name = f"fineweb_edu_{split}_{shard_idx}"
    np.savez(os.path.join(DATA_CACHE_DIR, output_dir, file_name), shard)

def tokenize(dataset, encoder, idx):
    """Tokenize document ``idx``, prefixed with the end-of-text token.

    The leading EOT id separates consecutive documents once their tokens are
    concatenated into shards.

    Args:
        dataset: indexable collection whose items expose a ``'text'`` field.
        encoder: tiktoken-style encoder with ``eot_token`` and ``encode_ordinary``.
        idx: index of the document to tokenize.

    Returns:
        list[int]: ``[eot] + encoder.encode_ordinary(text)``.
    """
    # Public property instead of the private ``_special_tokens`` mapping.
    eot = encoder.eot_token
    return [eot] + encoder.encode_ordinary(dataset[idx]['text'])

def _tokenize_to_uint16(dataset, encoder, idx):
    """Tokenize document ``idx`` in the worker and return a compact uint16 array.

    Converting in the worker sends 2 bytes per token back to the parent
    instead of a list of Python ints. Raises ValueError if an id would not fit
    in uint16 rather than silently wrapping it. Defined at top level so it
    pickles by reference for the worker processes.
    """
    ids = np.asarray(tokenize(dataset, encoder, idx), dtype=np.int64)
    if ids.max() >= 2**16:  # ids is never empty: tokenize always prepends EOT
        raise ValueError(f"document {idx}: token id {int(ids.max())} does not fit in uint16")
    return ids.astype(np.uint16)


f = partial(_tokenize_to_uint16, dataset, enc)

PRINT_EVERY = 10_000  # throttle progress output; printing per document floods the notebook

with cf.ProcessPoolExecutor(max_workers=num_cpus) as ex:
    start = time.time()

    docs_processed = 0
    shards_written = 0
    tokens_generated = 0
    shard_token_count = 0

    # Reused buffer; only shard[:shard_token_count] holds tokens for the current shard.
    shard = np.empty((SHARD_SIZE,), dtype=np.uint16)

    for tokens in ex.map(f, range(len(dataset)), chunksize=docs_per_cpu):
        docs_processed += 1
        tokens_generated += len(tokens)

        # Copy the document into the buffer, flushing whenever it fills exactly.
        # The loop also covers a document longer than one whole shard.
        offset = 0
        while offset < len(tokens):
            n = min(SHARD_SIZE - shard_token_count, len(tokens) - offset)
            shard[shard_token_count:shard_token_count + n] = tokens[offset:offset + n]
            shard_token_count += n
            offset += n
            if shard_token_count == SHARD_SIZE:
                write_shard(shard, shards_written)
                shards_written += 1
                shard_token_count = 0

        if docs_processed % PRINT_EVERY == 0:
            print(f"processed {docs_processed:,} documents | generated {tokens_generated:,} tokens | wrote {shards_written} shards", end="\r")

    # Write only the filled prefix of the last shard: the rest of the buffer
    # holds stale tokens from the previous shard (or uninitialized memory).
    if shard_token_count > 0:
        write_shard(shard[:shard_token_count], shards_written)
        shards_written += 1
    print(f"processed {docs_processed:,} documents | generated {tokens_generated:,} tokens | wrote {shards_written} shards", end="\r")
    print(f"finished in {time.time()-start:.2f} seconds")
    assert docs_processed == total_docs
    print(f"total shards written: {shards_written:,}")
    print(f"total tokens: {tokens_generated:,}")