In [1]:
!pip install -qq pyspark matplotlib
import os
import subprocess

from pyspark.sql import SparkSession

In [2]:
# Download the fineweb-edu 10BT sample shards into '../data' (relative to the notebook).
# Each shard is fetched to a temporary '.part' file and only renamed to its final
# name after wget exits successfully, so a failed or interrupted download never
# leaves a truncated file that a later run would skip as "already exists".
data_dir = "../data"
os.makedirs(data_dir, exist_ok=True)

# Base URL for the 10BT Sample
base_url = "https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu/resolve/main/sample/10BT"

# Shards are named 000_00000.parquet through 013_00000.parquet
num_files = 14

for i in range(num_files):
    filename = f"{i:03d}_00000.parquet"
    save_path = os.path.join(data_dir, filename)

    # Only download if missing; a zero-byte file is a leftover from a failed
    # download (wget -O creates the file up front) and is fetched again.
    if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
        print(f"Skipping {filename} (Already exists)")
        continue

    print(f"Downloading {filename}...")
    tmp_path = save_path + ".part"
    try:
        # check=True raises CalledProcessError on a non-zero wget exit status
        # instead of silently continuing with a broken file.
        subprocess.run(
            ["wget", "-q", "-O", tmp_path, f"{base_url}/{filename}"],
            check=True,
        )
        # Publish the finished file under its final name.
        os.replace(tmp_path, save_path)
    except BaseException:
        # Also covers a kernel interrupt: remove the partial download, then re-raise.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

print(f"All {num_files} sample files downloaded.")

Skipping 000_00000.parquet (Already exists)
Skipping 001_00000.parquet (Already exists)
Skipping 002_00000.parquet (Already exists)
Skipping 003_00000.parquet (Already exists)
Skipping 004_00000.parquet (Already exists)
Skipping 005_00000.parquet (Already exists)
Skipping 006_00000.parquet (Already exists)
Skipping 007_00000.parquet (Already exists)
Skipping 008_00000.parquet (Already exists)
Skipping 009_00000.parquet (Already exists)
Skipping 010_00000.parquet (Already exists)
Skipping 011_00000.parquet (Already exists)
Skipping 012_00000.parquet (Already exists)
Skipping 013_00000.parquet (Already exists)
All 14 sample files downloaded.


In [3]:
# Build (or reuse) the SparkSession with explicit memory / executor sizing.
#
# Sizing rules used for the values below (8 cores, 128 total memory):
#   driver memory      : small fixed amount, 1 - 2GB        => 2g
#   executor instances : total cores - 1 = 8 - 1            => 7
#   executor memory    : floor((128 - 2) / 7)               => 18g
#
# NOTE(review): no master URL is configured in this cell; if the session ends up
# running in local mode the executor settings may have no effect - confirm the
# deployment mode.

spark_settings = {
    "spark.driver.memory": "2g",
    "spark.executor.memory": "18g",
    "spark.executor.instances": 7,
}

session_builder = SparkSession.builder
for setting_name, setting_value in spark_settings.items():
    session_builder = session_builder.config(setting_name, setting_value)

spark = session_builder.getOrCreate()

# Last expression of the cell: display the session summary
spark

In [4]:
# Read all parquet files from the download directory into a single DataFrame.
# Reuses `data_dir` from the download cell so the location is defined in one place.
df = spark.read.parquet(data_dir)

In [5]:
# Print the DataFrame's schema (column names, types, nullability) as a quick
# check that the parquet files loaded correctly
df.printSchema()

root
 |-- text: string (nullable = true)
 |-- id: string (nullable = true)
 |-- dump: string (nullable = true)
 |-- url: string (nullable = true)
 |-- file_path: string (nullable = true)
 |-- language: string (nullable = true)
 |-- language_score: double (nullable = true)
 |-- token_count: long (nullable = true)
 |-- score: double (nullable = true)
 |-- int_score: long (nullable = true)



In [8]:
%%time
# Total number of rows
print(f"Total rows: {df.count():,}")

Total rows: 9,672,101
CPU times: user 1.17 ms, sys: 159 µs, total: 1.32 ms
Wall time: 1.35 s


In [9]:
%%time
# How many partitions Spark created (affects parallelism)
print(f"Number of partitions: {df.rdd.getNumPartitions()}")

Number of partitions: 213
CPU times: user 553 µs, sys: 59 µs, total: 612 µs
Wall time: 370 µs
