## StoryDiffusion: Consistent Self-Attention for Long-Range Image and Video Generation  
[![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-md-dark.svg)]()
[[Paper]()] &emsp; [[Project Page]()] &emsp; <br>

### Import Packages

In [1]:

# %load_ext autoreload
# %autoreload 2
import gradio as gr
import numpy as np
import torch
import requests
import random
import os
import sys
import pickle
from PIL import Image
from tqdm.auto import tqdm
from datetime import datetime

import diffusers
from diffusers import DDIMScheduler
import torch.nn.functional as F

from utils_sd15.gradio_utils import cal_attn_mask
import copy
import os
from diffusers.utils import load_image
from utils_sd15.utils import get_comic
from utils_sd15.style_template import styles

from animatediff_25.utils.util import save_videos_grid

from animatediff_25.models.unet import UNet3DConditionModel
from animatediff_25.pipelines.pipeline_animation import AnimationPipeline


# from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor, LoRAAttnProcessor

from utils_sd15.gradio_utils import \
    AttnProcessor2_0 as AttnProcessor

  from .autonotebook import tqdm as notebook_tqdm


### Set Config 

In [2]:
## Global
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "(No style)"
MAX_SEED = np.iinfo(np.int32).max
global models_dict
use_va = False
models_dict = {
   "Juggernaut":"RunDiffusion/Juggernaut-XL-v8",
   "RealVision":"SG161222/RealVisXL_V4.0" ,
   "SDXL":"stabilityai/stable-diffusion-xl-base-1.0" ,
   "Unstable": "stablediffusionapi/sdxl-unstable-diffusers-y",
   "SD15": "runwayml/stable-diffusion-v1-5",
}


In [3]:
torch.cuda.is_available()
torch.cuda.empty_cache()

In [4]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

    
#################################################
########Consistent Self-Attention################
#################################################
class SpatialAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapater for PyTorch 2.0.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        text_context_len (`int`, defaults to 77):
            The context length of the text features.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
    """

    def __init__(self, hidden_size = None, cross_attention_dim=None,id_length = 4,device = "cuda",dtype = torch.float16, single_model_length = 16):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.device = device
        self.dtype = dtype
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.total_length = id_length + 1
        self.id_length = id_length
        self.id_bank = {}
        self.single_model_length = single_model_length

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None):
        global total_count,attn_count,cur_step,mask256,mask1024,mask4096
        global sa32, sa64
        global write
        global height,width
        
        hidden_states = hidden_states.view(-1, self.single_model_length, hidden_states.shape[-2], hidden_states.shape[-1])
        # print("temb = ", temb)
        # print("cur_step = ", cur_step)
        
        if write:
            # print(f"white:{cur_step}")
            self.id_bank[cur_step] = [hidden_states[:self.id_length], hidden_states[self.id_length:]]
        else:
            
            encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(self.device),hidden_states[:1],self.id_bank[cur_step][1].to(self.device),hidden_states[1:]))
 
        # skip in early step
        if cur_step <5:
            hidden_states = self.__call2__(attn, hidden_states,encoder_hidden_states,attention_mask,temb)
        else:   # 256 1024 4096
            random_number = random.random()
            if cur_step <20:
                rand_num = 0.3
            else:
                rand_num = 0.1
            # if random_number > rand_num:
            # if cur_step > 11: ### ??? 
            if False:
                if write == False:
                    if hidden_states.shape[2] == 256 :
                        attention_mask = mask256[mask256.shape[0] // self.total_length * self.id_length:]
                        # print("hidden_states.shape[1] == 256, attention_mask", attention_mask.shape)
                    elif hidden_states.shape[2] == 1024:
                        attention_mask = mask1024[mask1024.shape[0] // self.total_length * self.id_length:]
                        # print("hidden_states.shape[1] == 1024, attention_mask", attention_mask.shape)
                    elif hidden_states.shape[2] == 4096:
                        attention_mask = mask4096[mask4096.shape[0] // self.total_length * self.id_length:]
                    else:
                        print("hidden_states.shape[1], attention_mask = ", hidden_states.shape)
                else:
                    if hidden_states.shape[2] == 256:
                        attention_mask = mask256[:mask256.shape[0] // self.total_length * self.id_length,:mask256.shape[0] // self.total_length * self.id_length]
                        # print("hidden_states.shape[1], attention_mask = ", hidden_states.shape, attention_mask.shape)
                    elif hidden_states.shape[2] == 1024:
                        attention_mask = mask1024[:mask1024.shape[0] // self.total_length * self.id_length,:mask1024.shape[0] // self.total_length * self.id_length]
                        # print("hidden_states.shape[1], attention_mask", hidden_states.shape ,attention_mask.shape)
                    elif hidden_states.shape[2] == 4096:
                        attention_mask = mask4096[:mask4096.shape[0] // self.total_length * self.id_length,:mask4096.shape[0] // self.total_length * self.id_length]
                        # print("hidden_states.shape[1], attention_mask = ", hidden_states.shape,attention_mask.shape)
                    else:
                        print("hidden_states.shape[1], attention_mask = ", hidden_states.shape)
                        
                hidden_states = self.__call1__(attn, hidden_states, encoder_hidden_states, attention_mask, temb)
            else:
                hidden_states = self.__call2__(attn, hidden_states, None, attention_mask, temb)
        attn_count +=1
        if attn_count == total_count:
            attn_count = 0
            cur_step += 1
            mask256,mask1024,mask4096 = cal_attn_mask(total_length,id_length,sa16, sa32,sa64,device=self.device, dtype= self.dtype)
        
        return hidden_states
    def __call1__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        # __import__("ipdb").set_trace()
        if encoder_hidden_states is not None:
            print("encoder_hidden_states.shape = ", encoder_hidden_states.shape)
        # from animatediff_25.models.unet import timesteps
        # print("timesteps = ", timesteps)
        
        # __import__('ipdb').set_trace()
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim

        if input_ndim == 5:
            batch_size, video_length, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, video_length, channel, height * width).transpose(2, 3)
            
        batch_size, video_length, nums_token, channel = hidden_states.shape
        
        hidden_states_frame1 = hidden_states[:, 0, :, :]
        hidden_states_frame1 = hidden_states_frame1.view(-1, nums_token, channel)
        img_nums = batch_size//2
        hidden_states_frame1 = hidden_states_frame1.view(-1,img_nums,nums_token,channel).reshape(-1,img_nums * nums_token,channel)
        # print("hidden_states_frame1.shape = ", hidden_states_frame1.shape)

        if attn.group_norm is not None:
            hidden_states_frame1 = attn.group_norm(hidden_states_frame1.transpose(1, 2)).transpose(1, 2)
        
        # hidden_states = hidden_states.view(-1, nums_token, channel) 
        # hidden_states = hidden_states[:, 0, :, :]
        # print("new hidden_states.shape = ", hidden_states.shape)
        # hidden_states = hidden_states.view(-1, nums_token, channel)

        query = attn.to_q(hidden_states_frame1)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states_frame1  # B, N, C
        else:
            encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,nums_token,channel).reshape(-1,(self.id_length+1) * nums_token,channel)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)


        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        batch_size_frame1 = hidden_states_frame1.shape[0]
        query = query.view(batch_size_frame1, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size_frame1, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size_frame1, -1, attn.heads, head_dim).transpose(1, 2)

        # print("query.shape, key.shape, value.shape = ", query.shape, key.shape, value.shape)
        hidden_states_frame1 = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        print("hidden_states_frame1.shape", hidden_states_frame1.shape)
        print("hidden_states.shape", hidden_states.shape)
        hidden_states[:, 0, :, :] = hidden_states_frame1.view(-1, nums_token, channel)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size * video_length , -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        print("final hidden_states.shape = ", hidden_states.shape)

        if input_ndim == 5:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, video_length, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        # print(hidden_states.shape)
        return hidden_states
    
    def __call2__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 5:
            batch_size, video_length, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, video_length, channel, height * width).transpose(2, 3)


        batch_size, video_length, sequence_length, channel = (
            hidden_states.shape
        )
        # print(hidden_states.shape)
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, video_length, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        hidden_states = hidden_states.view(-1, sequence_length, channel)
        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states  # B, N, C
        else:
            encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,sequence_length,channel).reshape(-1,(self.id_length+1) * sequence_length,channel)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size * video_length, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        
        if input_ndim == 5:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, video_length, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

### Load Pipeline

In [5]:
from omegaconf import OmegaConf
config = OmegaConf.load("./config/inference.yaml")
global attn_count, total_count, id_length, total_length,cur_step, cur_model_type
global write
global sa16, sa32, sa64
global height,width

global attn_procs,unet

global pipeline
global sd_model_path
global num_steps
attn_procs = {}

attn_count = 0
total_count = 0
cur_step = 0
id_length = 3
cur_model_type = ""
device="cuda"
###
write = False
### strength of consistent self-attention: the larger, the stronger
sa16 = 0.5
sa32 = 0.5
sa64 = 0.5
### Res. of the Generated Comics. Please Note: SDXL models may do worse in a low-resolution! 
height = config.get("height", 256)
width = config.get("width", 256)
with_csa = config.get("with_csa", False)
single_model_length = config.get("single_model_length", 4)
num_steps = config.get("steps", 20)
total_length = id_length + 1
print("id_length", id_length)
print("total_length", total_length)
print("single_model_length", single_model_length)
print("height, width", height, width)
print("num_steps", num_steps)

id_length 3
total_length 4
single_model_length 16
height, width 256 256
num_steps 8


In [6]:

from diffusers import AutoencoderKL, EulerDiscreteScheduler

from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection

sd_model_path = "../models/stable-diffusion-v1-5"
# Load Component
tokenizer    = CLIPTokenizer.from_pretrained(sd_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(sd_model_path, subfolder="text_encoder")
vae          = AutoencoderKL.from_pretrained(sd_model_path, subfolder="vae")   

# init unet model
unet = UNet3DConditionModel.from_pretrained_2d(sd_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(config.unet_additional_kwargs))

# print("unet.state_dict()", unet.state_dict().keys())

scheduler=DDIMScheduler(**OmegaConf.to_container(config.noise_scheduler_kwargs))



loaded temporal unet's pretrained weights from ../models/stable-diffusion-v1-5/unet ...
self.attn_processors dict_keys([])
### missing keys: 672; 
### unexpected keys: 0;
### Temporal Module Parameters: 417.1376 M


## Insert CSA

In [7]:

### Insert PairedAttention
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
        # print("hidden_size", hidden_size)
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    if cross_attention_dim is None and (name.startswith("up_blocks") ) :
        # print("id_length", id_length)
        attn_procs[name] =  SpatialAttnProcessor2_0(id_length = id_length, single_model_length=single_model_length)
        total_count +=1
    else:
        attn_procs[name] = AttnProcessor()
    # print("hidden_size", hidden_size)
print("successsfully load consistent self-attention")
print(f"number of the processor : {total_count}")

if with_csa:
    unet.set_attn_processor(copy.deepcopy(attn_procs))
    print("using consistent self-attention")
# print("unet", unet )
    
mask256,mask1024,mask4096 = cal_attn_mask(total_length,id_length,sa16, sa32,sa64)

successsfully load consistent self-attention
number of the processor : 9
using consistent self-attention


## Get AnimationPipeline

In [8]:
from animatediff_25.utils.util import load_weights
pipeline = AnimationPipeline(
    vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
    scheduler=scheduler,
).to("cuda")


pipeline = load_weights(
            pipeline,
            # motion module
            motion_module_path         = "../AnimateDiff/models/Motion_Module/mm_sd_v15.ckpt",
        ).to("cuda")

load motion module from ../AnimateDiff/models/Motion_Module/mm_sd_v15.ckpt


## load motion module

### Create the text description for the comics
Tips: Existing text2image diffusion models may not always generate images that accurately match text descriptions. Our training-free approach can improve the consistency of characters, but it does not enhance the control over the text. Therefore, in some cases, you may need to carefully craft your prompts.

In [9]:
guidance_scale = 5.0
seed = 2047
sa32 = 0.5
sa64 = 0.5
general_prompt = "a girl with a blue hair"
negative_prompt = "naked, deformed, bad anatomy, disfigured, poorly drawn face, mutation, extra limb, ugly, disgusting, poorly drawn hands, missing limb, floating limbs, disconnected limbs, blurry, watermarks, oversaturated, distorted hands, amputation"
prompt_array = [
                "have breakfast",
                "is on the road, go to the company",
                "work in the company",
                "running in the playground",
                "reading book in the home"
                ]

In [10]:
def apply_style_positive(style_name: str, positive: str):
    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    return p.replace("{prompt}", positive) 
def apply_style(style_name: str, positives: list, negative: str = ""):
    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    return [p.replace("{prompt}", positive) for positive in positives], n + ' ' + negative
### Set the generated Style
style_name = "Comic book"
# setup_seed(seed)
generator = torch.Generator(device="cuda").manual_seed(seed)
prompts = [general_prompt+","+prompt for prompt in prompt_array]
id_prompts = prompts[:id_length]
real_prompts = prompts[id_length:]
print("prompts = ", prompts)
print("id_prompts = ", id_prompts)
print("real_prompts = ", real_prompts)
print("negative_prompt = ", negative_prompt)
negative_prompt = negative_prompt[0]
torch.cuda.empty_cache()
cur_step = 0
attn_count = 0
savedir = "./output"
id_prompts, negative_prompt = apply_style(style_name, id_prompts, negative_prompt)

# negative_prompt = negative_prompt  * len(id_prompts)
# negative_prompt = negative_prompt[0]  * len(id_prompts)

print("negative_prompt = ", negative_prompt)



prompts =  ['a girl with a blue hair,have breakfast', 'a girl with a blue hair,is on the road, go to the company', 'a girl with a blue hair,work in the company', 'a girl with a blue hair,running in the playground', 'a girl with a blue hair,reading book in the home']
id_prompts =  ['a girl with a blue hair,have breakfast', 'a girl with a blue hair,is on the road, go to the company', 'a girl with a blue hair,work in the company']
real_prompts =  ['a girl with a blue hair,running in the playground', 'a girl with a blue hair,reading book in the home']
negative_prompt =  naked, deformed, bad anatomy, disfigured, poorly drawn face, mutation, extra limb, ugly, disgusting, poorly drawn hands, missing limb, floating limbs, disconnected limbs, blurry, watermarks, oversaturated, distorted hands, amputation
negative_prompt =  photograph, deformed, glitch, noisy, realistic, stock photo, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, 

In [11]:
write = True
with torch.inference_mode():
    samples = pipeline(
        prompt=id_prompts,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_steps,
        video_length = single_model_length,
        width				  = width,
        height			  = height,
    ).videos

  num_channels_latents = self.unet.in_channels
  0%|          | 0/8 [00:00<?, ?it/s]

timesteps  tensor(876, device='cuda:0')


 12%|█▎        | 1/8 [00:02<00:20,  2.93s/it]

timesteps  tensor(751, device='cuda:0')


 25%|██▌       | 2/8 [00:05<00:16,  2.72s/it]

timesteps  tensor(626, device='cuda:0')


 38%|███▊      | 3/8 [00:08<00:13,  2.65s/it]

timesteps  tensor(501, device='cuda:0')


 50%|█████     | 4/8 [00:10<00:10,  2.61s/it]

timesteps  tensor(376, device='cuda:0')


 62%|██████▎   | 5/8 [00:13<00:07,  2.59s/it]

timesteps  tensor(251, device='cuda:0')
hidden_states_frame1.shape torch.Size([2, 8, 768, 80])
hidden_states.shape torch.Size([6, 16, 256, 640])
final hidden_states.shape =  torch.Size([96, 256, 640])
hidden_states_frame1.shape torch.Size([2, 8, 768, 80])
hidden_states.shape torch.Size([6, 16, 256, 640])
final hidden_states.shape =  torch.Size([96, 256, 640])
hidden_states_frame1.shape torch.Size([2, 8, 3072, 40])
hidden_states.shape torch.Size([6, 16, 1024, 320])
final hidden_states.shape =  torch.Size([96, 1024, 320])
hidden_states_frame1.shape torch.Size([2, 8, 3072, 40])
hidden_states.shape torch.Size([6, 16, 1024, 320])
final hidden_states.shape =  torch.Size([96, 1024, 320])


 75%|███████▌  | 6/8 [00:15<00:04,  2.36s/it]

timesteps  tensor(126, device='cuda:0')
hidden_states.shape[1], attention_mask =  torch.Size([6, 16, 64, 1280])
hidden_states_frame1.shape torch.Size([2, 8, 192, 160])
hidden_states.shape torch.Size([6, 16, 64, 1280])
final hidden_states.shape =  torch.Size([96, 64, 1280])
hidden_states.shape[1], attention_mask =  torch.Size([6, 16, 64, 1280])
hidden_states_frame1.shape torch.Size([2, 8, 192, 160])
hidden_states.shape torch.Size([6, 16, 64, 1280])
final hidden_states.shape =  torch.Size([96, 64, 1280])
hidden_states_frame1.shape torch.Size([2, 8, 768, 80])
hidden_states.shape torch.Size([6, 16, 256, 640])
final hidden_states.shape =  torch.Size([96, 256, 640])
hidden_states_frame1.shape torch.Size([2, 8, 768, 80])
hidden_states.shape torch.Size([6, 16, 256, 640])
final hidden_states.shape =  torch.Size([96, 256, 640])
hidden_states_frame1.shape torch.Size([2, 8, 768, 80])
hidden_states.shape torch.Size([6, 16, 256, 640])
final hidden_states.shape =  torch.Size([96, 256, 640])
hidden_st

 88%|████████▊ | 7/8 [00:17<00:02,  2.29s/it]

timesteps  tensor(1, device='cuda:0')
hidden_states.shape[1], attention_mask =  torch.Size([6, 16, 64, 1280])
hidden_states_frame1.shape torch.Size([2, 8, 192, 160])
hidden_states.shape torch.Size([6, 16, 64, 1280])
final hidden_states.shape =  torch.Size([96, 64, 1280])
hidden_states.shape[1], attention_mask =  torch.Size([6, 16, 64, 1280])
hidden_states_frame1.shape torch.Size([2, 8, 192, 160])
hidden_states.shape torch.Size([6, 16, 64, 1280])
final hidden_states.shape =  torch.Size([96, 64, 1280])
hidden_states_frame1.shape torch.Size([2, 8, 768, 80])
hidden_states.shape torch.Size([6, 16, 256, 640])
final hidden_states.shape =  torch.Size([96, 256, 640])
hidden_states_frame1.shape torch.Size([2, 8, 768, 80])
hidden_states.shape torch.Size([6, 16, 256, 640])
final hidden_states.shape =  torch.Size([96, 256, 640])
hidden_states_frame1.shape torch.Size([2, 8, 768, 80])
hidden_states.shape torch.Size([6, 16, 256, 640])
final hidden_states.shape =  torch.Size([96, 256, 640])
hidden_stat

100%|██████████| 8/8 [00:19<00:00,  2.39s/it]


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.00 GiB (GPU 0; 23.69 GiB total capacity; 16.20 GiB already allocated; 2.13 GiB free; 21.23 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
print("sample.shape", samples.shape)
save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)
print("save to", f"{savedir}/sample.gif")
i = 0
for sample in samples:
    print("sample.shape", sample.shape)
    print("i", i)
    id_prompt = id_prompts[i]
    print("id_prompt", id_prompt)
    sample = sample.unsqueeze(0)
    save_videos_grid(sample, f"{savedir}/{id_prompt}.gif")
    i = i + 1
print(f"save to {savedir}/{id_prompt}.gif")
    
save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)

OmegaConf.save(config, f"{savedir}/config.yaml")

NameError: name 'samples' is not defined

### Continued Creation
From now on, you can create endless stories about this character without worrying about memory constraints.

### Make pictures into comics