# Interactive Chat with Steering

In [1]:
import torch
import os
import json
import sys
import asyncio
from pathlib import Path

sys.path.append('.')
sys.path.append('..')
sys.path.append(str(Path.home() / 'git' / 'chatspace'))

from vllm import SamplingParams
from chatspace.generation.vllm_steer_model import (
    VLLMSteerModel,
    VLLMSteeringConfig,
    SteeringSpec,
    LayerSteeringSpec,
    ProjectionCapSpec,
    AddSpec
)

torch.set_float32_matmul_precision('high')

INFO 11-06 14:52:26 [__init__.py:216] Automatically detected platform cuda.


## Model Configuration

In [None]:
# Model identifier handed to VLLMSteeringConfig(model_name=...) below and
# recorded in the "model" field of saved conversations.
MODEL_NAME = "Qwen/Qwen3-32B"
# Human-friendly model name; only used in printed messages.
MODEL_READABLE = "Qwen 3 32B"
# Directory that save_conversation() writes into; created up front so later
# saves do not fail on a missing directory.
OUTPUT_DIR = "./results/qwen-3-32b/interactive_steering"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Steering config path (read with torch.load by load_steering_config)
STEERING_CONFIG_PATH = "/workspace/qwen-3-32b/capped/configs/contrast/role_trait_sliding_config.pt"

## Load Steering Config

In [None]:
def load_steering_config(config_path):
    """Read a steering config file from disk and print a short summary.

    The file is deserialized with ``torch.load`` onto the CPU. The loaded
    object must support ``['vectors']`` and ``['experiments']`` lookups, and
    each of the first ten experiments must carry an ``'id'`` entry, because
    those values are printed here.

    Args:
        config_path: Path of the serialized config file.

    Returns:
        The object returned by ``torch.load``, unchanged.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # installed torch version's `weights_only` default -- only load trusted files.
    print(f"Loading steering config from {config_path}")
    loaded = torch.load(config_path, map_location='cpu')

    print(f"\nFound {len(loaded['vectors'])} vectors")
    experiments = loaded['experiments']
    print(f"Found {len(experiments)} experiments")

    # Preview: positional index alongside each experiment's id.
    print("\nFirst 10 experiments:")
    preview = experiments[:10]
    for position, entry in enumerate(preview):
        print(f"  {position}: {entry['id']}")

    return loaded

# Module-level config shared by build_steering_spec callers, set_steering and
# list_experiments below.
steering_config = load_steering_config(STEERING_CONFIG_PATH)

In [None]:
def build_steering_spec(cfg_data, experiment_id):
    """Build a SteeringSpec from one experiment of a loaded steering config.

    Handles both 'cap' (projection capping) and 'coeff' (additive)
    interventions. If an intervention carries both keys, 'cap' wins.

    Args:
        cfg_data: Loaded config. ``cfg_data['vectors']`` maps a vector name to
            an entry with ``'layer'`` and ``'vector'``; ``cfg_data['experiments']``
            is a list of entries with ``'id'`` and ``'interventions'``.
        experiment_id: Either an integer index into ``cfg_data['experiments']``
            or the ``'id'`` value of the experiment to use.

    Returns:
        A ``SteeringSpec`` whose ``layers`` mapping is keyed by the layer index
        stored with each referenced vector.

    Raises:
        ValueError: If no experiment has the requested id.
        IndexError: If an integer ``experiment_id`` is out of range.
        KeyError: If an intervention names a vector missing from the config.
    """
    # Find the experiment: integers are positional, anything else is matched
    # against the experiment 'id' field.
    if isinstance(experiment_id, int):
        experiment = cfg_data['experiments'][experiment_id]
    else:
        experiment = next(
            (exp for exp in cfg_data['experiments'] if exp['id'] == experiment_id),
            None,
        )

    if experiment is None:
        raise ValueError(f"Experiment '{experiment_id}' not found")

    print(f"\nBuilding steering spec for: {experiment['id']}")
    print(f"Interventions: {len(experiment['interventions'])}")

    layers = {}
    for intervention in experiment['interventions']:
        vector_name = intervention['vector']

        # Get the vector and layer from the vectors dict
        vec_data = cfg_data['vectors'][vector_name]
        layer_idx = vec_data['layer']
        vector = vec_data['vector'].to(dtype=torch.float32).contiguous()

        if 'cap' in intervention:
            # Projection capping: the raw vector is passed through unchanged.
            cap_spec = ProjectionCapSpec(
                vector=vector,
                max=float(intervention['cap'])
            )

            if layer_idx not in layers:
                layers[layer_idx] = LayerSteeringSpec(projection_cap=cap_spec)
            else:
                # Only one cap per layer is kept; make the replacement visible
                # instead of silently discarding the earlier intervention.
                if getattr(layers[layer_idx], 'projection_cap', None) is not None:
                    print(
                        f"Warning: layer {layer_idx} already has a projection cap; "
                        f"overwriting it with vector '{vector_name}'"
                    )
                layers[layer_idx].projection_cap = cap_spec

        elif 'coeff' in intervention:
            coeff = float(intervention['coeff'])

            # Pass a unit-length direction and fold the original magnitude into
            # the scale, so the applied offset equals coeff * vector (up to the
            # 1e-8 guard against division by zero).
            norm = vector.norm()
            add_spec = AddSpec(
                vector=vector / (norm + 1e-8),
                scale=coeff * norm.item()
            )

            if layer_idx not in layers:
                layers[layer_idx] = LayerSteeringSpec(add=add_spec)
            else:
                # Only one additive vector per layer is kept; see the cap branch.
                if getattr(layers[layer_idx], 'add', None) is not None:
                    print(
                        f"Warning: layer {layer_idx} already has an additive vector; "
                        f"overwriting it with vector '{vector_name}'"
                    )
                layers[layer_idx].add = add_spec

        else:
            # Previously such interventions were dropped without any trace even
            # though they are counted in the "Interventions" line above.
            print(
                f"Warning: intervention on vector '{vector_name}' has neither "
                f"'cap' nor 'coeff'; skipping"
            )

    affected_layers = sorted(layers.keys())
    print(f"Affecting layers: {affected_layers}")

    return SteeringSpec(layers=layers)

## Initialize Model

In [None]:
# Gather every distinct layer index referenced by a steering vector; the sorted
# tuple is handed to the model config as its bootstrap layers.
all_layers = {entry['layer'] for entry in steering_config['vectors'].values()}
bootstrap_layers = tuple(sorted(all_layers))

print(f"Bootstrap layers: {len(bootstrap_layers)} layers")
print(f"Range: {min(bootstrap_layers)} to {max(bootstrap_layers)}")

# Engine configuration. tensor_parallel_size should be adjusted to the
# available GPU setup.
vllm_cfg = VLLMSteeringConfig(
    model_name=MODEL_NAME,
    tensor_parallel_size=1,
    gpu_memory_utilization=0.95,
    max_model_len=8192,
    dtype="auto",
    bootstrap_layers=bootstrap_layers,
)

# Model construction (this may take a while)
print(f"\nInitializing {MODEL_READABLE}...")
model = VLLMSteerModel(vllm_cfg, enforce_eager=True)
print(f"âœ… Model {MODEL_READABLE} loaded successfully!")

## Conversation State

In [None]:
# Shared chat transcript: a list of {"role": ..., "content": ...} dicts that
# the chat functions append to and the helper functions reset or trim.
conversation_history = []
# SteeringSpec used by default when chatting; None means no steering.
# Managed through set_steering() / clear_steering().
current_steering_spec = None

## Chat Functions

In [None]:
async def chat_interactive_async(
    message, 
    steering_spec=None,
    show_history=False,
    max_tokens=3000,
    temperature=0.7,
    use_steering=True
):
    """Send one user message, generate a reply, and record both in the history.

    Args:
        message: User message to send
        steering_spec: SteeringSpec to apply (uses current_steering_spec if None)
        show_history: Whether to print a preview of the whole conversation
        max_tokens: Maximum tokens to generate
        temperature: Sampling temperature
        use_steering: Whether to apply steering at all

    Returns:
        The generated reply text.

    Raises:
        Whatever generation raises. In that case the user message is removed
        from the global history again, so a failed or interrupted call does not
        leave an unanswered user turn behind.
    """
    global conversation_history, current_steering_spec

    # Build the sampling params before touching the history so that invalid
    # arguments cannot leave a half-recorded turn.
    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        temperature=temperature
    )

    # Use provided steering spec or fall back to current
    active_spec = steering_spec if steering_spec is not None else current_steering_spec

    # The model is prompted with the full history including this new message.
    user_turn = {"role": "user", "content": message}
    conversation_history.append(user_turn)

    try:
        if use_steering and active_spec is not None:
            # Steering is applied only for the duration of this generation.
            async with model.steering(active_spec):
                responses = await model.chat(conversation_history, sampling_params)
        else:
            # No steering
            responses = await model.chat(conversation_history, sampling_params)

        # Extract response text
        response = responses[0].full_text()
    except BaseException:
        # BaseException on purpose: interrupting a long generation
        # (KeyboardInterrupt / task cancellation) must roll back too.
        if conversation_history and conversation_history[-1] is user_turn:
            conversation_history.pop()
        raise

    # Add assistant response to history
    conversation_history.append({"role": "assistant", "content": response})

    # Print conversation
    print(f"ðŸ‘¤ You: {message}")
    print(f"ðŸ¤– {MODEL_READABLE}: {response}")

    if show_history:
        print(f"\nðŸ“œ Conversation so far ({len(conversation_history)} turns):")
        for i, turn in enumerate(conversation_history):
            role_emoji = "ðŸ‘¤" if turn["role"] == "user" else "ðŸ¤–"
            # Long messages are cut to 100 characters for the overview.
            content = turn['content']
            content_preview = content[:100] + "..." if len(content) > 100 else content
            print(f"  {i+1}. {role_emoji} {content_preview}")

    return response

# Synchronous wrapper for notebook convenience
def chat_interactive(
    message,
    steering_spec=None,
    show_history=False,
    max_tokens=3000,
    temperature=0.7,
    use_steering=True
):
    """Synchronous wrapper for chat_interactive_async.

    Takes the same arguments as chat_interactive_async and returns its result
    (the generated reply text).

    ``asyncio.run`` refuses to start when an event loop is already running in
    the current thread, which is the case inside a Jupyter kernel. In that
    situation the coroutine is run with ``asyncio.run`` on a short-lived worker
    thread instead, and this call blocks until it finishes. Without a running
    loop the behaviour is a plain ``asyncio.run`` as before.
    """
    import concurrent.futures

    def _run_in_fresh_loop():
        # The coroutine is created inside the thread that will execute it.
        return asyncio.run(
            chat_interactive_async(
                message,
                steering_spec=steering_spec,
                show_history=show_history,
                max_tokens=max_tokens,
                temperature=temperature,
                use_steering=use_steering
            )
        )

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: asyncio.run is safe to call directly.
        return _run_in_fresh_loop()

    # A loop is already running here (e.g. notebook kernel); give the
    # coroutine its own thread and loop, then wait for the result.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(_run_in_fresh_loop).result()

## Helper Functions

In [None]:
def save_conversation(filename):
    """Write the current conversation as JSON into OUTPUT_DIR.

    The file records the model name, the number of history entries, the
    history itself, and whether a global steering spec is set at save time.

    Args:
        filename: File name, joined onto OUTPUT_DIR.

    Returns:
        The path written to, or None when there is nothing to save.
    """
    if not conversation_history:
        print("No conversation to save!")
        return

    filepath = os.path.join(OUTPUT_DIR, filename)
    steering_is_set = current_steering_spec is not None
    payload = {
        "model": MODEL_NAME,
        "turns": len(conversation_history),
        "conversation": conversation_history,
        "steering_active": steering_is_set,
    }

    with open(filepath, 'w') as out_file:
        json.dump(payload, out_file, indent=2)

    print(f"ðŸ’¾ Conversation saved to: {filepath}")
    return filepath


def reset_conversation():
    """Start over with an empty history.

    The global name is rebound to a new list; the previous list object is not
    modified.
    """
    global conversation_history
    conversation_history = list()
    print("ðŸ”„ Conversation history cleared!")


def delete_last_turn():
    """Delete the last turn from the conversation history.

    A complete turn is a user message followed by the assistant reply, and
    both are removed. If the history ends with an unanswered user message
    (for example after a failed generation), only that message is removed, so
    the previous turn's assistant reply is kept and the history stays
    well-formed. Prints what was done; does nothing but print when the history
    is already empty.
    """
    global conversation_history
    if not conversation_history:
        print("No conversation to delete!")
        return

    if len(conversation_history) == 1:
        # A lone entry cannot be a complete user + assistant pair.
        conversation_history = []
        print("ðŸ”„ Conversation cleared (was incomplete turn)!")
        return

    # Drop one entry for a dangling user message, two for a complete turn.
    ends_with_user = conversation_history[-1].get("role") == "user"
    drop_count = 1 if ends_with_user else 2
    conversation_history = conversation_history[:-drop_count]
    print("ðŸ”„ Last turn deleted!")


def set_steering(experiment_id):
    """Make one config experiment the default steering for later chats.

    Args:
        experiment_id: Experiment id or integer index, forwarded to
            build_steering_spec together with the loaded steering_config.
    """
    global current_steering_spec
    new_spec = build_steering_spec(steering_config, experiment_id)
    current_steering_spec = new_spec
    print(f"âœ… Steering set to experiment: {experiment_id}")


def clear_steering():
    """Remove the default steering spec (sets the global back to None)."""
    global current_steering_spec
    current_steering_spec = None
    print("âœ… Steering cleared")


def list_experiments(start=0, end=20):
    """Print the experiments in the slice ``[start:end]`` of the loaded config.

    Each line shows the experiment's position in the full list, its id, and
    how many interventions it contains, followed by the total experiment count.

    Args:
        start: Index of the first experiment to show.
        end: Index one past the last experiment to show.
    """
    every_experiment = steering_config['experiments']
    total = len(every_experiment)

    print(f"\nExperiments {start} to {min(end, total)}:")
    for offset, entry in enumerate(every_experiment[start:end]):
        position = start + offset
        intervention_count = len(entry['interventions'])
        print(f"  {position}: {entry['id']} ({intervention_count} interventions)")
    print(f"\nTotal experiments: {total}")

## Example Usage

In [None]:
# List available experiments (indices 0-19) with their ids and intervention counts
list_experiments(0, 20)

In [None]:
# Set steering to a specific experiment.
# An integer is treated as a position in the experiments list; a non-integer
# value is matched against the experiment 'id' instead.
set_steering(0)  # Use experiment index 0

In [None]:
# Chat with steering active (uses the global spec chosen by set_steering above)
chat_interactive("Hello! How are you today?")

In [None]:
# Continue the conversation: the previous turns are still in
# conversation_history and are sent along with this message.
chat_interactive("What's your perspective on artificial consciousness?")

In [None]:
# Compare with no steering: start from an empty history and disable steering
# for this call only (the global steering spec itself is left untouched).
reset_conversation()
chat_interactive("Hello! How are you today?", use_steering=False)

In [None]:
# Try a different steering experiment on a fresh conversation
reset_conversation()
set_steering(5)  # Different experiment
chat_interactive("Hello! How are you today?")

In [None]:
# Save the conversation as JSON inside OUTPUT_DIR
save_conversation("steering_experiment_5.json")

In [None]:
# Clear steering and continue: only the global steering spec is reset, the
# existing conversation history is kept, so this message extends it unsteered.
clear_steering()
chat_interactive("What do you think about AI safety?")

## One-off Steering

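You can also apply steering for just a single message without changing the globally configured steering spec (the message and its reply are still appended to the shared conversation history):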

In [None]:
# Build a one-off steering spec (experiment index 10); this does not touch
# current_steering_spec.
temp_spec = build_steering_spec(steering_config, 10)

# Use it for just this message: an explicit steering_spec takes precedence
# over the global one for this call.
chat_interactive(
    "Tell me about yourself.",
    steering_spec=temp_spec
)