In [1]:
import json
import os
from itertools import islice

from datasets import load_dataset

def _take_texts(samples_iter, n):
    """Pull up to `n` samples from `samples_iter` and return their "text" fields.

    Returns fewer than `n` texts (possibly none) once the iterator is exhausted.
    A non-positive `n` yields an empty list.
    """
    return [sample["text"] for sample in islice(samples_iter, max(n, 0))]


def _write_json(path, obj):
    """Write `obj` as UTF-8 JSON to `path`, creating the parent directory if needed."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps non-ASCII text readable instead of \uXXXX escapes.
        json.dump(obj, f, ensure_ascii=False)


def get_fineweb_edu_data_sharded(
    shard_size = 50000,
    max_samples = 10000000,
    out_prefix = "./train_shard",
    val_filename = "./val_shard.json",
    val_size = 1000,
    min_score = None,
    dataset = None,
):
    """
    Stream the FineWeb-Edu dataset and write out training samples in shards. The first
    'val_size' samples of the stream go to a validation file, so within one run they
    never appear in a training shard. Each file is a JSON list of raw text strings.

    Params:
        @shard_size: Number of samples per training shard. The last shard may be smaller.
        @max_samples: Total samples for training (a hard cap). If `None`, read until the stream ends.
        @out_prefix: Filename prefix for train shards; shard i is written to f"{out_prefix}_{i}.json".
        @val_filename: Filename for the validation shard.
        @val_size: Number of samples in the validation set.
        @min_score: If not `None`, only keep samples whose "score" field is >= this value
            (applies only when streaming FineWeb-Edu, i.e. when `dataset` is `None`).
        @dataset: Optional iterable of sample dicts with a "text" key. It is used as-is,
            with no filtering. If `None`, FineWeb-Edu is streamed from the Hugging Face Hub
            and filtered to samples whose "language" field is "en".

    Parent directories of the output files are created if they do not exist.
    """
    if dataset is None:
        ds = load_dataset("HuggingFaceFW/fineweb-edu", name="default", split="train", streaming=True)

        def _keep(x):
            if x.get("language") != "en":
                return False
            if min_score is None:
                return True
            score = x.get("score")
            return score is not None and score >= min_score

        dataset = ds.filter(_keep)
    ds_iter = iter(dataset)

    # ------------------------------------------------
    # Collect validation samples (taken from the head of the stream)
    # ------------------------------------------------
    val_data = _take_texts(ds_iter, val_size)
    _write_json(val_filename, val_data)
    print(f"Saved {len(val_data)} validation samples to {val_filename}")

    # ------------------------------------------------
    # Collect training shards in a single pass over the same iterator
    # ------------------------------------------------
    total_written = 0
    shard_idx = 0

    while max_samples is None or total_written < max_samples:
        # Cap the final shard so the total never exceeds max_samples.
        if max_samples is None:
            n_wanted = shard_size
        else:
            n_wanted = min(shard_size, max_samples - total_written)

        texts = _take_texts(ds_iter, n_wanted)
        if not texts:
            break  # Stream exhausted (or non-positive shard_size)

        shard_path = f"{out_prefix}_{shard_idx}.json"
        _write_json(shard_path, texts)

        shard_idx += 1
        total_written += len(texts)
        print(f"Wrote shard {shard_path} with {len(texts)} samples (total so far: {total_written}).")

    print("Done generating shards.")

# Full run: streams FineWeb-Edu from the Hugging Face Hub, writes a 1,000-sample
# validation file, then up to 10,000,000 training samples in 50,000-sample shards
# (at most 200 shard files). Expensive: network-bound and writes to the working directory.
get_fineweb_edu_data_sharded(
    shard_size=50000,
    max_samples=10000000,
    out_prefix="./train_shard",
    val_filename="./val_shard.json",
    val_size=1000,
)

Resolving data files:   0%|          | 0/2080 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/2080 [00:00<?, ?it/s]

Saved 1000 validation samples to ./d2/val_shard.json
Wrote shard ./d2/train_shard_0.json with 50000 samples (total so far: 50000).
Wrote shard ./d2/train_shard_1.json with 50000 samples (total so far: 100000).
Wrote shard ./d2/train_shard_2.json with 50000 samples (total so far: 150000).
Wrote shard ./d2/train_shard_3.json with 50000 samples (total so far: 200000).
Wrote shard ./d2/train_shard_4.json with 50000 samples (total so far: 250000).
Wrote shard ./d2/train_shard_5.json with 50000 samples (total so far: 300000).
Wrote shard ./d2/train_shard_6.json with 50000 samples (total so far: 350000).
Wrote shard ./d2/train_shard_7.json with 50000 samples (total so far: 400000).
Wrote shard ./d2/train_shard_8.json with 50000 samples (total so far: 450000).
Wrote shard ./d2/train_shard_9.json with 50000 samples (total so far: 500000).
Wrote shard ./d2/train_shard_10.json with 50000 samples (total so far: 550000).
Wrote shard ./d2/train_shard_11.json with 50000 samples (total so far: 600000)

In [None]:
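# Old, non-sharded loader (superseded by get_fineweb_edu_data_sharded above; kept for reference, requires `from itertools import islice`)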
# def get_fineweb_edu_data(n_samples: int = 1000):
#     dataset = load_dataset("HuggingFaceFW/fineweb-edu", name = "default", split = 'train', streaming = True)
#     dataset = dataset.filter(lambda x: x.get('language') == 'en' and x.get('score') >= 4)
#     dataset_pulled = list(islice(dataset, n_samples))  # Take the first n_samples samples as a list
#     dataset_pulled = [x['text'] for x in dataset_pulled]    
#     return dataset_pulled