In [None]:
import json
import random
from sklearn.model_selection import train_test_split
import os

In [None]:
def load_dataset(file_path):
    """Read and decode a JSON dataset file.

    Args:
        file_path (str): Location of the JSON file (read as UTF-8).

    Returns:
        The decoded JSON value. If the file does not exist or is not valid
        JSON, an error message is printed and an empty list is returned
        instead, so callers can treat "nothing loaded" as falsy.
        Other errors (e.g. permission problems) propagate unchanged.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            parsed = json.load(handle)
    except FileNotFoundError:
        print(f"Error: File {file_path} not found.")
        parsed = []
    except json.JSONDecodeError:
        print(f"Error: Invalid JSON format in {file_path}.")
        parsed = []
    return parsed

In [None]:
def save_dataset(data, file_path):
    """Write ``data`` to ``file_path`` as indented (2-space) JSON.

    Args:
        data: Any JSON-serializable value, e.g. a list of chains.
        file_path (str): Destination file; opened in 'w' mode, so an existing
            file is overwritten. The parent directory is not created here.

    Prints a confirmation line with the destination path once written.
    """
    with open(file_path, mode='w', encoding='utf-8') as sink:
        json.dump(data, sink, indent=2)
    print(f"Saved dataset to {file_path}")

In [None]:
def split_dataset(input_file, output_dir, train_size=140, val_size=30, seed=42):
    """
    Split dataset into train, validation, and test sets with exact sizes.

    The chains are shuffled with a private ``random.Random(seed)`` instance, so
    the split is reproducible for a given seed while the module-global
    ``random`` state is left untouched. The first ``train_size`` shuffled
    chains form the training set, the next ``val_size`` the validation set,
    and all remaining chains the test set.

    Args:
        input_file (str): Path to synthetic_email_data.json (must hold a JSON list)
        output_dir (str): Directory to save split files (created if missing)
        train_size (int): Number of chains for training set (default: 140)
        val_size (int): Number of chains for validation set (default: 30)
        seed (int): Random seed for reproducibility

    Returns:
        None. Writes train.json, val.json and test.json into ``output_dir``.
        On invalid input an error is printed and no files are written.
    """
    # Load dataset (load_dataset returns [] on missing/invalid files)
    data = load_dataset(input_file)
    if not data:
        return
    # A top-level JSON object or scalar cannot be shuffled/sliced into splits.
    if not isinstance(data, list):
        print(f"Error: Expected a JSON list of chains in {input_file}, got {type(data).__name__}.")
        return

    # Negative sizes would slice from the end of the list (e.g. [:-1]) and
    # silently produce overlapping/incorrect splits.
    if train_size < 0 or val_size < 0:
        print(f"Error: Split sizes must be non-negative (train: {train_size}, val: {val_size}).")
        return

    # Verify total size
    total_chains = len(data)
    test_size = total_chains - train_size - val_size
    if total_chains < train_size + val_size:
        print(f"Error: Dataset has {total_chains} chains, but requested {train_size} train + {val_size} val.")
        return
    if test_size <= 0:
        print(f"Error: Invalid split sizes (train: {train_size}, val: {val_size}, test: {test_size}).")
        return

    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # A dedicated generator yields the same permutation as
    # random.seed(seed) + random.shuffle(...), without reseeding the global RNG
    # that other notebook cells may be using.
    rng = random.Random(seed)
    shuffled_data = list(data)
    rng.shuffle(shuffled_data)

    # Split data into exact sizes
    val_end = train_size + val_size
    train_data = shuffled_data[:train_size]
    val_data = shuffled_data[train_size:val_end]
    test_data = shuffled_data[val_end:]

    # Report split sizes
    print(f"Total chains: {total_chains}")
    print(f"Training set: {len(train_data)} chains")
    print(f"Validation set: {len(val_data)} chains")
    print(f"Test set: {len(test_data)} chains")

    # Save splits
    save_dataset(train_data, os.path.join(output_dir, 'train.json'))
    save_dataset(val_data, os.path.join(output_dir, 'val.json'))
    save_dataset(test_data, os.path.join(output_dir, 'test.json'))

In [None]:
# Configuration
# Split sizes are absolute chain counts; whatever remains after train + val
# becomes the test set (here 140 / 30 / 30 = 70% / 15% / 15% of 200 chains).
input_file = "synthetic_email_data.json"  # dataset to split (JSON list of chains)
output_dir = "dataset_splits"             # train/val/test JSON files are written here
train_size, val_size = 140, 30            # chains for training / validation
seed = 42                                 # shuffle seed for a reproducible split

In [None]:
# Run the split using the configuration cell's values (writes train/val/test JSON)
split_dataset(input_file, output_dir, train_size=train_size, val_size=val_size, seed=seed)

Total chains: 200
Training set: 140 chains
Validation set: 30 chains
Test set: 30 chains
Saved dataset to dataset_splits/train.json
Saved dataset to dataset_splits/val.json
Saved dataset to dataset_splits/test.json
