# Math Dataset Preparation

This notebook prepares the OpenMathInstruct-2 dataset for causal head gating (CHG) training with few-shot prompting, tokenizing and saving it separately for each model.

In [None]:
import os
import torch
import yaml
from pathlib import Path
from tqdm.auto import tqdm

from transformers import AutoTokenizer
from causal_head_gating.data import load_math_dataset, MaskedSequenceDataset

In [None]:
# Load config (resolve paths relative to config file location).
config_path = Path("../config.yaml")
# Decode explicitly as UTF-8 so the result does not depend on the OS locale.
with open(config_path, encoding="utf-8") as f:
    config = yaml.safe_load(f)
config_dir = config_path.parent.resolve()
# Each entry under `directories` is joined onto the config file's directory;
# pathlib's `/` keeps absolute entries unchanged.
directories = {k: (config_dir / v).resolve() for k, v in config['directories'].items()}
# NOTE(review): transformers is imported in an earlier cell, and huggingface_hub
# reads HF_HOME when it is first imported, so this assignment may not redirect
# the download cache for the current kernel session. Confirm this, or pass
# `cache_dir` explicitly wherever models/tokenizers are loaded.
os.environ['HF_HOME'] = str(directories['huggingface'])

In [None]:
# Process for each model
model_names = [
    'meta-llama/Llama-3.2-3B',
    'meta-llama/Llama-3.2-1B',
    'meta-llama/Llama-3.1-8B',
]
splits = ['train', 'validation']
# Every file written per model below. A model is skipped only when ALL of them
# exist. A run interrupted between writes (e.g. train.pt saved but not
# validation.pt) is then redone instead of being left incomplete.
expected_files = ['prompts.parquet'] + [f'{split}.pt' for split in splits]
# NOTE(review): a single lengths cache is shared by every model. If it stores
# tokenizer-dependent lengths, this is only safe while all models share a
# tokenizer. Confirm before adding a model from another family.
lengths_cache_path = str(directories['save'] / 'datasets/math/lengths_cache.parquet')
# transformers is imported before HF_HOME is assigned in the config cell, and
# the hub reads HF_HOME at import time. Pass the intended cache location
# explicitly ($HF_HOME/hub is the hub's default cache directory).
hf_cache_dir = str(directories['huggingface'] / 'hub')

for model_name in tqdm(model_names):
    print(f"\nProcessing {model_name}")

    save_dir = directories['save'] / f'datasets/math/{model_name}'
    if all((save_dir / fname).exists() for fname in expected_files):
        print("Already exists, skipping")
        continue

    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=hf_cache_dir)
    # Padding reuses the EOS token.
    tokenizer.pad_token = tokenizer.eos_token

    # Load and prepare dataset (full workflow in one call)
    df_prompts, input_ids, loss_masks = load_math_dataset(
        tokenizer=tokenizer,
        num_examples=50,
        num_train=50000,
        num_validation=5000,
        lengths_cache_path=lengths_cache_path,
        verbose=True,
    )

    # Save prompts and datasets
    save_dir.mkdir(parents=True, exist_ok=True)
    df_prompts.to_parquet(save_dir / 'prompts.parquet')

    for split in splits:
        # Boolean row mask for this split. This assumes df_prompts rows align
        # with the first dimension of input_ids / loss_masks (TODO confirm).
        split_mask = df_prompts['split'] == split
        torch.save({
            'input_ids': input_ids[split_mask.values],
            'loss_masks': loss_masks[split_mask.values],
        }, save_dir / f'{split}.pt')

    print(f"Saved to {save_dir}")