# Notebook with Example Code for how to Convert Pretraining Data to Arrow

**NOTE: You do not need to read, understand, or use this notebook. It is provided entirely for reference.**

This notebook converts a JSONL.gz dataset to PyArrow format with **optional fixed-length sequences** for efficient GPT training.

## Key Features

- ✅ **Optional sequence length** - set to None for variable-length sequences (no padding)
- ✅ **Memory efficient** streaming processing
- ✅ **Ready to run** from start to finish
- ✅ **HuggingFace compatible** output format

## What it does

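1. **Streams** the JSONL.gz file in batches without loading it into memory
2. **Tokenizes** text using the GPT-2 tokenizer
3. **Creates chunks** for causal language modeling (labels are the inputs shifted by one token):
   - **Fixed-length**: every sequence is exactly `SEQUENCE_LENGTH` tokens; each document is split into non-overlapping chunks and any leftover tail is dropped
   - **Variable-length**: each document becomes one sequence (no padding); documents with fewer than 2 tokens are skipped
4. **Saves** the chunks as Parquet shards and merges them into a HuggingFace dataset (`save_to_disk` format)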

**Note:** _If_ you ever run this, you probably want to use the variable-length mode (`SEQUENCE_LENGTH = None`).


In [None]:
# =============================================================================
# HYPERPARAMETERS - MODIFY THESE AS NEEDED
# =============================================================================

# Token length of every generated sequence. None switches to variable-length
# mode: each document becomes exactly one sequence, with no padding.
SEQUENCE_LENGTH = None

# The input file is DATA_DIR + INPUT_FILE; all outputs go under OUTPUT_DIR.
DATA_DIR = "data/"
INPUT_FILE = "fineweb-edu-sample-10B.jsonl.gz"
OUTPUT_DIR = "fineweb-edu-sample-10B"

BATCH_SIZE = 1000             # documents read and tokenized per batch
MAX_CHUNKS_PER_SHARD = 10000  # chunks written to each output file

# Echo the effective configuration so the notebook output records it.
_sequence_length_text = (
    "Variable (no padding)" if SEQUENCE_LENGTH is None else f"{SEQUENCE_LENGTH}"
)
for _config_line in (
    "üîß Configuration:",
    f"   Sequence length: {_sequence_length_text}",
    f"   Input file: {DATA_DIR + INPUT_FILE}",
    f"   Output directory: {OUTPUT_DIR}",
    f"   Batch size: {BATCH_SIZE}",
    f"   Max chunks per shard: {MAX_CHUNKS_PER_SHARD}",
):
    print(_config_line)


## Setup and Imports


In [None]:
import os
import json
import gzip
from typing import List, Dict, Any, Iterator
import tempfile

import pyarrow as pa
import pyarrow.parquet as pq
from datasets import Dataset, concatenate_datasets
from transformers import AutoTokenizer
from tqdm.auto import tqdm

print("‚úÖ All imports successful")


## Initialize Tokenizer


In [None]:
# Initialize the GPT-2 tokenizer used for every document.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# GPT-2 has no pad token; register one so downstream code has a pad id available.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

# Chat-style marker tokens, added on top of the base vocabulary.
special_tokens_dict = {
    "additional_special_tokens": ["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"]
}
tokenizer.add_special_tokens(special_tokens_dict)

print(f"‚úÖ Tokenizer initialized")
# `tokenizer.vocab_size` counts only the base vocabulary. The tokens added above
# receive ids past that range, and `len(tokenizer)` includes them, so it is the
# size an embedding table must have to cover every id the tokenizer can produce.
print(f"   Vocab size: {len(tokenizer)} (base: {tokenizer.vocab_size})")
print(f"   Special tokens: {tokenizer.special_tokens_map}")

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"‚úÖ Output directory created: {OUTPUT_DIR}")


## Core Processing Functions


In [None]:
def stream_jsonl_file(file_path: str, batch_size: int) -> Iterator[List[Dict[str, Any]]]:
    """Stream a JSONL file (optionally gzip-compressed) in batches.

    Lines are read one at a time, so the whole file is never held in memory.

    Args:
        file_path: Path to the input file. A ``.gz`` suffix selects gzip
            decompression; anything else is read as plain UTF-8 text.
        batch_size: Number of documents per yielded batch; must be >= 1.

    Yields:
        Lists of ``{'text': ..., 'doc_id': ...}`` dicts, where ``doc_id`` is the
        0-based line index of the record in the file. Every batch has
        ``batch_size`` documents except possibly the last one.

    Raises:
        ValueError: If ``batch_size`` is less than 1 (raised on first iteration).

    Blank lines and JSON objects without a ``'text'`` key are skipped silently.
    Malformed JSON and valid JSON that is not an object (string, list, number)
    are skipped with a printed warning.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

    opener = gzip.open if file_path.endswith('.gz') else open
    batch: List[Dict[str, Any]] = []

    # `with` closes the file even if the consumer stops iterating early.
    with opener(file_path, 'rt', encoding='utf-8') as file_handle:
        for line_num, line in enumerate(file_handle, 1):
            line = line.strip()
            if not line:
                continue

            try:
                doc = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Warning: Skipping malformed JSON at line {line_num}: {e}")
                continue

            # A valid JSON line may still be a string/list/number. For those,
            # `'text' in doc` is a substring/element test and `doc['text']`
            # would raise TypeError, so only accept JSON objects.
            if not isinstance(doc, dict):
                print(f"Warning: Skipping non-object JSON at line {line_num}")
                continue
            if 'text' not in doc:
                continue

            batch.append({'text': doc['text'], 'doc_id': line_num - 1})
            if len(batch) >= batch_size:
                yield batch
                batch = []

    if batch:
        yield batch


def tokenize_and_chunk_documents(documents: List[Dict[str, Any]],
                                tokenizer: "AutoTokenizer",
                                sequence_length: int = None) -> List[Dict[str, Any]]:
    """Tokenize documents and create chunks for causal language modeling.

    Each chunk's ``labels`` are its ``input_ids`` shifted left by one token.

    Args:
        documents: List of documents with 'text' and 'doc_id' fields.
        tokenizer: Tokenizer exposing ``encode(text, add_special_tokens=False)``.
        sequence_length: Fixed length for chunks, or None for variable-length
            sequences.
            - None: each document with at least 2 tokens becomes one chunk of
              ``len(tokens) - 1`` inputs/labels.
            - int >= 1: each document is split into non-overlapping chunks of
              exactly ``sequence_length`` tokens. A chunk needs
              ``sequence_length + 1`` tokens (its labels reach one token further),
              so any shorter tail is dropped.

    Returns:
        List of dicts with keys 'input_ids', 'labels', 'doc_id', 'chunk_start'
        and 'chunk_end'. NOTE: in fixed-length mode ``chunk_end`` is the
        exclusive end of the input span; in variable-length mode it is
        ``len(tokens)``, i.e. the exclusive end of the label span.

    Raises:
        ValueError: If ``sequence_length`` is not None and is less than 1.

    Documents whose text is empty, whitespace-only, or not a string are
    skipped. Tokenizer failures are reported and the document is skipped.
    """
    if sequence_length is not None and sequence_length < 1:
        raise ValueError(
            f"sequence_length must be a positive integer or None, got {sequence_length!r}"
        )

    chunks = []

    for doc in documents:
        text = doc['text']
        doc_id = doc['doc_id']

        # Non-string payloads (e.g. a JSON number/list in 'text') cannot be
        # tokenized and would crash on `.strip()`.
        if not isinstance(text, str) or not text.strip():
            continue

        try:
            token_ids = tokenizer.encode(text, add_special_tokens=False)
        except Exception as e:  # best-effort: one bad document must not stop the run
            print(f"Warning: Failed to tokenize document {doc_id}: {e}")
            continue

        if sequence_length is None:
            # Variable-length mode: whole document as one sequence.
            # Need at least 2 tokens to form one input/label pair.
            if len(token_ids) < 2:
                continue
            chunks.append({
                'input_ids': token_ids[:-1],
                'labels': token_ids[1:],
                'doc_id': doc_id,
                'chunk_start': 0,
                'chunk_end': len(token_ids),
            })
        else:
            # Chunk at `start` uses tokens [start, start + sequence_length]
            # inclusive, so `start` must be < len - sequence_length. Bounding the
            # range that way yields only full chunks and drops the tail.
            for start in range(0, len(token_ids) - sequence_length, sequence_length):
                end = start + sequence_length
                chunks.append({
                    'input_ids': token_ids[start:end],
                    'labels': token_ids[start + 1:end + 1],
                    'doc_id': doc_id,
                    'chunk_start': start,
                    'chunk_end': end,
                })

    return chunks


def create_arrow_dataset_from_chunks(chunks: List[Dict[str, Any]],
                                   output_path: str,
                                   shard_index: int) -> None:
    """Write one shard of chunks to ``<output_path>/shard_<index>.parquet``.

    The shard index is zero-padded to six digits in the file name. The token
    columns ('input_ids', 'labels') are stored as list<int64>; the bookkeeping
    columns ('doc_id', 'chunk_start', 'chunk_end') as int64. The file is written
    with Snappy compression. An empty ``chunks`` list writes nothing.
    """
    if chunks:
        token_columns = ('input_ids', 'labels')
        scalar_columns = ('doc_id', 'chunk_start', 'chunk_end')

        schema = pa.schema(
            [(name, pa.list_(pa.int64())) for name in token_columns]
            + [(name, pa.int64()) for name in scalar_columns]
        )
        # Column-major view of the row dicts, in schema order.
        columns = {
            name: [chunk[name] for chunk in chunks]
            for name in schema.names
        }

        destination = os.path.join(output_path, f"shard_{shard_index:06d}.parquet")
        pq.write_table(pa.table(columns, schema=schema), destination, compression='snappy')
        print(f"‚úÖ Saved shard {shard_index} with {len(chunks)} chunks")


print("‚úÖ Core processing functions defined")


## Main Conversion Function


In [None]:
def convert_dataset_to_arrow(input_path: str,
                           output_dir: str,
                           tokenizer: AutoTokenizer,
                           batch_size: int,
                           max_chunks_per_shard: int,
                           sequence_length: int = None) -> None:
    """Convert a JSONL(.gz) corpus into Parquet shards plus a HuggingFace dataset.

    Pipeline:
      1. Stream documents in batches and tokenize/chunk each batch.
      2. Buffer the chunks and write them as ``shard_XXXXXX.parquet`` files.
         Every shard holds exactly ``max_chunks_per_shard`` chunks, except
         possibly the last.
      3. Merge all shards into ``<output_dir>/hf_dataset``.
      4. Write run statistics to ``<output_dir>/dataset_info.json``.

    Args:
        input_path: Path to the JSONL or JSONL.gz input file.
        output_dir: Directory that receives shards, the merged dataset and the
            metadata file; created if missing.
        tokenizer: Tokenizer used to encode documents.
        batch_size: Documents read and tokenized per batch.
        max_chunks_per_shard: Maximum chunks per Parquet shard; must be >= 1.
        sequence_length: Fixed chunk length, or None for one variable-length
            sequence per document.

    Raises:
        ValueError: If ``max_chunks_per_shard`` is less than 1.
    """
    if max_chunks_per_shard < 1:
        raise ValueError(f"max_chunks_per_shard must be >= 1, got {max_chunks_per_shard!r}")

    if sequence_length is None:
        print(f"üöÄ Starting conversion with variable-length sequences (no padding)")
    else:
        print(f"üöÄ Starting conversion with sequence length: {sequence_length}")
    print(f"üìÅ Input: {input_path}")
    print(f"üìÅ Output: {output_dir}")

    # Shards are written straight into output_dir, so make sure it exists.
    os.makedirs(output_dir, exist_ok=True)

    total_docs_processed = 0
    total_chunks_created = 0
    shard_index = 0
    current_shard_chunks = []

    # First pass only counts batches so the progress bar has a total. It reads
    # the whole file a second time, and any malformed-line warnings from
    # stream_jsonl_file are printed again in this pass.
    print("üìä Counting total batches...")
    total_batches = sum(1 for _ in stream_jsonl_file(input_path, batch_size))
    print(f"üìä Found {total_batches} batches to process")

    # The context manager closes the progress bar even if processing raises.
    with tqdm(
        stream_jsonl_file(input_path, batch_size),
        total=total_batches,
        desc="Processing batches",
        unit="batch"
    ) as batch_progress:
        for doc_batch in batch_progress:
            batch_chunks = tokenize_and_chunk_documents(
                doc_batch, tokenizer, sequence_length
            )
            current_shard_chunks.extend(batch_chunks)

            total_docs_processed += len(doc_batch)
            total_chunks_created += len(batch_chunks)

            batch_progress.set_postfix({
                'docs': f"{total_docs_processed:,}",
                'chunks': f"{total_chunks_created:,}",
                'shards': shard_index
            })

            # Write every full shard. The remainder carries over to the next
            # batch, so no shard ever exceeds max_chunks_per_shard, even when a
            # single batch produces more chunks than that.
            while len(current_shard_chunks) >= max_chunks_per_shard:
                create_arrow_dataset_from_chunks(
                    current_shard_chunks[:max_chunks_per_shard], output_dir, shard_index
                )
                del current_shard_chunks[:max_chunks_per_shard]
                shard_index += 1

    # Save final (partial) shard
    if current_shard_chunks:
        create_arrow_dataset_from_chunks(
            current_shard_chunks, output_dir, shard_index
        )
        shard_index += 1

    print(f"\nüîÑ Creating HuggingFace dataset from {shard_index} shards...")
    create_huggingface_dataset(output_dir, shard_index)

    # `tokenizer.vocab_size` counts only the base vocabulary. Tokens added via
    # add_special_tokens get ids past it and can appear in the data, so
    # len(tokenizer) is the size a model's embedding table actually needs.
    metadata = {
        "total_documents": total_docs_processed,
        "total_chunks": total_chunks_created,
        "total_shards": shard_index,
        "sequence_length": sequence_length,
        "tokenizer_name": getattr(tokenizer, "name_or_path", None) or "gpt2",
        "vocab_size": len(tokenizer),
        "base_vocab_size": tokenizer.vocab_size
    }

    metadata_path = os.path.join(output_dir, "dataset_info.json")
    with open(metadata_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    print(f"\nüéâ Conversion completed!")
    print(f"üìä Final Statistics:")
    print(f"   Total documents: {total_docs_processed:,}")
    print(f"   Total chunks: {total_chunks_created:,}")
    if sequence_length is None:
        print(f"   Sequence length: Variable (no padding)")
    else:
        print(f"   Sequence length: {sequence_length}")
    print(f"   Total shards: {shard_index}")
    print(f"\n‚úÖ Dataset ready for loading with:")
    print(f"   from datasets import load_from_disk")
    print(f"   dataset = load_from_disk('{output_dir}/hf_dataset')")


def create_huggingface_dataset(output_dir: str, num_shards: int) -> None:
    """Merge Parquet shards 0..num_shards-1 into one saved HuggingFace dataset.

    Each ``shard_XXXXXX.parquet`` present in ``output_dir`` is loaded with
    ``Dataset.from_parquet``. The shards are concatenated in index order and
    written with ``save_to_disk`` to ``<output_dir>/hf_dataset``.

    Missing shard files are reported and skipped. If no shard could be loaded,
    a warning is printed and nothing is saved, instead of silently doing nothing.
    """
    shard_datasets = []
    missing_shards = []

    # Context manager closes the progress bar even if loading a shard raises.
    with tqdm(range(num_shards), desc="Loading shards", unit="shard") as shard_progress:
        for shard_idx in shard_progress:
            shard_path = os.path.join(output_dir, f"shard_{shard_idx:06d}.parquet")
            if not os.path.exists(shard_path):
                missing_shards.append(shard_path)
                continue
            shard_datasets.append(Dataset.from_parquet(shard_path))
            shard_progress.set_postfix({'loaded': len(shard_datasets)})

    if missing_shards:
        print(f"Warning: {len(missing_shards)} expected shard file(s) not found "
              f"(first: {missing_shards[0]})")

    if not shard_datasets:
        print(f"Warning: no shards loaded from {output_dir}; "
              f"HuggingFace dataset was not created")
        return

    print(f"   Concatenating {len(shard_datasets)} shards...")
    full_dataset = concatenate_datasets(shard_datasets)

    # Save as HuggingFace dataset format
    hf_dataset_path = os.path.join(output_dir, "hf_dataset")
    full_dataset.save_to_disk(hf_dataset_path)

    print(f"‚úÖ HuggingFace dataset saved: {hf_dataset_path}")
    print(f"   Dataset size: {len(full_dataset):,} samples")
    print(f"   Columns: {full_dataset.column_names}")


print("‚úÖ Main conversion functions defined")


## Run the Conversion


In [None]:
# Run the conversion with your hyperparameters
input_path = DATA_DIR + INPUT_FILE

print(f"üöÄ Starting conversion...")
if SEQUENCE_LENGTH is None:
    print(f"   Sequence length: Variable (no padding)")
else:
    print(f"   Sequence length: {SEQUENCE_LENGTH}")
print(f"   Input file: {input_path}")
print(f"   Output directory: {OUTPUT_DIR}")

# Check if input file exists
if not os.path.exists(input_path):
    print(f"‚ùå Input file not found: {input_path}")
    print("Please update the DATA_DIR and INPUT_FILE variables at the top of the notebook")
else:
    # Run the conversion
    convert_dataset_to_arrow(
        input_path=input_path,
        output_dir=OUTPUT_DIR,
        tokenizer=tokenizer,
        batch_size=BATCH_SIZE,
        max_chunks_per_shard=MAX_CHUNKS_PER_SHARD,
        sequence_length=SEQUENCE_LENGTH
    )


## Load and Test the Dataset


In [None]:
# Load and test the converted dataset
hf_dataset_path = os.path.join(OUTPUT_DIR, "hf_dataset")

if os.path.exists(hf_dataset_path):
    print(f"üìä Loading dataset from: {hf_dataset_path}")

    # Load the dataset
    dataset = Dataset.load_from_disk(hf_dataset_path)

    print(f"‚úÖ Dataset loaded successfully!")
    print(f"   Total samples: {len(dataset):,}")
    print(f"   Columns: {dataset.column_names}")

    # Test a few samples
    print(f"\nüß™ Testing samples:")
    for i in range(min(3, len(dataset))):
        sample = dataset[i]
        input_ids = sample['input_ids']
        labels = sample['labels']

        print(f"\n   Sample {i + 1}:")
        print(f"   Doc ID: {sample['doc_id']}")
        print(f"   Input length: {len(input_ids)}")
        print(f"   Label length: {len(labels)}")
        print(f"   First 10 input tokens: {input_ids[:10]}")
        print(f"   First 10 label tokens: {labels[:10]}")

        # Decode a portion of the text
        try:
            decoded_text = tokenizer.decode(input_ids[:50])
            print(f"   Decoded preview: {decoded_text[:100]}...")
        except Exception as e:
            print(f"   Decode error: {e}")

    print(f"\nüéâ Dataset is ready for training!")
    print(f"   Use: dataset = Dataset.load_from_disk('{hf_dataset_path}')")

else:
    print(f"‚ùå Dataset not found at: {hf_dataset_path}")
    print("   Run the conversion cell above first")
