# CHG Example

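This notebook demonstrates Causal Head Gating (CHG) analysis using the `causal_head_gating` package: it loads a causal language model and a masked-sequence dataset, fits per-head gating masks with `CHGTrainer`, and collects the resulting masks into a long-format DataFrame.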

In [None]:
import os
import torch
import pandas as pd
import yaml
from pathlib import Path

from transformers import AutoModelForCausalLM, AutoTokenizer
from causal_head_gating import CHGTrainer
from causal_head_gating.data import MaskedSequenceDataset
from causal_head_gating.utils import to_long_df

In [None]:
# Load config
# NOTE: huggingface_hub resolves HF_HOME (and the hub cache path derived from
# it) when it is first imported. transformers is imported in an earlier cell,
# so this assignment may come too late to change the default cache location
# for the current session -- pass `cache_dir` explicitly to `from_pretrained`
# or set HF_HOME before importing transformers.
with open("config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
# config.yaml must contain a `directories` mapping of name -> path string;
# this notebook reads its 'huggingface' and 'save' entries.
directories = {k: Path(v) for k, v in config['directories'].items()}
os.environ['HF_HOME'] = str(directories['huggingface'])

In [None]:
# Initialize model
device = 0  # passed to .to(); an int selects that CUDA device index
model_name = 'meta-llama/Llama-3.2-3B-Instruct'

# Pass the cache directory explicitly. Setting HF_HOME after transformers has
# been imported may not change where huggingface_hub caches downloads, because
# it resolves that path at import time. `<HF_HOME>/hub` is the hub cache
# location huggingface_hub would derive from HF_HOME.
hf_cache_dir = directories['huggingface'] / 'hub'

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=hf_cache_dir)
# Reuse EOS as the padding token (presumably the tokenizer defines none).
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=hf_cache_dir).to(device)
model.eval()
# Freeze every model weight; gradients are presumably only needed for the
# gating parameters that CHGTrainer manages.
model.requires_grad_(False)

In [None]:
# Load dataset
# model_name contains '/', so this path has nested org/model directories.
dataset_path = directories['save'] / f'datasets/aba_abb/{model_name}/train.pt'
# Load onto the CPU first so a file saved from another (or absent) GPU still
# loads; the dataset is moved to `device` explicitly below.
data = torch.load(dataset_path, map_location='cpu')
# Some dataset files store the token ids under 'text_tokens'; rename them to
# the 'input_ids' keyword passed to MaskedSequenceDataset.
if 'text_tokens' in data:
    data['input_ids'] = data.pop('text_tokens')
dataset = MaskedSequenceDataset(tokenizer.pad_token_id, **data).to(device)

In [None]:
# Train CHG
trainer = CHGTrainer(model, dataset, gradient_accum_steps=2)

# trainer.fit is iterated as a stream of (mask, metric) pairs; every pair is
# kept so the next cell can stack the masks and tabulate the metrics.
fit_kwargs = dict(num_updates=500, num_reg_updates=500, verbose=True)
masks, metrics = [], []
for step_output in trainer.fit(**fit_kwargs):
    mask, metric = step_output
    masks.append(mask)
    metrics.append(metric)

In [None]:
# Analyze results
# Axis order of the reshaped masks follows the to_long_df labels below:
# (regularization, step, layer_idx, head_idx).
n_regularizations = 3  # hard-coded size of the regularization axis; TODO confirm against trainer.fit
stacked_masks = torch.stack(masks)
n_layers, n_heads = stacked_masks.shape[-2:]
# The step axis (-1) is inferred from the total number of collected masks.
masks = stacked_masks.view(n_regularizations, -1, n_layers, n_heads)
df_metrics = pd.DataFrame(metrics)

df = to_long_df(masks, ['regularization', 'step', 'layer_idx', 'head_idx'])
df