In [1]:
import tqdm.auto as tqdm
import numpy as np

In [2]:
import dask.dataframe as dd

data_files = "/local/home/dvruette/nemotron_tokenized/data.commoncrawl.org/contrib/Nemotron/Nemotron-CC/**/*.parquet"

ddf = dd.read_parquet(data_files, engine='pyarrow')
df = ddf
ddf.npartitions

5950

In [3]:
import re
import glob


regex = re.compile(r".*/quality=(?P<quality>[a-z-]+)/kind=(?P<kind>[a-z-]+)/kind2=(?P<kind2>[a-z-_]+)/CC-MAIN-(?P<year>\d+)-(?P<month>\d+)-part-(?P<part>\d+)(\.parquet|\.jsonl\.zstd)$")

def parse_filename(file):
    match = re.search(regex, file)
    if match:
        meta = match.groupdict()
        meta['part'] = int(meta['part'])
        meta['year'] = int(meta['year'])
        meta['month'] = int(meta['month'])
        return meta
    print(file)
    return None

In [4]:
from collections import defaultdict

files = glob.glob(data_files, recursive=True)
file_meta = [parse_filename(file) for file in files]

sorted_meta, sorted_files = zip(*sorted(zip(file_meta, files), key=lambda x: (x[0]["quality"], x[0]["kind"], x[0]["kind2"], x[0]['year'], x[0]['month'], x[0]['part'])))

grouped_files = defaultdict(list)
for meta, file in zip(sorted_meta, sorted_files):
    grouped_files[(meta['quality'], meta['kind'], meta['kind2'])].append(file)

grouped_files[("medium", "actual", "actual")] = grouped_files[("medium", "actual", "actual")][:int(0.75*len(grouped_files[("medium", "actual", "actual")]))]
grouped_files[("medium-high", "actual", "actual")] = grouped_files[("medium-high", "actual", "actual")][:int(0.75*len(grouped_files[("medium-high", "actual", "actual")]))]
grouped_files[("medium-low", "actual", "actual")] = grouped_files[("medium-low", "actual", "actual")][:int(0.75*len(grouped_files[("medium-low", "actual", "actual")]))]

total = sum(len(vals) for vals in grouped_files.values())
print(total, len(files))
{key: (len(vals), f"{100*len(vals) / total:.2f}%") for key, vals in grouped_files.items()}

5005 5950


{('high', 'actual', 'actual'): (404, '8.07%'),
 ('high', 'synthetic', 'distill'): (135, '2.70%'),
 ('high', 'synthetic', 'diverse_qa_pairs'): (398, '7.95%'),
 ('high', 'synthetic', 'extract_knowledge'): (254, '5.07%'),
 ('high', 'synthetic', 'knowledge_list'): (150, '3.00%'),
 ('high', 'synthetic', 'wrap_medium'): (274, '5.47%'),
 ('low', 'actual', 'actual'): (298, '5.95%'),
 ('low', 'synthetic', 'wrap_medium'): (263, '5.25%'),
 ('medium', 'actual', 'actual'): (1614, '32.25%'),
 ('medium-high', 'actual', 'actual'): (396, '7.91%'),
 ('medium-low', 'actual', 'actual'): (819, '16.36%')}

In [5]:
split = [f for files in grouped_files.values() for f in files]

with open("nemotron-cc-split.txt", "w") as f:
    f.writelines(f"{line}\n" for line in split)

In [7]:
import random

shuffled_files = random.sample(files, 5000)
shuffled_files_meta = [parse_filename(file) for file in shuffled_files]

sorted_meta, sorted_files = zip(*sorted(zip(shuffled_files_meta, shuffled_files), key=lambda x: (x[0]["quality"], x[0]["kind"], x[0]["kind2"], x[0]['year'], x[0]['month'], x[0]['part'])))

grouped_files = defaultdict(list)
for meta, file in zip(sorted_meta, sorted_files):
    grouped_files[(meta['quality'], meta['kind'], meta['kind2'])].append(file)

{key: (len(vals), f"{100*len(vals) / len(shuffled_files):.2f}%") for key, vals in grouped_files.items()}

{('high', 'actual', 'actual'): (330, '6.60%'),
 ('high', 'synthetic', 'distill'): (110, '2.20%'),
 ('high', 'synthetic', 'diverse_qa_pairs'): (326, '6.52%'),
 ('high', 'synthetic', 'extract_knowledge'): (204, '4.08%'),
 ('high', 'synthetic', 'knowledge_list'): (122, '2.44%'),
 ('high', 'synthetic', 'wrap_medium'): (227, '4.54%'),
 ('low', 'actual', 'actual'): (251, '5.02%'),
 ('low', 'synthetic', 'wrap_medium'): (222, '4.44%'),
 ('medium', 'actual', 'actual'): (1808, '36.16%'),
 ('medium-high', 'actual', 'actual'): (456, '9.12%'),
 ('medium-low', 'actual', 'actual'): (944, '18.88%')}

In [8]:
total_elements = ddf.partitions[0]["tokens"].map(len, meta=("len", "int64")).sum().compute()

In [9]:
import random

max_num_files = 5000

all_jsonls = sorted(glob.glob(f"/pub/hofmann-scratch/data/Nemotron-CC/**/*.jsonl.zstd", recursive=True))
rng = random.Random(0)
rng.shuffle(all_jsonls)
if max_num_files is not None:
    all_jsonls = all_jsonls[:max_num_files]

In [10]:
sampled_files = [f.replace("/pub/hofmann-scratch/data/Nemotron-CC", "/local/home/dvruette/nemotron_tokenized").replace("data-jsonl", "data").replace(".jsonl.zstd", ".parquet") for f in all_jsonls]

sampled_file_meta = [parse_filename(file) for file in sampled_files]

sorted_sampled_meta, sorted_sampled_files = zip(*sorted(zip(sampled_file_meta, sampled_files), key=lambda x: (x[0]["quality"], x[0]["kind"], x[0]["kind2"], x[0]['year'], x[0]['month'], x[0]['part'])))

grouped_sampled_files = defaultdict(list)
for meta, file in zip(sorted_sampled_meta, sorted_sampled_files):
    grouped_sampled_files[(meta['quality'], meta['kind'], meta['kind2'])].append(file)

{key: (len(vals), f"{100*len(vals) / len(sampled_files):.2f}%") for key, vals in grouped_sampled_files.items()}

{('high', 'actual', 'actual'): (428, '8.56%'),
 ('high', 'synthetic', 'distill'): (141, '2.82%'),
 ('high', 'synthetic', 'diverse_qa_pairs'): (428, '8.56%'),
 ('high', 'synthetic', 'extract_knowledge'): (265, '5.30%'),
 ('high', 'synthetic', 'knowledge_list'): (158, '3.16%'),
 ('high', 'synthetic', 'wrap_medium'): (293, '5.86%'),
 ('low', 'actual', 'actual'): (311, '6.22%'),
 ('low', 'synthetic', 'wrap_medium'): (283, '5.66%'),
 ('medium', 'actual', 'actual'): (1622, '32.44%'),
 ('medium-high', 'actual', 'actual'): (378, '7.56%'),
 ('medium-low', 'actual', 'actual'): (693, '13.86%')}

In [11]:
extra_files = sorted(set(files) - set(sampled_files))

stats = {}
for file in extra_files:
    meta = parse_filename(file)
    key = (meta['quality'], meta['kind'], meta['kind2'])
    if key not in stats:
        stats[key] = 0
    stats[key] += 1

stats

{('high', 'synthetic', 'diverse_qa_pairs'): 1,
 ('medium-high', 'actual', 'actual'): 167,
 ('medium-low', 'actual', 'actual'): 421,
 ('medium', 'actual', 'actual'): 609}

In [15]:
existing_files = sorted(set(files) & set(sampled_files[:4612]))

stats = {}
for file in existing_files:
    meta = parse_filename(file)
    key = (meta['quality'], meta['kind'], meta['kind2'])
    if key not in stats:
        stats[key] = 0
    stats[key] += 1

print(len(existing_files))
{key: (val, f"{100*val / len(existing_files):.2f}%") for key, val in stats.items()}

4608


{('high', 'actual', 'actual'): (395, '8.57%'),
 ('high', 'synthetic', 'distill'): (130, '2.82%'),
 ('high', 'synthetic', 'diverse_qa_pairs'): (390, '8.46%'),
 ('high', 'synthetic', 'extract_knowledge'): (241, '5.23%'),
 ('high', 'synthetic', 'knowledge_list'): (147, '3.19%'),
 ('high', 'synthetic', 'wrap_medium'): (261, '5.66%'),
 ('low', 'actual', 'actual'): (291, '6.32%'),
 ('low', 'synthetic', 'wrap_medium'): (257, '5.58%'),
 ('medium-high', 'actual', 'actual'): (353, '7.66%'),
 ('medium-low', 'actual', 'actual'): (650, '14.11%'),
 ('medium', 'actual', 'actual'): (1493, '32.40%')}

In [6]:
orig_files = glob.glob("/pub/hofmann-scratch/data/Nemotron-CC/**/*.jsonl.zstd", recursive=True)
orig_file_meta = [parse_filename(file) for file in orig_files]

sorted_orig_meta, sorted_orig_files = zip(*sorted(zip(orig_file_meta, orig_files), key=lambda x: (x[0]["quality"], x[0]["kind"], x[0]["kind2"], x[0]['year'], x[0]['month'], x[0]['part'])))

grouped_orig_files = defaultdict(list)
for meta, file in zip(sorted_orig_meta, sorted_orig_files):
    grouped_orig_files[(meta['quality'], meta['kind'], meta['kind2'])].append(file)

{key: (len(vals), f"{100*len(vals) / len(orig_files):.2f}%") for key, vals in grouped_orig_files.items()}

{('high', 'actual', 'actual'): (2755, '8.81%'),
 ('high', 'synthetic', 'distill'): (939, '3.00%'),
 ('high', 'synthetic', 'diverse_qa_pairs'): (2564, '8.20%'),
 ('high', 'synthetic', 'extract_knowledge'): (1694, '5.42%'),
 ('high', 'synthetic', 'knowledge_list'): (1140, '3.64%'),
 ('high', 'synthetic', 'wrap_medium'): (2016, '6.45%'),
 ('low', 'actual', 'actual'): (1964, '6.28%'),
 ('low', 'synthetic', 'wrap_medium'): (1788, '5.72%'),
 ('medium', 'actual', 'actual'): (9678, '30.94%'),
 ('medium-high', 'actual', 'actual'): (2454, '7.85%'),
 ('medium-low', 'actual', 'actual'): (4287, '13.71%')}

In [18]:
split = sorted(set(files) & set(sampled_files[:4612]))
assert len(split) == 4608  # 4096 + 512; ~1T tokens

with open("nemotron-cc-split.txt", "w") as f:
    f.writelines(f"{line}\n" for line in split)

In [34]:
import os
import tqdm.auto as tqdm

# delete files not in split

remaining_files = sorted(glob.glob(f"/local/home/dvruette/nemotron_tokenized/**/*.parquet", recursive=True))


split_files = []
with open("nemotron-cc-split.txt", "r") as f:
    for line in f:
        split_files.append(line.strip())

num_deleted = 0
num_remaining = len(set(remaining_files) - (set(remaining_files) - set(split_files)))
assert num_remaining == 4608, f"Expected 4608 files, found {num_remaining}"
for file in tqdm.tqdm(list(set(remaining_files) - set(split_files))):
    # print(f"Deleting {file}")
    try:
        os.remove(file)
        num_deleted += 1
        pass
    except FileNotFoundError:
        print(f"File {file} not found, skipping.")
    except Exception as e:
        print(f"Error deleting {file}: {e}")

print(f"Deleted {num_deleted} files.")

  0%|          | 0/3 [00:00<?, ?it/s]

Deleted 3 files.


In [30]:
remaining_files = sorted(glob.glob(f"/local/home/dvruette/nemotron_tokenized/**/*.parquet", recursive=True))
len(remaining_files)

4611

In [31]:

split_files = []
with open("nemotron-cc-split.txt", "r") as f:
    for line in f:
        split_files.append(line.strip())

len(split_files)

4608

In [32]:
set(remaining_files) - set(split_files), set(split_files) - set(remaining_files)

({'/local/home/dvruette/nemotron_tokenized/data.commoncrawl.org/contrib/Nemotron/Nemotron-CC/data/quality=high/kind=synthetic/kind2=distill/CC-MAIN-2013-20-part-00003.parquet',
  '/local/home/dvruette/nemotron_tokenized/data.commoncrawl.org/contrib/Nemotron/Nemotron-CC/data/quality=medium/kind=actual/kind2=actual/CC-MAIN-2015-06-part-00031.parquet',
  '/local/home/dvruette/nemotron_tokenized/data.commoncrawl.org/contrib/Nemotron/Nemotron-CC/data/quality=medium/kind=actual/kind2=actual/CC-MAIN-2024-30-part-00005.parquet'},
 set())

In [93]:
import numpy as np

tokens_per_partition = []

ddf = dd.read_parquet(existing_files, engine='pyarrow')
# select 16 random partitions
ids = [np.random.choice(ddf.npartitions, replace=False) for _ in range(16)]
for i in tqdm.tqdm(ids):
    total = ddf.partitions[i]["tokens"].map(len, meta=("len", "int64")).sum().compute()
    print(f"{i:<4}: {total}")
    tokens_per_partition.append(total)

print(f"mean: {np.mean(tokens_per_partition)}")

  0%|          | 0/16 [00:00<?, ?it/s]

1384: 213888153
625 : 217872431
3795: 224092477
754 : 208462859
978 : 194725227
926 : 192144100
2590: 221744631
2597: 224437231
1948: 225692886
4080: 222484555
2833: 231279686
2064: 227719733
256 : 222107057
4022: 221265877
737 : 209203770
2145: 218691504
mean: 217238261.0625


In [79]:
files[0], sampled_files[0]

('/local/home/dvruette/nemotron_tokenized/data.commoncrawl.org/contrib/Nemotron/Nemotron-CC/data/quality=medium-low/kind=actual/kind2=actual/CC-MAIN-2014-10-part-00000.parquet',
 '/home/loca/dvruette/nemotron_tokenized/data.commoncrawl.org/contrib/Nemotron/Nemotron-CC/data/quality=low/kind=synthetic/kind2=wrap_medium/CC-MAIN-2023-14-part-00013.parquet')

In [None]:
import numpy as np

tokens_per_partition = []
# select 16 random partitions
ids = [np.random.choice(ddf.npartitions, replace=False) for _ in range(16)]
for i in tqdm.tqdm(ids):
    total = ddf.partitions[i]["tokens"].map(len, meta=("len", "int64")).sum().compute()
    print(f"{i:<4}: {total}")
    tokens_per_partition.append(total)

print(f"mean: {np.mean(tokens_per_partition)}")

  0%|          | 0/16 [00:00<?, ?it/s]

534 : 214830047
2839: 221560310
1388: 210411937
2969: 225110990
3864: 230857967
58  : 222430617
2751: 223562157
4983: 226629966
2745: 222297603
483 : 210296546
4224: 226364975
742 : 209149153
3894: 230174659
2862: 224298112
4379: 227973104
3277: 219671421


np.float64(221601222.75)

In [7]:
import numpy as np
np.mean(tokens_per_partition)

np.float64(220234989.0)

In [3]:
x = next(it)

In [4]:
from gidd_easydel.sampler import BufferedPartitionSampler

sampler = BufferedPartitionSampler(
    ddf, K=16, random_state=0,
)

In [5]:
def generate_dataset():
    yield from sampler

In [6]:
from datasets import IterableDataset

ds = IterableDataset.from_generator(generate_dataset)

In [7]:
next(iter(ds))

{'tokens': array([ 24467,     80,  12081,   2093,   8333,     56,  32565,    194,
            27,   5924,   7274,    194,     27,  55970,   2851,   4695,
         66021,    370,    806,    341,    893,    361,    292,    791,
          6203,    253,   1483,    399,   1240,    291,   2899,   2937,
           484,    274,    864,   6203,    253,    250,     87,     84,
           922,    550,   3259,    287,   4483,    291,    274,    864,
           908,   6203,    253,    550,   1379,    308,    806,    341,
          4514,   9675,  15791,  38793,    484,   1592,   4728,   4149,
           593,   2937,    308,    806,    341,  58390,    361,    292,
          1302,    246,   4018,    426,    662,    340,   5320,    484,
           274,    421,    439,   3866,   2062,    308,    806,    341,
           893,    411,    399,   9760,   7623,    484,   3161,   5843,
          4234,   1483,   4296,    308,    806,    341,   4784,    697,
           292,    640,    654,    250,     87,     84

In [5]:
import numpy as np

def add_rand_key(pdf, *, seed=0, partition_info=None, as_uint64=True):
    """
    Add a random key column to a pandas DataFrame partition (pdf).
    Random stream is deterministic per partition: seed + partition_id.
    """
    pid = 0 if partition_info is None else partition_info["number"]
    rng = np.random.default_rng(seed + pid)

    if as_uint64:
        key = rng.integers(
            low=0,
            high=np.iinfo(np.uint64).max,
            size=len(pdf),
            dtype=np.uint64,
            endpoint=True,  # include max value
        )
        return pdf.assign(rand_key=key)
    else:
        # If you prefer float keys in [0, 1)
        return pdf.assign(rand_key=rng.random(len(pdf)))

# Tell Dask the dtype of the new column via meta
meta_uint64 = ddf._meta.assign(rand_key=np.uint64(0))

ddf_with_key = ddf.map_partitions(
    add_rand_key,
    seed=123456,                 # base seed (change per epoch if you like)
    as_uint64=True,              # set False to use float keys
    meta=meta_uint64,            # ensures correct dtype without a full compute
)

In [6]:
ddf_sorted = ddf_with_key.set_index(
    "rand_key",
    # npartitions=256,     # tune for your cluster
    shuffle="disk",       # scalable disk-based shuffle
    sort=True            # compute divisions via sampling
)

# 2) Ensure rows are ordered within each partition (set_index guarantees global
#    key ranges; this makes each partition locally sorted by the index)
ddf_sorted = ddf_sorted.map_partitions(lambda pdf: pdf.sort_index())

In [8]:
ddf_with_key.npartitions

28

In [7]:
ddf_sorted.npartitions

28

In [9]:

print(ddf_sorted.head(5))

                                                           tokens quality  \
rand_key                                                                    
2877534930248   [23889, 370, 17329, 370, 12944, 341, 719, 241,...    high   
2922431929578   [469, 12492, 194, 27, 289, 20729, 3876, 401, 6...    high   
6810374980414   [57460, 13718, 18564, 234, 194, 29, 469, 1810,...    high   
7745255807730   [323, 4255, 2789, 250, 87, 84, 13430, 12514, 1...    high   
11080897116137  [11569, 851, 9574, 370, 625, 361, 2459, 1165, ...    high   

                  kind   kind2  
rand_key                        
2877534930248   actual  actual  
2922431929578   actual  actual  
6810374980414   actual  actual  
7745255807730   actual  actual  
11080897116137  actual  actual  


In [4]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F


data_files = "/local/home/dvruette/nemotron_tokenized/data.commoncrawl.org/contrib/Nemotron/Nemotron-CC/data/quality=high/kind=actual/kind2=actual/CC-MAIN-2020-*.parquet"

spark = (
    SparkSession.builder
    .appName("shuffle-by-rand-key")
    .config("spark.sql.adaptive.enabled", "true")  # helps Top-N, skew, etc.
    .getOrCreate()
)

# 1) Load and add a deterministic random key (float in [0,1))
df = spark.read.parquet(data_files)
df_with_key = df.withColumn("rand_key", F.rand(seed=123456))


25/08/13 23:37:06 WARN FileStreamSink: Assume no metadata directory. Error while looking for metadata directory in the path: /local/home/dvruette/nemotron_tokenized/data.commoncrawl.org/contrib/Nemotron/Nemotron-CC/data/quality=high/kind=actual/kind2=actual/CC-MAIN-2020-*.parquet.
java.io.FileNotFoundException: File /local/home/dvruette/nemotron_tokenized/data.commoncrawl.org/contrib/Nemotron/Nemotron-CC/data/quality=high/kind=actual/kind2=actual/CC-MAIN-2020-*.parquet does not exist
	at org.apache.hadoop.fs.RawLocalFileSystem.deprecatedGetFileStatus(RawLocalFileSystem.java:917)
	at org.apache.hadoop.fs.RawLocalFileSystem.getFileLinkStatusInternal(RawLocalFileSystem.java:1238)
	at org.apache.hadoop.fs.RawLocalFileSystem.getFileStatus(RawLocalFileSystem.java:907)
	at org.apache.hadoop.fs.FilterFileSystem.getFileStatus(FilterFileSystem.java:462)
	at org.apache.spark.sql.execution.streaming.FileStreamSink$.hasMetadata(FileStreamSink.scala:56)
	at org.apache.spark.sql.execution.datasources

In [None]:

# 2) Print a couple of rows in global key order.
# This uses Spark's TakeOrdered path (doesn't fully sort 1.5TB just to fetch 5 rows)
df_with_key.orderBy(F.col("rand_key").asc()).show(5, truncate=False)

# --- Optional: create a scalable globally range-ordered layout by key ---

# Range-partition by key and sort within each partition (good for full scans)
nparts = 256  # tune for your cluster
df_range_sorted = (
    df_with_key
    # .repartitionByRange(nparts, "rand_key")
    .sortWithinPartitions("rand_key")
)

# If you want to materialize this layout for future fast epochs:
# df_range_sorted.write.mode("overwrite").parquet("/path/to/bucketed_by_rand_key")

# If you still want to peek a few rows from this layout:
# (Top-N is still simplest/fastest for just a sample)
# df_range_sorted.orderBy(F.col("rand_key").asc()).show(5, truncate=False)

In [5]:
import pyarrow.parquet as pq

path = "/local/home/dvruette/nemotron_tokenized/data.commoncrawl.org/contrib/Nemotron/Nemotron-CC/data/quality=high/kind=actual/kind2=actual/CC-MAIN-2023-40-part-00012.parquet"
pf = pq.ParquetFile(path)              # metadata only; cheap
md = pf.metadata

print(f"File: {path}")
print(f"Row groups: {md.num_row_groups}")
for i in range(md.num_row_groups):
    rg = md.row_group(i)
    nrows = rg.num_rows
    nbytes = rg.total_byte_size        # approximate physical size on disk
    print(f"  RG {i:>3}: rows={nrows:,}  ~{nbytes/1024/1024:.2f} MB")


File: /local/home/dvruette/nemotron_tokenized/data.commoncrawl.org/contrib/Nemotron/Nemotron-CC/data/quality=high/kind=actual/kind2=actual/CC-MAIN-2023-40-part-00012.parquet
Row groups: 1
  RG   0: rows=224,068  ~439.21 MB
