In [1]:
import os
import sys

# Make the local `daam` checkout (the `daam` directory under the current
# working directory) importable; `daam` and `vis_utils` are imported from it
# in the next cell. It is put at the FRONT of sys.path so that this local copy
# takes precedence over any same-named package installed in site-packages
# (appending would let an installed copy shadow it). The membership check keeps
# re-runs of this cell from adding duplicate entries.
daam_path = os.path.join(os.getcwd(), "daam")
if daam_path not in sys.path:
    sys.path.insert(0, daam_path)

In [2]:
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from matplotlib import pyplot as plt
from diffusers import StableDiffusionPipeline
from daam import GenerationExperiment, trace, set_seed
from vis_utils import compute_cosine_similarity

%config InlineBackend.figure_format='retina'

In [None]:
# Load the Stable Diffusion v1.5 pipeline and move it to the first CUDA device
# when one is available, falling back to the CPU otherwise.
# NOTE(review): the "runwayml/stable-diffusion-v1-5" Hub repo has reportedly
# been taken down; if loading fails, the community mirror
# "stable-diffusion-v1-5/stable-diffusion-v1-5" is the usual replacement — confirm.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe = pipe.to(device)

In [4]:
# One list per (prompt, word) pair. The plotting cells below append one word
# heat map per entry of `layers_without_mid` to the matching list.
heatmaps_car_red, heatmaps_car_blue, heatmaps_blue, heatmaps_red = [], [], [], []

In [None]:
# Word heat maps for "car" in the red-car prompt, one panel per traced layer.
prompt = "A photo of a red car"
word = "car"

# Presumably the layers discussed in the article; not used in this cell.
article_layers = [3, 6, 9, 10, 12, 14, 16]
# Panel-title labels: 15 entries filling the 3x5 grid below (10 is absent).
layers_without_mid = [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16]

# Same seed in every plotting cell so the prompts are generated comparably.
gen = set_seed(228)

# Start from an empty list so re-running this cell does not append a second
# set of heat maps (which would break the pairing in the similarity cell).
heatmaps_car_red = []

# Figure size is left at the matplotlib default; pass figsize= to subplots()
# to change it (a separate plt.figure() call would only open a new, empty figure).
fig, ax = plt.subplots(3, 5, constrained_layout=True)
fig.suptitle(f'prompt: "{prompt}",\n word: "{word}"', fontsize=18)

for ax_ in ax.flat:
    ax_.set_xticks([])
    ax_.set_yticks([])

with torch.no_grad():
    with trace(pipe) as tc:
        out = pipe(prompt, num_inference_steps=50, generator=gen)
        for i, layer_idx in enumerate(layers_without_mid):
            ax_ = ax[i // 5, i % 5]
            # NOTE(review): the traced layer is selected by the position `i`
            # (0..14); `layer_idx` is only used as the panel title. Presumably
            # the traced layer(s) past index 14 (the mid block?) are skipped
            # this way — confirm against the daam layer ordering.
            heat_map_car_red = tc.compute_global_heat_map(layer_idx=i)
            heat_map_car_red = heat_map_car_red.compute_word_heat_map(word)
            heatmaps_car_red.append(heat_map_car_red)
            heat_map_car_red.plot_overlay(out.images[0], ax=ax_)
            ax_.set_title(f'Layer {layer_idx}', size=14)

plt.show()

In [None]:
# Word heat maps for "red" in the red-car prompt, one panel per traced layer.
prompt = "A photo of a red car"
word = "red"

# Presumably the layers discussed in the article; not used in this cell.
article_layers = [3, 6, 9, 10, 12, 14, 16]
# Panel-title labels: 15 entries filling the 3x5 grid below (10 is absent).
layers_without_mid = [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16]

# Same seed in every plotting cell so the prompts are generated comparably.
gen = set_seed(228)

# Start from an empty list so re-running this cell does not append a second
# set of heat maps (which would break the pairing in the similarity cell).
heatmaps_red = []

# Figure size is left at the matplotlib default; pass figsize= to subplots()
# to change it (a separate plt.figure() call would only open a new, empty figure).
fig, ax = plt.subplots(3, 5, constrained_layout=True)
fig.suptitle(f'prompt: "{prompt}",\n word: "{word}"', fontsize=18)

for ax_ in ax.flat:
    ax_.set_xticks([])
    ax_.set_yticks([])

with torch.no_grad():
    with trace(pipe) as tc:
        out = pipe(prompt, num_inference_steps=50, generator=gen)
        for i, layer_idx in enumerate(layers_without_mid):
            ax_ = ax[i // 5, i % 5]
            # NOTE(review): the traced layer is selected by the position `i`
            # (0..14); `layer_idx` is only used as the panel title. Presumably
            # the traced layer(s) past index 14 (the mid block?) are skipped
            # this way — confirm against the daam layer ordering.
            heat_map_red = tc.compute_global_heat_map(layer_idx=i)
            heat_map_red = heat_map_red.compute_word_heat_map(word)
            heatmaps_red.append(heat_map_red)
            heat_map_red.plot_overlay(out.images[0], ax=ax_)
            ax_.set_title(f'Layer {layer_idx}', size=14)

plt.show()

In [None]:
# Word heat maps for "car" in the blue-car prompt, one panel per traced layer.
prompt = "A photo of a blue car"
word = "car"

# Presumably the layers discussed in the article; not used in this cell.
article_layers = [3, 6, 9, 10, 12, 14, 16]
# Panel-title labels: 15 entries filling the 3x5 grid below (10 is absent).
layers_without_mid = [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16]

# Same seed in every plotting cell so the prompts are generated comparably.
gen = set_seed(228)

# Start from an empty list so re-running this cell does not append a second
# set of heat maps (which would break the pairing in the similarity cell).
heatmaps_car_blue = []

# Figure size is left at the matplotlib default; pass figsize= to subplots()
# to change it (a separate plt.figure() call would only open a new, empty figure).
fig, ax = plt.subplots(3, 5, constrained_layout=True)
fig.suptitle(f'prompt: "{prompt}",\n word: "{word}"', fontsize=18)

for ax_ in ax.flat:
    ax_.set_xticks([])
    ax_.set_yticks([])

with torch.no_grad():
    with trace(pipe) as tc:
        out = pipe(prompt, num_inference_steps=50, generator=gen)
        for i, layer_idx in enumerate(layers_without_mid):
            ax_ = ax[i // 5, i % 5]
            # NOTE(review): the traced layer is selected by the position `i`
            # (0..14); `layer_idx` is only used as the panel title. Presumably
            # the traced layer(s) past index 14 (the mid block?) are skipped
            # this way — confirm against the daam layer ordering.
            heat_map_car_blue = tc.compute_global_heat_map(layer_idx=i)
            heat_map_car_blue = heat_map_car_blue.compute_word_heat_map(word)
            heatmaps_car_blue.append(heat_map_car_blue)
            heat_map_car_blue.plot_overlay(out.images[0], ax=ax_)
            ax_.set_title(f'Layer {layer_idx}', size=14)

plt.show()

In [None]:
# Word heat maps for "blue" in the blue-car prompt, one panel per traced layer.
prompt = "A photo of a blue car"
word = "blue"

# Presumably the layers discussed in the article; not used in this cell.
article_layers = [3, 6, 9, 10, 12, 14, 16]
# Panel-title labels: 15 entries filling the 3x5 grid below (10 is absent).
layers_without_mid = [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16]

# Same seed in every plotting cell so the prompts are generated comparably.
gen = set_seed(228)

# Start from an empty list so re-running this cell does not append a second
# set of heat maps (which would break the pairing in the similarity cell).
heatmaps_blue = []

# Figure size is left at the matplotlib default; pass figsize= to subplots()
# to change it (a separate plt.figure() call would only open a new, empty figure).
fig, ax = plt.subplots(3, 5, constrained_layout=True)
fig.suptitle(f'prompt: "{prompt}",\n word: "{word}"', fontsize=18)

for ax_ in ax.flat:
    ax_.set_xticks([])
    ax_.set_yticks([])

with torch.no_grad():
    with trace(pipe) as tc:
        out = pipe(prompt, num_inference_steps=50, generator=gen)
        for i, layer_idx in enumerate(layers_without_mid):
            ax_ = ax[i // 5, i % 5]
            # NOTE(review): the traced layer is selected by the position `i`
            # (0..14); `layer_idx` is only used as the panel title. Presumably
            # the traced layer(s) past index 14 (the mid block?) are skipped
            # this way — confirm against the daam layer ordering.
            heat_map_blue = tc.compute_global_heat_map(layer_idx=i)
            heat_map_blue = heat_map_blue.compute_word_heat_map(word)
            heatmaps_blue.append(heat_map_blue)
            heat_map_blue.plot_overlay(out.images[0], ax=ax_)
            ax_.set_title(f'Layer {layer_idx}', size=14)

plt.show()

In [None]:
# Per-layer cosine similarity between heat maps of the same role in the two
# prompts: the object word ("car" in red car vs. blue car) and the colour word
# ("red" vs. "blue").
#
# The lists are paired positionally, so their lengths must match each other
# and the layer labels the print cells index with; zip() would otherwise
# silently drop entries (e.g. after a plotting cell was re-run or skipped).
assert len(heatmaps_car_red) == len(heatmaps_car_blue) == len(layers_without_mid), (
    f"car heat maps: {len(heatmaps_car_red)} red vs {len(heatmaps_car_blue)} blue, "
    f"expected {len(layers_without_mid)} each"
)
assert len(heatmaps_red) == len(heatmaps_blue) == len(layers_without_mid), (
    f"colour heat maps: {len(heatmaps_red)} red vs {len(heatmaps_blue)} blue, "
    f"expected {len(layers_without_mid)} each"
)

heatmaps_car_sim = [
    compute_cosine_similarity(car_red.heatmap, car_blue.heatmap)
    for car_red, car_blue in zip(heatmaps_car_red, heatmaps_car_blue)
]
heatmaps_color_sim = [
    compute_cosine_similarity(red.heatmap, blue.heatmap)
    for red, blue in zip(heatmaps_red, heatmaps_blue)
]

In [None]:
# Per-layer cosine similarity of the "car" heat maps (red car vs. blue car),
# labelled with the same layer labels as the plot panels.
for position in range(len(heatmaps_car_sim)):
    print(f"layer_idx={layers_without_mid[position]}, sim={heatmaps_car_sim[position]}")

In [None]:
# Per-layer cosine similarity of the colour-word heat maps ("red" vs. "blue"),
# labelled with the same layer labels as the plot panels.
for position in range(len(heatmaps_color_sim)):
    print(f"layer_idx={layers_without_mid[position]}, sim={heatmaps_color_sim[position]}")