# ABA/ABB Dataset Preparation

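This notebook prepares the ABA/ABB pattern dataset for causal head gating (CHG) training: for each model listed below, it tokenizes the dataset with that model's tokenizer and saves the resulting `input_ids` and `loss_masks` to disk.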

In [None]:
import os
import torch
import yaml
from pathlib import Path
from tqdm.auto import tqdm

from transformers import AutoTokenizer
from causal_head_gating import CHGDataset
from causal_head_gating.data import get_aba_abb_path

In [None]:
# Load the project config. Entries under `directories` are resolved against
# the folder containing the config file rather than the notebook's working
# directory, so the notebook behaves the same wherever the kernel was started.
config_path = Path("../config.yaml")
# Explicit UTF-8: without it, open() falls back to the platform's locale
# encoding, which can mis-decode non-ASCII characters (e.g. in paths).
with open(config_path, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
config_dir = config_path.parent.resolve()
directories = {k: (config_dir / v).resolve() for k, v in config['directories'].items()}
# Point the Hugging Face cache at the configured directory.
# NOTE(review): the Hugging Face libraries appear to read HF_HOME when they
# are first imported, and `transformers` is already imported in the import
# cell above -- so this assignment may not change the cache location for the
# current kernel. Confirm, and if so set HF_HOME before importing transformers.
os.environ['HF_HOME'] = str(directories['huggingface'])

In [None]:
# Checkpoints whose tokenizers are used; one tokenized copy of the dataset
# is written per entry.
model_names = [
    'meta-llama/Llama-3.2-3B-Instruct',
    'meta-llama/Llama-3.2-3B',
    'meta-llama/Llama-3.2-1B',
    'meta-llama/Llama-3.1-8B',
]

# Locate the ABA/ABB data file once, up front (per the original note, it is
# downloaded from HuggingFace and cached after the first download).
data_path = get_aba_abb_path()

for model_name in tqdm(model_names):
    print(f"Tokenizing {model_name}")

    # Each model gets its own tokenizer; the EOS token doubles as padding.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Read the TSV and tokenize the prompt/target columns.
    # last_token_only=True is kept to match the original behavior.
    tsv_options = {
        'tokenizer': tokenizer,
        'prompt_column': "prompt",
        'target_column': "target",
        'last_token_only': True,
    }
    dataset = CHGDataset.from_tsv(str(data_path), **tsv_options)

    # Destination: <save dir>/datasets/aba_abb/<model name>/train.pt
    relative_target = f'datasets/aba_abb/{model_name}/train.pt'
    save_path = directories['save'] / relative_target
    save_path.parent.mkdir(parents=True, exist_ok=True)

    # Only these two fields of the dataset are persisted.
    payload = {
        'input_ids': dataset['input_ids'],
        'loss_masks': dataset['loss_masks'],
    }
    torch.save(payload, save_path)
    print(f"Saved to {save_path}")