# Stable Diffusion 1.5 + SAE

⚠️WORK IN PROGRESS⚠️ - currently not ready for use

This notebook demonstrates the use of Stable Diffusion 1.5 for concept unlearning using Sparse Autoencoders (SAE).

## Setup and Imports

This section handles imports, device configuration, and environment checks.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# SAE part
import torch
from diffusers import StableDiffusionPipeline
from einops import rearrange
from IPython.display import display
from overcomplete.visualization import overlay_top_heatmaps

from utils.sae_utils import (
    callback_handler,
    criterion,
    sae_integration_hook,
    sae_train,
    select_features,
)

# Model identifier passed to `StableDiffusionPipeline.from_pretrained` in the cells below.
model_id = "sd-legacy/stable-diffusion-v1-5"

# Query CUDA availability once and reuse the answer for the device choice and the report.
cuda_available = torch.cuda.is_available()
device = "cuda" if cuda_available else "cpu"
cuda_device_name = torch.cuda.get_device_name(0) if cuda_available else 'None'

# Environment report, so the notebook output records where it was executed.
print(f"CUDA available: {cuda_available}")
print(f"CUDA device: {cuda_device_name}")
print(f"PyTorch CUDA version: {torch.version.cuda}")
print(f"CUDA device count: {torch.cuda.device_count()}")
print(f"Using device: {device}")

## Model loading and activation caching

Model loading

In [None]:
# Load the Stable Diffusion pipeline identified by `model_id` and move it to `device`.
# float16 is kept for CUDA; on the CPU fallback the weights are loaded as float32,
# because half-precision pipelines fail or run extremely slowly on CPU.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    safety_checker=None,  # no safety checker is attached to this pipeline
).to(device)


Add hooks in the Cross-Attention layers

In [None]:
# Shared state for activation capture: later cells index `cached_activations` by
# layer name, and `hook_handles` is meant to hold the handles returned by
# `register_forward_hook`.
cached_activations, hook_handles = {}, []

# 1. Text representation (text embedding)
# Last layer of the text encoder; the original note gives its output shape as [1, 77, 768].
text_encoder_layers = pipe.text_encoder.text_model.encoder.layers
text_encoder_output_layer = text_encoder_layers[-1]
# Hook registration is currently disabled.
# NOTE(review): `save_activation` is not defined or imported in the visible cells —
# restore it before re-enabling the lines below.
# hook_handles.append(
#     text_encoder_output_layer.register_forward_hook(
#         save_activation("text_embedding")
#     )
# )

# 2. Latent representation (latent codes)
# Target: the first transformer block of the second attention module in up block 1,
# i.e. `up_blocks[1].attentions[1].transformer_blocks[0]`.
unet_up_block = pipe.unet.up_blocks[1]
target_unet_block_path = unet_up_block.attentions[1].transformer_blocks[0]
# Hook registration is currently disabled (see the note on `save_activation` in step 1).
# hook_handles.append(
#     target_unet_block_path.register_forward_hook(
#         save_activation("unet_latent_up_block_1_att_1")
#     )
# )

Model inference

In [None]:
# Sampling configuration for the first generation run (before the SAE hook is registered).
num_inference_steps = 10
guidance_scale = 8.0
num_images_per_prompt = 1
prompt = "a football ball"
# Fixed seed so the generated image is reproducible across runs.
generator = torch.Generator(device).manual_seed(42)

# `callback_on_step_end` is invoked by the pipeline after every denoising step, so
# the deprecated legacy `callback_steps` argument (which only applied to the old
# `callback` parameter) is no longer passed.
image = pipe(
    prompt,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    generator=generator,
    callback_on_step_end=callback_handler,
    num_images_per_prompt=num_images_per_prompt,
).images[0]

display(image)

In [None]:
# Inspect the shape of the activations cached under the UNet layer key.
# NOTE(review): no visible cell fills this key (the hook registration is commented
# out) — confirm it is populated before running this cell.
cached_activations['unet_latent_up_block_1_att_1'].shape

Save the activations to a file as a tensor

In [None]:
# Prepare the cached UNet activations for SAE training.
# Keep every second entry along the first axis (indices 0, 2, 4, ...), each as a
# length-1 slice keyed by its position in the subsampled sequence.
# NOTE(review): presumably the first axis interleaves the two classifier-free-guidance
# branches for each step — confirm that the even entries are the branch you want.
layer_activations = cached_activations['unet_latent_up_block_1_att_1']
guided_cached_activations = {
    start // 2: layer_activations[start:start + 1]
    for start in range(0, layer_activations.shape[0], 2)
}

# Concatenate the slices back into a single tensor along the first axis.
guided_cached_activations_tensor = torch.cat(tuple(guided_cached_activations.values()), dim=0)

# Persist the tensor for later cells, then read it back from disk.
torch.save(guided_cached_activations_tensor, "guided_cached_activations_tensor.pt")
guided_cached_activations_tensor = torch.load("guided_cached_activations_tensor.pt")

## SAE training example

1. Get activations

In [None]:
# Load the saved activation tensor and cast it to float32 in one expression.
activations = torch.load("guided_cached_activations_tensor.pt").float()
activations.shape

2. Flatten into a 2-D matrix of shape `(samples × tokens, features)`

In [None]:
# Flatten every leading axis into one sample axis: (n, t, d) -> (n * t, d).
# `reshape(-1, d)` produces the same layout as the einops pattern 'n t d -> (n t) d'
# but leaves an already flattened tensor unchanged, so re-running this cell no
# longer raises a shape error. The same flattening is used before `sae.encode`
# in the visualization section.
activations = activations.reshape(-1, activations.shape[-1]).float()
activations.shape

3. Use SAE

In [None]:
# Hyperparameters forwarded unchanged to `sae_train` as keyword arguments.
sae_train_kwargs = {
    "criterion": criterion,
    "expansion_factor": 16,
    "top_k": 32,
    "batch_size": 1024,
    "num_epochs": 15,
    "learning_rate": 1e-3,
}

# Train the sparse autoencoder on `activations`.
sae = sae_train(activations, **sae_train_kwargs)

## SAE feature selection

In [None]:
# "True" set: a mapping of tensors loaded from disk, concatenated along the first axis.
# NOTE(review): no visible cell writes "guided_cached_activations.pt" — confirm
# which run produces this file.
guided_cached_activations = torch.load("guided_cached_activations.pt")
true_slices = tuple(guided_cached_activations.values())
guided_cached_activations_tensor_true = torch.cat(true_slices, dim=0)

# "False" set: the activation tensor saved earlier in this notebook.
guided_cached_activations_tensor_false = torch.load("guided_cached_activations_tensor.pt")

In [None]:
# Positional arguments for `select_features`, in call order: the "true" activations,
# the "false" activations, the literal 10, and the trained SAE.
# NOTE(review): the meaning of 10 is defined by `select_features` — presumably the
# number of features to select; confirm against its signature.
feature_selection_args = (
    guided_cached_activations_tensor_true,
    guided_cached_activations_tensor_false,
    10,
    sae,
)
select_features(*feature_selection_args)

## Inference with SAE integration

In [None]:
# Reload the pipeline so the SAE hook is attached to a freshly constructed model.
# float16 is kept for CUDA; on the CPU fallback the weights are loaded as float32,
# because half-precision pipelines fail or run extremely slowly on CPU.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    safety_checker=None,  # no safety checker is attached to this pipeline
).to(device)

In [None]:
# Register the SAE hook on the target UNet block (run this after the SAE is trained).
# Handles left over from an earlier run are removed first, so re-running this cell
# does not stack a second copy of the hook on the same block. `globals().get` keeps
# the cell usable even when no earlier cell has defined `hook_handles`.
for stale_handle in globals().get("hook_handles", []):
    stale_handle.remove()

target_unet_block = pipe.unet.up_blocks[1].attentions[1].transformer_blocks[0]
hook_handles = [target_unet_block.register_forward_hook(sae_integration_hook)]

In [None]:
# Sampling configuration for the generation run with the SAE hook registered.
# The values mirror the first generation run so the two images are comparable.
num_inference_steps = 10
guidance_scale = 8.0
num_images_per_prompt = 1
prompt = "a football ball"
# Fixed seed so the generated image is reproducible across runs.
generator = torch.Generator(device).manual_seed(42)

# `callback_on_step_end` is invoked by the pipeline after every denoising step, so
# the deprecated legacy `callback_steps` argument (which only applied to the old
# `callback` parameter) is no longer passed.
image = pipe(
    prompt,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    generator=generator,
    callback_on_step_end=callback_handler,
    num_images_per_prompt=num_images_per_prompt,
).images[0]

display(image)

# After generation, inspect extracted encodings.
# NOTE(review): `sae_encodings` is not defined in the visible cells — restore its
# definition before re-enabling the line below.
# print(f"Extracted {len(sae_encodings)} sets of encodings (one per step).")

## Visualization of concept extraction
Requires `cached_activations`, `sae`, and `image` from the cells above.

In [None]:
# Reload the pipeline; this replaces the instance that carries the SAE hook.
# NOTE(review): the visible visualization cells below do not use `pipe` — confirm
# this reload is still needed.
# float16 is kept for CUDA; on the CPU fallback the weights are loaded as float32,
# because half-precision pipelines fail or run extremely slowly on CPU.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    safety_checker=None,  # no safety checker is attached to this pipeline
).to(device)

In [None]:
# Encode the cached UNet activations with the trained SAE and reshape the codes
# into per-token spatial maps.
# The cast to float32 mirrors the cast applied before SAE training, so the encoder
# input dtype matches the data the SAE was trained on (a no-op if already float32).
activations = cached_activations['unet_latent_up_block_1_att_1'].to(device).float()

# The second axis is treated as a flattened square token grid; derive the side
# length from it instead of hard-coding 16.
num_maps, num_tokens = activations.shape[0], activations.shape[1]
grid_side = round(num_tokens ** 0.5)
assert grid_side * grid_side == num_tokens, (
    f"expected a square token grid, got {num_tokens} tokens"
)

with torch.no_grad():
    # `encode` receives one row per token: (num_maps * num_tokens, features).
    pre_codes, codes = sae.encode(activations.reshape(-1, activations.shape[-1]))

# (num_maps * w * h, d) -> (num_maps, w, h, d)
codes = rearrange(codes, '(b w h) d -> b w h d', b=num_maps, w=grid_side, h=grid_side)

# Convert the generated image to a (1, C, H, W) float tensor scaled to [0, 1]
# (assuming `image` converts to an H x W x C array of 8-bit values).
# NOTE(review): `image` was last assigned by the generation run with the SAE hook,
# while `cached_activations` is filled elsewhere — confirm they belong together,
# and that a single image is intended alongside `num_maps` code maps.
image_tensor = np.array(image)
image_tensor = torch.tensor(image_tensor).permute(2, 0, 1).unsqueeze(0).float().to(device) / 255.0

In [None]:
# SAE concept indices to visualize.
# NOTE(review): presumably taken from the feature-selection step — confirm.
concept_ids = [39, 41, 78, 109, 154, 328, 362, 413, 513, 708]

for concept_id in concept_ids:
    print('Concept', concept_id)
    # image_tensor has shape (B, C, H, W); codes has shape (B, W, H, D) after the
    # rearrange in the previous cell.
    overlay_top_heatmaps(image_tensor, codes, concept_id=concept_id)
    plt.show()