## StoryDiffusion: Consistent Self-Attention for Long-Range Image and Video Generation  
[![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-md-dark.svg)]()
[[Paper]()] &emsp; [[Project Page]()] &emsp; <br>

### Import Packages

In [1]:
# %load_ext autoreload
# %autoreload 2
import gradio as gr
import numpy as np
import torch
import requests
import random
import os
import sys
import pickle
from PIL import Image
from tqdm.auto import tqdm
from datetime import datetime
from utils.gradio_utils import is_torch2_available
if is_torch2_available():
    from utils.gradio_utils import \
        AttnProcessor2_0 as AttnProcessor
else:
    from utils.gradio_utils  import AttnProcessor

import diffusers
from diffusers import StableDiffusionXLPipeline
from diffusers import DDIMScheduler
import torch.nn.functional as F
from utils.gradio_utils import cal_attn_mask_xl
import copy
import os
from diffusers.utils import load_image
from utils.utils import get_comic
from utils.style_template import styles

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# torch.cuda.set_device(1)

In [3]:
import torch
torch.cuda.empty_cache()

### Set Config 

In [4]:
## Global
# Names of all style templates exposed by utils.style_template.styles.
STYLE_NAMES = list(styles.keys())
# Key of the fallback template used by apply_style / apply_style_positive
# when the requested style name is not present in `styles`.
DEFAULT_STYLE_NAME = "(No style)"
# Bare expression: its value is only shown as the notebook cell output.
torch.cuda.is_available()

True

In [5]:
def setup_seed(seed):
    """Seed the random number generators used in this notebook.

    Passes ``seed`` to ``torch.manual_seed``, ``torch.cuda.manual_seed_all``,
    ``np.random.seed`` and ``random.seed``, then switches cuDNN to
    deterministic mode.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

    
#################################################
########Consistent Self-Attention################
#################################################
class SpatialAttnProcessor2_0(torch.nn.Module):
    r"""
    Consistent Self-Attention processor (requires PyTorch 2.0).

    `__call__` reshapes the incoming hidden states to
    `(-1, single_model_length, tokens, channels)` and routes them either to
    `__call2__` (attention computed for every reshaped sample) or to
    `__call1__` (attention computed on the first two reshaped groups only and
    repeated for the remaining ones), depending on the current sampling step
    and a random draw.

    The processor relies on module-level globals that the surrounding script
    must define before the first call: `total_count`, `attn_count`,
    `cur_step`, `mask1024`, `mask4096`, `write`, `height` and `width`.

    Args:
        hidden_size (`int`, *optional*):
            Stored on the instance; not read by the methods of this class.
        cross_attention_dim (`int`, *optional*):
            Stored on the instance; not read by the methods of this class.
        id_length (`int`, defaults to 4):
            Presumably the number of identity (reference) prompts;
            `total_length` is set to `id_length + 1`.
        device (`str`, defaults to `"cuda"`):
            Device the tensors read from `id_bank` are moved to.
        dtype (`torch.dtype`, defaults to `torch.float16`):
            Stored on the instance; not read by the methods of this class.
        single_model_length (`int`, defaults to 4):
            Size of the second axis that `__call__` reshapes the hidden states
            to (presumably the number of frames per sample -- confirm against
            the pipeline).
        attn_dropout (`float`, defaults to 0.2):
            `dropout_p` handed to `F.scaled_dot_product_attention`. That
            function applies dropout whenever the value is greater than 0,
            independent of train/eval mode, so pass `0.0` for deterministic
            inference. The default keeps the value that used to be hard-coded.
    """
    def __init__(self, hidden_size = None, cross_attention_dim=None,id_length = 4,device = "cuda",dtype = torch.float16, single_model_length = 4, attn_dropout = 0.2):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.device = device
        self.dtype = dtype
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.total_length = id_length + 1
        self.id_length = id_length
        # Per-step cache read in the `write == False` branch of `__call__`.
        self.id_bank = {}
        self.single_model_length = single_model_length
        self.attn_dropout = attn_dropout

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None):
        """Run one attention call and advance the global step bookkeeping.

        Args:
            attn: Attention module providing `to_q`/`to_k`/`to_v`/`to_out`,
                `heads`, the optional norms and the residual settings.
            hidden_states: Tensor whose last two axes are (tokens, channels);
                it is reshaped to `(-1, single_model_length, tokens, channels)`.
            encoder_hidden_states: Optional key/value source; replaced by a
                tensor built from `id_bank` when the global `write` is False.
            attention_mask: Optional mask; replaced by a slice of the global
                `mask1024`/`mask4096` on the `__call1__` path.
            temb: Forwarded to `attn.spatial_norm` when that norm exists.

        Returns:
            The attention output produced by `__call1__` or `__call2__`.

        Raises:
            KeyError: when `write` is False and `id_bank` has no entry for the
                current step.
        """
        global total_count,attn_count,cur_step,mask1024,mask4096
        global sa32, sa64
        global write
        global height,width

        hidden_states = hidden_states.view(-1, self.single_model_length , hidden_states.shape[-2], hidden_states.shape[-1])
        if write:
            # NOTE(review): the line that filled `id_bank` is disabled, so the
            # `else` branch below raises KeyError on its first use. Restore it
            # only after confirming the intended tensor layout.
            # self.id_bank[cur_step] = [hidden_states[:self.id_length], hidden_states[self.id_length:]] # batch_size  = id_length * 2
            pass
        else:
            encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].repeat(self.id_length,1,1).to(self.device),
                                               hidden_states[:1],
                                               self.id_bank[cur_step][1].repeat(self.id_length,1,1).to(self.device),
                                               hidden_states[1:]))
        # skip in early step
        if cur_step <5:
            hidden_states = self.__call2__(attn, hidden_states,encoder_hidden_states,attention_mask,temb)
        else:   # 256 1024 4096
            random_number = random.random()
            # Probability of falling back to `__call2__`: 0.3 before step 20,
            # 0.1 afterwards.
            rand_num = 0.3 if cur_step < 20 else 0.1
            if random_number > rand_num:
                # Pick the precomputed mask that matches the token count.
                if hidden_states.shape[2] == (height//32) * (width//32):
                    mask = mask1024
                else:
                    mask = mask4096
                split = mask.shape[0] // self.total_length * self.id_length
                # Write pass keeps the leading square block, read pass keeps
                # the rows after it.
                attention_mask = mask[:split, :split] if write else mask[split:]
                hidden_states = self.__call1__(attn, hidden_states,encoder_hidden_states,attention_mask,temb)
            else:
                hidden_states = self.__call2__(attn, hidden_states,None,attention_mask,temb)
        attn_count +=1
        # Once every installed processor has run, a sampling step is complete.
        if attn_count == total_count:
            attn_count = 0
            cur_step += 1
            # mask1024,mask4096 = cal_attn_mask_xl(self.total_length,self.id_length,sa32,sa64,height,width, device=self.device, dtype= self.dtype)
            print("total_count",total_count)
        return hidden_states
    def __call1__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Attention computed on the first two reshaped groups and repeated.

        `hidden_states` is expected as
        `(total_batch_size, video_length, nums_token, channel)` (or the 5-D
        image layout, which is flattened first). Only the first two groups of
        the `(-1, img_nums, nums_token, channel)` view are used as queries;
        the result is repeated `video_length` times so the output has shape
        `(total_batch_size * video_length, nums_token, channel)`.
        """
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim

        if input_ndim == 5:
            total_batch_size, video_length, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(total_batch_size, video_length, channel, height * width).transpose(2, 3)

        total_batch_size, video_length, nums_token, channel = hidden_states.shape

        img_nums = total_batch_size//2  # = id_length (per the original author; assumes an even batch)
        hidden_states = hidden_states.view(-1,img_nums,nums_token,channel) # (video_length * 2, img_nums, nums_token, channel)
        # e.g. nums_token = 256, channel = 1280
        hidden_states_frame1 = hidden_states[:2, :, :, :] # (2, img_nums, nums_token, channel)
        hidden_states_frame1 = hidden_states_frame1.reshape(-1,img_nums * nums_token,channel) # (2, img_nums * nums_token, channel)

        batch_size = hidden_states_frame1.shape[0]
        hidden_states = hidden_states.reshape(-1,img_nums * nums_token,channel)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        hidden_states = hidden_states.view(-1, nums_token, channel)
        hidden_states_frame1 = hidden_states_frame1.view(-1, nums_token, channel)

        query = attn.to_q(hidden_states_frame1)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states_frame1  # B, N, C
        else:
            print("this is encoder_hidden_states, encoder_hidden_states.shape", encoder_hidden_states.shape)
            encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,nums_token,channel).reshape(-1,(self.id_length+1) * nums_token,channel)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        # Original author's note: the query is taken from the first group of
        # the hidden states only; this accounts for the batch size but not
        # for video_length.
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size , -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size , -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=self.attn_dropout, is_causal=False
        ) # (2, heads, img_nums * nums_token, head_dim)
        # Repeat once per frame so the reshape below recovers exactly
        # `nums_token` tokens per row. This used to be a literal 16, which is
        # only correct when video_length == 16.
        hidden_states = hidden_states.repeat(video_length, 1, 1, 1)

        hidden_states = hidden_states.transpose(1, 2)

        hidden_states = hidden_states.reshape(total_batch_size * video_length, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 5:
            hidden_states = hidden_states.transpose(-1, -2).reshape(total_batch_size, video_length, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states

    def __call2__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None):
        """Attention computed for every sample of the reshaped batch.

        `hidden_states` is expected as
        `(batch_size, video_length, sequence_length, channel)` (or the 5-D
        image layout, which is flattened first). Queries, keys and values are
        grouped per `batch_size` entry, and the output has shape
        `(batch_size * video_length, sequence_length, channel)`.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 5:
            batch_size, video_length, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, video_length, channel, height * width).transpose(2, 3)

        batch_size, video_length, sequence_length, channel = (
            hidden_states.shape
        ) # sequence_length = nums_token

        if attention_mask is not None:
            print("call2 attention_mask.shape", attention_mask.shape)
            attention_mask = attn.prepare_attention_mask(attention_mask, video_length, sequence_length, batch_size)
            print("call2 attention_mask.shape", attention_mask.shape)
            # scaled_dot_product_attention expects attention_mask shape to be  (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        hidden_states = hidden_states.view(-1, sequence_length, channel)
        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states  # B, N, C
        else:
            encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,sequence_length,channel).reshape(-1,(self.id_length+1) * sequence_length,channel)
            # e.g. [10, 256, 1280] -> [2, 5, 256, 1280] -> [2, 1280, 1280]
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=self.attn_dropout, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size * video_length , -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 5:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, video_length, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

### Load Pipeline

In [6]:
# Load the inference configuration and initialise the module-level state that
# SpatialAttnProcessor2_0 reads through `global` declarations.
from omegaconf import OmegaConf
config = OmegaConf.load("./config/inference.yaml")
# `global` at module level has no effect; these lines only list the shared names.
global attn_count, total_count, id_length, total_length,cur_step, cur_model_type
global write
global  sa32, sa64
global height,width

global attn_procs,unet

global pipeline
global sd_model_path
# Maps attention-processor names to the processor instances to install.
attn_procs = {}

# Number of SpatialAttnProcessor2_0 calls made so far in the current step.
attn_count = 0
# Number of SpatialAttnProcessor2_0 instances installed (incremented on insertion).
total_count = 0
# Sampling-step counter advanced by the processor when attn_count == total_count.
cur_step = 0
id_length = 2 # original note: any value >= 3 works -- NOTE(review): the value set here is 2
cur_model_type = ""
device="cuda"
###
# When False the processor reads its id_bank instead of running the write pass.
write = False
### strength of consistent self-attention: the larger, the stronger
sa32 = 0.5
sa64 = 0.5
### Resolution of the generated output. Please note: SDXL models may do worse at low resolution!
height = config.get("height", 512)
width = config.get("width", 512)
# Passed to the processors and to the pipeline call; falls back to 4 when absent.
single_model_length = config.get("single_model_length", 4)
#
# id_length * single_model_length // 4
total_length = id_length + 1
print("id_length", id_length)
print("total_length", total_length)
print("single_model_length", single_model_length)
print("height, width", height, width)


id_length 3
total_length 4
single_model_length 16
height, width 512 512


In [7]:

from diffusers import AutoencoderKL, EulerDiscreteScheduler

from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection

# Local directory holding the SDXL checkpoint; each component is loaded from
# its own subfolder.
sd_model_path = "../models/sd_xl"
# Load Component
tokenizer	 = CLIPTokenizer.from_pretrained(sd_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(sd_model_path, subfolder="text_encoder")
vae			 = AutoencoderKL.from_pretrained(sd_model_path, subfolder="vae")
tokenizer_two = CLIPTokenizer.from_pretrained(sd_model_path, subfolder="tokenizer_2")
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(sd_model_path, subfolder="text_encoder_2")

from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
# init unet model: built from the 2D UNet weights plus the extra kwargs from the config
unet = UNet3DConditionModel.from_pretrained_2d(sd_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(config.unet_additional_kwargs))
# print("unet.state_dict()", unet.state_dict().keys())
# Scheduler settings other than spacing/offset come from config.noise_scheduler_kwargs.
scheduler = EulerDiscreteScheduler(timestep_spacing='leading', steps_offset=1,	**config.noise_scheduler_kwargs)



loaded temporal unet's pretrained weights from ../models/sd_xl/unet ...
using vanilla temporal module
using vanilla temporal module
using vanilla temporal module
using vanilla temporal module
using vanilla temporal module
using vanilla temporal module
using vanilla temporal module
using vanilla temporal module
using vanilla temporal module
using vanilla temporal module
using vanilla temporal module
using vanilla temporal module
using vanilla temporal module
using vanilla temporal module
using vanilla temporal module
### missing keys: 420; 
### unexpected keys: 0;
### Temporal Module Parameters: 236.7792 M


# generate style prompt

In [8]:
# Seed and self-attention strengths used for this run.
seed = 2047
sa32 = 0.5
sa64 = 0.5

# `global` at module level has no effect; kept as a list of shared names.
global id_prompts
global negative_prompt
# Each prompt entry is indexed with [0] -- presumably the config stores them as
# lists. NOTE(review): with the "" default the [0] lookup raises IndexError.
general_prompt = config.get("general_prompt", "")
general_prompt = general_prompt[0]
negative_prompt = config.get("negative_prompt", "")
negative_prompt = negative_prompt[0]
# Per-scene prompt fragments; general_prompt is prepended to each of them later.
prompt_array = config.get("prompt_array", [])
# print("prompt_array", prompt_array)
# print("general_prompt", general_prompt)

def apply_style_positive(style_name: str, positive: str):
    """Fill the positive template of ``style_name`` with ``positive``.

    Unknown style names fall back to ``styles[DEFAULT_STYLE_NAME]``. The
    negative half of the template pair is ignored.
    """
    fallback = styles[DEFAULT_STYLE_NAME]
    template, _negative = styles.get(style_name, fallback)
    return template.replace("{prompt}", positive)
def apply_style(style_name: str, positives: list, negative: str = ""):
    """Apply a style template to several prompts at once.

    Returns a pair: the list of ``positives`` with each entry substituted into
    the style's positive template, and the style's negative prompt joined to
    ``negative`` by a single space. Unknown style names fall back to
    ``styles[DEFAULT_STYLE_NAME]``.
    """
    fallback = styles[DEFAULT_STYLE_NAME]
    template, style_negative = styles.get(style_name, fallback)
    styled = []
    for text in positives:
        styled.append(template.replace("{prompt}", text))
    combined_negative = style_negative + ' ' + negative
    return styled, combined_negative
### Set the generated Style
# NOTE(review): style_name is only consumed by the apply_style calls that are
# commented out below, so the style currently has no effect on the prompts.
style_name = "Comic book"
setup_seed(seed)
# NOTE(review): `generator` is not passed to any pipeline call in the visible cells.
generator = torch.Generator(device="cuda").manual_seed(seed)
print(id_length)
# Full prompt = shared character description + per-scene fragment.
prompts = [general_prompt + " " + prompt for prompt in prompt_array]
# The first id_length prompts are generated together in the identity pass ...
id_prompts = prompts[:id_length]
# print("id_prompts", id_prompts)
# ... and the remaining prompts are generated one by one in the continuation pass.
real_prompts = prompts[id_length:]
# Reset the state shared with SpatialAttnProcessor2_0 before generation.
write = False
cur_step = 0
attn_count = 0
# _, negative_prompt = apply_style(style_name, id_prompts, negative_prompt)
# id_prompts, negative_prompt = apply_style(style_name, id_prompts, negative_prompt)
print("id_prompts", id_prompts)
print("len(id_prompts)", len(id_prompts))
print(id_prompts[0])
print(negative_prompt)
# One copy of the negative prompt per identity prompt.
negative_prompts = [negative_prompt] * len(id_prompts)


3
id_prompts ['a girl with blond hair and blue eyes go to the company', 'a girl with blond hair and blue eyes work in the company', 'a girl with blond hair and blue eyes running in the playground']
len(id_prompts) 3
a girl with blond hair and blue eyes go to the company
naked, deformed, bad anatomy, disfigured, poorly drawn face, mutation, extra limb, ugly, disgusting, poorly drawn hands, missing limb, floating limbs, disconnected limbs, blurry, watermarks, oversaturated, distorted hands, amputation


### Insert PairedAttention

In [9]:

## Insert PairedAttention
# Replace the self-attention ("attn1") processors of the up blocks with
# SpatialAttnProcessor2_0; every other attention layer gets the stock AttnProcessor.
for name in unet.attn_processors.keys():
    # Self-attention layers have no cross-attention dimension.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    # NOTE(review): hidden_size is computed here but never passed to a processor.
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    if cross_attention_dim is None and (name.startswith("up_blocks") ) :
        attn_procs[name] =  SpatialAttnProcessor2_0(id_length = id_length, single_model_length = single_model_length)
        # total_count tells the processors how many calls make up one sampling step.
        total_count +=1
    else:
        attn_procs[name] = AttnProcessor()
print("successsfully load consistent self-attention")
print(f"number of the processor : {total_count}")
# The UNet receives a deep copy, so the local dict can be dropped afterwards.
unet.set_attn_processor(copy.deepcopy(attn_procs))
del attn_procs ## release memory (the original author was unsure whether this helps)
global mask1024,mask4096
# Precompute the two attention masks that SpatialAttnProcessor2_0 slices at call time.
mask1024, mask4096 = cal_attn_mask_xl(total_length,id_length,sa32,sa64,height,width,device=device,dtype= torch.float16)


successsfully load consistent self-attention
number of the processor : 36


## Put the UNet with Consistent Self-Attention (CSA) inserted into the AnimationPipeline

In [10]:

# Assemble the animation pipeline from the patched UNet and the SDXL
# components loaded earlier, then move it to the GPU.
pipeline = AnimationPipeline(
        unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=scheduler,
        text_encoder_2 = text_encoder_two, tokenizer_2=tokenizer_two
).to("cuda")
# print("pipeline = ", pipeline)

## insert motion module and generate storygen video

In [11]:

from animatediff.utils.util import load_weights, save_videos_grid
# NOTE(review): this rebinds `datetime` to the module, replacing the class
# imported at the top of the notebook with `from datetime import datetime`.
import datetime
# One timestamped output directory per run.
savedir = f"./output/{datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S')}"
# Load the motion module and the optional checkpoint / LoRA weights named in the config.
pipeline = load_weights(
		pipeline = pipeline,
		motion_module_path = config.get("motion_module_path", ""),
		ckpt_path = config.get("ckpt_path", ""),
		lora_path = config.get("lora_path", ""),
		lora_alpha = config.get("lora_alpha", 0.8)
	)

# Run the UNet and both text encoders in half precision.
pipeline.unet = pipeline.unet.half()
pipeline.text_encoder = pipeline.text_encoder.half()
pipeline.text_encoder_2 = pipeline.text_encoder_2.half()
pipeline.enable_model_cpu_offload()
pipeline.enable_vae_slicing()


# Normalise the configured seed(s) to a list; a single seed is repeated once
# per identity prompt. -1 means "draw a fresh seed" in the generation cell.
random_seeds = config.get("seed", [-1])
random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
random_seeds = random_seeds * len(id_prompts) if len(random_seeds) == 1 else random_seeds
# Seeds actually used (saved with the config) and the generated samples.
seeds = []
samples = []


Loading motion module from ./models/Motion_Module/mm_sdxl_v10_beta.ckpt...


In [12]:

# Identity pass: generate all id_prompts in one batch, then save one
# mp4 + gif per prompt, a combined grid video, and the config actually used.
# Imported up front so a missing dependency fails before the long generation run.
import imageio

write = True
with torch.inference_mode():
    # Only the first configured seed is used; -1 requests a fresh random seed.
    random_seed = random_seeds[0]
    if random_seed != -1:
        torch.manual_seed(random_seed)
    else:
        torch.seed()
    seeds.append(torch.initial_seed())
    print(f"current seed: {torch.initial_seed()}")
    print("n_prompt", negative_prompts)
    print("len(negative_prompts)", len(negative_prompts))
    print("prompt", id_prompts)
    samples = pipeline(
        id_prompts,
        negative_prompt=negative_prompts,
        num_inference_steps=config.get('steps', 20),
        guidance_scale=config.get('guidance_scale', 10),
        width=width,
        height=height,
        single_model_length=single_model_length,
    ).videos
    print("samples", samples.shape)
    # The i-th sample was generated from the i-th prompt.
    for prompt, sample in zip(prompts, samples):
        # save video (save_videos_grid receives a batch of one)
        sample = sample.unsqueeze(0)
        video_path = f"{savedir}/sample/{prompt}.mp4"
        save_videos_grid(sample, video_path)
        # Re-read the mp4 to build the gif; the context manager closes the
        # reader, which the previous version left open.
        with imageio.get_reader(video_path) as video_reader:
            frames = list(video_reader)
        imageio.mimsave(f"{savedir}/sample/{prompt}.gif", frames, fps=25)
        print(f"save to {video_path}")

    save_videos_grid(samples, f"{savedir}/main.mp4")
    print(f"save to {savedir}/main.mp4")

# save_videos_grid(samples, f"{savedir}/sample-{datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S')}.mp4", n_rows=4)
# Record the seed(s) actually used next to the outputs.
config.seed = seeds
OmegaConf.save(config, f"{savedir}/config.yaml")

current seed: 17323659542486420036
n_prompt ['naked, deformed, bad anatomy, disfigured, poorly drawn face, mutation, extra limb, ugly, disgusting, poorly drawn hands, missing limb, floating limbs, disconnected limbs, blurry, watermarks, oversaturated, distorted hands, amputation', 'naked, deformed, bad anatomy, disfigured, poorly drawn face, mutation, extra limb, ugly, disgusting, poorly drawn hands, missing limb, floating limbs, disconnected limbs, blurry, watermarks, oversaturated, distorted hands, amputation', 'naked, deformed, bad anatomy, disfigured, poorly drawn face, mutation, extra limb, ugly, disgusting, poorly drawn hands, missing limb, floating limbs, disconnected limbs, blurry, watermarks, oversaturated, distorted hands, amputation']
len(negative_prompts) 3
prompt ['a girl with blond hair and blue eyes go to the company', 'a girl with blond hair and blue eyes work in the company', 'a girl with blond hair and blue eyes running in the playground']
batch_size 3


  0%|          | 0/25 [00:00<?, ?it/s]

total_count 36


  4%|▍         | 1/25 [00:06<02:35,  6.48s/it]

total_count 36


  8%|▊         | 2/25 [00:10<01:58,  5.15s/it]

total_count 36


 12%|█▏        | 3/25 [00:14<01:43,  4.72s/it]

total_count 36


 16%|█▌        | 4/25 [00:19<01:35,  4.53s/it]

total_count 36


 20%|██        | 5/25 [00:23<01:28,  4.42s/it]

total_count 36


 24%|██▍       | 6/25 [00:26<01:18,  4.11s/it]

total_count 36


 28%|██▊       | 7/25 [00:30<01:10,  3.91s/it]

total_count 36


 32%|███▏      | 8/25 [00:33<01:04,  3.82s/it]

total_count 36


 36%|███▌      | 9/25 [00:37<00:58,  3.68s/it]

total_count 36


 40%|████      | 10/25 [00:41<00:54,  3.66s/it]

total_count 36


 44%|████▍     | 11/25 [00:44<00:51,  3.65s/it]

total_count 36


 48%|████▊     | 12/25 [00:48<00:46,  3.60s/it]

total_count 36


 52%|█████▏    | 13/25 [00:51<00:43,  3.63s/it]

total_count 36


 56%|█████▌    | 14/25 [00:55<00:40,  3.69s/it]

total_count 36


 60%|██████    | 15/25 [00:59<00:36,  3.68s/it]

total_count 36


 64%|██████▍   | 16/25 [01:02<00:32,  3.63s/it]

total_count 36


 68%|██████▊   | 17/25 [01:06<00:28,  3.59s/it]

total_count 36


 72%|███████▏  | 18/25 [01:09<00:25,  3.58s/it]

total_count 36


 76%|███████▌  | 19/25 [01:13<00:21,  3.57s/it]

total_count 36


 80%|████████  | 20/25 [01:17<00:18,  3.67s/it]

total_count 36


 84%|████████▍ | 21/25 [01:20<00:14,  3.62s/it]

total_count 36


 88%|████████▊ | 22/25 [01:24<00:10,  3.57s/it]

total_count 36


 92%|█████████▏| 23/25 [01:27<00:07,  3.50s/it]

total_count 36


 96%|█████████▌| 24/25 [01:31<00:03,  3.50s/it]

total_count 36


100%|██████████| 25/25 [01:34<00:00,  3.78s/it]


samples torch.Size([3, 3, 16, 512, 512])
Saving video grid to ./output/2024-09-10T11-30-42/sample/a girl with blond hair and blue eyes go to the company.mp4
save to ./output/2024-09-10T11-30-42/sample/a girl with blond hair and blue eyes go to the company.mp4
Saving video grid to ./output/2024-09-10T11-30-42/sample/a girl with blond hair and blue eyes work in the company.mp4
save to ./output/2024-09-10T11-30-42/sample/a girl with blond hair and blue eyes work in the company.mp4
Saving video grid to ./output/2024-09-10T11-30-42/sample/a girl with blond hair and blue eyes running in the playground.mp4
save to ./output/2024-09-10T11-30-42/sample/a girl with blond hair and blue eyes running in the playground.mp4




Saving video grid to ./output/2024-09-10T11-30-42/main.mp4




save to ./output/2024-09-10T11-30-42/main.mp4


## continued generation

In [13]:
# Continuation pass: generate the remaining prompts one at a time and append
# them to the samples produced by the identity pass.
# NOTE(review): with write = False, SpatialAttnProcessor2_0 reads
# id_bank[cur_step], which the write pass never fills; this is the
# `KeyError: 0` recorded in this cell's output.
write = False
# `samples` is a tensor after the identity pass (or the initial empty list if
# that cell was skipped); collect everything in a list so it can be
# concatenated. The previous code called .append() on the tensor.
all_samples = [samples] if torch.is_tensor(samples) else list(samples)
with torch.inference_mode():
    print("real_prompts", real_prompts)
    # NOTE(review): random_seed is unpacked but never applied; zip() also
    # limits the loop to len(random_seeds) prompts.
    for prompt_idx, (id_prompt, random_seed) in enumerate(zip(real_prompts,  random_seeds)):
        # Restart the step counter read by the attention processors.
        cur_step = 0
        print(f"current seed: {torch.initial_seed()}")
        print("n_prompt", negative_prompt)
        print("prompt", id_prompt)
        sample = pipeline(
            id_prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=config.get('steps', 20),
            guidance_scale=config.get('guidance_scale', 10),
            width=width,
            height=height,
            single_model_length=single_model_length,
        ).videos
        print("sample = ", sample)
        all_samples.append(sample)
        # save video
        save_videos_grid(sample, f"{savedir}/{id_prompt}.mp4")
        print(f"save to {savedir}/{id_prompt}.mp4")
# torch.concat rejects an empty list, so keep `samples` untouched in that case.
if all_samples:
    samples = torch.concat(all_samples)

real_prompts ['a girl with blond hair and blue eyes reading book in the home', 'a girl with blond hair and blue eyes wake up in the bed']
current seed: 17323659542486420036
n_prompt naked, deformed, bad anatomy, disfigured, poorly drawn face, mutation, extra limb, ugly, disgusting, poorly drawn hands, missing limb, floating limbs, disconnected limbs, blurry, watermarks, oversaturated, distorted hands, amputation
prompt a girl with blond hair and blue eyes reading book in the home
batch_size 1


  0%|          | 0/25 [00:02<?, ?it/s]


KeyError: 0

### Create the text description for the comics
Tips: Existing text2image diffusion models may not always generate images that accurately match text descriptions. Our training-free approach can improve the consistency of characters, but it does not enhance the control over the text. Therefore, in some cases, you may need to carefully craft your prompts.

### Continued Creation
From now on, you can create endless stories about this character without worrying about memory constraints.

### Make pictures into comics