<a href="https://colab.research.google.com/github/b1becker/LLM_steering/blob/main/RePS_training_6_18.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
# Install core dependencies into the running kernel's environment (%pip, unlike !pip, targets the kernel).
%pip install torch torchvision torchaudio transformers huggingface_hub pandas numpy pyyaml requests
# Clone AxBench only if it is not already present, so re-running this cell does not error.
![ -d axbench ] || git clone https://github.com/stanfordnlp/axbench.git
# NOTE: axbench declares requires-python >= 3.12; on older runtimes (e.g. Colab's 3.11) this
# editable install fails and the setup code below falls back to adding the repo to sys.path.
%pip install -e axbench


fatal: destination path 'axbench' already exists and is not an empty directory.
Obtaining file:///content/axbench
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Collecting adjusttext>=1.3.0 (from axbench==0.1.0)
  Downloading adjustText-1.3.0-py3-none-any.whl.metadata (3.1 kB)
Collecting asyncio>=3.4.3 (from axbench==0.1.0)
  Downloading asyncio-3.4.3-py3-none-any.whl.metadata (1.7 kB)
Collecting datasets>=3.0.2 (from axbench==0.1.0)
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting jupyter>=1.1.1 (from axbench==0.1.0)
  Downloading jupyter-1.1.1-py2.py3-none-any.whl.metadata (2.0 kB)
Collecting pyreft>=0.0.8 (from axbench==0.1.0)
  Downloading pyreft-0.1.0-py3-none-any.whl.metadata (17 kB)
Collecting p

In [10]:
#!/usr/bin/env python3
"""
Robust Google Colab AxBench Training Script
Handles installation failures gracefully with multiple fallback options.
"""

import os
import sys
import subprocess
import json
import pickle
import torch
import datetime
import pandas as pd
import numpy as np
import logging
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, Any, Optional, List

In [13]:
# Robust setup function that handles failures
def robust_setup_colab():
    """Prepare the Colab runtime for AxBench, tolerating individual failures.

    Steps (every failure is printed and skipped, never raised):

    1. Core dependencies: try to import each one and ``pip install`` it only
       if the import fails.
    2. Clone ``axbench``, ``pyreft`` and ``pyvene`` into the current working
       directory unless a path with that name already exists.
    3. Run ``pip install -e ./<repo>`` for every repo from step 2.
    4. Prepend every repo directory from step 2 to ``sys.path``. This is done
       even when the editable install succeeded -- presumably because an
       editable install's path hook is normally only processed at interpreter
       start-up, so the running kernel may not see it until restarted.

    Returns:
        tuple[list[str], list[str]]: ``(cloned_repos, installed_packages)`` --
        repo names present locally after step 2, and those whose
        ``pip install -e`` exited with status 0.
    """
    import importlib

    print("🚀 Setting up AxBench environment for Google Colab...")

    # (pip distribution name, import name). These differ for PyYAML, which is
    # installed as "pyyaml" but imported as "yaml"; probing __import__("pyyaml")
    # always raised ImportError and triggered a needless reinstall every run.
    core_packages = [
        ("torch", "torch"),
        ("transformers", "transformers"),
        ("huggingface_hub", "huggingface_hub"),
        ("pandas", "pandas"),
        ("numpy", "numpy"),
        ("pyyaml", "yaml"),
        ("requests", "requests"),
    ]

    print("📦 Installing core packages...")
    for package, module_name in core_packages:
        try:
            __import__(module_name)
            print(f"✓ {package} already available")
        except ImportError:
            try:
                print(f"Installing {package}...")
                # Use run() instead of check_call() for capture_output
                result = subprocess.run([sys.executable, "-m", "pip", "install", package],
                                        capture_output=True, text=True)
                if result.returncode == 0:
                    print(f"✓ {package} installed successfully")
                else:
                    print(f"⚠️ Failed to install {package}: {result.stderr}")
            except Exception as e:
                print(f"⚠️ Error installing {package}: {e}")
    # Make packages installed during this session discoverable by the import system.
    importlib.invalidate_caches()

    # Clone repositories with error handling
    repos = [
        ("axbench", "https://github.com/stanfordnlp/axbench.git"),
        ("pyreft", "https://github.com/stanfordnlp/pyreft.git"),
        ("pyvene", "https://github.com/stanfordnlp/pyvene.git")
    ]

    print("\n📂 Setting up repositories...")
    cloned_repos = []
    for repo_name, repo_url in repos:
        try:
            if not os.path.exists(repo_name):
                print(f"Cloning {repo_name}...")
                result = subprocess.run(["git", "clone", repo_url],
                                        capture_output=True, text=True, timeout=300)
                if result.returncode == 0:
                    print(f"✓ {repo_name} cloned successfully")
                    cloned_repos.append(repo_name)
                else:
                    print(f"⚠️ Failed to clone {repo_name}: {result.stderr}")
            else:
                print(f"✓ {repo_name} already exists")
                cloned_repos.append(repo_name)
        except Exception as e:
            print(f"⚠️ Error with {repo_name}: {e}")

    # Try to install packages, but don't fail if they don't work
    print("\n🔧 Attempting package installations...")
    installed_packages = []
    for repo_name in cloned_repos:
        try:
            if os.path.exists(repo_name):
                print(f"Installing {repo_name}...")
                result = subprocess.run([sys.executable, "-m", "pip", "install", "-e", f"./{repo_name}"],
                                        capture_output=True, text=True, timeout=300)
                if result.returncode == 0:
                    print(f"✓ {repo_name} installed successfully")
                    installed_packages.append(repo_name)
                else:
                    print(f"⚠️ Installation failed for {repo_name}")
                    print(f"Error: {result.stderr[:500]}...")  # Show first 500 chars of error
                    print(f"We'll add {repo_name} to Python path instead")
        except Exception as e:
            print(f"⚠️ Exception installing {repo_name}: {e}")
    importlib.invalidate_caches()

    # Add repositories to Python path
    print("\n🔗 Adding repositories to Python path...")
    current_dir = os.getcwd()
    for repo_name in cloned_repos:
        repo_path = os.path.join(current_dir, repo_name)
        if os.path.exists(repo_path) and repo_path not in sys.path:
            sys.path.insert(0, repo_path)
            print(f"✓ Added {repo_name} to Python path")

    print("\n✅ Setup complete!")
    print(f"Cloned repos: {cloned_repos}")
    print(f"Installed packages: {installed_packages}")
    return cloned_repos, installed_packages

# Run setup: clone/install dependencies and extend sys.path where needed.
cloned_repos, installed_packages = robust_setup_colab()

# Core libraries are mandatory -- report a helpful hint, then re-raise.
try:
    from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
    from huggingface_hub import hf_hub_download
except ImportError as import_error:
    print(f"❌ Error importing transformers: {import_error}")
    print("Please run: !pip install transformers")
    raise
else:
    print("✓ Transformers imported successfully")

# Fallback constants (AxBench's own constants module, when importable, may provide these too).
EMPTY_CONCEPT = ""
CHAT_MODELS = [
    "google/gemma-2-2b-it",
    "google/gemma-2-9b-it",
    "meta-llama/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-13b-chat-hf",
]
# NOTE(review): Gemma-2 chat templates are generally documented as not accepting a
# "system" role -- confirm this list before relying on it for prompt formatting.
HAS_SYSTEM_PROMPT_MODELS = [
    "google/gemma-2-2b-it",
    "google/gemma-2-9b-it",
]

🚀 Setting up AxBench environment for Google Colab...
📦 Installing core packages...
✓ torch already available
✓ transformers already available
✓ huggingface_hub already available
✓ pandas already available
✓ numpy already available
Installing pyyaml...
✓ pyyaml installed successfully
✓ requests already available

📂 Setting up repositories...
✓ axbench already exists
✓ pyreft already exists
✓ pyvene already exists

🔧 Attempting package installations...
Installing axbench...
⚠️ Installation failed for axbench
Error: ERROR: Package 'axbench' requires a different Python: 3.11.13 not in '>=3.12'
...
We'll add axbench to Python path instead
Installing pyreft...
✓ pyreft installed successfully
Installing pyvene...
✓ pyvene installed successfully

🔗 Adding repositories to Python path...
✓ Added pyreft to Python path
✓ Added pyvene to Python path

✅ Setup complete!
Cloned repos: ['axbench', 'pyreft', 'pyvene']
Installed packages: ['pyreft', 'pyvene']
✓ Transformers imported successfully


In [14]:


def get_prefix_length(tokenizer):
    """Fallback prefix length used when AxBench's model utilities are unavailable.

    The tokenizer is accepted only for signature compatibility; the result is
    always 1.
    """
    prefix_length = 1
    return prefix_length

def get_suffix_length(tokenizer):
    """Fallback suffix info: ``(1, eos)`` built from the tokenizer's EOS token.

    ``"</s>"`` is substituted when ``tokenizer.eos_token`` is falsy (``None``
    or an empty string).
    """
    return 1, tokenizer.eos_token or "</s>"

def save_pruned_sae(metadata_path, dump_dir):
    """No-op fallback for SAE saving: ignores both arguments and returns ``None``."""
    del metadata_path, dump_dir  # intentionally unused in the fallback

def prepare_df_combined(*args, **kwargs):
    """Fallback: hand back the first positional argument unchanged.

    Keyword arguments are ignored. With no positional arguments an empty
    ``pd.DataFrame`` is returned.
    """
    if not args:
        return pd.DataFrame()
    first, *_rest = args
    return first

# Try to import AxBench modules.
# axbench declares requires-python >= 3.12, so on older runtimes (e.g. Colab's
# 3.11) an import may fail with something other than ImportError (for example a
# SyntaxError from newer syntax). Each step therefore catches Exception, so one
# failure does not skip the remaining attempts and fallbacks.
AXBENCH_AVAILABLE = False
# Start from a known state: a binding left over from an earlier run of this cell
# (e.g. TrainingArgs = SimpleTrainingArgs) must not be mistaken for the real class.
TrainingArgs = None
DatasetArgs = None
try:
    # Probe several import paths; the first success marks AxBench as available.
    import_attempts = [
        lambda: __import__('axbench.utils.constants', fromlist=['*']),
        lambda: __import__('args.training_args', fromlist=['TrainingArgs']),
        lambda: __import__('axbench'),
    ]

    for attempt in import_attempts:
        try:
            attempt()
        except Exception:
            continue
        print("✓ Some AxBench modules imported")
        AXBENCH_AVAILABLE = True
        break

    if AXBENCH_AVAILABLE:
        # Import specific modules
        try:
            from args.training_args import TrainingArgs
            from args.dataset_args import DatasetArgs
            print("✓ AxBench argument classes imported")
        except Exception:
            print("⚠️ Using fallback argument classes")
            TrainingArgs = None
            DatasetArgs = None

        try:
            # Star import is deliberate: it may override the fallback constants
            # (EMPTY_CONCEPT, CHAT_MODELS, ...) with AxBench's own values.
            from axbench.utils.constants import *
            from axbench.utils.model_utils import get_prefix_length, get_suffix_length
            print("✓ AxBench utilities imported")
        except Exception:
            print("⚠️ Using fallback utility functions")

        try:
            import axbench
            print("✓ AxBench main module imported")
        except Exception:
            print("⚠️ AxBench main module not available")
            axbench = None

except Exception as e:
    print(f"⚠️ AxBench modules not fully available: {e}")

# Define simplified classes for fallback
@dataclass
class SimpleTrainingArgs:
    """Minimal stand-in for AxBench's training arguments, used when AxBench is unavailable.

    ``models`` defaults (via ``__post_init__``) to a fresh dict with one
    ``SimpleModelArgs`` for each of "DiffMean" and "LinearProbe".
    """
    model_name: str = "distilgpt2"  # Small model for Colab
    # NOTE(review): distilgpt2 has 6 transformer blocks (indices 0-5), so 6 is out
    # of range if used as a 0-based block index. In the code shown here it is only
    # recorded in config.json -- confirm before hooking activations at this layer.
    layer: int = 6
    component: str = "mlp_out"
    seed: int = 42
    use_bf16: bool = False
    use_wandb: bool = False
    wandb_project: str = "axbench"
    wandb_name: str = "training"
    max_concepts: int = 2
    max_num_of_examples: int = 50
    output_length: int = 64
    data_dir: str = "./sample_data"
    dump_dir: str = "./results"
    overwrite_data_dir: Optional[str] = None
    use_dpo_loss: bool = False
    # None is the sentinel for "build defaults"; annotated Optional accordingly.
    models: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Build a new dict per instance so instances never share a mutable default.
        if self.models is None:
            self.models = {
                "DiffMean": SimpleModelArgs(),
                "LinearProbe": SimpleModelArgs(),
            }
@dataclass
class SimpleModelArgs:
    """Simplified per-method model arguments (fallback when AxBench is unavailable).

    None of the simplified training functions in the code shown here read these
    fields; presumably they exist so this object can stand in for AxBench's
    per-model config. TODO confirm which fields the real AxBench consumes.

    NOTE: dataclass field order defines the generated ``__init__`` signature --
    do not reorder fields.
    """
    binarize_dataset: bool = False
    train_on_negative: bool = True
    low_rank_dimension: int = 4
    intervention_type: str = "simple"
    intervention_positions: str = "last"
    exclude_bos: bool = True
    dropout: float = 0.0
    intervention_positions_dropout: float = 0.0
    preference_pairs: bool = False
    negative_only: bool = False
    steering_prompt_type: str = "prepend"
    # (sic) spelling kept as-is; presumably mirrors the upstream field name -- check callers before renaming.
    substraction_type: str = "mean"

@dataclass
class SimpleDatasetArgs:
    """Simplified dataset arguments (fallback when AxBench's dataset args are unavailable).

    None of the code shown here reads these fields.
    """
    # NOTE(review): same default as SimpleTrainingArgs.output_length; presumably a generation length -- confirm.
    output_length: int = 64
    keep_orig_axbench_format: bool = True

# Choose the argument classes: the real AxBench ones when they were imported
# above, otherwise the simplified fallbacks defined in this cell.
if not (AXBENCH_AVAILABLE and globals().get('TrainingArgs') is not None):
    print("✓ Using simplified argument classes")
    TrainingArgs = SimpleTrainingArgs
    DatasetArgs = SimpleDatasetArgs
else:
    print("✓ Using real AxBench argument classes")

# Module-level logger (main() obtains its own via setup_logging()).
logger = logging.getLogger(name=__name__)

def create_sample_data(output_dir):
    """Write a tiny synthetic training set and matching metadata to ``output_dir``.

    Creates ``output_dir`` if needed, then writes (overwriting existing files):

    * ``train_data.parquet``: 4 "technology" rows (``concept_id`` 0), 4
      "animals" rows (``concept_id`` 1) and 2 negative rows (``concept_id`` -1,
      ``output_concept`` = ``EMPTY_CONCEPT``, ``category`` "negative").
    * ``metadata.jsonl``: one JSON object per concept, written in
      ``concept_id`` order (line i describes concept_id i).

    Args:
        output_dir: directory to write into.

    Returns:
        pandas.DataFrame: the rows written to ``train_data.parquet``.
    """
    print(f"📊 Creating sample data in {output_dir}")
    os.makedirs(output_dir, exist_ok=True)

    sample_data = []

    # Technology examples: (input, output, concept name)
    tech_examples = [
        ("What is artificial intelligence?", "AI is computer science focused on creating intelligent machines", "technology"),
        ("Explain machine learning", "ML is a method of data analysis that automates analytical model building", "technology"),
        ("What is deep learning?", "Deep learning uses neural networks with multiple layers", "technology"),
        ("How do computers work?", "Computers process information using binary code and logic gates", "technology"),
    ]

    # Animal examples: (input, output, concept name)
    animal_examples = [
        ("Tell me about dogs", "Dogs are loyal companion animals known for their friendship with humans", "animals"),
        ("What are cats like?", "Cats are independent animals that make great pets", "animals"),
        ("Describe elephants", "Elephants are large mammals with excellent memory and social bonds", "animals"),
        ("How do birds fly?", "Birds fly using their wings to create lift and navigate through air", "animals"),
    ]

    # Position in this tuple is the concept_id; it must match the metadata order below.
    for concept_id, examples in enumerate((tech_examples, animal_examples)):
        for input_text, output_text, concept_name in examples:
            sample_data.append({
                'concept_id': concept_id,
                'input': input_text,
                'output': output_text,
                'output_concept': concept_name,
                'category': 'positive',
                'concept_genre': 'general'
            })

    # Negative examples carry concept_id -1 and no concept.
    negative_examples = [
        ("What's the weather?", "It's sunny today", EMPTY_CONCEPT, "negative"),
        ("How are you?", "I'm doing well, thank you", EMPTY_CONCEPT, "negative"),
    ]

    for input_text, output_text, concept, category in negative_examples:
        sample_data.append({
            'concept_id': -1,
            'input': input_text,
            'output': output_text,
            'output_concept': concept,
            'category': category,
            'concept_genre': 'general'
        })

    df = pd.DataFrame(sample_data)
    df.to_parquet(os.path.join(output_dir, 'train_data.parquet'))

    # Metadata, one entry per concept_id in order.
    metadata = [
        {
            "concept": "technology",
            "concept_genres_map": {"technology": ["tech", "general"]}
        },
        {
            "concept": "animals",
            "concept_genres_map": {"animals": ["nature", "general"]}
        }
    ]

    with open(os.path.join(output_dir, 'metadata.jsonl'), 'w', encoding='utf-8') as f:
        for item in metadata:
            f.write(json.dumps(item) + '\n')

    print(f"✓ Sample data created with {len(df)} examples")
    return df

def data_generator(data_dir, use_dpo_loss=False):
    """Yield ``(concept_id, rows)`` pairs from the training parquet files in ``data_dir``.

    Args:
        data_dir: directory scanned for ``*.parquet`` files whose names start
            with ``dpo_train_data`` (when ``use_dpo_loss``) or ``train_data``.
        use_dpo_loss: selects which file-name prefix is read.

    Yields:
        tuple: ``(concept_id, df_subset)`` for every non-negative ``concept_id``
        in each file, in ascending order within a file; ``df_subset`` holds that
        concept's rows. Rows with negative ids (negative examples, -1 in
        ``create_sample_data``) are never yielded on their own.

    Files are processed in sorted file-name order. ``os.listdir`` order is
    arbitrary and callers may truncate the concept list (``max_concepts``), so
    an unsorted order made the selected concepts platform-dependent.

    Fallbacks:
        * No matching file: sample data is written with ``create_sample_data``
          and read back.
        * Any error before something has been yielded: sample data is
          (re)created and its concepts yielded instead. NOTE(review): this
          overwrites ``train_data.parquet`` in ``data_dir``.
        * Any error after concepts were already yielded: the error is printed
          and the generator stops, instead of appending sample concepts to a
          partially read real dataset.
    """
    file_pattern = 'dpo_train_data' if use_dpo_loss else 'train_data'
    yielded_any = False

    try:
        file_paths = sorted(
            os.path.join(data_dir, name)
            for name in os.listdir(data_dir)
            if name.startswith(file_pattern) and name.endswith('.parquet')
        )

        if not file_paths:
            print(f"⚠️ No {file_pattern} files found, creating sample data")
            create_sample_data(data_dir)
            file_paths = [os.path.join(data_dir, 'train_data.parquet')]

        for file_path in file_paths:
            df = pd.read_parquet(file_path)
            # Filter valid (non-negative) concept IDs; negatives mark negative examples.
            concept_ids = sorted(cid for cid in df['concept_id'].unique() if cid >= 0)

            for concept_id in concept_ids:
                df_subset = df[df['concept_id'] == concept_id]
                yielded_any = True
                yield (concept_id, df_subset)

    except Exception as e:
        print(f"⚠️ Error in data generator: {e}")
        if yielded_any:
            # Don't mix synthetic concepts into real data that was partly consumed.
            return
        # Create and yield sample data
        df = create_sample_data(data_dir)
        for concept_id in sorted(cid for cid in df['concept_id'].unique() if cid >= 0):
            df_subset = df[df['concept_id'] == concept_id]
            if not df_subset.empty:
                yield (concept_id, df_subset)

def load_metadata(metadata_path):
    """Load per-concept metadata from a JSON Lines file.

    Each non-blank line is parsed with ``json.loads``. Blank or whitespace-only
    lines (e.g. a trailing empty line) are skipped; previously they raised a
    decode error that discarded the whole file. ``main()`` indexes the result
    by ``concept_id``, so record order matters.

    Args:
        metadata_path: path to the ``.jsonl`` file.

    Returns:
        list: the parsed records, or -- if the file cannot be opened or any
        non-blank line is not valid JSON -- a default two-entry list for the
        "technology" and "animals" concepts.
    """
    metadata = []
    try:
        with open(metadata_path, 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue
                metadata.append(json.loads(stripped))
    except Exception as e:
        print(f"⚠️ Error loading metadata: {e}")
        # Create default metadata
        metadata = [
            {"concept": "technology", "concept_genres_map": {"technology": ["general"]}},
            {"concept": "animals", "concept_genres_map": {"animals": ["general"]}}
        ]
    return metadata

def simple_prepare_df(original_df, negative_df, concept, metadata, tokenizer, **kwargs):
    """Select the rows of ``original_df`` belonging to ``concept``, capped in count.

    If ``original_df`` has an ``output_concept`` column, only rows whose value
    equals ``concept`` are kept; otherwise every row is kept. The result is a
    copy truncated to its first ``max_num_of_examples`` rows (keyword argument,
    default 20). ``negative_df``, ``metadata`` and ``tokenizer`` are accepted
    for interface compatibility but not used.
    """
    print(f"📝 Preparing data for concept: {concept}")

    has_concept_column = 'output_concept' in original_df.columns
    selected = (
        original_df.loc[original_df["output_concept"] == concept].copy()
        if has_concept_column
        else original_df.copy()
    )

    # Keep the demo small.
    limit = kwargs.get('max_num_of_examples', 20)
    if len(selected) > limit:
        selected = selected.head(limit)

    print(f"✓ Prepared {len(selected)} examples for {concept}")
    return selected

def simple_diffmean_training(df, model_instance, tokenizer, concept, device):
    """Placeholder for difference-in-means training -- the result is random.

    No activations are computed and ``tokenizer`` is unused. Two batches of
    standard-normal "representations" (``len(df)`` rows by
    ``model_instance.config.hidden_size`` columns) are drawn on ``device``, and
    the difference of their column means is returned as the "steering vector".

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(vector, bias)`` on the CPU, where
        ``vector`` has shape ``(hidden_size, 1)`` and ``bias`` is ``zeros(1)``.
    """
    print(f"🧮 Computing difference-in-means for concept: {concept}")

    n_rows, width = len(df), model_instance.config.hidden_size

    pos = torch.randn(n_rows, width, device=device)
    neg = torch.randn(n_rows, width, device=device)

    direction = pos.mean(dim=0).sub(neg.mean(dim=0)).unsqueeze(1)
    bias = torch.zeros(1, device=device)

    print(f"✓ Computed steering vector of shape {direction.shape}")
    return direction.cpu(), bias.cpu()

def simple_linear_probe_training(df, model_instance, tokenizer, concept, device):
    """Placeholder for linear-probe training -- returns random parameters.

    Only ``model_instance.config.hidden_size`` is read; ``df``, ``tokenizer``
    and ``device`` are ignored. The weight ``(hidden_size, 1)`` and bias
    ``(1,)`` are drawn from a standard normal on torch's default device, then
    nudged in place by five rounds of Gaussian noise scaled by 0.01 to mimic
    training steps.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(weight, bias)``.
    """
    print(f"🎯 Training linear probe for concept: {concept}")

    width = model_instance.config.hidden_size
    weight, bias = torch.randn(width, 1), torch.randn(1)

    # Stand-in for real optimisation: small random in-place updates.
    for _step in range(5):
        weight.add_(torch.randn_like(weight) * 0.01)
        bias.add_(torch.randn_like(bias) * 0.01)

    print(f"✓ Trained linear probe with weight shape {weight.shape}")
    return weight, bias

def setup_logging(rank=0):
    """Configure root logging at INFO with a rank-tagged format.

    ``force=True`` replaces any handlers already attached to the root logger.
    Without it, ``logging.basicConfig`` silently does nothing when the root
    logger is already configured (as in some hosted notebook kernels, and on
    every call after the first), so INFO messages from ``main()`` could be
    dropped and ``rank`` never appeared in the format.

    Args:
        rank: process rank embedded in every log line.

    Returns:
        logging.Logger: ``logging.getLogger(__name__)`` for this module.
    """
    logging.basicConfig(
        level=logging.INFO,
        format=f'%(asctime)s [Rank {rank}] %(levelname)s: %(message)s',
        datefmt='%H:%M:%S',
        force=True,
    )
    return logging.getLogger(__name__)

def main():
    """Main training function optimized for Google Colab.

    Pipeline:
      1. Instantiate the bound ``TrainingArgs`` class with defaults and swap any
         ``google/gemma*`` model for ``distilgpt2``.
      2. Configure logging, seed RNGs with ``set_seed``, and create sample data
         in ``args.data_dir`` if ``metadata.jsonl`` is missing.
      3. Load metadata and every ``(concept_id, rows)`` pair from
         ``data_generator``; keep only the first ``args.max_concepts`` if set.
      4. Load the tokenizer and causal-LM model named by ``args.model_name``.
      5. For each concept, run every method in ``methods`` and save its
         ``weight``/``bias`` to ``<dump_dir>/train/<method>/concept_<id>_*.pt``.
      6. Write ``training_results.json`` and ``config.json`` into
         ``<dump_dir>/train`` and print a summary.

    Returns:
        dict | None: ``{concept_id: {"concept", "methods", "num_examples"}}``,
        or ``None`` if the tokenizer/model failed to load (the error is logged).
    """

    print("🎯 Starting AxBench training in Google Colab")

    # Setup arguments: TrainingArgs is bound by the selection cell to either the
    # real AxBench class or SimpleTrainingArgs; both are built with defaults.
    args = TrainingArgs()

    # Colab-friendly settings
    if hasattr(args, 'model_name') and args.model_name.startswith('google/gemma'):
        print("⚠️ Large model detected, switching to smaller model for Colab")
        args.model_name = "distilgpt2"

    # Single process setup for Colab
    rank = 0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"🖥️ Using device: {device}")

    # Setup logging (this local logger shadows the module-level one)
    logger = setup_logging(rank)

    # Set seed
    set_seed(args.seed + rank)

    # Ensure data directory exists
    os.makedirs(args.data_dir, exist_ok=True)

    # Load or create data
    metadata_path = os.path.join(args.data_dir, 'metadata.jsonl')
    if not os.path.exists(metadata_path):
        print("📊 Creating sample data...")
        create_sample_data(args.data_dir)

    metadata = load_metadata(metadata_path)
    df_generator = data_generator(args.data_dir, use_dpo_loss=getattr(args, 'use_dpo_loss', False))

    # Process concepts
    df_list = list(df_generator)
    logger.info(f"Found {len(df_list)} concepts to process")

    if hasattr(args, 'max_concepts') and args.max_concepts:
        df_list = df_list[:args.max_concepts]
        logger.info(f"Limited to {len(df_list)} concepts for demo")

    # Setup output directory
    dump_dir = Path(args.dump_dir) / "train"
    dump_dir.mkdir(parents=True, exist_ok=True)

    # Load tokenizer and model
    try:
        logger.info(f"Loading tokenizer and model: {args.model_name}")
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model_instance = AutoModelForCausalLM.from_pretrained(
            args.model_name,
            torch_dtype=torch.float32,  # Use float32 for compatibility
            device_map="auto" if torch.cuda.is_available() else None
        )
        model_instance.eval()

        logger.info("✓ Model loaded successfully")
        logger.info(f"Model parameters: {sum(p.numel() for p in model_instance.parameters())/1e6:.1f}M")

    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return

    # Methods run for every concept. Defined once, before the loop (it is
    # loop-invariant), so config.json can list them even when no concept is
    # processed -- defining it inside the loop raised NameError in that case.
    methods = {
        "DiffMean": simple_diffmean_training,
        "LinearProbe": simple_linear_probe_training,
    }

    # Training loop
    results = {}
    for concept_id, concept_df in df_list:
        concept_id = int(concept_id)
        # metadata is indexed by concept_id
        if concept_id >= len(metadata):
            logger.warning(f"Concept ID {concept_id} exceeds metadata length, skipping")
            continue

        concept = metadata[concept_id]["concept"]
        logger.info(f"🎯 Processing concept {concept_id}: {concept}")

        # Prepare data
        prepared_df = simple_prepare_df(
            concept_df, pd.DataFrame(), concept, metadata[concept_id], tokenizer,
            max_num_of_examples=getattr(args, 'max_num_of_examples', 20)
        )

        if prepared_df.empty:
            logger.warning(f"No data for concept {concept}, skipping")
            continue

        concept_results = {}

        for method_name, train_func in methods.items():
            logger.info(f"🔧 Training {method_name} for {concept}")

            try:
                weight, bias = train_func(prepared_df, model_instance, tokenizer, concept, device)

                # Save results
                method_dir = dump_dir / method_name
                method_dir.mkdir(exist_ok=True)

                torch.save(weight, method_dir / f"concept_{concept_id}_weight.pt")
                torch.save(bias, method_dir / f"concept_{concept_id}_bias.pt")

                concept_results[method_name] = {
                    "weight_shape": list(weight.shape),
                    "bias_shape": list(bias.shape),
                    "concept": concept,
                    "num_examples": len(prepared_df)
                }

                logger.info(f"✓ {method_name} completed for {concept}")

            except Exception as e:
                # Best effort: one failing method should not stop the others.
                logger.error(f"Error training {method_name}: {e}")
                continue

        results[concept_id] = {
            "concept": concept,
            "methods": concept_results,
            "num_examples": len(prepared_df)
        }

    # Save final results (json turns the int concept_id keys into strings)
    results_file = dump_dir / "training_results.json"
    with open(results_file, 'w') as f:
        json.dump(results, f, indent=2)

    config = {
        "model_name": args.model_name,
        "layer": args.layer,
        "device": str(device),
        "concepts_processed": len(results),
        "methods": list(methods.keys())
    }

    config_file = dump_dir / "config.json"
    with open(config_file, 'w') as f:
        json.dump(config, f, indent=2)

    print("\n🎉 Training completed successfully!")
    print(f"📊 Processed {len(results)} concepts")
    print(f"💾 Results saved to {dump_dir}")
    print("📋 Summary:")
    for concept_id, result in results.items():
        print(f"   - Concept {concept_id} ({result['concept']}): {len(result['methods'])} methods")

    return results

# Run the main function
if __name__ == "__main__":
    try:
        results = main()
    except Exception as e:
        print(f"❌ Script failed with error: {e}")
        import traceback
        traceback.print_exc()
    else:
        # main() reports some failures (e.g. the model failing to load) by
        # logging and returning None instead of raising; don't call that a success.
        if results is None:
            print("❌ Script failed: main() returned no results (see the log output above)")
        else:
            print("✅ Script completed successfully!")

nnsight is not detected. Please install via 'pip install nnsight' for nnsight backend.
✓ Using simplified argument classes
🎯 Starting AxBench training in Google Colab
🖥️ Using device: cuda
📊 Creating sample data...
📊 Creating sample data in ./sample_data
✓ Sample data created with 10 examples
📝 Preparing data for concept: technology
✓ Prepared 4 examples for technology
🧮 Computing difference-in-means for concept: technology
✓ Computed steering vector of shape torch.Size([768, 1])
🎯 Training linear probe for concept: technology
✓ Trained linear probe with weight shape torch.Size([768, 1])
📝 Preparing data for concept: animals
✓ Prepared 4 examples for animals
🧮 Computing difference-in-means for concept: animals
✓ Computed steering vector of shape torch.Size([768, 1])
🎯 Training linear probe for concept: animals
✓ Trained linear probe with weight shape torch.Size([768, 1])

🎉 Training completed successfully!
📊 Processed 2 concepts
💾 Results saved to results/train
📋 Summary:
   - Concept 