# Functions

In [None]:
import pyarrow.parquet as pq
import os
from streaming.base.format.mds.writer import MDSWriter
from pathlib import Path
import pandas as pd

def parquet_to_mds_modernbert(
    parquet_file: str,
    output_dir: str,
    split: str,
    chunk_size: int,
    compression: str = None,
    text_column: str = "text",
    max_samples: int = None
):
    """
    Convert a parquet file to MDS format for ModernBERT training.

    Every non-null value of ``text_column`` whose string form is not blank
    (after ``strip()``) is written as one sample with a single ``"text"``
    column (MDS type ``str``) into ``<output_dir>/<split>``.

    Args:
        parquet_file: Path to input parquet file
        output_dir: Directory where MDS files will be saved
        split: Split name (train/val/test); used as the output subdirectory
        chunk_size: Number of rows per record batch read from the parquet
            file; must be positive
        compression: Shard compression passed to ``MDSWriter`` (None for
            uncompressed; e.g. 'zstd' — see the streaming docs for the
            supported names)
        text_column: Name of the text column in parquet
        max_samples: Maximum number of samples to convert (None for all;
            values <= 0 convert nothing)

    Returns:
        The number of samples written.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    # Validate before touching the filesystem; a zero/negative batch size is
    # never meaningful for iter_batches.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    split_dir = os.path.join(output_dir, split)
    os.makedirs(split_dir, exist_ok=True)

    # Define columns as expected by ModernBERT
    columns = {"text": "str"}

    print(f"🔄 Converting {parquet_file} to MDS format...")
    print(f"📁 Output directory: {split_dir}")
    print(f"📦 Compression: {compression}")

    processed_samples = 0
    with MDSWriter(columns=columns, out=split_dir, compression=compression, exist_ok=True) as writer:
        pf = pq.ParquetFile(parquet_file)
        # iter_batches yields ceil(num_rows / chunk_size) batches; the number of
        # row groups is unrelated to chunk_size and gave a misleading total.
        total_batches = (pf.metadata.num_rows + chunk_size - 1) // chunk_size

        for batch_idx, batch in enumerate(pf.iter_batches(batch_size=chunk_size, columns=[text_column])):
            for txt in batch.column(text_column).to_pylist():
                # `is not None` so that max_samples=0 means "write nothing",
                # not "no limit".
                if max_samples is not None and processed_samples >= max_samples:
                    break
                if txt is None:
                    continue
                text = str(txt)
                if not text.strip():  # Skip empty/whitespace-only texts
                    continue
                writer.write({"text": text})
                processed_samples += 1

            if batch_idx % 10 == 0:
                print(f"📊 Processed {batch_idx + 1}/{total_batches} batches, {processed_samples} samples")

            if max_samples is not None and processed_samples >= max_samples:
                break

    print(f"✅ Conversion complete!")
    print(f"📈 Total samples: {processed_samples}")
    print(f"📂 MDS files saved to: {split_dir}")
    return processed_samples


In [None]:
def convert_multiple_parquets(
    parquet_files: list,
    output_dir: str,
    split: str,
    chunk_size: int,
    compression: str = None,
    text_column: str = "text",
    size_limit: int = None,
):
    """Convert multiple parquet files to a single MDS dataset.

    The files are streamed in order into one ``MDSWriter`` at
    ``<output_dir>/<split>``. Null values and texts that are blank after
    ``str(...).strip()`` are skipped. Every other value is written as a
    ``"text"`` sample.

    Args:
        parquet_files: Paths of the input parquet files.
        output_dir: Root directory of the MDS dataset.
        split: Split name; used as the output subdirectory.
        chunk_size: Rows per record batch read from each file; must be positive.
        compression: Shard compression passed to ``MDSWriter`` (None = uncompressed).
        text_column: Name of the text column in each parquet file.
        size_limit: Optional shard size limit forwarded to ``MDSWriter``'s
            ``size_limit`` argument; None keeps the writer's own default.

    Returns:
        Total number of samples written across all files.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    print(f"🚀 Converting {len(parquet_files)} parquet files to MDS...")

    split_dir = os.path.join(output_dir, split)
    os.makedirs(split_dir, exist_ok=True)
    columns = {"text": "str"}

    # The previous call passed `shard_size=writer_shard_size`, an undefined
    # name (NameError on every call). MDSWriter's shard-size knob is
    # `size_limit`; only forward it when the caller asks for one.
    writer_kwargs = {
        "columns": columns,
        "out": split_dir,
        "compression": compression,
        "exist_ok": True,
    }
    if size_limit is not None:
        writer_kwargs["size_limit"] = size_limit

    total_samples = 0

    with MDSWriter(**writer_kwargs) as writer:
        for file_idx, parquet_file in enumerate(parquet_files):
            print(f"\n📄 Processing file {file_idx + 1}/{len(parquet_files)}: {parquet_file}")

            pf = pq.ParquetFile(parquet_file)
            file_samples = 0

            for batch_idx, batch in enumerate(pf.iter_batches(batch_size=chunk_size, columns=[text_column])):
                for txt in batch.column(text_column).to_pylist():
                    if txt is None:
                        continue
                    text = str(txt)
                    if not text.strip():  # Skip empty/whitespace-only texts
                        continue
                    writer.write({"text": text})
                    file_samples += 1

                if batch_idx % 10 == 0:
                    print(f"   Batch {batch_idx + 1}, samples: {file_samples}")

            total_samples += file_samples
            print(f"   ✅ File complete: {file_samples} samples")

    print(f"\n🎉 All files converted!")
    print(f"📈 Total samples: {total_samples}")
    print(f"📂 MDS dataset saved to: {split_dir}")
    return total_samples

In [None]:
def validate_mds_dataset(mds_path: str, split: str = "train", sample_count: int = 5):
    """Validate the created MDS dataset by reading some samples.

    Reads ``<mds_path>/<split>/index.json``, prints the shard count and index
    version, and previews up to ``sample_count`` texts from the first shard.

    Args:
        mds_path: Root directory of the MDS dataset.
        split: Split subdirectory to check.
        sample_count: Maximum number of texts to preview from the first shard.

    Returns:
        False if the index file is missing or lists no shards; True otherwise.
        If the first shard is compressed and no decompressed copy exists on
        disk, the preview is skipped and True is returned.
    """
    import json

    split_dir = os.path.join(mds_path, split)
    index_file = os.path.join(split_dir, "index.json")

    if not os.path.exists(index_file):
        print(f"❌ Index file not found: {index_file}")
        return False

    with open(index_file, 'r') as f:
        index_data = json.load(f)

    shards = index_data.get('shards', [])

    print(f"📋 Dataset info:")
    print(f"   Shards: {len(shards)}")
    print(f"   Version: {index_data.get('version', 'unknown')}")

    # An index without shards previously crashed with IndexError on shards[0].
    if not shards:
        print("❌ Index lists no shards; nothing to sample")
        return False

    shard_info = shards[0]

    # Readers open the raw (decompressed) shard file. A compressed dataset only
    # has the zipped file on disk, so reading samples would fail.
    raw_name = (shard_info.get('raw_data') or {}).get('basename')
    if (shard_info.get('compression') and raw_name
            and not os.path.exists(os.path.join(split_dir, raw_name))):
        print(f"⚠️ First shard is {shard_info['compression']}-compressed; skipping sample preview")
        return True

    # Imported lazily so the index checks above work without `streaming`.
    from streaming.base.format import reader_from_json

    # Read some samples
    shard = reader_from_json(mds_path, split, shard_info)

    print(f"\n📖 Sample texts:")
    for i in range(min(sample_count, shard.samples)):
        sample = shard[i]
        text = sample['text'][:100] + "..." if len(sample['text']) > 100 else sample['text']
        print(f"   Sample {i}: {text}")

    return True

# Conversions

In [None]:
# Convert one cleaned parquet file into the "train" split.
# Use "val" as the split to build a validation split instead, and set
# compression="zstd" to get smaller shards at the cost of conversion speed.
samples = parquet_to_mds_modernbert(
    "data/sentences_cleaned.parquet",  # parquet_file
    "data/sentences_mds/",             # output_dir
    "train",                           # split
    1_000_000,                         # chunk_size (rows per read batch)
    compression=None,
    text_column="text",
)

In [None]:
# Input shards for the combined conversion: data/file1.parquet ... data/file3.parquet
parquet_files = [f"data/file{index}.parquet" for index in range(1, 4)]

In [None]:
# Combine several parquet files into one MDS split.
# chunk_size is a required argument of convert_multiple_parquets; omitting it
# raises TypeError before any conversion happens.
convert_multiple_parquets(
    parquet_files=parquet_files,
    output_dir="data/combined_mds/",
    split="train",
    chunk_size=1_000_000,
    compression="zstd"  # Recommended for large datasets
)

# Validate the dataset
# NOTE(review): this checks the single-file output ("data/sentences_mds/");
# pass "data/combined_mds/" to check the combined dataset instead.
validate_mds_dataset("data/sentences_mds/", split="train")