## StoryDiffusion: Consistent Self-Attention for Long-Range Image and Video Generation  
[![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-md-dark.svg)]()
[[Paper]()] &emsp; [[Project Page]()] &emsp; <br>

### Import Packages

In [None]:
# %load_ext autoreload
# %autoreload 2
import gradio as gr
import numpy as np
import torch
import requests
import random
import os
import sys
import pickle
from PIL import Image
from tqdm.auto import tqdm
from datetime import datetime
from utils.gradio_utils import is_torch2_available
if is_torch2_available():
    from utils.gradio_utils import \
        AttnProcessor2_0 as AttnProcessor
else:
    from utils.gradio_utils  import AttnProcessor

import diffusers
from diffusers import StableDiffusionXLPipeline
from diffusers import DDIMScheduler
import torch.nn.functional as F
from utils.gradio_utils import cal_attn_mask_xl
import copy
import os
from diffusers.utils import load_image
from utils.utils import get_comic
from utils.style_template import styles

In [None]:
# Prefer GPU 1 (the original hard-coded choice) when the host exposes at least
# two CUDA devices; otherwise fall back to GPU 0. On a CPU-only host nothing is
# selected here, so the cell no longer fails before any model is loaded.
if torch.cuda.is_available():
    torch.cuda.set_device(1 if torch.cuda.device_count() > 1 else 0)

In [None]:
import torch

# Release cached CUDA allocator blocks left over from previously run cells.
torch.cuda.empty_cache()
# Prefer GPU 1 when it exists, otherwise GPU 0; skip entirely on CPU-only hosts
# instead of failing on a device index that is not present.
if torch.cuda.is_available():
    torch.cuda.set_device(1 if torch.cuda.device_count() > 1 else 0)

### Set Config 

In [None]:
## Global
# Key of the `styles` entry used as the fallback when a style name is unknown.
DEFAULT_STYLE_NAME = "(No style)"
# All style names offered by utils.style_template, in the mapping's own order.
STYLE_NAMES = [*styles]
# Bare expression: displayed as the cell output to show whether CUDA is usable.
torch.cuda.is_available()

In [None]:
def setup_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds PyTorch (CPU and all CUDA devices), NumPy and Python's ``random``
    with the same ``seed`` and switches cuDNN to deterministic mode.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

    
#################################################
########Consistent Self-Attention################
#################################################
class SpatialAttnProcessor2_0(torch.nn.Module):
    r"""
    Self-attention processor that lets a group of images attend to each other's tokens
    (the "Consistent Self-Attention" announced by the banner comment above this class).

    The processor is driven by module-level globals rather than by arguments:
    `write`, `cur_step`, `attn_count`, `total_count`, `mask1024`, `mask4096`,
    `sa32`, `sa64`, `height` and `width` are all read (and some are updated)
    inside `__call__`, so they must be defined before the first call.

    Args:
        hidden_size (`int`, optional):
            Stored on the instance; not read anywhere else in this class.
        cross_attention_dim (`int`, optional):
            Stored on the instance; not read anywhere else in this class.
        id_length (`int`, defaults to 4):
            Number of leading batch entries that are cached as identity features
            while `write` is true. `total_length` is set to `id_length + 1`.
        device (`str`, defaults to `"cuda"`):
            Device the cached features are moved to and the masks are created on.
        dtype (`torch.dtype`, defaults to `torch.float16`):
            dtype passed to `cal_attn_mask_xl` when the masks are regenerated.

    Raises:
        ImportError: if `torch.nn.functional` has no `scaled_dot_product_attention`.
    """

    def __init__(self, hidden_size = None, cross_attention_dim=None,id_length = 4,device = "cuda",dtype = torch.float16):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.device = device
        self.dtype = dtype
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.total_length = id_length + 1
        self.id_length = id_length
        # Per-step cache of hidden states, keyed by the global `cur_step`.
        self.id_bank = {}

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None):
        """Run one attention call and advance the global step bookkeeping.

        While the global `write` is true, the incoming `hidden_states` are split
        at `id_length` along dim 0 and cached in `self.id_bank[cur_step]`.
        Otherwise the cached tensors for `cur_step` are interleaved with the
        current `hidden_states` and used as `encoder_hidden_states` (the value
        passed by the caller is overwritten).

        Dispatch:
            * `cur_step < 5`: always `__call2__`.
            * later steps: `__call1__` with a slice of `mask1024`/`mask4096`
              when `random.random()` exceeds 0.3 (`cur_step < 20`) or 0.1
              (otherwise); else `__call2__` with no encoder states.

        Returns:
            The tensor produced by `__call1__` or `__call2__`.
        """
        global total_count,attn_count,cur_step,mask1024,mask4096
        global sa32, sa64
        global write
        global height,width
        # __import__('ipdb').set_trace()
        
        if write:
            # Cache the first `id_length` batch entries and the remainder for this step.
            self.id_bank[cur_step] = [hidden_states[:self.id_length], hidden_states[self.id_length:]]
        else:
            # Order along dim 0: cached[0], current[:1], cached[1], current[1:].
            # NOTE(review): presumably this mirrors a two-halves (e.g. classifier-free
            # guidance) batch layout -- confirm against the pipeline.
            encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(self.device),hidden_states[:1],self.id_bank[cur_step][1].to(self.device),hidden_states[1:]))

        # For the first five steps only the plain attention path (__call2__) is used.
        if cur_step <5:
            hidden_states = self.__call2__(attn, hidden_states,encoder_hidden_states,attention_mask,temb)
        else:   # 256 1024 4096
            # Randomly choose between the masked shared-attention path and the plain
            # path; the threshold drops from 0.3 to 0.1 once cur_step reaches 20.
            random_number = random.random()
            if cur_step <20:
                rand_num = 0.3
            else:
                rand_num = 0.1
            if random_number > rand_num:
                # The token count decides which precomputed mask applies:
                # (height//32) * (width//32) tokens -> mask1024, anything else -> mask4096.
                # `shape[0] // total_length * id_length` is the row index that separates
                # the first id_length/total_length share of the mask from the rest.
                if not write:
                    # Reading: keep the rows after that boundary, all columns.
                    if hidden_states.shape[1] == (height//32) * (width//32):
                        attention_mask = mask1024[mask1024.shape[0] // self.total_length * self.id_length:]
                    else:
                        attention_mask = mask4096[mask4096.shape[0] // self.total_length * self.id_length:]
                else:
                    # Writing: keep the top-left square block up to that boundary.
                    if hidden_states.shape[1] == (height//32) * (width//32):
                        attention_mask = mask1024[:mask1024.shape[0] // self.total_length * self.id_length,:mask1024.shape[0] // self.total_length * self.id_length]
                    else:
                        attention_mask = mask4096[:mask4096.shape[0] // self.total_length * self.id_length,:mask4096.shape[0] // self.total_length * self.id_length]
                # Optional shape dump, enabled by the DEBUG_MODE environment variable.
                if os.environ.get("DEBUG_MODE") == "true":
                    if encoder_hidden_states is not None:
                        print("call encoder hidden_states: ", encoder_hidden_states.shape)
                    else:
                        print("call encoder hidden_states: None")
                    if hidden_states is not None:
                        print("call hidden_states: ", hidden_states.shape)
                    else:
                        print("call hidden_states: None")
                    print("call attention_mask: ", attention_mask.shape)
                hidden_states = self.__call1__(attn, hidden_states,encoder_hidden_states,attention_mask,temb)
            else:
                # Plain path: encoder states are deliberately dropped here.
                hidden_states = self.__call2__(attn, hidden_states,None,attention_mask,temb)
        # One more processor call finished; once `total_count` calls have run, a
        # full step is complete: reset the counter, advance `cur_step` and
        # regenerate both masks with cal_attn_mask_xl.
        attn_count +=1
        if attn_count == total_count:
            attn_count = 0
            cur_step += 1
            # NOTE(review): these prints are not gated by DEBUG_MODE and run every step.
            print("height, width = ", height, width)
            mask1024,mask4096 = cal_attn_mask_xl(self.total_length,self.id_length,sa32,sa64,height,width, device=self.device, dtype= self.dtype)
            print("mask1024,mask4096 = ", mask1024.shape, mask4096.shape)
            if attention_mask is not None:
                print("attention_mask = ", attention_mask.shape)

        return hidden_states
    def __call1__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Masked attention over the tokens of several batch entries merged into one sequence.

        `hidden_states` is regrouped so that `total_batch_size // 2` consecutive
        batch entries form a single sequence; `encoder_hidden_states`, when
        given, is regrouped in chunks of `id_length + 1` entries. The result is
        reshaped back to `total_batch_size` entries before the output projection.
        `attention_mask` is handed to `F.scaled_dot_product_attention` unchanged.
        """
        if os.environ.get("DEBUG_MODE") == "true":
            print("call1 hidden_states: ", hidden_states.shape)
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # 4-D input: flatten the two trailing dims into one token axis.
            # `height`/`width` here are locals (no `global` statement in this method).
            total_batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(total_batch_size, channel, height * width).transpose(1, 2)
        total_batch_size,nums_token,channel = hidden_states.shape
        # Merge the tokens of `img_nums` consecutive batch entries into one sequence.
        # NOTE(review): the `// 2` presumably reflects a batch made of two halves
        # (e.g. unconditional + conditional) -- confirm against the pipeline.
        img_nums = total_batch_size//2
        hidden_states = hidden_states.view(-1,img_nums,nums_token,channel).reshape(-1,img_nums * nums_token,channel)

        batch_size, sequence_length, _ = hidden_states.shape

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            # No external states: keys/values come from the merged hidden states.
            encoder_hidden_states = hidden_states  # B, N, C
        else:
            # Merge each group of (id_length + 1) entries into one key/value sequence.
            encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,nums_token,channel).reshape(-1,(self.id_length+1) * nums_token,channel)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)


        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # Split the channel axis into heads: (batch, heads, tokens, head_dim).
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        
        # print("call1 query, key, value, ", query.shape, key.shape, value.shape, )
        # if attention_mask is not None:
        #     print("attn_mask", attention_mask.shape)
        
        # __import__('ipdb').set_trace()
        
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Undo the merge: back to one entry per original batch element.
        hidden_states = hidden_states.transpose(1, 2).reshape(total_batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)


        if input_ndim == 4:
            # Restore the original 4-D layout.
            hidden_states = hidden_states.transpose(-1, -2).reshape(total_batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        # print(hidden_states.shape)
        return hidden_states
   
    def __call2__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None):
        """Plain per-entry attention, without merging batch entries.

        When `attention_mask` is given it is first expanded through
        `attn.prepare_attention_mask`. When `encoder_hidden_states` is given
        it is regrouped in chunks of `id_length + 1` entries before the
        key/value projections; otherwise keys and values come from
        `hidden_states` itself.
        """
        
        if os.environ.get("DEBUG_MODE") == "true":
            print("call2 hidden_states: ", hidden_states.shape)
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # 4-D input: flatten the two trailing dims into one token axis.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, channel = (
            hidden_states.shape
        )
        # print(hidden_states.shape)
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states  # B, N, C
        else:
            # Merge each group of (id_length + 1) entries into one key/value sequence.
            encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,sequence_length,channel).reshape(-1,(self.id_length+1) * sequence_length,channel)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # Split the channel axis into heads: (batch, heads, tokens, head_dim).
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # print("call2 query, key, value, ", query.shape, key.shape, value.shape,)
        # if attention_mask is not None:
        #     print("call2 attention_mask: ", attention_mask.shape)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            # Restore the original 4-D layout.
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

### Load Pipeline

In [None]:
# Shared state for the consistent self-attention processors. The processors read
# these names as module globals at call time, so they are plain module-level
# assignments (a `global` statement is not needed at module scope).

# Attention-processor table, filled in when the processors are installed.
attn_procs = {}

# Call/step bookkeeping, all starting from zero.
attn_count = total_count = cur_step = 0

# Number of identity images, and identity images plus one generated image.
id_length = 4
total_length = id_length + 1

cur_model_type = ""
device = "cuda"

# False: reuse cached identity features; True: cache them.
write = False

### strength of consistent self-attention: the larger, the stronger
sa32 = sa64 = 0.5

### Resolution of the generated comics. Note: SDXL models may do worse at low resolution.
height = width = 512

In [None]:
from omegaconf import OmegaConf
from diffusers import AutoencoderKL, EulerDiscreteScheduler
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline

# Inference settings (UNet extras, scheduler kwargs, prompts, ...).
config = OmegaConf.load("./config/inference.yaml")

# Local SDXL checkpoint directory; every component is read from its own subfolder.
sd_model_path = "../models/sd_xl"

# Tokenizers and text encoders (SDXL uses two of each).
tokenizer = CLIPTokenizer.from_pretrained(sd_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(sd_model_path, subfolder="text_encoder")
tokenizer_two = CLIPTokenizer.from_pretrained(sd_model_path, subfolder="tokenizer_2")
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    sd_model_path, subfolder="text_encoder_2"
)

# Image autoencoder.
vae = AutoencoderKL.from_pretrained(sd_model_path, subfolder="vae")

# 3-D UNet initialised from the 2-D SDXL UNet weights plus the extra kwargs from the config.
unet = UNet3DConditionModel.from_pretrained_2d(
    sd_model_path,
    subfolder="unet",
    unet_additional_kwargs=OmegaConf.to_container(config.unet_additional_kwargs),
)

# Noise scheduler; remaining settings come from the config file.
scheduler = EulerDiscreteScheduler(
    timestep_spacing='leading',
    steps_offset=1,
    **config.noise_scheduler_kwargs,
)



In [None]:
# ###
# sd_model_path = "../models/sd_xl"
# ### LOAD Stable Diffusion Pipeline
# pipe = StableDiffusionXLPipeline.from_pretrained(sd_model_path, torch_dtype=torch.float16, use_safetensors=False)
# pipe = pipe.to(device)
# # pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)

# from animatediff.models.unet import UNet3DConditionModel
# from omegaconf import OmegaConf
# config = OmegaConf.load("./config/inference.yaml")
# from diffusers import AutoencoderKL, EulerDiscreteScheduler
# scheduler = EulerDiscreteScheduler(timestep_spacing='leading', steps_offset=1,	**config.noise_scheduler_kwargs)
# # pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
# # pipe.scheduler.set_timesteps(50)

# unet = UNet3DConditionModel.from_pretrained_2d(sd_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(config.unet_additional_kwargs))
# # print("unet.state_dict()", unet.state_dict().keys())

# from animatediff.pipelines.pipeline_animation import AnimationPipeline
# pipe = AnimationPipeline(
#     unet=unet, vae=pipe.vae, tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder, scheduler=scheduler,
#     text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2,
# ).to(device)
# print("pipe", pipe)

### Insert PairedAttention

In [None]:

## Insert PairedAttention
# Rebuild the processor table and the counter from scratch. The attention
# processors treat `attn_count == total_count` as "one denoising step done", so
# `total_count` must equal the number of installed processors even when this
# cell is executed more than once in the same session.
attn_procs = {}
total_count = 0
for name in unet.attn_processors:
    # Only the self-attention layers ("attn1") of the up blocks get the
    # consistent self-attention processor; every other layer keeps the default.
    if name.endswith("attn1.processor") and name.startswith("up_blocks"):
        attn_procs[name] = SpatialAttnProcessor2_0(id_length=id_length)
        total_count += 1
    else:
        attn_procs[name] = AttnProcessor()
print("successsfully load consistent self-attention")
print(f"number of the processor : {total_count}")

# Install deep copies so each layer owns its own processor instance.
unet.set_attn_processor(copy.deepcopy(attn_procs))
# Initial attention masks; the processors regenerate them after every step.
mask1024, mask4096 = cal_attn_mask_xl(
    total_length, id_length, sa32, sa64, height, width,
    device=device, dtype=torch.float16,
)


## Put the UNet with Consistent Self-Attention (CSA) inserted into the AnimationPipeline

In [None]:

# Assemble the animation pipeline from the individually loaded components,
# then move the whole pipeline onto the GPU.
pipeline = AnimationPipeline(
    unet=unet,
    vae=vae,
    scheduler=scheduler,
    tokenizer=tokenizer,
    tokenizer_2=tokenizer_two,
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_two,
)
pipeline = pipeline.to("cuda")
# print("pipeline = ", pipeline)

### insert motion module

In [None]:
# motion_module_path = "/userhome/37/ahhfdkx/AnimateDiff_sdxl/models/Motion_Module/mm_sdxl_v10_beta.ckpt"
# motion_module_ckpt = torch.load(motion_module_path, map_location="cpu")
# motion_module_state_dict = {}

# print("motion_module = ", motion_module_ckpt.keys())
# print("pipe.unet.state_dict()", pipe.unet.state_dict().keys())
# m_k = None
# for k, v in motion_module_ckpt.items():
#     if 'motion_module' in k and k in pipe.unet.state_dict().keys():
#         motion_module_state_dict[k] = v
#         m_k = k
#         # print("pipeline.unet.state_dict()", pipeline.unet.state_dict().keys())
#     elif 'motion_module' in k and k not in pipe.unet.state_dict().keys():
#         print(k)
# pipe.unet.load_state_dict(motion_module_state_dict, strict=False)

# del motion_module_ckpt
# del motion_module_state_dict
# print(f'Loading motion module from {motion_module_path}...')

# pipe.unet = pipe.unet.half()
# pipe.text_encoder = pipe.text_encoder.half()
# pipe.text_encoder_2 = pipe.text_encoder_2.half()
# pipe.enable_model_cpu_offload()
# pipe.enable_vae_slicing()
# # print(unet)

In [None]:

from animatediff.utils.util import load_weights, save_videos_grid

savedir = "./output"
# Make sure the output directory exists before anything is written into it.
os.makedirs(savedir, exist_ok=True)

# Load the motion module / checkpoint / LoRA weights named in the config.
pipeline = load_weights(
    pipeline=pipeline,
    motion_module_path=config.get("motion_module_path", ""),
    ckpt_path=config.get("ckpt_path", ""),
    lora_path=config.get("lora_path", ""),
    lora_alpha=config.get("lora_alpha", 0.8),
)

# Half precision plus CPU offload / VAE slicing to reduce GPU memory use.
pipeline.unet = pipeline.unet.half()
pipeline.text_encoder = pipeline.text_encoder.half()
pipeline.text_encoder_2 = pipeline.text_encoder_2.half()
pipeline.enable_model_cpu_offload()
pipeline.enable_vae_slicing()

prompts = config.prompt
n_prompts = config.n_prompt

# Normalise `seed` to a list with one entry per prompt.
random_seeds = config.get("seed", [-1])
random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
seeds = []
samples = []

# Put the attention processors into "write" (feature caching) mode.
write = True
with torch.inference_mode():
    # NOTE(review): `random_seed` is currently unused because the manual seeding
    # below is disabled; it is still zipped so the loop length is unchanged.
    for prompt, n_prompt, random_seed in zip(prompts, n_prompts, random_seeds):
        # manually set random seed for reproduction

        # if random_seed != -1: torch.manual_seed(random_seed)
        # else: torch.seed()
        # seeds.append(torch.initial_seed())
        print(f"current seed: {torch.initial_seed()}")
        print(f"sampling {prompt} ...")
        print("n_prompt", n_prompt)
        print("prompt", prompt)
        sample = pipeline(
            prompt,
            negative_prompt=n_prompt,
            num_inference_steps=config.get('steps', 20),
            guidance_scale=config.get('guidance_scale', 10),
            width=width,
            height=height,
            single_model_length=4,
        ).videos
        print("sample = ", sample)
        samples.append(sample)
        # save video
        # NOTE(review): every prompt is written to the same file, so only the
        # last one survives -- confirm whether per-prompt files are wanted.
        video_path = f"{savedir}/b.mp4"
        save_videos_grid(sample, video_path)
        # Report the path that was actually written.
        print(f"save to {video_path}")

samples = torch.concat(samples)
# `datetime` is the class imported via `from datetime import datetime`, so the
# timestamp comes from `datetime.now()`; `datetime.datetime` does not exist.
save_videos_grid(samples, f"{savedir}/sample-{datetime.now().strftime('%Y-%m-%dT%H-%M-%S')}.mp4", n_rows=4)
# config.seed = seeds
OmegaConf.save(config, f"{savedir}/config.yaml")



In [None]:
# Debug-only inspection of the attention masks; enabled with DEBUG_MODE=true.
if os.environ.get("DEBUG_MODE") == "true":
    print("mask1024", mask1024.shape)
    print("mask4096", mask4096.shape)
    print(mask1024)
    print(mask1024.shape[0] // 5 * 4)
    print(mask1024[196:].shape)
    # Re-derive the top-left block that covers the identity rows and columns.
    total_length, id_length = 5, 4
    _id_rows = mask1024.shape[0] // total_length * id_length
    attention_mask = mask1024[:_id_rows, :_id_rows]
    print(attention_mask.shape)


### Create the text description for the comics
Tips: Existing text2image diffusion models may not always generate images that accurately match text descriptions. Our training-free approach can improve the consistency of characters, but it does not enhance the control over the text. Therefore, in some cases, you may need to carefully craft your prompts.

In [None]:
# Sampling settings for the story below.
seed = 2047
num_steps = 30
guidance_scale = 10.0
# Consistent self-attention strengths (the two values are kept equal here).
sa32 = sa64 = 0.5
# The first `id_length` prompts serve as the identity prompts.
id_length = 4

# Character description shared by every panel, and the per-panel scenes.
general_prompt = "a man with a black suit"
negative_prompt = (
    "naked, deformed, bad anatomy, disfigured, poorly drawn face, mutation, "
    "extra limb, ugly, disgusting, poorly drawn hands, missing limb, "
    "floating limbs, disconnected limbs, blurry, watermarks, oversaturated, "
    "distorted hands, amputation"
)
prompt_array = [
    "wake up in the bed",
    "have breakfast",
    "is on the road, go to the company",
    "work in the company",
    "running in the playground",
    "reading book in the home",
]

def apply_style_positive(style_name: str, positive: str):
    """Return `positive` inserted into the positive template of `style_name`.

    A style name missing from `styles` falls back to the
    `DEFAULT_STYLE_NAME` entry. The style's negative text is ignored.
    """
    template, _unused_negative = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    styled = template.replace("{prompt}", positive)
    return styled
def apply_style(style_name: str, positives: list, negative: str = ""):
    """Apply a style to several prompts at once.

    Returns a pair: the list of `positives` each inserted into the style's
    positive template, and the style's negative text followed by a space and
    `negative`. A style name missing from `styles` falls back to the
    `DEFAULT_STYLE_NAME` entry.
    """
    template, style_negative = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    styled_positives = []
    for text in positives:
        styled_positives.append(template.replace("{prompt}", text))
    combined_negative = style_negative + ' ' + negative
    return styled_positives, combined_negative
### Set the generated Style
style_name = "Comic book"

# Seed the global RNGs and create a dedicated CUDA generator with the same seed.
setup_seed(seed)
generator = torch.Generator(device="cuda").manual_seed(seed)
print(id_length)

# Prefix every scene with the shared character description, then split the
# result into identity prompts (first `id_length`) and the remaining prompts.
prompts = [f"{general_prompt} {scene}" for scene in prompt_array]
id_prompts, real_prompts = prompts[:id_length], prompts[id_length:]
print("id_prompts", id_prompts)

# Reset the state shared with the attention processors.
write = False
cur_step = attn_count = 0

# Only the negative prompt is styled here; the identity prompts are left unstyled.
_, negative_prompt = apply_style(style_name, id_prompts, negative_prompt)
print("id_prompts", id_prompts)
print(id_prompts[0])
print(negative_prompt)


In [None]:
import os
os.environ["DEBUG_MODE"] = "false"
from animatediff.utils.util import save_videos_grid

# Put the attention processors into "write" (feature caching) mode.
write = True
id_images = []
print("negative_prompt", negative_prompt)
with torch.inference_mode():
    for id_prompt in id_prompts:
        print("cur_step", cur_step)
        print("id_prompt", id_prompt)
        # NOTE(review): this reseeds torch's global RNG non-deterministically on
        # every iteration; sampling still receives the explicit `generator`.
        torch.seed()
        # The notebook's live pipeline object is `pipeline`; the `pipe` name used
        # here before is only created in a commented-out cell and is undefined.
        sample = pipeline(
            id_prompt,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            # TODO: height/width are module globals shared with the attention
            # processors, so every call must use the same resolution.
            height=height,
            width=width,
            # negative_prompt = negative_prompt,
            negative_prompt="",
            single_model_length=16,
            generator=generator,
        ).videos
        id_images.append(sample)
        print("sample ", sample)
        # NOTE(review): `cur_step` is the global step counter advanced by the
        # attention processors, not a per-prompt index -- confirm the intended file name.
        save_videos_grid(sample, f"./output/sample_{cur_step}.mp4")

### Continued Creation
From now on, you can create endless stories about this character without worrying about memory constraints.

### Make pictures into comics