In [None]:
from datasets import load_dataset
from itertools import islice
import json

def get_fineweb_edu_data_sharded(
    shard_size = 20_000,
    max_samples = 2_000_000,  # or None for infinite
    out_prefix = "data/train_shard",
    val_filename = "data/val_shard.json",
    val_size = 1_000
):
    """
    Stream the FineWeb-Edu dataset and write out training samples in shards. Also create a validation shard of
    'val_size' samples taken from the beginning of the stream.

    Only records whose "language" is "en" and whose "score" is at least 4 are kept. The validation samples and
    every training shard are drawn from ONE shared iterator, so they never overlap.

    Params:
        @shard_size: The # of samples to store in each train shard
        @max_samples: The maximum total number of training samples to write (None = until the stream ends)
        @out_prefix: The filename prefix for the train shards; shard i is written to '<out_prefix>_<i>.json'
        @val_filename: The filename for the val set
        @val_size: The total size of the val set

    Each output file is a JSON list of the "text" fields. Nothing is returned.
    """
    import os  # function-scope import so this block does not depend on edits elsewhere in the file

    def _ensure_parent_dir(path):
        # Create the directory that will hold 'path' (no-op for a bare filename).
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    def _keep(x):
        # English, score >= 4. A missing score is treated as "drop" instead of raising TypeError.
        score = x.get("score")
        return x.get("language") == "en" and score is not None and score >= 4

    ds = load_dataset("HuggingFaceFW/fineweb-edu", name = "default", split = "train", streaming=True)
    ds = ds.filter(_keep)

    # The streaming dataset is re-iterable: calling islice(ds, n) repeatedly would restart from the first
    # record every time. Create a single iterator and consume everything from it.
    stream = iter(ds)

    # We will first get 'val_size' samples for validation:
    val_data = [x["text"] for x in islice(stream, val_size)]
    _ensure_parent_dir(val_filename)
    with open(val_filename, "w", encoding="utf-8") as f:
        json.dump(val_data, f)
    # Report what was actually written; the stream may hold fewer than 'val_size' samples.
    print(f"Saved {len(val_data)} samples to {val_filename}")

    # Now the remaining data goes to training shards
    _ensure_parent_dir(out_prefix)
    total_written = 0
    shard_idx = 0

    while max_samples is None or total_written < max_samples:
        # Pull up to shard_size items, but never go past max_samples with the final shard
        take = shard_size
        if max_samples is not None:
            take = min(shard_size, max_samples - total_written)

        texts = [x["text"] for x in islice(stream, take)]
        if not texts:
            break  # stream exhausted

        shard_path = f"{out_prefix}_{shard_idx}.json"
        with open(shard_path, "w", encoding="utf-8") as f:
            json.dump(texts, f)

        shard_idx += 1
        total_written += len(texts)
        print(f"Wrote shard {shard_path} with {len(texts)} samples (total so far: {total_written}).")

# Run the export with the function's default arguments; this writes the validation file and the train shards.
get_fineweb_edu_data_sharded()

In [None]:
# Old, non-sharded version (kept commented out for reference; returns the texts in memory instead of writing shards)
# def get_fineweb_edu_data(n_samples: int = 1000):
#     dataset = load_dataset("HuggingFaceFW/fineweb-edu", name = "default", split = 'train', streaming = True)
#     dataset = dataset.filter(lambda x: x.get('language') == 'en' and x.get('score') >= 4)
#     dataset_pulled = list(islice(dataset, n_samples))  # Materialize the first n_samples samples as a list
#     dataset_pulled = [x['text'] for x in dataset_pulled]    
#     return dataset_pulled