docs/source/en/api/pipelines/text_to_video_zero.mdx (41 changes: 38 additions & 3 deletions)
@@ -80,6 +80,41 @@ You can change these parameters in the pipeline call:
* Video length:
* `video_length`, the number of frames to be generated. Default: `video_length=8`

We can also generate longer videos by processing them in a chunk-by-chunk manner:
```python
import torch
import imageio
from diffusers import TextToVideoZeroPipeline
import numpy as np

model_id = "runwayml/stable-diffusion-v1-5"
pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
seed = 0
video_length = 8
chunk_size = 4
prompt = "A panda is playing guitar on times square"

# Generate the video chunk-by-chunk
result = []
chunk_ids = np.arange(0, video_length, chunk_size - 1)
generator = torch.Generator(device="cuda")
for i in range(len(chunk_ids)):
    print(f"Processing chunk {i + 1} / {len(chunk_ids)}")
    ch_start = chunk_ids[i]
    ch_end = video_length if i == len(chunk_ids) - 1 else chunk_ids[i + 1]
    # Attach the first frame for Cross Frame Attention
    frame_ids = [0] + list(range(ch_start, ch_end))
    # Fix the seed for temporal consistency across chunks
    generator.manual_seed(seed)
    output = pipe(prompt=prompt, video_length=len(frame_ids), generator=generator, frame_ids=frame_ids)
    result.append(output.images[1:])

# Concatenate chunks and save
result = np.concatenate(result)
result = [(r * 255).astype("uint8") for r in result]
imageio.mimsave("video.mp4", result, fps=4)
```
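To see how the chunks above are laid out, here is a minimal dry run of the same indexing (using the example's `video_length=8`, `chunk_size=4`): frame 0 is prepended to every chunk so that Cross Frame Attention always has the first frame available, and `output.images[1:]` drops it again before concatenation.

```python
import numpy as np

video_length, chunk_size = 8, 4
chunk_ids = np.arange(0, video_length, chunk_size - 1)  # [0, 3, 6]
for i in range(len(chunk_ids)):
    ch_start = chunk_ids[i]
    ch_end = video_length if i == len(chunk_ids) - 1 else chunk_ids[i + 1]
    print([0] + list(range(ch_start, ch_end)))
# [0, 0, 1, 2]  ->  keeps frames 0-2
# [0, 3, 4, 5]  ->  keeps frames 3-5
# [0, 6, 7]     ->  keeps frames 6-7
```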


### Text-To-Video with Pose Control
To generate a video from a prompt with additional pose control
@@ -202,7 +237,7 @@ can run with custom [DreamBooth](../training/dreambooth) models, as shown below

reader = imageio.get_reader(video_path, "ffmpeg")
frame_count = 8
-video = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
+canny_edges = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
```
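Note the rename from `video` to `canny_edges`: the frames fed to the ControlNet pipeline are expected to be edge maps, not raw RGB frames. If the source video were raw footage, it would first have to go through an edge detector; a hypothetical preprocessing sketch using OpenCV (the `to_canny` helper and the `(100, 200)` thresholds are illustrative and not part of the original example):

```python
import cv2
import numpy as np
from PIL import Image

def to_canny(frame: Image.Image) -> Image.Image:
    # Hypothetical helper: grayscale -> Canny edges -> 3-channel PIL image
    gray = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, 100, 200)  # thresholds chosen arbitrarily
    return Image.fromarray(np.stack([edges] * 3, axis=-1))

canny_edges = [to_canny(Image.fromarray(reader.get_data(i))) for i in range(frame_count)]
```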

3. Run `StableDiffusionControlNetPipeline` with a custom-trained DreamBooth model
@@ -223,10 +258,10 @@ can run with custom [DreamBooth](../training/dreambooth) models, as shown below
pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))

# fix latents for all frames
-latents = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16).repeat(len(pose_images), 1, 1, 1)
+latents = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16).repeat(len(canny_edges), 1, 1, 1)

prompt = "oil painting of a beautiful girl avatar style"
-result = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images
+result = pipe(prompt=[prompt] * len(canny_edges), image=canny_edges, latents=latents).images
imageio.mimsave("video.mp4", result, fps=4)
```
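The `latents` line is what enforces temporal consistency here: a single noise tensor of shape `(1, 4, 64, 64)` (one 4-channel latent for 512×512 output with Stable Diffusion v1.5) is repeated so that every frame starts denoising from the same noise. Written out explicitly:

```python
# One shared starting latent, broadcast across all frames
base_latent = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16)
latents = base_latent.repeat(len(canny_edges), 1, 1, 1)  # (num_frames, 4, 64, 64)
```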

@@ -38,12 +38,12 @@ def rearrange_4(tensor):

class CrossFrameAttnProcessor:
    """
-    Cross frame attention processor. For each frame the self-attention is replaced with attention with first frame
+    Cross frame attention processor. Each frame attends to the first frame.

    Args:
        batch_size: The number that represents actual batch size, other than the frames.
-            For example, using calling unet with a single prompt and num_images_per_prompt=1, batch_size should be
-            equal to 2, due to classifier-free guidance.
+            For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to
+            2, due to classifier-free guidance.
    """

    def __init__(self, batch_size=2):
@@ -63,7 +63,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

-        # Sparse Attention
+        # Cross Frame Attention
        if not is_cross_attention:
            video_length = key.size()[0] // self.batch_size
            first_frame_index = [0] * video_length
@@ -95,6 +95,81 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
        return hidden_states


class CrossFrameAttnProcessor2_0:
    """
    Cross frame attention processor with scaled_dot_product_attention of PyTorch 2.0.

    Args:
        batch_size: The number that represents actual batch size, other than the frames.
            For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to
            2, due to classifier-free guidance.
    """

    def __init__(self, batch_size=2):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.batch_size = batch_size

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        inner_dim = hidden_states.shape[-1]

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)

        is_cross_attention = encoder_hidden_states is not None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Cross Frame Attention
        if not is_cross_attention:
            video_length = key.size()[0] // self.batch_size
            first_frame_index = [0] * video_length

            # rearrange keys to have batch and frames in the 1st and 2nd dims respectively
            key = rearrange_3(key, video_length)
            key = key[:, first_frame_index]
            # rearrange values to have batch and frames in the 1st and 2nd dims respectively
            value = rearrange_3(value, video_length)
            value = value[:, first_frame_index]

            # rearrange back to original shape
            key = rearrange_4(key)
            value = rearrange_4(value)

        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


@dataclass
class TextToVideoPipelineOutput(BaseOutput):
    images: Union[List[PIL.Image.Image], np.ndarray]
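Stepping back from the diff: the key/value swap that both processors share is easiest to see on a dummy tensor. A standalone sketch of the same indexing (assuming `rearrange_3` reshapes `(batch * frames, seq, dim)` to `(batch, frames, seq, dim)` and `rearrange_4` undoes it, as the hunk header's function names suggest):

```python
import torch

batch_size, video_length, seq_len, dim = 2, 4, 77, 320
key = torch.randn(batch_size * video_length, seq_len, dim)

# (batch * frames, seq, dim) -> (batch, frames, seq, dim), i.e. rearrange_3
key = key.reshape(batch_size, video_length, seq_len, dim)
# replace every frame's keys with the keys of frame 0
key = key[:, [0] * video_length]
# back to (batch * frames, seq, dim), i.e. rearrange_4
key = key.reshape(batch_size * video_length, seq_len, dim)

# within each batch element, all frames now share frame 0's keys
assert torch.equal(key[0], key[video_length - 1])
```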
@@ -227,7 +302,12 @@ def __init__(
        super().__init__(
            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
        )
-        self.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
+        processor = (
+            CrossFrameAttnProcessor2_0(batch_size=2)
+            if hasattr(F, "scaled_dot_product_attention")
+            else CrossFrameAttnProcessor(batch_size=2)
+        )
+        self.unet.set_attn_processor(processor)

    def forward_loop(self, x_t0, t0, t1, generator):
        """
@@ -338,6 +418,7 @@ def __call__(
        callback_steps: Optional[int] = 1,
        t0: int = 44,
        t1: int = 47,
+        frame_ids: Optional[List[int]] = None,
    ):
        """
        Function invoked when calling the pipeline for generation.
@@ -399,6 +480,9 @@
            t1 (`int`, *optional*, defaults to 47):
                Timestep t1. Should be in the range [t0 + 1, num_inference_steps - 1]. See the
                [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
+            frame_ids (`List[int]`, *optional*):
+                Indexes of the frames that are being generated. This is used when generating longer videos
+                chunk-by-chunk.

        Returns:
            [`~pipelines.text_to_video_synthesis.TextToVideoPipelineOutput`]:
@@ -407,7 +491,9 @@
                likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        assert video_length > 0
-        frame_ids = list(range(video_length))
+        if frame_ids is None:
+            frame_ids = list(range(video_length))
+        assert len(frame_ids) == video_length

        assert num_videos_per_prompt == 1
