# Simple example of how to use the SAE to intervene on a diffusion model

In [14]:
import torch
from SAE.sae import Sae
from SAE.hooked_sd_noised_pipeline import HookedStableDiffusionPipeline
import utils.hooks as hooks

In [10]:
# Stable Diffusion checkpoint loaded into the hooked pipeline.
model_name = "CompVis/stable-diffusion-v1-4"
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on the GPU: float16 on CPU is poorly supported
# (missing or very slow kernels), so the CPU fallback runs in float32.
dtype = torch.float16 if device == "cuda" else torch.float32
# Hub repository passed to Sae.load_from_hub to fetch the pretrained SAE.
hub_name = "bcywinski/SAeUron_coco"
# Hooked module name; used both to select the SAE and as the key of the
# position_hook_dict when running the intervention.
hookpoint = "unet.up_blocks.1.attentions.1"

## Load models

In [None]:
# Build the hooked pipeline from the pretrained checkpoint (safety checker
# disabled), then move it onto the selected device.
model = HookedStableDiffusionPipeline.from_pretrained(
    model_name, torch_dtype=dtype, safety_checker=None
)
model = model.to(device)


In [None]:
# Fetch the SAE trained for this hookpoint, then cast it to the same dtype
# that is passed to the pipeline.
sae = Sae.load_from_hub(hub_name, hookpoint=hookpoint, device=device)
sae = sae.to(dtype)


## Create intervention hook

In [79]:
# Index of the SAE feature targeted by the intervention hook.
feature_idx: int = 11627
# Factor handed to the hook as `multiplier`.
# NOTE(review): presumably a negative value suppresses the feature — confirm
# against SAEFeatureInterventionHook.
multiplier: float = -1.0

In [80]:
# Bundle the SAE, the target feature index and the multiplier into a single
# hook object that is later attached at `hookpoint`.
intervention_hook = hooks.SAEFeatureInterventionHook(
    sae=sae, feature_idx=feature_idx, multiplier=multiplier
)

## Run intervention
Run the intervention, multiplying the selected feature by `multiplier` at each denoising step.

In [81]:
# Text prompt passed to the pipeline.
prompt: str = "A photo of a cat"
# Passed as num_inference_steps.
steps: int = 50
# Passed as guidance_scale.
guidance_scale: float = 7.5
# Seed for the torch.Generator so the run is repeatable.
seed: int = 0

In [82]:
# Seeded generator created on the same device as the pipeline.
generator = torch.Generator(device=device).manual_seed(seed)

# Run the pipeline with the intervention hook registered at `hookpoint`.
images = model.run_with_hooks(
    prompt=prompt,
    num_inference_steps=steps,
    guidance_scale=guidance_scale,
    generator=generator,
    position_hook_dict={hookpoint: intervention_hook},
)

In [None]:
# Show the first element of the run_with_hooks result as the cell output
# (presumably the generated image — confirm the pipeline's return type).
images[0]