## Baseline model

In [1]:
import torch
from omegaconf import OmegaConf
import torchvision.transforms as transforms
from PIL import Image

from einops import rearrange

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler

from animatediff.pipelines.pipeline_unipaint import AnimationPipeline
from animatediff.models.unet import UNet3DConditionModel
from animatediff.models.sparse_controlnet import SparseControlNetModel

from animatediff.utils.util import load_weights,save_videos_grid

# ----- Configuration -----
path = "models/StableDiffusion/stable-diffusion-v1-5"
device = "cuda:0"
dtype = torch.float16
# Seed the global torch RNG (all devices) so anything later in the notebook that draws
# from it without an explicit generator is repeatable on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# ----- Base Stable Diffusion 1.5 components -----
# The tokenizer has no weights, so it takes no dtype argument.
tokenizer        = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer")
text_encoder     = CLIPTextModel.from_pretrained(path, subfolder="text_encoder").to(device, dtype)
vae              = AutoencoderKL.from_pretrained(path, subfolder="vae").to(device, dtype)

# Build the 3D (video) UNet from the 2D SD UNet weights; extra kwargs come from the
# v3 inference config.
inference_config = OmegaConf.load("configs/inference/inference-v3.yaml")
unet             = UNet3DConditionModel.from_pretrained_2d(
    path,
    subfolder="unet",
    unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
).to(device, dtype)

# ----- SparseCtrl RGB controlnet -----
# These UNet config fields are overridden immediately before from_unet builds the
# controlnet from the UNet — presumably so it is created with 8 attention heads and no
# class-embedding projection (TODO confirm against SparseControlNetModel.from_unet).
unet.config.num_attention_heads = 8
unet.config.projection_class_embeddings_input_dim = None
controlnet_config = OmegaConf.load("configs/inference/sparsectrl/latent_condition.yaml")
controlnet = SparseControlNetModel.from_unet(
    unet,
    controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {}),
)
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
# PyTorch >= 2.6 defaults to weights_only=True, which may reject non-tensor entries such
# as "animatediff_config"; confirm before changing this call.
controlnet_state_dict = torch.load("models/Motion_Module/v3_sd15_sparsectrl_rgb.ckpt", map_location="cpu")
# Accept both layouts: weights nested under a "controlnet" key, or a flat state dict.
controlnet_state_dict = controlnet_state_dict.get("controlnet", controlnet_state_dict)
# Exclude positional-encoding buffers ("pos_encoder.pe") from the load, and drop the
# "animatediff_config" entry so load_state_dict only sees module weights.
controlnet_state_dict = {
    name: param
    for name, param in controlnet_state_dict.items()
    if "pos_encoder.pe" not in name
}
controlnet_state_dict.pop("animatediff_config", "")
controlnet.load_state_dict(controlnet_state_dict)
controlnet.to(device, dtype)

# Inference only: freeze every network so calls made outside torch.no_grad()
# (e.g. vae.encode on control images) do not build autograd graphs.
for module in (text_encoder, vae, unet, controlnet):
    module.requires_grad_(False)

# ----- Pipeline -----
pipeline = AnimationPipeline(
    vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
    controlnet=controlnet,
    scheduler=DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="linear",
        steps_offset=0,
        clip_sample=False,
    ),
).to(device)

  from .autonotebook import tqdm as notebook_tqdm


loaded 3D unet's pretrained weights from models/StableDiffusion/stable-diffusion-v1-5 ...
### missing keys: 520; 
### unexpected keys: 0;
### Motion Module Parameters: 417.1376 M
load motion module from models/Motion_Module/v3_sd15_mm.ckpt


In [7]:
# Layer the AnimateDiff v3 weights onto the base pipeline. A blank path appears to leave
# that component out (only the motion module and domain LoRA are reported as loaded) —
# confirm in load_weights.
weight_sources = dict(
    motion_module_path="models/Motion_Module/v3_sd15_mm.ckpt",      # temporal motion module
    motion_module_lora_configs=[],
    adapter_lora_path="models/Motion_Module/v3_sd15_adapter.ckpt",  # domain adapter LoRA
    adapter_lora_scale=1.0,
    dreambooth_model_path="",                                       # image layers: none
    lora_model_path="",
    lora_alpha=0.8,
)
pipeline = load_weights(pipeline, **weight_sources).to(device)

load motion module from models/Motion_Module/v3_sd15_mm.ckpt
load domain lora from models/Motion_Module/v3_sd15_adapter.ckpt


In [9]:
# Encode the SparseCtrl keyframe image(s) into VAE latents shaped (b, c, f, h, w).
CONTROL_IMAGE_PATHS = ["__assets__/demos/image/interpolation_1.png"]


def _load_rgb(image_path):
    """Load an image as RGB, closing the underlying file handle."""
    with Image.open(image_path) as img:
        return img.convert("RGB")


image_transform = transforms.ToTensor()
# Each ToTensor output is (c, h, w) in [0, 1]; stacking gives (f, c, h, w), which is the
# "(b f) c h w" layout with b = 1.
control_pixels = torch.stack([image_transform(_load_rgb(p)) for p in CONTROL_IMAGE_PATHS])
num_controlnet_images = control_pixels.shape[0]
control_pixels = control_pixels.to(device, dtype)

# Inference only: without no_grad the VAE encoder would record an autograd graph.
with torch.no_grad():
    # Map [0, 1] -> [-1, 1] for the VAE, then apply the SD-1.x latent scaling factor.
    control_latents = vae.encode(control_pixels * 2. - 1.).latent_dist.sample() * 0.18215
controlnet_images = rearrange(control_latents, "(b f) c h w -> b c f h w", f=num_controlnet_images)

prompt = "aerial view, beautiful forest, autumn, 4k, high quality"
n_prompt = "worst quality, low quality, letterboxed"

In [10]:
# Baseline sample: text prompt plus the encoded keyframe latents as sparse control
# (presumably pinned to frame 0 via controlnet_image_index).
baseline_kwargs = {
    "negative_prompt": n_prompt,
    "num_inference_steps": 25,
    "guidance_scale": 8.5,
    "width": 384,
    "height": 256,
    "video_length": 16,
    "controlnet_images": controlnet_images,
    "controlnet_image_index": [0],
    "controlnet_conditioning_scale": 1.0,
}
output = pipeline(prompt=prompt, **baseline_kwargs)
sample = output.videos
# NOTE(review): the load_weights cell passes a domain-adapter path, so the "no_adapter"
# file name may be stale — confirm before comparing runs.
save_videos_grid(sample, "samples/baseline_test_with_control_no_adapter.gif")

  0%|          | 0/25 [00:00<?, ?it/s]

100%|██████████| 25/25 [00:16<00:00,  1.49it/s]
100%|██████████| 16/16 [00:00<00:00, 45.68it/s]


## BrushNet

In [2]:
from animatediff.models.unipaint.brushnet import BrushNetModel
import torch

# Load the BrushNet branch in the notebook's dtype, move it to the device, and attach it
# to the existing pipeline.
brushnet_path = "models/BrushNet/random_mask_brushnet_ckpt"
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=dtype)
brushnet = brushnet.to(device)
setattr(pipeline, "brushnet", brushnet)

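### Read video
Here we read the first 16 frames of a video (resized to 512×512) and build a rectangular mask for them with `RectangularMaskGenerator`. Masked pixels are zeroed. The frames are then divided by 255 and mapped to [-1, 1] via `x * 2 - 1`, and the mask goes through the same `x * 2 - 1` mapping. The frames end up in shape `(b, c, f, h, w)` with `b = 1`.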

In [3]:
import decord
from scripts.mask import RectangularMaskGenerator

# Clip settings; the BrushNet sampling cell below also uses 16 frames at 512x512.
VIDEO_PATH = "outpaint_videos/SB_Bear.mp4"
NUM_FRAMES = 16
FRAME_SIZE = 512

decord.bridge.set_bridge("torch")  # make decord return torch tensors
vr = decord.VideoReader(VIDEO_PATH, width=FRAME_SIZE, height=FRAME_SIZE)
if len(vr) < NUM_FRAMES:
    raise ValueError(f"{VIDEO_PATH} has {len(vr)} frames, need at least {NUM_FRAMES}")

# Frames arrive as (f, h, w, c); reorder to (c, f, h, w) and add a batch axis -> (1, c, f, h, w).
video = vr.get_batch(list(range(NUM_FRAMES)))
video = rearrange(video, "f h w c -> c f h w")
frame = torch.unsqueeze(video, dim=0)

# Each side's range is [0, 0.4] — presumably a fraction of the frame size (see scripts.mask).
mask_generator = RectangularMaskGenerator(mask_l=[0, 0.4],
                                          mask_r=[0, 0.4],
                                          mask_t=[0, 0.4],
                                          mask_b=[0, 0.4])
mask = mask_generator(frame)

# Zero the masked pixels out-of-place: `frame` is a view of `video` (unsqueeze/rearrange),
# so an in-place assignment would also overwrite `video`.
frame = frame.masked_fill(mask == 1, 0)

# Scale pixels to [0, 1], move to BrushNet's device/dtype, then map both tensors with x*2-1.
frame = frame / 255
frame = frame.to(device, brushnet.dtype)
mask = mask.to(device, brushnet.dtype)
frame = frame * 2 - 1
mask = mask * 2 - 1
# NOTE(review): the mask is handed to the pipeline after the x*2-1 mapping; confirm this
# matches the convention AnimationPipeline expects for `mask_video`.

In [8]:
# Inpaint the masked bear clip with BrushNet. The SparseCtrl branch is disabled here:
# no control images and a conditioning scale of 0.
brushnet_kwargs = dict(
    prompt="bear walking in the forest",
    negative_prompt="",
    num_inference_steps=25,
    guidance_scale=12.5,
    width=512,
    height=512,
    video_length=16,
    controlnet_images=None,
    controlnet_image_index=[0],
    controlnet_conditioning_scale=0.0,
    # full-extent slices: these pass views of `frame` / `mask`
    init_video=frame[:, :, :],
    mask_video=mask[:, :, :],
    brushnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
)
sample = pipeline(**brushnet_kwargs).videos
save_videos_grid(sample, "samples/baseline_test_Bear_with_brushnet.gif")

100%|██████████| 25/25 [00:19<00:00,  1.27it/s]
100%|██████████| 16/16 [00:00<00:00, 27.38it/s]


In [5]:
# Preview the masked input: map frames from [-1, 1] back to [0, 1], then move them to the CPU for saving.
masked_preview = ((frame + 1) / 2).cpu()
save_videos_grid(masked_preview, "samples/masked_video_bear.gif")