In [None]:
#@title Config

#@markdown HF Hub:
HF_TOKEN = "" #@param
DATASET_NAME = "maximuspowers/intp-class-small-pca-10-seperate" #@param
BASE_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct" #@param
OUTPUT_DIR = "./llama-pattern-classifier"
PUSH_TO_HUB = True  #@param
OUTPUT_HUB_MODEL_NAME = "maximuspowers/llama-3.1-8b-interpreter-pca-10-separate" #@param

#@markdown Dataset:
MAX_LENGTH = 7000 #@param # this can be found in the hf dataset readme

## Set up

In [1]:
!pip install -q -U transformers peft bitsandbytes accelerate datasets trl wandb huggingface_hub

In [2]:
import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, DataCollatorForLanguageModeling, Trainer)
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer
from datasets import load_dataset
import json
from collections import defaultdict
import numpy as np
import pandas as pd
from huggingface_hub import login
import matplotlib.pyplot as plt

login(token=HF_TOKEN)

PyTorch version: 2.9.0+cu126
CUDA available: True
GPU: NVIDIA A100-SXM4-80GB


## Training Prep

In [None]:
#@title Load dataset
dataset = load_dataset(DATASET_NAME, split="train")

# splits
dataset_split = dataset.train_test_split(test_size=0.2, seed=42) # 20% - held out for final eval
test_dataset = dataset_split["test"]
train_val_split = dataset_split["train"].train_test_split(test_size=0.1, seed=42)
train_dataset = train_val_split["train"]  # 72% of full dataset
eval_dataset = train_val_split["test"]    # 8% of full dataset (used during training for validation)

print(f"Total dataset size: {len(dataset)} examples")
print(f"Train: {len(train_dataset)} - Val: {len(eval_dataset)} - Test: {len(test_dataset)}")

In [6]:
#@title Format/Tokenize for Llama 3.1

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, token=HF_TOKEN)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

def format_and_tokenize(example): # for llama 3.1
    formatted_text = (
        f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        f"{example["classification_prompt"]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        f"{example["classification_completion"]}<|eot_id|>"
    )
    tokenized = tokenizer(
        formatted_text,
        truncation=True,
        max_length=MAX_LENGTH,
        padding="max_length",
        return_tensors=None,
    )
    tokenized["labels"] = tokenized["input_ids"].copy() # casualLM uses input id for labels
    return tokenized

train_dataset = train_dataset.map(
    format_and_tokenize,
    remove_columns=train_dataset.column_names,
    desc="Formatting and tokenizing train dataset"
)
eval_dataset = eval_dataset.map(
    format_and_tokenize,
    remove_columns=eval_dataset.column_names,
    desc="Formatting and tokenizing eval dataset"
)

Formatting and tokenizing eval dataset:   0%|          | 0/473 [00:00<?, ? examples/s]


=== Tokenized Example ===
Input IDs length: 7000
First 10 tokens: [128000, 128000, 128006, 882, 128007, 271, 567, 5008, 38943, 198]


In [7]:
#@title Load model + config qlora

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    token=HF_TOKEN,
    trust_remote_code=True,
    dtype=torch.bfloat16,
)
model = prepare_model_for_kbit_training(model)

peft_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

Loading meta-llama/Llama-3.1-8B-Instruct with 4-bit quantization...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model loaded successfully!


In [9]:
#@title Training config

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4, # effective batch = per device * grad steps
    num_train_epochs=3,
    learning_rate=2e-5,
    bf16=True,
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,
    optim="paged_adamw_8bit",
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    report_to="none",
)

print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"Total training steps: ~{len(train_dataset) * training_args.num_train_epochs // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps)}")

Effective batch size: 8
Total training steps: ~1596


## Start Training

In [10]:
#@title Create Trainer

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

Trainer initialized. Starting training...


In [None]:
#@title Train

trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss,Validation Loss
50,1.6523,1.599219
100,1.5154,1.505608
150,1.5019,1.497993
200,1.4961,1.489805
250,1.5456,1.48274


In [None]:
#@title Save model

trainer.model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
if PUSH_TO_HUB:
    trainer.model.push_to_hub(OUTPUT_HUB_MODEL_NAME, token=HF_TOKEN)
    tokenizer.push_to_hub(OUTPUT_HUB_MODEL_NAME, token=HF_TOKEN)

## Evaluate

In [None]:
#@title Load model with LoRA

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    token=HF_TOKEN,
    torch_dtype=torch.bfloat16,
)
inference_model = PeftModel.from_pretrained(base_model, OUTPUT_DIR)
inference_model.eval()

In [None]:
#@title Prompt Formatter

def generate_prediction(prompt, model, tokenizer, max_new_tokens=100):
    formatted_prompt = (
        f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        f"{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    generated_text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    return generated_text.strip()

test_example = test_dataset[0]
test_prompt = test_example['classification_prompt']
test_label = test_example['classification_completion']
prediction = generate_prediction(test_prompt, inference_model, tokenizer)

print(f"Test Prompt: {test_prompt}")
print(f"Prediction: {prediction}")

In [None]:
#@title Eval Fn

def parse_patterns(text):
    text = text.strip().lower()
    patterns = [p.strip() for p in text.split(',')]
    return set(patterns)

def evaluate_on_test_set(test_dataset, model, tokenizer, max_samples=None):
    results = {
        'predictions': [],
        'ground_truth': [],
        'correct': 0,
        'total': 0,
        'pattern_stats': defaultdict(lambda: {'correct': 0, 'total': 0})
    }
    samples_to_eval = min(max_samples or len(test_dataset), len(test_dataset))

    for i in range(samples_to_eval):
        example = test_dataset[i]
        prompt = example['classification_prompt']
        ground_truth = example['classification_completion']
        prediction = generate_prediction(prompt, model, tokenizer, max_new_tokens=150)
        pred_patterns = parse_patterns(prediction)
        true_patterns = parse_patterns(ground_truth)
        results['predictions'].append(prediction)
        results['ground_truth'].append(ground_truth)
        results['total'] += 1
        if pred_patterns == true_patterns:
            results['correct'] += 1
        for pattern in true_patterns:
            results['pattern_stats'][pattern]['total'] += 1
            if pattern in pred_patterns:
                results['pattern_stats'][pattern]['correct'] += 1
        if (i + 1) % 10 == 0:
            print(f"Progress: {i + 1}/{samples_to_eval} (Accuracy so far: {results['correct']/(i+1):.2%})")

    return results

print("Starting evaluation")
eval_results = evaluate_on_test_set(test_dataset, inference_model, tokenizer, max_samples=100)
print(f"Overall Accuracy: {eval_results['correct']}/{eval_results['total']} ({eval_results['correct']/eval_results['total']:.2%})")

## Results Visualization

In [None]:
#@title Per Pattern Accuracy

pattern_results = []
for pattern, stats in sorted(eval_results['pattern_stats'].items()):
    accuracy = stats['correct'] / stats['total'] if stats['total'] > 0 else 0
    pattern_results.append({
        'Pattern': pattern,
        'Correct': stats['correct'],
        'Total': stats['total'],
        'Accuracy': f"{accuracy:.2%}"
    })

df = pd.DataFrame(pattern_results)
print(df.to_string(index=False))

In [None]:
#@title Example Predictions

print("\n=== Example Predictions ===")
for i in range(min(5, len(eval_results['predictions']))):
    print(f"\nExample {i+1}:")
    print(f"Ground Truth: {eval_results['ground_truth'][i]}")
    print(f"Prediction:   {eval_results['predictions'][i]}")
    match = "✓" if parse_patterns(eval_results['predictions'][i]) == parse_patterns(eval_results['ground_truth'][i]) else "✗"
    print(f"Match: {match}")

In [None]:
#@title Plot Results

patterns = [r['Pattern'] for r in pattern_results]
accuracies = [float(r['Accuracy'].strip('%'))/100 for r in pattern_results]

plt.figure(figsize=(12, 6))
plt.bar(range(len(patterns)), accuracies)
plt.xticks(range(len(patterns)), patterns, rotation=45, ha='right')
plt.ylabel('Accuracy')
plt.title('Pattern Classification Accuracy')
plt.ylim(0, 1.0)
plt.axhline(y=eval_results['correct']/eval_results['total'], color='r', linestyle='--', label='Overall Accuracy')
plt.legend()
plt.tight_layout()
plt.show()