In [None]:
import sys
import torch
import pandas as pd

sys.path.append('../src')

from utils import load_pipeline, get_unet_layers, show_images
from extraction import collect_dataset_activations, get_top_k_layers, mask_vectors_by_top_k
from analysis import compute_steering_vectors, compute_layer_scores
from steering import generate_with_steering

In [None]:
# --- Experiment configuration ------------------------------------------------
# Everything a reader might want to tune lives in this cell.

# CSV with the prompt pairs; the path is relative to the notebook's working dir.
DATASET_PATH = "../data/extraction/dogs.csv"

# Number of denoising steps, used for both extraction and generation.
NUM_STEPS = 30
# Guidance value handed to the extraction and steering helpers.
GUIDANCE = 7.5
# Passed as `lam` to the steering helper (negative here).
STEERING_STRENGTH = -2.5
# Passed as `top_k` / `k` when scoring and selecting layers.
TOP_K_LAYER_NAV = 8

# Integer indices 0..NUM_STEPS inclusive, i.e. NUM_STEPS + 1 entries.
# NOTE(review): a NUM_STEPS-step run would normally expose indices
# 0..NUM_STEPS-1 -- confirm the helpers really expect the extra index.
TIMESTEPS = list(range(NUM_STEPS + 1))

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(
    f"Running on {DEVICE}",
    f"Dataset: {DATASET_PATH}",
    f"Steering strength: {STEERING_STRENGTH}",
    sep="\n",
)


In [None]:
# Build the pipeline and pick the UNet sub-modules to work with
# (requested with extract_resnet=False, extract_attentions=True).
pipe = load_pipeline()
nets = get_unet_layers(pipe, extract_resnet=False, extract_attentions=True)

extraction_dataset = pd.read_csv(DATASET_PATH)

# Only the first MAX_PROMPT_PAIRS rows are used. A single named cap keeps the
# two prompt lists the same length and makes the limit easy to tune.
MAX_PROMPT_PAIRS = 50

# The 'positive' column supplies the forget prompts and the 'negative' column
# the retain prompts.
forget_prompts = extraction_dataset['positive'].tolist()[:MAX_PROMPT_PAIRS]
retain_prompts = extraction_dataset['negative'].tolist()[:MAX_PROMPT_PAIRS]

print(f"Loaded {len(extraction_dataset)} prompt pairs.")
# The line above counts every row in the CSV; also report how many are
# actually used so the log does not overstate the extraction set.
print(f"Using {len(forget_prompts)} prompt pairs for extraction.")

In [None]:
# Keyword options for the extraction helper, gathered in one place so the
# call itself stays short. Every entry in `nets` is passed as a layer.
extraction_options = {
    "total_steps": NUM_STEPS,
    "guidance": GUIDANCE,
    "nets": nets,
    "layers": nets.keys(),
    "timesteps": TIMESTEPS,
}

# The helper returns one result for the forget prompts and one for the
# retain prompts, in that order.
forget_acts, retain_acts = collect_dataset_activations(
    pipe, forget_prompts, retain_prompts, **extraction_options
)

print("Extraction complete.")

In [None]:
# Turn the two activation sets into steering vectors; the helper takes the
# forget activations first and the retain activations second.
steering_vectors = compute_steering_vectors(
    forget_acts,
    retain_acts,
)

# len() of the result is reported as the number of layers covered.
print(f"Computed steering vectors for {len(steering_vectors)} layers.")

In [None]:
# Score the layers, keep the TOP_K_LAYER_NAV best per timestep, and mask the
# steering vectors with that selection.
# NOTE(review): the activations are passed here as (retain, forget), the
# reverse of the compute_steering_vectors call -- confirm this order is what
# compute_layer_scores expects.
results = compute_layer_scores(
    retain_acts,
    forget_acts,
    timesteps=TIMESTEPS,
    top_k=TOP_K_LAYER_NAV,
)

top_k_per_timestep = get_top_k_layers(results, k=TOP_K_LAYER_NAV)

filtered_steering_vectors = mask_vectors_by_top_k(
    steering_vectors, timesteps=TIMESTEPS, top_k_per_timestep=top_k_per_timestep
)

In [None]:
# Ask for a prompt. Surrounding whitespace is stripped so that an empty or
# whitespace-only answer falls back to the default prompt.
user_prompt = input("Enter a prompt (e.g., 'A cute dog in the park'): ").strip()
if not user_prompt:
    user_prompt = "A cute dog in the park"

print(f"\nGenerating for: '{user_prompt}'")

# Fixed seed so the baseline image is reproducible across runs.
BASE_SEED = 362

# Baseline (unsteered) image.
# The generator is created on DEVICE rather than a hard-coded "cuda": on a
# CPU-only machine DEVICE is "cpu", and a CUDA generator cannot be built there.
# NOTE(review): no guidance scale is passed here, so the pipeline's default
# applies, while the steered call below receives GUIDANCE -- confirm they match.
base_images = pipe(
    user_prompt,
    num_inference_steps=NUM_STEPS,
    generator=torch.Generator(device=DEVICE).manual_seed(BASE_SEED),
).images

# Steered image: same prompt and step count, with the filtered steering
# vectors applied at strength STEERING_STRENGTH.
# NOTE(review): no seed or generator is passed to generate_with_steering;
# whether it starts from the same noise as the baseline cannot be determined
# from this notebook -- confirm, otherwise the comparison mixes steering and
# seed effects.
steered_images = generate_with_steering(
    pipe,
    user_prompt,
    GUIDANCE,
    nets,
    filtered_steering_vectors,
    timesteps=TIMESTEPS,
    inference_steps=NUM_STEPS,
    lam=STEERING_STRENGTH,
)

# Show the first image of each run, each with its own caption.
show_images(
    [base_images[0], steered_images[0]],
    [user_prompt, f'Steered with lambda = {STEERING_STRENGTH}'],
)