In [None]:
# Notebook / Runpod setup, safe to ignore
%pip install matplotlib accelerate datasets einops huggingface-hub jaxtyping natsort simple-parsing triton transformers gguf sentencepiece scikit-learn seaborn
%pip install -U safetensors
%pip install git+https://github.com/EleutherAI/sae.git

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys, os
# Directory the notebook is running from, and the directory that contains it.
notebook_dir = os.path.abspath('')
experiments_dir = os.path.dirname(notebook_dir)

# Keep Hugging Face files under <parent>/hf and put the parent directory at the
# front of the import search path.
os.environ["HF_HOME"] = os.path.join(experiments_dir, "hf")
sys.path.insert(0, experiments_dir)

In [None]:
# Interactive Hugging Face Hub login for this notebook session.
# NOTE(review): presumably required because the Meta-Llama-3 checkpoints loaded
# further down are gated — confirm.
from huggingface_hub import notebook_login
notebook_login()

In [5]:
import json

import tqdm
import torch
import numpy as np

from repeng import ControlVector, ControlModel, DatasetEntry
import repeng.saes

import matplotlib.pyplot as plt
from IPython.display import clear_output
from collections import Counter
import math
import torch.nn.functional as F

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc
# Release memory left over from a previous run of this cell. Garbage is
# collected first so that tensors released by the collector can be handed back
# by the cache flush that follows.
gc.collect()
torch.cuda.empty_cache()

# Wait for outstanding CUDA work to finish before loading new weights.
torch.cuda.synchronize()


def _load_control_model(checkpoint, layers, device="cuda:0"):
    """Load `checkpoint`, move it to `device`, and wrap it in a ControlModel over `layers`."""
    model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto")
    return ControlModel(model.to(torch.device(device)), layers)


# One tokenizer is shared by both models; EOS doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.pad_token_id = tokenizer.eos_token_id

# Layer indices handed to ControlModel (and later to the SAE / vector training).
control_layers = list(range(2, 30))

base_model = _load_control_model("meta-llama/Meta-Llama-3-8B", control_layers)
instruct_model = _load_control_model("meta-llama/Meta-Llama-3-8B-Instruct", control_layers)

In [None]:
# Build the SAE object for the controlled layers via repeng's EleutherAI loader;
# `sae` is later passed to ControlVector.train_with_sae.
sae = repeng.saes.from_eleuther(device="cuda:0", layers=control_layers)

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import textwrap

# Read the suffix strings stored next to the notebook.
with open(os.path.join(notebook_dir, "data/all_truncated_outputs.json")) as f:
    output_suffixes = json.load(f)

# For every suffix keep each proper token prefix (1 .. len-1 tokens), rendered
# back to text. The full-length suffix itself is never included.
truncated_output_suffixes = []
for suffix_text in output_suffixes:
    suffix_tokens = tokenizer.tokenize(suffix_text)
    truncated_output_suffixes.extend(
        tokenizer.convert_tokens_to_string(suffix_tokens[:end])
        for end in range(1, len(suffix_tokens))
    )

def visualize_discontinuities(x, y, discontinuities):
    """
    Scatter-plot the data and highlight the points at the given indices.

    `x` and `y` are indexed with `discontinuities` directly, so they must
    support index-list selection (the caller in this notebook passes NumPy
    arrays). Each highlighted point is drawn as a large red marker with a
    dashed vertical guide line at its x position. The figure is shown
    immediately; nothing is returned.
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    # All samples as individual points (not a connected line).
    ax.scatter(x, y, c='b', s=20, alpha=0.6, label='Original data')

    # Highlighted samples, drawn on top of the base points.
    ax.scatter(x[discontinuities], y[discontinuities], c='r', s=100,
               label='Detected discontinuities', zorder=3)

    # One dashed guide line per highlighted index.
    for idx in discontinuities:
        ax.axvline(x=x[idx], color='r', linestyle='--', alpha=0.3)

    ax.legend()
    ax.grid(True)
    ax.set_title('Detected Discontinuities')
    plt.show()

def find_significant_drop(x, y, drop_threshold=0.2, window_size=1):
    """
    Find the first index where `y` falls significantly below its starting value.

    Parameters:
    x: array-like, coefficients. Not used by the search; accepted so callers
        can keep passing the x/y pair together.
    y: array-like, values to scan; y[0] is taken as the baseline.
    drop_threshold: float, required drop as a fraction of the baseline
        (default 0.2, i.e. values must be below 80% of y[0]).
    window_size: int, number of consecutive values that must all be below
        the target for the drop to count (default 1).

    Returns:
    int index of the first element of the first qualifying window, or None
    when `y` is empty or no window qualifies.

    Note: the target is baseline * (1 - drop_threshold), so the notion of a
    "drop" only makes sense for a positive baseline.
    """
    if len(y) == 0:
        return None

    # Get baseline from start of sequence
    baseline = y[0]
    target_value = baseline * (1 - drop_threshold)

    # Look for first window where all values are below target
    for start in range(len(y) - window_size + 1):
        if all(val < target_value for val in y[start:start + window_size]):
            return start

    # If no significant drop found
    return None


def calculate_sequence_probability(model, sequence):
    """
    Average per-token log-probability the model assigns to `sequence`.

    Parameters:
    model: callable taking a (1, seq_len) tensor of token ids and returning an
        object whose `.logits` is indexed as [batch, position, vocab].
    sequence: 1-D tensor of token ids.

    Returns:
    float: mean of log p(token[i+1] | token[0..i]) over the seq_len - 1
    predicted tokens. The first token has no preceding context and is not
    scored, so a single-token sequence returns 0.0.
    """
    input_ids = sequence.unsqueeze(0)
    with torch.no_grad():
        logits = model(input_ids).logits

    # For a causal LM the logits at position i describe token i + 1, so the
    # last position's logits are dropped and the rest are paired with the ids
    # shifted left by one. (Scoring token i with logits[i] would measure how
    # likely the model is to repeat the current token instead.)
    targets = input_ids[:, 1:]
    if targets.size(1) == 0:
        return 0.0

    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    # Average in float64 so low-precision model dtypes do not lose accuracy.
    return token_log_probs.double().mean().item()

def generate_sequence(model, input_ids, vector, coeff, max_new_tokens):
    """
    Generate a continuation with `vector` applied at strength `coeff`.

    The model is reset and its control is set to `coeff * vector` before
    generating; that control is still set when the function returns.
    `input_ids` is unpacked as keyword arguments into `model.generate`, along
    with the module-level tokenizer's EOS id as the pad token and a near-zero
    temperature. Returns the first sequence of the generate() output.
    """
    model.reset()
    model.set_control(coeff * vector)

    with torch.no_grad():
        generated = model.generate(
            **input_ids,
            pad_token_id=tokenizer.eos_token_id,
            temperature=1e-6,
            max_new_tokens=max_new_tokens,
        )
    return generated[0]

def calculate_perplexity_for_each_token(model, sequence):
    """
    Score every prefix of `sequence` with calculate_sequence_probability.

    `model.reset()` is called first (presumably so scoring does not run under a
    control set by an earlier generation call). Returns a list whose k-th entry
    is the score of `sequence[:k + 1]`; one model call is made per prefix.
    """
    model.reset()
    return [
        calculate_sequence_probability(model, sequence[:end])
        for end in range(1, len(sequence) + 1)
    ]

def calculate_perplexities_over_sequence(model, tokenizer, input_text, vector, token_count, start_coeff=0.16, iterations=20, end_coeff=1.0):
    """
    Sweep the steering coefficient and score the text generated at each value.

    Parameters:
    model: model passed to generate_sequence / calculate_perplexity_for_each_token.
    tokenizer: used to encode `input_text` and decode each generated sequence.
    input_text: prompt string used for every coefficient.
    vector: steering vector, scaled by each coefficient.
    token_count: forwarded as max_new_tokens for generation.
    start_coeff, end_coeff: first and last coefficient of the sweep.
    iterations: number of evenly spaced coefficients to test. With a single
        iteration only `start_coeff` is tested.

    Returns:
    (coefficients, all_perplexities, outputs): three parallel lists holding,
    per tested coefficient, the coefficient itself, the per-prefix scores of
    the generated sequence, and the decoded text.
    """
    coefficients = []
    all_perplexities = []
    outputs = []

    # A single iteration has no interval to divide (previously a
    # ZeroDivisionError), so the step is simply zero in that case.
    step = (end_coeff - start_coeff) / (iterations - 1) if iterations > 1 else 0.0

    input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)
    for i in tqdm.tqdm(range(iterations), desc="Testing coefficients"):
        # Derive the coefficient from the index rather than accumulating the
        # step, so rounding error does not build up over the sweep.
        coeff = start_coeff + i * step

        sequence = generate_sequence(model, input_ids, vector, coeff, token_count)
        perplexities = calculate_perplexity_for_each_token(model, sequence)

        coefficients.append(coeff)
        all_perplexities.append(perplexities)
        outputs.append(tokenizer.decode(sequence, skip_special_tokens=True))

    return coefficients, all_perplexities, outputs

TEMPLATE = """{persona}. {prefill}"""

def make_dataset(
    persona_template: str,
    positive_personas: list[str],
    negative_personas: list[str],
    user_msg: str,
    suffix_list: list[str]
) -> list[DatasetEntry]:
    """
    Build contrastive positive/negative prompt pairs for vector training.

    For every suffix (outer loop) and every positive/negative persona pair
    (inner loop, paired with zip), each persona is substituted into
    `persona_template` and the result is rendered through TEMPLATE together
    with the suffix. `user_msg` is passed to TEMPLATE.format as well, but
    TEMPLATE currently has no `{user_msg}` field, so it does not appear in
    the output.
    """
    def render(persona: str, suffix: str) -> str:
        filled_persona = persona_template.format(persona=persona)
        return TEMPLATE.format(persona=filled_persona, user_msg=user_msg, prefill=suffix)

    return [
        DatasetEntry(
            positive=render(positive_persona, suffix),
            negative=render(negative_persona, suffix),
        )
        for suffix in suffix_list
        for positive_persona, negative_persona in zip(positive_personas, negative_personas)
    ]

# ---- Widgets and shared UI state ----

# Chooses between the base model and the instruct (chat) model.
mode_toggle = widgets.ToggleButtons(options=['Base Mode', 'Chat Mode'], value='Base Mode')

# Text inputs: the prompt used to train the steering vector, and the
# prompt / chat message to generate from.
steering_input = widgets.Textarea(placeholder='Enter steering prompt...', layout={'width': '100%'})
input_box = widgets.Textarea(placeholder='Enter text...', layout={'width': '100%'})

# Symmetric limit of the strength slider; reassigned after a vector is trained.
upper_bound = 2.0

strength_slider = widgets.FloatSlider(
    value=0.0,
    min=-1 * upper_bound,
    max=upper_bound,
    step=0.01,
    description='Strength:',
    readout=True,
    continuous_update=True,
)

# Action buttons.
generate_button = widgets.Button(description='Generate')
steering_generate_button = widgets.Button(description='Train Steering Vector')
clear_history_button = widgets.Button(description='Clear History')

# Output areas: status line, generic output area, and the chat transcript.
status_display = widgets.HTML()
output_display = widgets.Output()
chat_display = widgets.HTML()

# Conversation so far as (role, content) tuples, and the trained steering
# vector (None until the training callback has run).
chat_history = []
steering_vector = None

def _render_chat_history():
    """Rebuild chat_display from chat_history, HTML-escaping the message text."""
    from html import escape

    rows = []
    for role, content in chat_history:
        color = "blue" if role == "user" else "green"
        rows.append(f'<div style="color: {color}"><b>{role}:</b> {escape(content)}</div>')
    chat_display.value = '<br>'.join(rows)


def _generate_steered(model, text):
    """Generate from `model` for `text` with the slider-scaled steering vector.

    Returns the token ids of the first generated sequence (prompt included).
    """
    input_ids = tokenizer(text, return_tensors="pt").to(model.device)
    model.reset()
    model.set_control(strength_slider.value * steering_vector)
    settings = {
        "pad_token_id": tokenizer.eos_token_id,
        "temperature": 0.7,
        "max_new_tokens": 128,
        "repetition_penalty": 1.1,
    }
    with torch.no_grad():
        return model.generate(**input_ids, **settings)[0]


def on_generate_click(b):
    """
    'Generate' button callback.

    Base Mode: completes the text in `input_box` with `base_model` and shows
    the result in `status_display`.
    Chat Mode: appends the text in `input_box` to `chat_history` as a user
    turn, prompts `instruct_model` with the whole history, appends the reply
    and re-renders `chat_display`. If generation fails, the user turn is
    removed again and the error is shown in `status_display`.

    Both modes require non-empty input and a trained steering vector.
    Everything written to the HTML widgets is escaped, since both the user's
    text and the model's output are arbitrary strings.
    """
    from html import escape

    text = input_box.value.strip()

    if mode_toggle.value == 'Base Mode':
        if not text:
            status_display.value = 'Please enter text'
            return
        if steering_vector is None:
            status_display.value = 'Please generate a steering vector first'
            return

        try:
            status_display.value = f'Generating from base model with prompt "{escape(text)}" and coefficient {strength_slider.value:.2f}. Please wait...'
            sequence = _generate_steered(base_model, text)
            output = tokenizer.decode(sequence, skip_special_tokens=True)
            status_display.value = escape(textwrap.fill(output, width=80))
        except Exception as e:
            status_display.value = f'Error: {escape(str(e))}'
        return

    # ---- Chat mode ----
    if not text:
        status_display.value = 'Please enter a message'
        return
    # Check before touching the history or clearing the input, so a missing
    # vector does not cost the user the message they just typed.
    if steering_vector is None:
        status_display.value = 'Please generate a steering vector first'
        return

    chat_history.append(('user', text))
    input_box.value = ''

    try:
        status_display.value = f'Generating from instruct model with prompt "{escape(text)}" and coefficient {strength_slider.value:.2f}. Please wait...'

        # Flatten the whole conversation into one prompt, ending with an open
        # assistant header for the model to continue.
        # NOTE(review): the user header is followed by one newline while the
        # assistant header is followed by two, and each <|eot_id|> is followed
        # by a newline — confirm against the model's official chat template.
        chat_prompt = ""
        for role, content in chat_history:
            if role == "user":
                chat_prompt += f"<|start_header_id|>user<|end_header_id|>\n{content}<|eot_id|>\n"
            else:
                chat_prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>\n"
        chat_prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"

        # Special tokens are kept in the decoded text so the reply can be cut
        # out by its header.
        response = tokenizer.decode(_generate_steered(instruct_model, chat_prompt))

        # Keep only what follows the last assistant header, minus a trailing
        # end-of-turn marker.
        response = response.split("<|start_header_id|>assistant<|end_header_id|>\n\n")[-1]
        response = response.strip()
        if response.endswith("<|eot_id|>"):
            response = response[:-len("<|eot_id|>")]
    except Exception as e:
        status_display.value = f'Error: {escape(str(e))}'
        # The assistant turn is only appended after this block, so the last
        # entry here is always the user turn that got no reply.
        chat_history.pop()
        _render_chat_history()
        return

    chat_history.append(('assistant', response))
    _render_chat_history()
    status_display.value = ""
            

# on_mode_change is defined once, further down (just before the callbacks are
# wired up). The copy that used to sit here was always overwritten by that
# later definition and did not toggle clear_history_button.

def on_steering_generate_click(b):
    """
    'Train Steering Vector' button callback.

    Steps:
    1. Build a contrastive dataset from the steering prompt (positive persona)
       against the fixed persona "an AI" (negative).
    2. Train `steering_vector` on `base_model` with the SAE.
    3. Sweep the coefficient from 0 to 1 on the prompt "I am", score each
       generated sequence, normalise the scores by the first one, and look
       for the first significant drop.
    4. Set `upper_bound` to the coefficient just before that drop (or 1.0 when
       no drop is found) and apply it symmetrically to the strength slider.
    5. Plot the sweep inside `output_display`.

    Progress and errors are reported through `status_display`. Rebinds the
    module-level `steering_vector` and `upper_bound`.
    """
    global steering_vector, upper_bound

    steering_prompt = steering_input.value.strip()
    if not steering_prompt:
        status_display.value = 'Please enter steering prompt'
        return

    # Largest coefficient probed, and the fallback bound when no drop is found.
    max_coeff = 1.0

    try:
        steering_dataset = make_dataset(
            "{persona}",
            [steering_prompt],
            ["an AI"],
            "",
            truncated_output_suffixes,
        )

        status_display.value = f'Made dataset for prompt "{steering_prompt}", now training...'

        base_model.reset()
        steering_vector = ControlVector.train_with_sae(
            base_model,
            tokenizer,
            sae,
            steering_dataset,
            batch_size=32,
            method="pca_center",
            hidden_layers=control_layers
        )

        status_display.value = 'Trained, now calculating maximum steering coefficient...'

        input_text = "I am"
        token_count = 24

        coefficients, all_perplexities, _ = calculate_perplexities_over_sequence(
            base_model, tokenizer, input_text, steering_vector, token_count,
            start_coeff=0., iterations=10, end_coeff=max_coeff,
        )

        # Only the score of the complete sequence is used for each
        # coefficient, expressed relative to the first coefficient's score.
        final_scores = [scores[-1] for scores in all_perplexities]
        x = np.array(coefficients)
        y = np.array([score / final_scores[0] for score in final_scores])

        drop_idx = find_significant_drop(x, y)
        if drop_idx is not None and drop_idx > 0:
            # Last coefficient before the drop.
            upper_bound = coefficients[drop_idx - 1]
        else:
            upper_bound = max_coeff
            drop_idx = len(coefficients) - 1

        strength_slider.min = -1 * upper_bound
        strength_slider.max = upper_bound
        # Keep the current value, clamped into the new range.
        strength_slider.value = max(min(strength_slider.value, upper_bound), -1 * upper_bound)

        status_display.value = 'Recommended min/max coefficients are -{} and {}. Ready to steer!'.format(upper_bound, upper_bound)

        # Output produced inside a widget callback is only reliably shown when
        # captured by an Output widget; also replace any plot from an earlier
        # training run instead of stacking figures.
        output_display.clear_output(wait=True)
        with output_display:
            visualize_discontinuities(x, y, [drop_idx])
    except Exception as e:
        # Same reporting style as the generate callback; without this a
        # failure in a callback leaves the status line stuck on the last step.
        status_display.value = f'Error: {str(e)}'

def on_clear_history(b):
    """'Clear History' button callback: forget the conversation and blank the chat and status panes."""
    global chat_history
    chat_history = list()
    for pane in (chat_display, status_display):
        pane.value = ""

def on_mode_change(change):
    """
    Observer for the mode toggle.

    In 'Base Mode' the chat transcript and the clear-history button are hidden
    and the output area is shown; in any other mode the reverse. The input box
    and status line are cleared on every switch.
    """
    in_base_mode = change['new'] == 'Base Mode'
    chat_only = 'none' if in_base_mode else 'block'
    base_only = 'block' if in_base_mode else 'none'

    chat_display.layout.display = chat_only
    clear_history_button.layout.display = chat_only
    output_display.layout.display = base_only

    input_box.value = ''
    status_display.value = ''

# Hook each control up to its handler.
generate_button.on_click(on_generate_click)
steering_generate_button.on_click(on_steering_generate_click)
clear_history_button.on_click(on_clear_history)
mode_toggle.observe(on_mode_change, names='value')

# Stack the controls top to bottom and render the interface.
interface_rows = [
    widgets.HTML("<h3>Text Generation Demo</h3>"),
    mode_toggle,
    steering_input,
    steering_generate_button,
    input_box,
    strength_slider,
    generate_button,
    clear_history_button,
    status_display,
    output_display,
    chat_display,
]
display(widgets.VBox(interface_rows))