## Load

In [None]:
import os
import torch

from animatediff.pipelines.pipeline_animation_inpaint import AnimationInpaintPipeline

from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from animatediff.models.unet import UNet3DConditionModel

# Root directory of the Stable Diffusion checkpoint; every component below is
# loaded from its own subfolder of this directory.
stable_diffusion_model_path = os.path.join(os.getcwd(), "models", "StableDiffusion", "ACertainThing")

tokenizer = CLIPTokenizer.from_pretrained(stable_diffusion_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(stable_diffusion_model_path, subfolder="text_encoder").cuda()
vae = AutoencoderKL.from_pretrained(stable_diffusion_model_path, subfolder="vae").cuda()

# Build the 3D UNet from the 2D checkpoint. The additional kwargs switch the
# motion module on and configure it (type, resolutions, attention layout).
unet = UNet3DConditionModel.from_pretrained_2d(
    stable_diffusion_model_path,
    subfolder="unet",
    unet_additional_kwargs={
        "unet_use_cross_frame_attention": False,
        "unet_use_temporal_attention": False,
        "use_motion_module": True,
        "motion_module_resolutions": [1, 2, 4, 8],
        "motion_module_mid_block": False,
        "motion_module_decoder_only": False,
        "motion_module_type": "Vanilla",
        "motion_module_kwargs": {
            "num_attention_heads": 8,
            "num_transformer_block": 1,
            "attention_block_types": ["Temporal_Self", "Temporal_Self"],
            "temporal_position_encoding": True,
            "temporal_position_encoding_max_len": 24,
            "temporal_attention_dim_div": 1,
        },
    },
).cuda()

# NOTE: torch.load unpickles the file, so only point this at checkpoints that
# come from a trusted source.
motion_module_path = os.path.join(os.getcwd(), "models", "Motion_Module", "mm_sd_v14.ckpt")
motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")

# strict=False: the checkpoint is not required to cover every UNet parameter.
# Keys the UNet does not know about are still treated as an error, raised
# explicitly so the check is not stripped when Python runs with -O.
missing, unexpected = unet.load_state_dict(motion_module_state_dict, strict=False)
if unexpected:
    raise RuntimeError(
        f"Motion module checkpoint contains keys the UNet does not accept: {unexpected}"
    )

unet.enable_xformers_memory_efficient_attention()
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="linear")

# The pipeline is deliberately bound to the name `self` so that the inference
# code further down can refer to it as `self.<attr>`.
self = AnimationInpaintPipeline(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler).to("cuda")

## Prompt

In [None]:
# Import the submodule explicitly: a bare `import PIL` does not load
# `PIL.Image`, so `PIL.Image.open` would only work if some other library had
# already imported it. This form still binds the name `PIL`.
import PIL.Image

# Text conditioning.
prompt = "best quality, masterpiece, 1girl, looking at viewer, hatsune miku"
negative_prompt = ""

# Sampling settings.
num_inference_steps = 25
guidance_scale = 7.5
width = 512
height = 512
video_length = 16
seed = 1

# Frame index -> image for that frame. The indices are used to index the
# frame axis of the latents, so they must be smaller than `video_length`.
keyframes = {
    0: PIL.Image.open("images/0.jpeg"),
    15: PIL.Image.open("images/15.jpeg"),
}

## Inference

In [None]:
import time
from datetime import datetime

from animatediff.utils.util import save_videos_grid

# Seed the global RNG. `generator` is None below, so this is presumably what
# makes the sampling reproducible -- TODO confirm against the pipeline helpers.
torch.manual_seed(seed)

# Inputs taken unchanged from the "Prompt" cell:
#   prompt, negative_prompt, num_inference_steps, guidance_scale,
#   width, height, video_length, keyframes

# Keyframe handling inside the denoising loop.
add_predicted_noise = False
do_reconstruction_guidance = True
reconstruction_guidance_scale = 3.0

# Remaining call parameters, fixed to their defaults for this notebook.
num_videos_per_prompt = 1
eta = 0.0
generator = None
latents = None
output_type = "tensor"
return_dict = True
callback = None
callback_steps = 1

# Inference only: disable gradient tracking on every model parameter.
# The text encoder is frozen as a whole so that its token embedding table is
# covered as well. This cell does not run under torch.no_grad(), so a
# parameter left trainable there can make autograd track the text embeddings
# through every UNet call of the denoising loop.
self.unet.requires_grad_(False)
self.vae.requires_grad_(False)
self.text_encoder.requires_grad_(False)

# Fall back to the UNet's configured sample size (scaled up to pixel space)
# for any dimension that was not provided.
if not height:
    height = self.unet.config.sample_size * self.vae_scale_factor
if not width:
    width = self.unet.config.sample_size * self.vae_scale_factor

# Validate the inputs; check_inputs raises if something is wrong.
self.check_inputs(prompt, height, width, callback_steps)

# Batch size: a list of prompts takes precedence, then pre-supplied latents,
# otherwise a single sample is generated.
if isinstance(prompt, list):
    batch_size = len(prompt)
elif latents is not None:
    batch_size = latents.shape[0]
else:
    batch_size = 1

device = self._execution_device

# `guidance_scale` corresponds to the guidance weight `w` of equation (2) in
# the Imagen paper (https://arxiv.org/pdf/2205.11487.pdf). A value of 1 means
# no classifier-free guidance.
do_classifier_free_guidance = guidance_scale > 1.0

# Normalise the prompts to lists of length `batch_size`, then encode them.
if not isinstance(prompt, list):
    prompt = [prompt] * batch_size
if negative_prompt is not None and not isinstance(negative_prompt, list):
    negative_prompt = [negative_prompt] * batch_size
text_embeddings = self._encode_prompt(
    prompt,
    device,
    num_videos_per_prompt,
    do_classifier_free_guidance,
    negative_prompt,
)

# Configure the scheduler and read back its timestep sequence.
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps

# Initial latents, laid out as [Batch, Channel, Frame, Height, Width].
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
    batch_size * num_videos_per_prompt,
    num_channels_latents,
    video_length,
    height,
    width,
    text_embeddings.dtype,
    device,
    generator,
    latents,
)
# Remembered so the UNet output can be cast back to it in the denoising loop.
latents_dtype = latents.dtype
# Template for all keyframe bookkeeping tensors. It has the same layout as the
# latents, [Batch, Channel, Frame, Height, Width], and uses the latents' dtype
# so that blending with the mask below does not upcast the latents when the
# pipeline runs in reduced precision.
zeros = torch.zeros(
    (
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        video_length,
        int(height / self.vae_scale_factor),
        int(width / self.vae_scale_factor),
    ),
    dtype=latents.dtype,
    device=device,
)

# Mask for keyframes: 1 on every keyframe's frame index, 0 everywhere else.
mask = zeros.clone()
for keyframe_idx in keyframes:
    mask[:, :, keyframe_idx, :, :] = 1

# Per-keyframe tensors, written into the frame slot of each keyframe.
keyframes_latents = zeros.clone()
keyframes_init_latents_orig = zeros.clone()
keyframes_noise = zeros.clone()

# The first scheduler timestep, repeated once per generated sample.
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
for keyframe_idx, keyframe in keyframes.items():
    # PIL images are preprocessed first; anything else is passed through as is.
    if isinstance(keyframe, PIL.Image.Image):
        keyframe = self.preprocess_image(keyframe)
    keyframe_latents, keyframe_init_latents_orig, keyframe_noise = self.prepare_image_latents(
        keyframe,
        latent_timestep,
        batch_size * num_videos_per_prompt,
        text_embeddings.dtype,
        device,
        generator,
    )
    keyframes_latents[:, :, keyframe_idx, :, :] = keyframe_latents
    keyframes_init_latents_orig[:, :, keyframe_idx, :, :] = keyframe_init_latents_orig
    keyframes_noise[:, :, keyframe_idx, :, :] = keyframe_noise

# Start from the keyframe latents on keyframe slots and from the prepared
# latents on every other frame.
latents = keyframes_latents * mask + latents * (1 - mask)
# Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        # Duplicate the latents when doing classifier-free guidance so the
        # unconditional and text-conditioned predictions share one UNet call.
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual.
        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)

        # Perform guidance.
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        else:
            # Only one prediction exists without guidance. Bind it here so the
            # `add_predicted_noise` path below never hits an undefined name.
            noise_pred_uncond = noise_pred

        # Compute the previous noisy sample x_t -> x_t-1.
        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        if do_reconstruction_guidance:
            # TODO: reconstruction guidance is not implemented yet; this flag
            # and `reconstruction_guidance_scale` currently have no effect.
            pass

        # Masking: overwrite the keyframe slots with the original keyframe
        # latents, re-noised with either the predicted noise or the noise
        # stored when the keyframes were prepared.
        # NOTE(review): the keyframes are noised to timestep `t`, while
        # `latents` was just stepped to the previous sample -- confirm that
        # this one-step offset is intended.
        renoise_source = noise_pred_uncond if add_predicted_noise else keyframes_noise
        init_latents_proper = self.scheduler.add_noise(
            keyframes_init_latents_orig,
            renoise_source,
            torch.tensor([t]),
        )
        latents = init_latents_proper * mask + latents * (1 - mask)

        # Advance the progress bar and call the callback, if provided.
        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
            progress_bar.update()
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)
# Post-processing: decode the final latents into frames.
video = self.decode_latents(latents)
# decode_latents yields a NumPy array; convert it when a tensor was requested.
if output_type == "tensor":
    video = torch.from_numpy(video)

# One timestamped output directory per run, with the clip stored under
# "<savedir>/sample/<unix-time>.mp4".
savedir = os.path.join(os.getcwd(), "sample", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
savedir_sample = os.path.join(savedir, "sample")
# Create the directory the file is actually written into; this creates
# `savedir` as well because it is the parent.
os.makedirs(savedir_sample, exist_ok=True)

save_sample_path = os.path.join(savedir_sample, f"{int(time.time())}.mp4")
save_videos_grid(video, save_sample_path)