# Preprocess the FineWeb-Edu Dataset


## Load the Dataset


In [1]:
import os
from datasets import load_dataset


# Root of the training workspace. Set the GPT3_TRAIN_DIR environment variable to
# relocate it instead of editing this machine-specific default path.
local_dir = "data"
DATA_CACHE_DIR = os.path.join(os.environ.get("GPT3_TRAIN_DIR", "/home/ubuntu/gpt-3-train"), local_dir)
os.makedirs(DATA_CACHE_DIR, exist_ok=True)

# download the "sample-100BT" config of FineWeb-Edu (train split) into DATA_CACHE_DIR
dataset = load_dataset("HuggingFaceFW/fineweb-edu", "sample-100BT", split="train", cache_dir=DATA_CACHE_DIR)


Resolving data files:   0%|          | 0/2110 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/140 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/984 [00:00<?, ?it/s]

## Setup


In [2]:
import time
import math
from functools import partial
import concurrent.futures as cf

import numpy as np

import tiktoken

# os.cpu_count() returns None when the count cannot be determined; fall back to 1
# so the divisions below (and max_workers later) stay valid.
num_cpus = os.cpu_count() or 1
print(f"""
system statistics:
------------------
cpu count: {num_cpus}""")
total_docs = len(dataset)
# Exact integer ceiling division (no float rounding on large document counts).
docs_per_cpu = -(-total_docs // num_cpus)
print(f"""
dataset statistics
------------------
documents: {total_docs:,}
docs_per_cpu: {docs_per_cpu:,}""")

enc = tiktoken.get_encoding("gpt2")


sytem statistics:
-----------------
cpu count: 64

dataset statistics
------------------
documents: 97,270,686
docs_per_cpu: 1,519,855


## Dry Run: Count Tokens Without Writing Output

In [14]:
def count_tokens(dataset, enc, idx):
    """Count the tokens produced for a single document.

    Args:
        dataset: indexable collection whose items expose a ``'text'`` field.
        enc: encoder providing ``encode_ordinary(text)``.
        idx: position of the document in ``dataset``.

    Returns:
        int: length of ``enc.encode_ordinary`` applied to the document text.
    """
    text = dataset[idx]['text']
    encoded = enc.encode_ordinary(text)
    return len(encoded)

# Dry run: tokenize every document in parallel only to total the token count;
# nothing is written to disk.
f = partial(count_tokens, dataset, enc)

with cf.ProcessPoolExecutor(max_workers = num_cpus) as ex:
    start = time.time()
    documents = 0
    tokens = 0

    # Executor.map requires chunksize >= 1; docs_per_cpu // 10 is 0 for tiny datasets.
    for result in ex.map(f, range(len(dataset)), chunksize=max(1, docs_per_cpu // 10)):
        documents += 1
        tokens += result
        if documents % 10_000 == 0:
            print(f"processed {documents:,}", end="\r")

    print(f"processed documents in {time.time()-start:0.2f} seconds")
    print(f"total tokens: {tokens:,}")
    print(f"total documents: {documents:,}")
    assert documents == total_docs

processed documents in 987.50 seconds
total tokens: 100,069,194,385
total documents: 97,270,686


## Tokenize and Write uint16 Shards

In [None]:
# Output shards go under DATA_CACHE_DIR/<output_dir>; each buffer holds
# SHARD_SIZE token ids.
output_dir = "processed"
SHARD_SIZE = 100_000_000
os.makedirs(os.path.join(DATA_CACHE_DIR, output_dir), exist_ok=True)

def write_shard(shard, shard_idx):
    """Persist one token shard with ``np.savez`` under DATA_CACHE_DIR/output_dir.

    Shards whose index is a multiple of 100 (0, 100, 200, ...) are labelled
    "valid"; every other shard is labelled "train".

    Args:
        shard: array of token ids to save.
        shard_idx: running shard number, used for both the split and the filename.
    """
    split = "valid" if shard_idx % 100 == 0 else "train"
    filename = f"fineweb_edu_{split}_{shard_idx}"
    np.savez(os.path.join(DATA_CACHE_DIR, output_dir, filename), shard)

def tokenize(dataset, encoder, idx):
    """Encode one document, prefixed with the end-of-text token.

    The leading end-of-text id separates documents once they are concatenated
    into a shard.

    Args:
        dataset: indexable collection whose items expose a ``'text'`` field.
        encoder: tiktoken-style encoder providing ``eot_token`` and
            ``encode_ordinary(text)``.
        idx: position of the document in ``dataset``.

    Returns:
        list[int]: ``[encoder.eot_token]`` followed by the encoded text.
    """
    # Use the public ``eot_token`` accessor rather than the private
    # ``_special_tokens`` mapping.
    return [encoder.eot_token, *encoder.encode_ordinary(dataset[idx]['text'])]

# Tokenize all documents in parallel and pack the token stream into fixed-size
# uint16 shards, flushing each shard to disk as soon as it fills.
f = partial(tokenize, dataset, enc)

# Shards are stored as uint16, so every token id must be < 2**16.
assert enc.n_vocab <= 2**16, "token ids do not fit in uint16 shards"

with cf.ProcessPoolExecutor(max_workers = num_cpus) as ex:
    start = time.time()

    docs_processed = 0
    shards_written = 0
    tokens_generated = 0
    shard_token_count = 0

    # Reused buffer: only shard[:shard_token_count] holds valid tokens; anything
    # beyond that is uninitialized or left over from a previously written shard.
    shard = np.empty((SHARD_SIZE,), dtype=np.uint16)

    # Executor.map requires chunksize >= 1; docs_per_cpu // 100 is 0 for small datasets.
    for tokens in ex.map(f, range(len(dataset)), chunksize=max(1, docs_per_cpu // 100)):
        docs_processed += 1
        tokens_generated += len(tokens)

        if docs_processed % 10_000 == 0:
            print(f"processed {docs_processed:,} documents | generated {tokens_generated:,} tokens | wrote {shards_written} shards", end="\r")

        # Copy the document into the buffer, writing a shard every time it fills.
        # Looping (instead of a single split) stays correct even when one
        # document is longer than SHARD_SIZE tokens.
        offset = 0
        while offset < len(tokens):
            n_copy = min(SHARD_SIZE - shard_token_count, len(tokens) - offset)
            shard[shard_token_count:shard_token_count + n_copy] = tokens[offset:offset + n_copy]
            shard_token_count += n_copy
            offset += n_copy
            if shard_token_count == SHARD_SIZE:
                write_shard(shard, shards_written)
                shards_written += 1
                shard_token_count = 0

    # Write the partially filled final shard, trimmed to its valid tokens.
    # Skip it when the stream ended exactly on a shard boundary.
    if shard_token_count > 0:
        write_shard(shard[:shard_token_count], shards_written)
        shards_written += 1
    # End with a newline so the summary below does not overwrite this line.
    print(f"processed {docs_processed:,} documents | generated {tokens_generated:,} tokens | wrote {shards_written} shards")
    print(f"finished in {time.time()-start:.2f} seconds")
    assert docs_processed == total_docs
    print(f"total shards written: {shards_written:,}")
    print(f"total tokens: {tokens_generated:,}")

processed 590,000 documents | generated 610,238,999 tokens | wrote 6 shards