# Consistent Self-Attention

### Paper Summary

Source: https://arxiv.org/pdf/2405.01434  
GitHub: https://github.com/HVision-NKU/StoryDiffusion/blob/main/README.md

The authors proposed Consistent Self-Attention, a training-free method for generating subject-consistent images. It replaces the self-attention in the U-Net of diffusion-based image generation models such as Stable Diffusion.

### Key Concepts and Methodology

**Objective:**

- To keep the subject consistent across a batch of generated images, i.e. to preserve the same character, face, and attire in every image.

**Self-Attention in Diffusion Models:**

- In standard self-attention, every image in a batch attends only to its own tokens.
- Given image features $I \in \mathbb{R}^{B \times N \times C}$ (where $B$ is the batch size, $N$ is the number of tokens, and $C$ is the number of channels), self-attention projects the features $I_i$ of each image to queries $Q_i$, keys $K_i$, and values $V_i$ and computes, independently for each image $i$:

  $O_i = \operatorname{Attention}(Q_i, K_i, V_i)$, where $\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\left(QK^\top / \sqrt{d}\right) V$ and $d$ is the key dimension.
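
A minimal NumPy sketch of per-image self-attention, assuming a single head and random projection matrices `Wq`, `Wk`, `Wv` (hypothetical stand-ins for the UNet's `to_q`, `to_k`, `to_v`). It illustrates that the output for one image does not depend on the other images in the batch.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(I, Wq, Wk, Wv):
    """Single-head self-attention applied to each image of I (shape B x N x C) separately."""
    Q, K, V = I @ Wq, I @ Wk, I @ Wv                           # each (B, N, C)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])   # (B, N, N): image i only sees its own tokens
    return softmax(scores) @ V                                 # (B, N, C)

rng = np.random.default_rng(0)
B, N, C = 3, 4, 8
I = rng.standard_normal((B, N, C))
Wq, Wk, Wv = (rng.standard_normal((C, C)) / np.sqrt(C) for _ in range(3))

O = self_attention(I, Wq, Wk, Wv)

# Perturbing image 1 leaves the output of image 0 unchanged: there is no cross-image interaction
I_perturbed = I.copy()
I_perturbed[1] += 1.0
O_perturbed = self_attention(I_perturbed, Wq, Wk, Wv)
print(np.allclose(O[0], O_perturbed[0]))  # True
```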

**Consistent Self-Attention:**

- **Token Sampling**: To establish connections between images, the method samples tokens $S_i$ from other images in the batch. This is done using a random sampling function:  

  $S_i = \operatorname{RandSample}(I_1, I_2, \ldots, I_{i-1}, I_{i+1}, \ldots, I_B)$

- **Token Pairing**: The sampled tokens $S_i$ are concatenated with the original image features $I_i$ to form a new set of tokens $P_i = \operatorname{Concat}(S_i, I_i)$.
- **Linear Projections**: $P_i$ is linearly projected to new keys $K_{P_i}$ and values $V_{P_i}$, while the query $Q_i$ is still computed from $I_i$ alone.
- **Consistent Self-Attention Calculation**: Attention is then computed with the original queries and the paired keys and values, so each image can attend to features of the other images in the batch:

  $O_i = \operatorname{Attention}(Q_i, K_{P_i}, V_{P_i})$

Because each image's queries attend to tokens sampled from the other images, the characters, faces, and attire converge toward a shared appearance during generation.
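
The four steps above can be sketched in NumPy as follows. This is an illustrative sketch, not the authors' implementation: it assumes a single head, random projections, and a hypothetical `sample_rate` that keeps each token of the other images independently. (The notebook's processors realise the same sampling with a random boolean attention mask instead of gathering tokens.)

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def consistent_self_attention(I, Wq, Wk, Wv, sample_rate=0.5, rng=None):
    """Single-head Consistent Self-Attention sketch for features I of shape (B, N, C), B >= 2."""
    rng = np.random.default_rng() if rng is None else rng
    B, N, C = I.shape
    outputs = []
    for i in range(B):
        # Token sampling: S_i = RandSample(I_1, ..., I_{i-1}, I_{i+1}, ..., I_B)
        others = np.concatenate([I[j] for j in range(B) if j != i], axis=0)  # ((B-1)*N, C)
        S_i = others[rng.random(len(others)) < sample_rate]
        # Token pairing: P_i = Concat(S_i, I_i)
        P_i = np.concatenate([S_i, I[i]], axis=0)
        # Queries come from I_i only; keys and values from the paired tokens P_i
        Q_i, K_Pi, V_Pi = I[i] @ Wq, P_i @ Wk, P_i @ Wv
        scores = Q_i @ K_Pi.T / np.sqrt(C)       # (N, len(P_i))
        outputs.append(softmax(scores) @ V_Pi)   # O_i, shape (N, C)
    return np.stack(outputs)

rng = np.random.default_rng(0)
I = rng.standard_normal((3, 4, 8))
Wq, Wk, Wv = (rng.standard_normal((8, 8)) / np.sqrt(8) for _ in range(3))

O = consistent_self_attention(I, Wq, Wk, Wv, sample_rate=1.0, rng=np.random.default_rng(1))

# Unlike plain self-attention, image 0 now reacts to a change in image 1
I_perturbed = I.copy()
I_perturbed[1] += 1.0
O_perturbed = consistent_self_attention(I_perturbed, Wq, Wk, Wv, sample_rate=1.0, rng=np.random.default_rng(1))
print(np.allclose(O[0], O_perturbed[0]))  # False
```

With `sample_rate=0.0` no tokens are sampled, $P_i = I_i$, and the computation reduces to ordinary per-image self-attention.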


## Install and Import Packages & Libraries

In [1]:
# Install packages
!pip install --quiet diffusers gradio accelerate safetensors huggingface_hub

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m39.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.3/12.3 MB[0m [31m101.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m309.4/309.4 kB[0m [31m42.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.0/92.0 kB[0m [31m15.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m318.1/318.1 kB[0m [31m44.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.6/75.6 kB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m145.0/145.0 kB[0m [31m22.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.8/8.8 MB[0m [31m116.3

In [2]:
# Import libraries
from google.colab import drive

import os
import sys
import gradio as gr
import numpy as np
import requests
import random
import copy
import pickle
from PIL import Image, ImageDraw, ImageFont
from tqdm.auto import tqdm
import textwrap

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers import StableDiffusionXLPipeline, DiffusionPipeline, StableDiffusion3Pipeline, DDIMScheduler
from diffusers.utils import load_image

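from huggingface_hub import login, notebook_login

# Mount Google Drive to the Colab environment
drive.mount('/content/drive')

# Log in to the Hugging Face Hub
# notebook_login() does not take a token (its first parameter is `new_session`),
# so pass the token to login() and fall back to the interactive widget if HF_TOKEN is unset
hf_token = os.getenv('HF_TOKEN')
if hf_token:
  login(token=hf_token)
else:
  notebook_login()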

# Check CUDA availability
print(f"CUDA is available: {torch.cuda.is_available()}")

Mounted at /content/drive
User is already logged in.
CUDA is available: True


## Consistent Self-Attention Processors


In [3]:
class AttnProcessor(torch.nn.Module):
  """
  Standard scaled dot-product attention processor (PyTorch 2.0 SDPA).
  Used for every UNet attention layer that is not replaced by the consistent self-attention processor.
  """

  def __init__(self, hidden_size=None, cross_attention_dim=None):
    """
    Initialize the attention processor.

    Arguments:
    hidden_size (int, optional): Size of the hidden layers (unused, kept for a uniform constructor)
    cross_attention_dim (int, optional): Dimension for cross attention (unused)
    """
    super().__init__()


  def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
    """
    Apply the attention mechanism to the input hidden states.

    Arguments:
    attn: Attention module containing necessary submodules and parameters
    hidden_states (torch.Tensor): Input hidden states
    encoder_hidden_states (torch.Tensor, optional): Hidden states from the encoder for cross attention
    attention_mask (torch.Tensor, optional): Mask to apply on the attention scores
    temb (torch.Tensor, optional): Timestep embedding, passed to spatial normalization if present

    Returns:
    torch.Tensor: Output hidden states after applying the attention mechanism
    """

    residual = hidden_states

    # Apply spatial normalization if available
    if attn.spatial_norm is not None:
      hidden_states = attn.spatial_norm(hidden_states, temb)

    input_ndim = hidden_states.ndim

    # If input is 4-dimensional (batch_size, channel, height, width)
    if input_ndim == 4:
      batch_size, channel, height, width = hidden_states.shape
      hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    batch_size, sequence_length, _ = (hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape)

    # Prepare the attention mask if provided
    if attention_mask is not None:
      attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
      attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

    # Apply group normalization if available
    if attn.group_norm is not None:
      hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    # Compute query
    query = attn.to_q(hidden_states)

    # Set encoder hidden states and normalize if required
    if encoder_hidden_states is None:
      encoder_hidden_states = hidden_states
    elif attn.norm_cross:
      encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    # Compute key and value
    key = attn.to_k(encoder_hidden_states)
    value = attn.to_v(encoder_hidden_states)

    inner_dim = key.shape[-1]
    head_dim = inner_dim // attn.heads

    # Reshape query, key, and value for multi-head attention
    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    # Apply scaled dot-product attention
    hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)

    # Reshape and project the output
    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
    hidden_states = hidden_states.to(query.dtype)
    hidden_states = attn.to_out[0](hidden_states)
    hidden_states = attn.to_out[1](hidden_states)

    # Reshape back to original dimensions if input was 4-dimensional
    if input_ndim == 4:
      hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

    # Apply residual connection if required
    if attn.residual_connection:
      hidden_states = hidden_states + residual

    # Rescale the output
    hidden_states = hidden_states / attn.rescale_output_factor

    return hidden_states
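
The reshapes in `AttnProcessor` can be traced with a small NumPy mirror (the array sizes are arbitrary assumptions). It flattens a `(B, C, H, W)` feature map into tokens, splits the channels into heads the way `query.view(...).transpose(1, 2)` does, and checks that merging the heads and restoring the layout recovers the input.

```python
import numpy as np

batch, channels, h, w, heads = 2, 8, 3, 3, 2
x = np.random.default_rng(0).standard_normal((batch, channels, h, w))

# (B, C, H, W) -> (B, H*W, C): one token per spatial position
tokens = x.reshape(batch, channels, h * w).transpose(0, 2, 1)

# Split the channel axis into heads: (B, N, C) -> (B, heads, N, C // heads)
head_dim = channels // heads
per_head = tokens.reshape(batch, -1, heads, head_dim).transpose(0, 2, 1, 3)

# Merge heads back: (B, heads, N, head_dim) -> (B, N, heads * head_dim)
merged = per_head.transpose(0, 2, 1, 3).reshape(batch, -1, heads * head_dim)

# Back to image layout: (B, N, C) -> (B, C, H, W)
restored = merged.transpose(0, 2, 1).reshape(batch, channels, h, w)
print(per_head.shape, np.array_equal(restored, x))  # (2, 2, 9, 4) True
```

NumPy's `reshape` copies non-contiguous arrays when needed; in PyTorch, `.view` requires a contiguous tensor, which is why the processor uses `.reshape` after `.transpose`.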

In [4]:
def cal_attn_mask_xl(total_length, id_length, sa32, sa64, height, width, device="cuda", dtype=torch.float16):
  """
  Sample the random attention masks used by consistent self-attention at the two attention resolutions of the SDXL up blocks.

  Tokens of the identity images are kept with probability sa32 / sa64; every image always attends to all of its own tokens
  and never to the tokens of non-identity images. In the returned boolean masks, True means "may attend".

  Arguments:
  total_length (int): Number of images attending to each other (id_length identity images + 1 image being generated)
  id_length (int): Number of identity (reference) images
  sa32 (float): Probability of keeping a reference token at the 1/32-downsampled feature resolution
  sa64 (float): Probability of keeping a reference token at the 1/16-downsampled feature resolution
  height (int): Height of the generated image in pixels
  width (int): Width of the generated image in pixels
  device (str, optional): Device to perform computations on. Defaults to "cuda"
  dtype (torch.dtype, optional): Data type for the random values. Defaults to torch.float16

  Returns:
  mask1024 (torch.Tensor): Boolean mask of shape (total_length * nums_1024, total_length * nums_1024) for the 1/32 resolution
  mask4096 (torch.Tensor): Boolean mask of shape (total_length * nums_4096, total_length * nums_4096) for the 1/16 resolution
  """

  # Tokens per image at each attention resolution
  # (named after a 1024x1024 image: 32x32 = 1024 and 64x64 = 4096 tokens)
  nums_1024 = (height // 32) * (width // 32)
  nums_4096 = (height // 16) * (width // 16)

  # One random keep/drop decision per key token: True with probability sa32 / sa64
  bool_matrix1024 = torch.rand((1, total_length * nums_1024), device=device, dtype=dtype) < sa32
  bool_matrix4096 = torch.rand((1, total_length * nums_4096), device=device, dtype=dtype) < sa64

  # Share the same sampled row across all total_length images
  bool_matrix1024 = bool_matrix1024.repeat(total_length, 1)
  bool_matrix4096 = bool_matrix4096.repeat(total_length, 1)

  for i in range(total_length):
    # Block access to the tokens of non-identity images...
    bool_matrix1024[i:i+1, id_length * nums_1024:] = False
    bool_matrix4096[i:i+1, id_length * nums_4096:] = False
    # ...but always let image i attend to all of its own tokens
    bool_matrix1024[i:i+1, i * nums_1024:(i + 1) * nums_1024] = True
    bool_matrix4096[i:i+1, i * nums_4096:(i + 1) * nums_4096] = True

  # Expand each image's row to one row per query token of that image
  mask1024 = bool_matrix1024.unsqueeze(1).repeat(1, nums_1024, 1).reshape(-1, total_length * nums_1024)
  mask4096 = bool_matrix4096.unsqueeze(1).repeat(1, nums_4096, 1).reshape(-1, total_length * nums_4096)

  return mask1024, mask4096
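
To see the structure of the masks, here is a NumPy mirror of one of the two masks built by `cal_attn_mask_xl`, using tiny hypothetical sizes. `consistent_attn_mask` is a name introduced here for illustration; it takes the number of tokens per image `nums` directly instead of computing it from `height` and `width`.

```python
import numpy as np

def consistent_attn_mask(total_length, id_length, nums, p_keep, rng):
    """NumPy mirror of one mask built by cal_attn_mask_xl (True = may attend).

    nums is the number of tokens per image at one resolution,
    e.g. (height // 32) * (width // 32) for mask1024.
    """
    # One random keep/drop decision per key token, shared by all images
    row = rng.random(total_length * nums) < p_keep
    matrix = np.tile(row, (total_length, 1))       # (total_length, total_length * nums)
    for i in range(total_length):
        matrix[i, id_length * nums:] = False       # never attend to non-identity images...
        matrix[i, i * nums:(i + 1) * nums] = True  # ...but always to all of one's own tokens
    # Repeat image i's row for each of its nums query tokens
    return np.repeat(matrix, nums, axis=0)         # (total_length * nums, total_length * nums)

mask = consistent_attn_mask(total_length=3, id_length=2, nums=2, p_keep=0.5,
                            rng=np.random.default_rng(0))
print(mask.astype(int))
```

Here rows 0-3 belong to the two identity images and rows 4-5 to the image being generated. The processor uses the top-left `id_length * nums` block when generating the identity batch and the last `nums` rows when generating a new image.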

In [5]:
# Consistent Self Attention Processor

class SpatialAttnProcessor(torch.nn.Module):
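  """
  Consistent Self-Attention processor for the self-attention (attn1) layers of the UNet up blocks.

  In the write phase the identity images attend to each other and their layer inputs are cached in id_bank;
  in the read phase the image being generated attends to those cached tokens. The branch taken is controlled
  by the globals write and cur_step and by the masks mask1024 / mask4096 (resampled with sa32 / sa64 every step).

  Arguments:
  hidden_size ('int'): Hidden size of the attention layer (unused)
  cross_attention_dim ('int'): Number of channels in the 'encoder_hidden_states' (unused)
  id_length ('int', defaults to 4): Number of identity images whose hidden states are cached in id_bank
  device ('str', defaults to "cuda"): Device to run the computations on
  dtype ('torch.dtype', defaults to torch.float16): Data type for tensor computations
  """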

  def __init__(self, hidden_size=None, cross_attention_dim=None, id_length=4, device="cuda", dtype=torch.float16):

    super().__init__()

    self.device = device
    self.dtype = dtype
    self.hidden_size = hidden_size
    self.cross_attention_dim = cross_attention_dim
    self.total_length = id_length + 1
    self.id_length = id_length
    # Maps denoising step -> [unconditional, conditional] hidden states of the identity images
    self.id_bank = {}

  def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):

    global total_count, attn_count, cur_step, mask1024, mask4096
    global sa32, sa64
    global write
    global height, width

    if write:
      # Write phase (identity batch): cache this layer's input for the current step,
      # split into the unconditional and conditional classifier-free-guidance halves
      self.id_bank[cur_step] = [hidden_states[:self.id_length], hidden_states[self.id_length:]]
    else:
      # Read phase (one new image): place the cached identity tokens before the current image in each half
      encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(self.device), hidden_states[:1], self.id_bank[cur_step][1].to(self.device), hidden_states[1:]))

    # Early steps: no random mask. __call2__ runs plain self-attention in the write phase and
    # full attention over all cached identity tokens plus the current image in the read phase
    if cur_step < 5:
      hidden_states = self.__call2__(attn, hidden_states, encoder_hidden_states, attention_mask, temb)
    else:
      # Later steps: use masked consistent attention with probability 0.7 (steps 5-19) or 0.9 (steps 20+),
      # drawn independently for every call; otherwise fall back to plain per-image self-attention
      random_number = random.random()
      if cur_step < 20:
        rand_num = 0.3
      else:
        rand_num = 0.1
      if random_number > rand_num:
        # Select the mask for this resolution: the current image's rows (read) or the identity-only block (write)
        if not write:
          if hidden_states.shape[1] == (height//32) * (width//32):
            attention_mask = mask1024[mask1024.shape[0] // self.total_length * self.id_length:]
          else:
            attention_mask = mask4096[mask4096.shape[0] // self.total_length * self.id_length:]
        else:
          if hidden_states.shape[1] == (height//32) * (width//32):
            attention_mask = mask1024[:mask1024.shape[0] // self.total_length * self.id_length, :mask1024.shape[0] // self.total_length * self.id_length]
          else:
            attention_mask = mask4096[:mask4096.shape[0] // self.total_length * self.id_length, :mask4096.shape[0] // self.total_length * self.id_length]
        hidden_states = self.__call1__(attn, hidden_states, encoder_hidden_states, attention_mask, temb)
      else:
        hidden_states = self.__call2__(attn, hidden_states, None, attention_mask, temb)

    attn_count += 1

    # Once every SpatialAttnProcessor has run for this denoising step, advance the step and resample the masks
    if attn_count == total_count:
      attn_count = 0
      cur_step += 1
      mask1024, mask4096 = cal_attn_mask_xl(self.total_length, self.id_length, sa32, sa64, height, width, device=self.device, dtype=self.dtype)

    return hidden_states

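  def __call1__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
    """
    Masked consistent self-attention: concatenate the tokens of several images along the sequence axis
    and attend over them with the sampled attention_mask (True = may attend).
    """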

    residual = hidden_states

    if attn.spatial_norm is not None:
      hidden_states = attn.spatial_norm(hidden_states, temb)

    input_ndim = hidden_states.ndim

    if input_ndim == 4:
      total_batch_size, channel, height, width = hidden_states.shape
      hidden_states = hidden_states.view(total_batch_size, channel, height * width).transpose(1, 2)

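    total_batch_size, nums_token, channel = hidden_states.shape
    # The batch holds the unconditional and conditional halves of classifier-free guidance
    img_nums = total_batch_size//2
    # Join the tokens of all images in each half: (2 * img_nums, N, C) -> (2, img_nums * N, C)
    hidden_states = hidden_states.view(-1,img_nums,nums_token,channel).reshape(-1,img_nums * nums_token,channel)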

    batch_size, sequence_length, _ = hidden_states.shape

    if attn.group_norm is not None:
      hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
      encoder_hidden_states = hidden_states # B, N, C
    else:
      encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,nums_token,channel).reshape(-1,(self.id_length+1) * nums_token,channel)

    key = attn.to_k(encoder_hidden_states)
    value = attn.to_v(encoder_hidden_states)

    inner_dim = key.shape[-1]
    head_dim = inner_dim // attn.heads

    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
    hidden_states = hidden_states.transpose(1, 2).reshape(total_batch_size, -1, attn.heads * head_dim)
    hidden_states = hidden_states.to(query.dtype)

    # Linear projection and dropout
    hidden_states = attn.to_out[0](hidden_states)
    hidden_states = attn.to_out[1](hidden_states)

    if input_ndim == 4:
      hidden_states = hidden_states.transpose(-1, -2).reshape(total_batch_size, channel, height, width)
    if attn.residual_connection:
      hidden_states = hidden_states + residual

    hidden_states = hidden_states / attn.rescale_output_factor

    return hidden_states


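  def __call2__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
    """
    Unmasked attention. With encoder_hidden_states=None this is plain per-image self-attention;
    otherwise each image attends to all cached identity tokens plus its own tokens.
    """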

    residual = hidden_states

    if attn.spatial_norm is not None:
      hidden_states = attn.spatial_norm(hidden_states, temb)

    input_ndim = hidden_states.ndim

    if input_ndim == 4:
      batch_size, channel, height, width = hidden_states.shape
      hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    batch_size, sequence_length, channel = hidden_states.shape

    if attention_mask is not None:
      attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
      attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

    if attn.group_norm is not None:
      hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
      encoder_hidden_states = hidden_states  # B, N, C
    else:
      encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,sequence_length,channel).reshape(-1,(self.id_length+1) * sequence_length,channel)

    key = attn.to_k(encoder_hidden_states)
    value = attn.to_v(encoder_hidden_states)

    inner_dim = key.shape[-1]
    head_dim = inner_dim // attn.heads

    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)

    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
    hidden_states = hidden_states.to(query.dtype)

    # Linear projection and dropout
    hidden_states = attn.to_out[0](hidden_states)
    hidden_states = attn.to_out[1](hidden_states)

    if input_ndim == 4:
      hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

    if attn.residual_connection:
      hidden_states = hidden_states + residual

    hidden_states = hidden_states / attn.rescale_output_factor

    return hidden_states
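
In the read phase, `SpatialAttnProcessor.__call__` concatenates the cached identity tokens with the current image, and `__call1__` / `__call2__` reshape them so that each classifier-free-guidance half becomes one long key/value sequence. A NumPy sketch with hypothetical sizes (`n_id = 3` identity images standing in for `id_length`, `N = 4` tokens, `C = 8` channels) shows the resulting order:

```python
import numpy as np

n_id, N, C = 3, 4, 8  # hypothetical sizes; n_id plays the role of id_length
rng = np.random.default_rng(0)

# Write phase: the identity batch is [uncond x n_id, cond x n_id] under classifier-free guidance
write_batch = rng.standard_normal((2 * n_id, N, C))
id_bank_step = [write_batch[:n_id], write_batch[n_id:]]  # what id_bank[cur_step] holds

# Read phase: the batch is [uncond, cond] for the single new image
read_batch = rng.standard_normal((2, N, C))
encoder = np.concatenate((id_bank_step[0], read_batch[:1], id_bank_step[1], read_batch[1:]))

# Group per guidance half and flatten the images into one sequence:
# (2 * (n_id + 1), N, C) -> (2, (n_id + 1) * N, C)
paired = encoder.reshape(-1, n_id + 1, N, C).reshape(-1, (n_id + 1) * N, C)
print(paired.shape)  # (2, 16, 8)
```

The queries still come from the `N` tokens of the current image only, so a read-phase attention map has shape `N x (id_length + 1) * N`, matching the rows `mask[id_length * nums:]` that the processor selects.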

## Comic Style Lists

In [6]:
style_list = [
    {
        "name": "(No style)",
        "prompt": "{prompt}",
        "negative_prompt": "",
    },
    {
        "name": "Japanese Anime",
        "prompt": "anime artwork illustrating {prompt}. created by japanese anime studio. highly emotional. best quality, high resolution, (Anime Style, Manga Style:1.3), Low detail, sketch, concept art, line art, webtoon, manhua, hand drawn, defined lines, simple shades, minimalistic, High contrast, Linear compositions, Scalable artwork, Digital art, High Contrast Shadows",
        "negative_prompt": "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
    },
    {
        "name": "Digital/Oil Painting",
        "prompt": "{prompt} . (Extremely Detailed Oil Painting:1.2), glow effects, godrays, Hand drawn, render, 8k, octane render, cinema 4d, blender, dark, atmospheric 4k ultra detailed, cinematic sensual, Sharp focus, humorous illustration, big depth of field",
        "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
    },
    {
        "name": "Pixar/Disney Character",
        "prompt": "Create a Disney Pixar 3D style illustration on {prompt} . The scene is vibrant, motivational, filled with vivid colors and a sense of wonder.",
        "negative_prompt": "lowres, bad anatomy, bad hands, text, bad eyes, bad arms, bad legs, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, blurry, grayscale, noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo",
    },
    {
        "name": "Photographic",
        "prompt": "cinematic photo {prompt} . Hyperrealistic, Hyperdetailed, detailed skin, matte skin, soft lighting, realistic, best quality, ultra realistic, 8k, golden ratio, Intricate, High Detail, film photography, soft focus",
        "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
    },
    {
        "name": "Comic Book",
        "prompt": "comic {prompt} . Graphic illustration, comic art, graphic novel art, vibrant, highly detailed",
        "negative_prompt": "photograph, deformed, glitch, noisy, realistic, stock photo, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
    },
    {
        "name": "Line Art",
        "prompt": "line art drawing {prompt} . professional, sleek, modern, minimalist, graphic, line art, vector graphics",
        "negative_prompt": "anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
    },
    {
        "name": "Black and White Film Noir",
        "prompt": "{prompt} . (b&w, Monochromatic, Film Photography:1.3), film noir, analog style, soft lighting, subsurface scattering, realistic, heavy shadow, masterpiece, best quality, ultra realistic, 8k",
        "negative_prompt": "anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
    },
    {
        "name": "Isometric Rooms",
        "prompt": "Tiny cute isometric {prompt} . in a cutaway box, soft smooth lighting, soft colors, 100mm lens, 3d blender render",
        "negative_prompt": "anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
    },
]

styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}

# Comic Generation

## Technical Configurations

In [7]:
# Release unoccupied cached memory currently held by the CUDA memory allocator back to the GPU
torch.cuda.empty_cache()

# Function to set up random seed for reproducibility
def setup_seed(seed):
  """
  Set up random seed for reproducibility across different runs
  """
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
  np.random.seed(seed)
  random.seed(seed)
  torch.backends.cudnn.deterministic = True

In [8]:
# Fallback style used when style_name is not a key of `styles` - see Comic Style Lists
default_style_name = "(No style)"

# Dictionary of Text2Image models holding model names and their Hub IDs
# Only works for UNet-based models; e.g. it does not work for SD3, whose denoiser is a transformer (MMDiT) rather than a UNet
models_dict = {"RealVision":"SG161222/RealVisXL_V4.0",
               "SDXL":"stabilityai/stable-diffusion-xl-base-1.0"}

# Global state shared with the attention processors
# (`global` has no effect at module level; these lines only document the shared names)
global attn_count, total_count, id_length, total_length, cur_step, cur_model_type
global write
global sa32, sa64
global height, width
global mask1024, mask4096
global attn_procs, unet

# Initialize the attention count and step count
attn_count = 0  # Counter for the number of attention calls
cur_step = 0  # Counter for the current processing step

# Identity settings
id_length = 3  # Number of identity images generated together with consistent self-attention
total_length = id_length + 1  # Identity images + the image currently being generated

# Current model type and device settings
cur_model_type = ""  # Variable to hold the type of the current model
device = "cuda"  # Device to be used (GPU in this case)

# Phase flag for the attention processors
write = False  # True: generate identity images and cache their features in id_bank; False: read from id_bank

# Strength of consistent self-attention: probability of keeping each identity-image token in the random mask (larger = stronger)
sa32 = 0.5  # At the 1/32-downsampled feature resolution
sa64 = 0.5  # At the 1/16-downsampled feature resolution

# Resolution of the comic to be generated
height = 800  # Height of the generated image
width = 800  # Width of the generated image

# Number of steps for the generation process
num_steps = 35

# Hub ID of the SDXL-based checkpoint (RealVisXL V4.0)
sd_model_path = models_dict["RealVision"]
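
The step schedule in `SpatialAttnProcessor.__call__` interacts with `num_steps`: the first 5 steps never use the random mask, steps 5-19 use it with probability 0.7, and later steps with probability 0.9, with a fresh draw for every processor call. A small sketch of the expected number of masked steps per processor for 35 steps; the helper `masked_attention_probability` is hypothetical and mirrors the thresholds hard-coded in the processor.

```python
def masked_attention_probability(step, early_steps=5, late_step=20):
    """Chance that one SpatialAttnProcessor call at `step` takes the masked branch (__call1__)."""
    if step < early_steps:
        return 0.0                              # always __call2__ (no random mask)
    return 0.7 if step < late_step else 0.9     # random.random() > 0.3 or > 0.1

n_steps = 35  # same value as num_steps above
expected = sum(masked_attention_probability(s) for s in range(n_steps))
print(f"{expected:.1f} of {n_steps} steps use the random mask on average")
# prints: 24.0 of 35 steps use the random mask on average
```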

## Load Pipeline

In [9]:
## LOAD STABLE DIFFUSION PIPELINE
# Reference: https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/stable_diffusion_xl

pipe = DiffusionPipeline.from_pretrained(sd_model_path,
                                         torch_dtype=torch.float16,
                                         use_safetensors=True).to(device)

# Improve generation quality with FreeU
# Reference: https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu
pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)

# Set the scheduler to be DDIM
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

# Note: the pipeline call sets the timesteps again from num_inference_steps,
# so this call has no effect on generation
pipe.scheduler.set_timesteps(50)

# Get a handle to the pipeline's UNet, whose attention processors are replaced below
unet = pipe.unet

total_count = 0
attn_procs = {}

## INSERT PAIRED ATTENTION
for name in unet.attn_processors.keys():

    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim

    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    # Replace only the self-attention layers (attn1) of the up blocks with consistent self-attention
    if cross_attention_dim is None and name.startswith("up_blocks"):
        attn_procs[name] = SpatialAttnProcessor(id_length=id_length)
        total_count += 1
    else:
        attn_procs[name] = AttnProcessor(hidden_size, cross_attention_dim)

print("Successfully load consistent self-attention")
print(f"Number of processors : {total_count}")

unet.set_attn_processor(attn_procs)

mask1024, mask4096 = cal_attn_mask_xl(total_length, id_length, sa32, sa64, height, width, device=device, dtype= torch.float16)

Successfully load consistent self-attention
Number of processors : 36
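
The selection rule in the loop above can be checked on plain layer names. The names below are examples following diffusers' naming scheme for SDXL attention processors, and `processor_kind` is a hypothetical helper that mirrors the `if` condition:

```python
def processor_kind(name):
    """Mirror of the selection rule used when building attn_procs."""
    is_self_attention = name.endswith("attn1.processor")  # attn2 layers are cross-attention to the text
    if is_self_attention and name.startswith("up_blocks"):
        return "SpatialAttnProcessor"
    return "AttnProcessor"

example_names = [
    "down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor",
    "mid_block.attentions.0.transformer_blocks.0.attn1.processor",
    "up_blocks.0.attentions.0.transformer_blocks.0.attn1.processor",
    "up_blocks.0.attentions.0.transformer_blocks.0.attn2.processor",
]
for n in example_names:
    print(f"{n:<64} -> {processor_kind(n)}")
```

Only `attn1` (self-attention) layers in `up_blocks` are swapped. Assuming SDXL's `up_blocks.0` and `up_blocks.1` each contain three attention modules with 10 and 2 transformer layers respectively, that gives 3 × 10 + 3 × 2 = 36 layers, consistent with the count printed above.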


In [10]:
# Reduce GPU memory usage: keep each sub-model on the CPU and move it to the GPU only while it runs
pipe.enable_model_cpu_offload()
torch.cuda.empty_cache()

## Generate Comics

### [USER] Set Guidance Scale, Seed, Style and Prompts

In [11]:
# Classifier-free guidance scale
guidance_scale = 4.0  # Higher values follow the prompt more closely

# Set the seed for reproducibility
seed = 2024

# Style applied to all generated images (a key of `styles`, see Comic Style Lists)
style_name = "Comic Book"

# General description of the character
# general_prompt = "a medieval English king in battle armour, red cape and crown"
general_prompt = "A medieval English king wearing a red cape, silver armour and crown."

# Negative prompt to avoid undesirable features in the generated images
negative_prompt = "naked, deformed, bad anatomy, disfigured, poorly drawn face, mutation, extra limb, ugly, disgusting, poorly drawn hands, missing limb, floating limbs, disconnected limbs, blurry, watermarks, oversaturated, distorted hands, amputation"

# Array of specific scene descriptions
prompt_array = [
    "Standing at the top of the hill with medieval soldiers (in silver armour and red capes) surrounding and forming a shield wall.",
    "Looking down the hill at enemy Norman soldiers charging up the hill towards him and his soldiers. Arrows are flying in the air in the background.",
    "Fight and defend against enemy Norman soldiers. Swords and shields are clashing in fierce battle.",
    "Look at enemy Norman soldiers running away, thinking to himself that they have fled, the medieval English king instructed his soldiers (in silver armour and red capes) to charge.",
    "It was a trap by the enemy Norman soldiers. The medieval English soldiers (in silver armour and red capes) were attacked and lay down motionless on floor.",
    "Laying with face down on ground dead and not moving with many other medieval English soldiers (in silver armour and red capes) laying dead."
]

captions_array = [
    "King Harold Godwinson and his army were able to position themselves along a ridge at the top of a hill.",
    "The epic Battle of Hastings is about to begin. William of Normandy is surrounded by his knights and foot soldiers, who are a mixture of Normans and European mercenaries.",
    "William of Normandy and his knights were the first to attack.",
    "William feigned a retreat, leading the English fyrdsmen to leave the shield wall and chase after the Normans.",
    "Norman soldiers suddenly turned around and charged through the wall, causing great damage.",
    "King Harold and his brothers made a final stand and fought to the death."
]

### Run Model

In [12]:
def apply_style_positive(style_name: str, positive: str):
  """
  Apply a given style to a single positive prompt.
  Used for the prompts generated after the identity images, whose negative prompt was already styled by apply_style().

  Arguments:
  style_name (str): Name of the style to apply
  positive (str): Positive prompt

  Returns:
  str: Styled positive prompt.
  """
  p, n = styles.get(style_name, styles[default_style_name])

  return p.replace("{prompt}", positive)

def apply_style(style_name: str, positives: list, negative: str = ""):
  """
  Apply a given style to a list of positive prompts and combine with a negative prompt.

  Arguments:
  style_name (str): Name of the style to apply
  positives (list): List of positive prompts
  negative (str): Negative prompt to append to the style's negative prompt

  Returns:
  tuple: List of styled positive prompts and combined negative prompt.
  """
  p, n = styles.get(style_name, styles[default_style_name])

  return [p.replace("{prompt}", positive) for positive in positives], n + ' ' + negative


# Set up the random seed for reproducibility
setup_seed(seed)

# Create a generator for random number generation with the specified seed
generator = torch.Generator(device="cuda").manual_seed(seed)

# Combine the general prompt with each specific scene description
prompts = [general_prompt + "," + prompt for prompt in prompt_array]

# Split the combined prompts: the first id_length are identity prompts, generated together as one batch
# with Consistent Self-Attention; the rest are generated afterwards, one at a time
id_prompts = prompts[:id_length]
real_prompts = prompts[id_length:]

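# write = True: the Consistent Self-Attention processors store the identity images' features
# while those images attend to each other
write = True
# Reset the counters the processors use to track the current denoising step and attention call
cur_step = 0
attn_count = 0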

# Apply the style to the ID prompts
id_prompts, negative_prompt = apply_style(style_name, id_prompts, negative_prompt)

# Generate the identity images as one batch; their features are cached for the remaining prompts
id_images = pipe(id_prompts, num_inference_steps=num_steps,
                 guidance_scale=guidance_scale, height=height,
                 width=width, negative_prompt=negative_prompt, generator=generator).images

# write = False: the processors now read the cached identity features instead of storing new ones
write = False

# Keep and display the identity images
id_images_list = list(id_images)
for id_image in id_images_list:
  display(id_image)

# Generate the remaining images one prompt at a time; each attends to the cached identity features.
# negative_prompt already contains the style's negative prompt returned by apply_style.
real_images_list = []
for real_prompt in real_prompts:
  cur_step = 0  # restart the denoising-step counter for each new image
  real_prompt = apply_style_positive(style_name, real_prompt)
  real_image = pipe(
      real_prompt,
      num_inference_steps=num_steps,
      guidance_scale=guidance_scale,
      height=height,
      width=width,
      negative_prompt=negative_prompt,
      generator=generator,
  ).images[0]
  real_images_list.append(real_image)

# Display the final generated images
for real_image in real_images_list:
  display(real_image)
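The `write` flag above turns the batch described in the paper summary into a two-phase process: the identity prompts are generated together and their features are stored, then each later prompt attends to those stored features. The sketch below illustrates that idea in NumPy under simplifying assumptions: a single attention head, a single layer, random matrices standing in for the U-Net's learned projections, and explicit random token sampling following $S_i = \operatorname{RandSample}(\ldots)$. `ToyConsistentSelfAttention`, `sample_rate`, and `id_bank` are hypothetical names chosen for illustration, not the StoryDiffusion processor API.

```python
import numpy as np


def attention(q, k, v):
    """Scaled dot-product attention for one image: q is (N, C); k and v are (M, C)."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over the M key tokens
    return weights @ v


class ToyConsistentSelfAttention:
    """Hypothetical single-head, single-layer sketch of the write/read idea."""

    def __init__(self, channels, sample_rate=0.5, seed=0):
        self.rng = np.random.default_rng(seed)
        # Random matrices stand in for the U-Net's learned query/key/value projections
        self.w_q, self.w_k, self.w_v = [
            self.rng.standard_normal((channels, channels)) / np.sqrt(channels) for _ in range(3)
        ]
        self.sample_rate = sample_rate
        self.id_bank = {}  # denoising step -> cached identity tokens, shape (B_id, N, C)

    def __call__(self, x, step, write):
        """x: (B, N, C) image tokens at one denoising step; returns (B, N, C)."""
        batch, _, channels = x.shape
        if write:
            self.id_bank[step] = x.copy()  # store identity features for later prompts
        outputs = []
        for i in range(batch):
            if write and batch > 1:
                # S_i: tokens randomly sampled from the other identity images
                others = np.concatenate([x[j] for j in range(batch) if j != i])
                n_sample = int(self.sample_rate * len(others))
                idx = self.rng.choice(len(others), size=n_sample, replace=False)
                paired = np.concatenate([x[i], others[idx]])  # P_i: I_i paired with S_i
            elif write:
                paired = x[i]  # a single identity image: plain self-attention
            else:
                # Read phase: pair the new image's tokens with all cached identity tokens
                paired = np.concatenate([x[i], self.id_bank[step].reshape(-1, channels)])
            # Query from the image itself; keys and values from the paired tokens
            outputs.append(attention(x[i] @ self.w_q, paired @ self.w_k, paired @ self.w_v))
        return np.stack(outputs)


# Phase 1 (write = True): three identity images, 4 tokens of 8 channels each
layer = ToyConsistentSelfAttention(channels=8)
id_tokens = np.random.default_rng(1).standard_normal((3, 4, 8))
print(layer(id_tokens, step=0, write=True).shape)    # (3, 4, 8)

# Phase 2 (write = False): a new image reads the cached identity tokens for the same step
new_tokens = np.random.default_rng(2).standard_normal((1, 4, 8))
print(layer(new_tokens, step=0, write=False).shape)  # (1, 4, 8)
```

Caching per step mirrors why `cur_step` is reset before each new image: the later image must read the identity features from the matching denoising step. In the read phase this sketch uses every cached identity token rather than a sample; the actual StoryDiffusion processors work on multi-head U-Net features and may sample or mask tokens differently.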

In [13]:
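# Free GPU memory: keep model components on the CPU (each is moved to the GPU only while it runs)
# and release cached CUDA memory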
pipe.enable_model_cpu_offload()
torch.cuda.empty_cache()

### Organize Images into a Comic Strip

In [14]:
import textwrap
from PIL import Image, ImageDraw, ImageFont

def create_comic_strip(images, texts, output_path, image_width, image_height, panels_horizontal, panels_vertical, border_size=10, text_height=150, font_path=None, font_size=30):
  """
  Organize images into a comic strip with borders and a caption box under each panel.

  Args:
    images: List of PIL images, ordered left to right, then top to bottom
    texts: List of captions, one per panel
    output_path: Path to save the final comic strip
    image_width: Width each image is resized to, in pixels
    image_height: Height each image is resized to, in pixels
    panels_horizontal: Number of panels per row
    panels_vertical: Number of panel rows
    border_size: Size of the border around each panel, in pixels
    text_height: Height of the caption box at the bottom of each panel, in pixels
    font_path: Path to a TrueType font file; Pillow's default font is used if None
    font_size: Font size for the captions (not applied to the default font)

  Returns:
    None. The comic strip is saved to output_path.
  """
  # Ensure the number of images matches the number of panels
  num_panels = panels_horizontal * panels_vertical
  if len(images) != num_panels:
      raise ValueError(f"The number of images ({len(images)}) does not match the total number of panels ({num_panels})")

  # Ensure the number of texts matches the number of panels
  if len(texts) != num_panels:
      raise ValueError(f"The number of texts ({len(texts)}) does not match the total number of panels ({num_panels})")

  # Calculate total dimensions of the comic strip
  total_width = (image_width + 2 * border_size) * panels_horizontal
  total_height = (image_height + 2 * border_size + text_height) * panels_vertical

  # Create a new image with the total dimensions
  new_image = Image.new('RGB', (total_width, total_height), 'white')

  # Load the font
  font = ImageFont.truetype(font_path, size=font_size) if font_path else ImageFont.load_default()

  # Paste images into the new image
  for i, (image, text) in enumerate(zip(images, texts)):
      image = image.resize((image_width, image_height))

      # Create a panel with a border
      panel = Image.new('RGB', (image_width + 2 * border_size, image_height + 2 * border_size + text_height), 'white')
      panel.paste(image, (border_size, border_size))

      # Draw the text box
      draw = ImageDraw.Draw(panel)
      text_position = (border_size, image_height + 2 * border_size)
      draw.rectangle([text_position, (panel.width - border_size, panel.height - border_size)], fill="white")

      # Wrap the caption to fit the text box, assuming an average character width of about half the font size
      wrapped_text = textwrap.fill(text, width=(image_width // font_size * 2))

      # Add the text to the text box
      draw.text((text_position[0] + 10, text_position[1] + 10), wrapped_text, font=font, fill="black")

      # Calculate position in the new image
      x_offset = (i % panels_horizontal) * (image_width + 2 * border_size)
      y_offset = (i // panels_horizontal) * (image_height + 2 * border_size + text_height)
      new_image.paste(panel, (x_offset, y_offset))

  # Save the new comic strip image
  new_image.save(output_path)
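The placement arithmetic in `create_comic_strip` can be checked in isolation. The helpers `comic_layout` and `wrap_width` below are hypothetical (not part of the notebook); they simply restate the function's size, offset, and text-wrapping formulas, with panels filled row by row, left to right.

```python
def comic_layout(image_width, image_height, panels_horizontal, panels_vertical,
                 border_size=10, text_height=150):
    """Return the canvas size and the top-left offset of every panel (row-major)."""
    # One panel = image plus a border on each side, with the caption box below the image
    panel_width = image_width + 2 * border_size
    panel_height = image_height + 2 * border_size + text_height
    canvas = (panel_width * panels_horizontal, panel_height * panels_vertical)
    offsets = [
        ((i % panels_horizontal) * panel_width, (i // panels_horizontal) * panel_height)
        for i in range(panels_horizontal * panels_vertical)
    ]
    return canvas, offsets


def wrap_width(image_width, font_size):
    """Characters per caption line, using the same heuristic as create_comic_strip."""
    return image_width // font_size * 2


canvas, offsets = comic_layout(800, 800, 3, 2)
print(canvas)               # (2460, 1940)
print(offsets[3])           # (0, 970): first panel of the second row
print(wrap_width(800, 30))  # 52
```

With the settings used below, the saved strip is therefore 2460 × 1940 pixels. The function does not check whether the wrapped caption fits in `text_height`, so a caption with too many lines is cut off at the bottom edge of its panel.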

In [15]:
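# Create the comic strip: identity images first, then the remaining images, in the same order as captions_array
images = id_images_list + real_images_list

output_path = 'comic_strip.png'
image_width = 800
image_height = 800
panels_horizontal = 3
panels_vertical = 2
border_size = 10
text_height = 150
# Requires Google Drive to be mounted in Colab; set font_path = None to use Pillow's default font
font_path = '/content/drive/My Drive/Colab Notebooks/W210_Capstone/ComicNeue-BoldItalic.ttf'
font_size = 30

create_comic_strip(
    images, captions_array, output_path,
    image_width=image_width, image_height=image_height,
    panels_horizontal=panels_horizontal, panels_vertical=panels_vertical,
    border_size=border_size, text_height=text_height,
    font_path=font_path, font_size=font_size,
)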
