In [1]:
pip install torch transformers accelerate bitsandbytes

[0mNote: you may need to restart the kernel to use updated packages.


In [3]:
# @title 1.5. For access to Gemma models, log in to Hugging Face
import os
from getpass import getpass

from huggingface_hub import login

# Never hard-code the token in the notebook: it leaks into version control and
# every shared copy. Read it from the HF_TOKEN environment variable, falling back
# to an interactive prompt that does not echo what is typed.
HUGGING_FACE_TOKEN = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
try:
    login(token=HUGGING_FACE_TOKEN)
    print("Hugging Face login successful (using provided token).")
except Exception as e:
    # Best-effort: report the failure and let the notebook continue (public models still load).
    print(f"Hugging Face login failed. Error: {e}")

Hugging Face login successful (using provided token).


In [7]:
# List every submodule name so a valid TARGET_LAYER_NAME can be picked.
# NOTE: `model` is created by the setup / model-loading cell further down (In [9]).
# This cell was executed out of order, so on a fresh kernel run it after that cell.
if "model" not in globals():
    raise NameError("`model` is not defined yet - run the model-loading cell first.")
for name, _module in model.named_modules():
    if name:  # the root module is reported under the empty name ""; skip it
        print(name)


model
model.embed_tokens
model.layers
model.layers.0
model.layers.0.self_attn
model.layers.0.self_attn.q_proj
model.layers.0.self_attn.k_proj
model.layers.0.self_attn.v_proj
model.layers.0.self_attn.o_proj
model.layers.0.mlp
model.layers.0.mlp.gate_proj
model.layers.0.mlp.up_proj
model.layers.0.mlp.down_proj
model.layers.0.mlp.act_fn
model.layers.0.input_layernorm
model.layers.0.post_attention_layernorm
model.layers.0.pre_feedforward_layernorm
model.layers.0.post_feedforward_layernorm
model.layers.1
model.layers.1.self_attn
model.layers.1.self_attn.q_proj
model.layers.1.self_attn.k_proj
model.layers.1.self_attn.v_proj
model.layers.1.self_attn.o_proj
model.layers.1.mlp
model.layers.1.mlp.gate_proj
model.layers.1.mlp.up_proj
model.layers.1.mlp.down_proj
model.layers.1.mlp.act_fn
model.layers.1.input_layernorm
model.layers.1.post_attention_layernorm
model.layers.1.pre_feedforward_layernorm
model.layers.1.post_feedforward_layernorm
model.layers.2
model.layers.2.self_attn
model.layers.2.se

In [9]:
# -*- coding: utf-8 -*-
# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.1
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# # Gemma2 9B Activation Steering Notebook

# This notebook demonstrates how to compute a steering vector based on the difference
# in activations between positive and negative prompts for a specific layer in Gemma2 9B.
# It then uses this vector to steer the model's generation.

# ## 1. Setup and Imports

# +
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import os
import gc
from contextlib import contextmanager
from typing import List, Dict, Optional, Callable

# Environment report: library versions and, when present, the active GPU.
_cuda_available = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"Transformers version: {transformers.__version__}")
print(f"CUDA available: {_cuda_available}")
if _cuda_available:
    _device_index = torch.cuda.current_device()
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Current device: {_device_index}")
    print(f"Device name: {torch.cuda.get_device_name(_device_index)}")
# -

# ## 2. Configuration

# +
# --- Model Configuration ---
MODEL_ID = "google/gemma-2-9b-it" # Or "google/gemma-2-9b" if you prefer the base model
# Set to True if you have limited VRAM (e.g., < 24GB). Requires bitsandbytes
USE_4BIT_QUANTIZATION = False

# --- Steering Configuration ---
# !! IMPORTANT !! Find the correct layer name for your model.
# Names are dotted paths exactly as listed by `model.named_modules()`; list indices
# are plain path components, e.g. 'model.layers.15.mlp.gate_proj' or
# 'model.layers.20.self_attn.o_proj'. Bracket syntax ('model.layers[15]...') does
# NOT resolve, because get_module_by_name looks the path up one '.'-separated
# component at a time. See `print(model)` in Section 3 to explore the structure.
TARGET_LAYER_NAME = 'model.layers.20' # <--- CHANGE THIS

# Lists of prompts to define the direction (steering = mean(positive) - mean(negative))
POSITIVE_PROMPTS = [
    "This story should be very optimistic and uplifting.",
    "Write a hopeful and positive narrative.",
    "Generate text with a cheerful and encouraging tone.",
]
NEGATIVE_PROMPTS = [
    "This story should be very pessimistic and bleak.",
    "Write a depressing and negative narrative.",
    "Generate text with a gloomy and discouraging tone.",
]

# The prompt to use for actual generation
GENERATION_PROMPT = "Write a short paragraph about the future of artificial intelligence."

# How strongly to apply the steering vector. Tune this value (e.g., 0.5 to 5.0)
STEERING_MULTIPLIER = 1.5

# --- Generation Parameters ---
MAX_NEW_TOKENS = 150
TEMPERATURE = 0.7
DO_SAMPLE = True

# --- Output ---
OUTPUT_FILE = "gemma2_steering_output.txt"

# Check if configuration seems valid
if not TARGET_LAYER_NAME or '.' not in TARGET_LAYER_NAME:
    print("WARNING: TARGET_LAYER_NAME looks suspicious. Please verify it.")
if TARGET_LAYER_NAME and '[' in TARGET_LAYER_NAME:
    print("WARNING: TARGET_LAYER_NAME uses bracket indexing; use the dotted form, e.g. 'model.layers.20'.")
if not POSITIVE_PROMPTS or not NEGATIVE_PROMPTS:
    raise ValueError("Positive and Negative prompt lists cannot be empty.")
# -

# ## 3. Load Model and Tokenizer

# +
# Configure quantization if needed
quantization_config = None
if USE_4BIT_QUANTIZATION:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16 # Recommended for new models
    )
    print("Using 4-bit quantization.")

# Determine device and dtype
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float32 # BF16 recommended on Ampere+

print(f"Loading model: {MODEL_ID}")
print(f"Using device: {device}")
print(f"Using dtype: {dtype}")

# Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token # Set pad token if not present

# Load Model
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    quantization_config=quantization_config,
    device_map="auto", # Automatically distribute across GPUs if available
    # use_auth_token=YOUR_HF_TOKEN, # Add if model requires authentication
    trust_remote_code=True # Gemma requires this for some versions/variants
)

print(f"Model loaded on device(s): {model.hf_device_map}")

# --- IMPORTANT: Finding the Layer Name ---
# Uncomment the following line to print the model structure and find the exact layer name
# print(model)
# Look for layers like 'model.layers[INDEX].mlp...' or 'model.layers[INDEX].self_attn...'

# Ensure model is in evaluation mode
model.eval()
# -

# ## 4. Hooking and Activation Handling Functions

# +
# Global storage for captured activations: layer name -> tensor.
# Written by capture_activation_hook and reset at the start of get_activations.
activation_storage = {}

def get_module_by_name(model, module_name):
    """Return the submodule of `model` addressed by a dotted path.

    The path uses the same names that `model.named_modules()` lists, e.g.
    'model.layers.20' or 'model.layers.20.mlp.down_proj'. List indices are plain
    path components; bracket syntax such as 'model.layers[20]' is not understood.
    An empty name returns `model` itself, matching the root entry ("") of
    `named_modules()`.

    Raises:
        AttributeError: if some component of the path does not name a submodule.
    """
    # nn.Module.get_submodule walks the '.'-separated path and reports which
    # component failed, unlike a hand-rolled getattr loop.
    return model.get_submodule(module_name)

def capture_activation_hook(module, input, output, layer_name):
    """Forward hook that records the final-position activation of a layer.

    The layer output -- or, for tuple outputs, its first element -- is sliced at
    the last sequence position, detached, moved to CPU and stored in the global
    ``activation_storage[layer_name]``. Any other output type only triggers a
    warning. The hook always returns None, so the module's output is untouched.

    Assumes the tensor is laid out as (batch, seq_len, hidden_dim) -- TODO confirm
    for layer types other than decoder blocks.
    """
    if isinstance(output, tuple):
        # Some layers (e.g. decoder blocks) return tuples; the hidden state is first.
        hidden = output[0]
    elif isinstance(output, torch.Tensor):
        hidden = output
    else:
        print(f"Warning: Unexpected output type from layer {layer_name}: {type(output)}")
        return None

    # Keep only the last token position: shape (batch_size, hidden_dim).
    activation_storage[layer_name] = hidden[:, -1, :].detach().cpu()
    return None


def get_activations(model, tokenizer, prompts: List[str], layer_name: str) -> Optional[torch.Tensor]:
    """
    Run prompts through the model and return the mean last-token activation of a layer.

    Each prompt is tokenized and forwarded on its own (batch size 1), so no padding
    is involved and position -1 is the prompt's final token. A temporary forward
    hook (capture_activation_hook) stores that activation in the global
    `activation_storage`; the hook is always removed before returning, even if a
    forward pass raises.

    Args:
        model: the causal LM to probe.
        tokenizer: tokenizer matching `model`.
        prompts: prompts whose activations are averaged.
        layer_name: dotted module path as listed by `model.named_modules()`.

    Returns:
        1-D CPU tensor of shape (hidden_dim,) -- the average over all prompts --
        or None if no activation was captured.
    """
    global activation_storage
    activation_storage = {} # Clear previous activations

    target_module = get_module_by_name(model, layer_name)
    hook_handle = target_module.register_forward_hook(
        lambda module, input, output: capture_activation_hook(module, input, output, layer_name)
    )

    all_layer_activations = []
    try:
        with torch.no_grad():
            for prompt in prompts:
                # No padding/truncation flags: a single prompt needs no padding, and
                # truncation=True without a max_length only emitted a tokenizer warning
                # (with a finite max_length it could cut off the final token we want).
                inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
                # We only need the forward pass, not generation here; the hook records the activation.
                model(**inputs)

                captured = activation_storage.pop(layer_name, None) # Shape (1, hidden_dim)
                if captured is None:
                    print(f"Warning: Activation for layer {layer_name} not captured for prompt: '{prompt}'")
                else:
                    all_layer_activations.append(captured)
    finally:
        # Without this, an exception mid-loop (e.g. CUDA OOM) would leave the capture
        # hook attached, and every re-run would stack another one on the layer.
        hook_handle.remove()

    if not all_layer_activations:
        print(f"Error: No activations were captured for layer {layer_name}.")
        return None

    # (num_prompts, 1, hidden_dim) -> mean over prompts -> (1, hidden_dim) -> (hidden_dim,)
    # squeeze(0) drops only the batch axis, never the hidden dimension.
    avg_activation = torch.stack(all_layer_activations).mean(dim=0).squeeze(0)
    print(f"Calculated average activation for layer '{layer_name}' with shape: {avg_activation.shape}")
    return avg_activation


# --- Steering Hook during Generation ---

# Steering state read by steering_hook: the vector to add (None = hook is a no-op)
# and the scale applied to it.
steering_vector_internal = None
steering_multiplier_internal = 1.0

def steering_hook(module, input, output):
    """Forward hook that shifts a layer's output along the steering direction.

    When `steering_vector_internal` is set, returns
    ``output + steering_vector_internal * steering_multiplier_internal`` (the vector
    cast to the output's device and dtype). For tuple outputs only the first
    element is shifted and the remaining elements are passed through. Returns the
    output unchanged when no vector is set or the output type is unrecognized.

    Assumes the vector's length equals the last dimension of the output, so it
    broadcasts over the batch and sequence axes -- TODO confirm per target layer.
    """
    global steering_vector_internal, steering_multiplier_internal
    if steering_vector_internal is None:
        return output # Return original if no steering vector

    def _shifted(hidden):
        delta = steering_vector_internal.to(hidden.device, dtype=hidden.dtype)
        return hidden + (delta * steering_multiplier_internal)

    if isinstance(output, torch.Tensor):
        return _shifted(output)
    if isinstance(output, tuple):
        return (_shifted(output[0]),) + output[1:]

    print(f"Warning: Steering hook encountered unexpected output type: {type(output)}")
    return output # Return original if type is unknown

@contextmanager
def apply_steering(model, layer_name, steering_vector, multiplier):
    """Temporarily steer `layer_name` by adding `multiplier * steering_vector` to its output.

    Sets the module-level `steering_vector_internal` / `steering_multiplier_internal`
    read by `steering_hook`, registers that hook on the target module, and yields.
    On exit -- normal or via an exception -- the hook is removed, the globals are
    reset to "no steering", and cached memory is released.

    Note: the steering state is global, so nesting two `apply_steering` blocks is
    not supported: the inner block's exit resets the vector the outer hook reads.
    """
    global steering_vector_internal, steering_multiplier_internal

    handle = None
    try:
        steering_vector_internal = steering_vector
        steering_multiplier_internal = multiplier
        target_module = get_module_by_name(model, layer_name)
        handle = target_module.register_forward_hook(steering_hook)
        print(f"Steering hook applied to {layer_name} with multiplier {multiplier}")
        yield # Generation happens here
    finally:
        # Only report removal when a hook was actually registered (the module
        # lookup above can fail before registration).
        if handle is not None:
            handle.remove()
            print(f"Steering hook removed from {layer_name}")
        steering_vector_internal = None # Clear global state
        steering_multiplier_internal = 1.0
        gc.collect() # Suggest garbage collection
        if torch.cuda.is_available():
            torch.cuda.empty_cache() # Clear cache if using GPU
# -

PyTorch version: 2.8.0.dev20250319+cu128
Transformers version: 4.51.3
CUDA available: True
CUDA version: 12.8
Current device: 0
Device name: NVIDIA A100 80GB PCIe
Loading model: google/gemma-2-9b-it
Using device: cuda
Using dtype: torch.bfloat16


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Model loaded on device(s): {'': 0}
Calculating activations for POSITIVE prompts...
Calculated average activation for layer 'model.layers.20' with shape: torch.Size([3584])

Calculating activations for NEGATIVE prompts...
Calculated average activation for layer 'model.layers.20' with shape: torch.Size([3584])

Steering vector computed successfully. Shape: torch.Size([3584])

--- Generating Baseline Output (No Steering) ---
Write a short paragraph about the future of artificial intelligence.

The future of artificial intelligence is brimming with both promise and uncertainty.  As AI algorithms become increasingly sophisticated, we can expect breakthroughs in fields like medicine, transportation, and environmental science.  AI-powered assistants will likely become more integrated into our daily lives, automating tasks and providing personalized experiences.  However, alongside these advancements come ethical concerns regarding job displacement, algorithmic bias, and the potential misuse o

In [14]:
# Possible rhyme set: the same four prompt framings (two instructions, each with a
# single or double newline before the line) applied to a "carrot" line and a
# "schoolyard" line.
_RHYME_PREFIXES = [
    'A rhymed couplet:\n',
    'A rhymed couplet:\n\n',
    'Continue a rhyming poem starting with the following line:\n\n',
    'Continue a rhyming poem starting with the following line:\n',
]
_CARROT_LINE = 'He saw a carrot and had to grab it'
_SCHOOLYARD_LINE = 'Footsteps echoing on the schoolyard bricks'

POSITIVE_PROMPTS = [prefix + _CARROT_LINE + '\n' for prefix in _RHYME_PREFIXES]
NEGATIVE_PROMPTS = [prefix + _SCHOOLYARD_LINE + '\n' for prefix in _RHYME_PREFIXES]

OUTPUT_FILE = "gemma2_steering_rhyme.txt"

# Generate from the first (single-newline couplet) framing of the carrot line.
GENERATION_PROMPT = POSITIVE_PROMPTS[0]

In [17]:
# Gemma chat format: the user turn carries the instruction, and the model turn is
# pre-filled with the opening line so generation continues from it.
# NOTE(review): no <bos> here -- presumably the tokenizer adds it by default; verify.
GEMMA_PROMPT_TEMPLATE = (
    "<start_of_turn>user\n{instruction}<end_of_turn>\n"
    "<start_of_turn>model\n{line}\n"
)

In [22]:
# Formatted version 1: two instruction wordings, each with and without a trailing
# newline, plus candidate opening lines for GEMMA_PROMPT_TEMPLATE.
_couplet = 'A rhymed couplet:'
_continue = 'Continue a rhyming poem starting with the following line:'
instructions = [_couplet, _couplet + '\n', _continue + '\n', _continue]

lines = [
    'He saw a carrot and had to grab it',
    'A single rose, its petals unfold',
    'Footsteps echoing on the schoolyard bricks',
]

In [23]:
# Formatted for Gemma: each instruction goes in the user turn, and the opening line
# pre-fills the model turn. lines[0] defines the positive side, lines[1] the negative.
# NOTE(review): the earlier rhyme set used the 'Footsteps' line (lines[2]) as the
# negative; confirm lines[1] is intended here.
_pos_line, _neg_line = lines[0], lines[1]
POSITIVE_PROMPTS = [GEMMA_PROMPT_TEMPLATE.format(instruction=inst, line=_pos_line) for inst in instructions]
NEGATIVE_PROMPTS = [GEMMA_PROMPT_TEMPLATE.format(instruction=inst, line=_neg_line) for inst in instructions]
# First instruction + positive line, i.e. identical to POSITIVE_PROMPTS[0].
GENERATION_PROMPT = POSITIVE_PROMPTS[0]

In [24]:
# ## 5. Compute the Steering Vector

# +
print("Calculating activations for POSITIVE prompts...")
avg_pos_activation = get_activations(model, tokenizer, POSITIVE_PROMPTS, TARGET_LAYER_NAME)

print("\nCalculating activations for NEGATIVE prompts...")
avg_neg_activation = get_activations(model, tokenizer, NEGATIVE_PROMPTS, TARGET_LAYER_NAME)

# Steering direction = mean(positive) - mean(negative); stays None if either side failed.
_activations_missing = avg_pos_activation is None or avg_neg_activation is None
steering_vector = None if _activations_missing else avg_pos_activation - avg_neg_activation
if _activations_missing:
    print("\nError: Could not compute steering vector due to missing activations.")
else:
    print(f"\nSteering vector computed successfully. Shape: {steering_vector.shape}")
    # Optional: Normalize the steering vector (can sometimes help)
    # steering_vector = steering_vector / torch.norm(steering_vector)
    # print("Steering vector normalized.")

# Clean up memory
del avg_pos_activation, avg_neg_activation
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
# -

# ## 6. Generate Text (Baseline vs. Steered)

# +
# Sampling is stochastic (DO_SAMPLE=True). Re-seeding before every generate() call
# gives the baseline and both steered runs the same random stream, so their
# differences come from steering rather than from sampling noise.
GENERATION_SEED = 0


def _generate_text(inputs, multiplier=None):
    """Generate a completion for `inputs`, optionally with the steering hook active.

    Args:
        inputs: tokenizer output already moved to `model.device`.
        multiplier: scale for the global `steering_vector`; None means no steering.

    Returns:
        The decoded first sequence (prompt + completion) with special tokens removed.
    """
    torch.manual_seed(GENERATION_SEED)
    generate_kwargs = dict(
        max_new_tokens=MAX_NEW_TOKENS,
        temperature=TEMPERATURE,
        do_sample=DO_SAMPLE,
        pad_token_id=tokenizer.eos_token_id, # Important for generation
    )
    with torch.no_grad():
        if multiplier is None:
            output_ids = model.generate(**inputs, **generate_kwargs)
        else:
            # Apply the steering hook using the context manager
            with apply_steering(model, TARGET_LAYER_NAME, steering_vector, multiplier):
                output_ids = model.generate(**inputs, **generate_kwargs)
    text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    del output_ids
    return text


if steering_vector is not None:
    print("\n--- Generating Baseline Output (No Steering) ---")
    inputs = tokenizer(GENERATION_PROMPT, return_tensors="pt").to(model.device)
    text_baseline = _generate_text(inputs)
    print(text_baseline)

    print(f"\n--- Generating Steered Output (Multiplier: {STEERING_MULTIPLIER}) ---")
    text_steered = _generate_text(inputs, STEERING_MULTIPLIER) # Same input tokens
    print(text_steered)

    print(f"\n--- Generating Steered Output (Multiplier: {-STEERING_MULTIPLIER}) ---")
    text_negsteered = _generate_text(inputs, -STEERING_MULTIPLIER)
    print(text_negsteered)

    # Clean up generation inputs
    del inputs
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

else:
    print("\nSkipping generation because the steering vector could not be computed.")
# -

# ## 7. Save Results to File

# +
if steering_vector is not None:
    # Every output section gets the same separator and an explicit multiplier label,
    # so the baseline and the two steered runs are unambiguous in the file.
    _section_break = "\n" + "=" * 30 + "\n"
    try:
        with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
            f.write("--- Configuration ---\n")
            f.write(f"Model ID: {MODEL_ID}\n")
            f.write(f"Quantized: {USE_4BIT_QUANTIZATION}\n")
            f.write(f"Target Layer: {TARGET_LAYER_NAME}\n")
            f.write(f"Steering Multiplier: {STEERING_MULTIPLIER}\n")
            f.write(f"Max New Tokens: {MAX_NEW_TOKENS}\n")
            f.write(f"Temperature: {TEMPERATURE}\n")
            f.write(f"Do Sample: {DO_SAMPLE}\n")
            f.write("\n--- Positive Prompts ---\n")
            for p in POSITIVE_PROMPTS:
                f.write(f"- {p}\n")
            f.write("\n--- Negative Prompts ---\n")
            for p in NEGATIVE_PROMPTS:
                f.write(f"- {p}\n")
            f.write("\n--- Generation Prompt ---\n")
            f.write(f"{GENERATION_PROMPT}\n")
            f.write(_section_break)
            f.write("--- Baseline Output ---\n")
            f.write(text_baseline + "\n")
            f.write(_section_break)
            f.write(f"--- Steered Output (Multiplier: {STEERING_MULTIPLIER}) ---\n")
            f.write(text_steered + "\n")
            f.write(_section_break)
            f.write(f"--- Neg Steered Output (Multiplier: {-STEERING_MULTIPLIER}) ---\n")
            f.write(text_negsteered + "\n")
        print(f"\nResults saved to {OUTPUT_FILE}")
    except Exception as e:
        # Best-effort save: report the failure but keep the notebook running.
        print(f"\nError saving results to file: {e}")
else:
    print("\nResults not saved as generation was skipped.")
# -

# ## 8. Optional: Clean up Model from Memory
# (Uncomment if you need to free up GPU memory)

# +
# del model
# del tokenizer
# gc.collect()
# if torch.cuda.is_available():
#     torch.cuda.empty_cache()
# print("Model and tokenizer removed from memory.")
# -

Calculating activations for POSITIVE prompts...
Calculated average activation for layer 'model.layers.20' with shape: torch.Size([3584])

Calculating activations for NEGATIVE prompts...
Calculated average activation for layer 'model.layers.20' with shape: torch.Size([3584])

Steering vector computed successfully. Shape: torch.Size([3584])

--- Generating Baseline Output (No Steering) ---
user
A rhymed couplet:
model
He saw a carrot and had to grab it
And now he's singing, "I'm a veggie rabbit!" 


Let me know if you'd like another one! 😊  


--- Generating Steered Output (Multiplier: 1.5) ---
Steering hook applied to model.layers.20 with multiplier 1.5
Steering hook removed from model.layers.20
user
A rhymed couplet:
model
He saw a carrot and had to grab it
And then he ran off, oh what a habit! 





--- Generating Steered Output (Multiplier: -1.5) ---
Steering hook applied to model.layers.20 with multiplier -1.5
Steering hook removed from model.layers.20
user
A rhymed couplet:
model
H