## Baseline model

In [1]:
import torch
from omegaconf import OmegaConf

from einops import rearrange

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler

from unipaint.pipelines.pipeline_unipaint import AnimationPipeline
from unipaint.models.unet import UNet3DConditionModel
from unipaint.models.sparse_controlnet import SparseControlNetModel
from unipaint.models.unipaint.brushnet import BrushNetModel

from unipaint.utils.util import load_weights,save_videos_grid
import decord
decord.bridge.set_bridge("torch")
from unipaint.utils.mask import RectangularMaskGenerator

# ---- Experiment configuration ----
# `data` selects the input clip and names the output folder used by later cells.
data = "Running4"

# Runtime placement / precision shared by every model loaded below.
device = "cuda:0"
dtype = torch.float16

# Checkpoint locations (relative to the working directory).
path = "models/StableDiffusion/stable-diffusion-v1-5"
brushnet_path = "models/BrushNet/random_mask_brushnet_ckpt"

# Feature toggles: a disabled component gets an empty checkpoint path.
use_motion_module = True
use_adapter = True

motion_module_path = ""
if use_motion_module:
    motion_module_path = "models/Motion_Module/v3_sd15_mm.ckpt"

adapter_path = ""
if use_adapter:
    adapter_path = "models/Motion_Module/v3_sd15_adapter.ckpt"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ---- Base Stable Diffusion components ----
# NOTE(review): `torch_dtype` is passed to the tokenizer exactly as before; a tokenizer
# holds no weights, so this argument presumably has no effect -- confirm before removing.
tokenizer = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer", torch_dtype=dtype)

text_encoder = CLIPTextModel.from_pretrained(path, subfolder="text_encoder")
text_encoder = text_encoder.to(device, dtype)

vae = AutoencoderKL.from_pretrained(path, subfolder="vae")
vae = vae.to(device, dtype)

# ---- 3D UNet built from the 2D Stable Diffusion checkpoint ----
inference_config = OmegaConf.load("configs/inference/inference-v3.yaml")
unet = UNet3DConditionModel.from_pretrained_2d(
    path,
    subfolder="unet",
    unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
)
unet = unet.to(device, dtype)

# ---- Sparse ControlNet ----
# These config fields are overridden before deriving the ControlNet from the UNet;
# presumably `from_unet` reads them -- TODO confirm.
unet.config.num_attention_heads = 8
unet.config.projection_class_embeddings_input_dim = None

controlnet_config = OmegaConf.load("configs/inference/sparsectrl/latent_condition.yaml")
controlnet = SparseControlNetModel.from_unet(
    unet,
    controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {}),
)

# NOTE(review): torch.load unpickles the checkpoint file; only use trusted checkpoints.
controlnet_state_dict = torch.load("models/Motion_Module/v3_sd15_sparsectrl_rgb.ckpt", map_location="cpu")
if "controlnet" in controlnet_state_dict:
    # Some checkpoints nest the weights under a "controlnet" key.
    controlnet_state_dict = controlnet_state_dict["controlnet"]

# Keep everything except the positional-encoding buffers and the embedded config entry.
controlnet_state_dict = {
    name: param
    for name, param in controlnet_state_dict.items()
    if "pos_encoder.pe" not in name and name != "animatediff_config"
}
controlnet.load_state_dict(controlnet_state_dict)
controlnet.to(device, dtype)

# ---- BrushNet ----
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=dtype)
brushnet = brushnet.to(device)

# ---- Assemble the pipeline ----
ddim_scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="linear",
    steps_offset=0,
    clip_sample=False,
)
pipeline = AnimationPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    controlnet=controlnet,
    brushnet=brushnet,
    scheduler=ddim_scheduler,
)
pipeline = pipeline.to(device)

loaded 3D unet's pretrained weights from models/StableDiffusion/stable-diffusion-v1-5 ...
### missing keys: 520; 
### unexpected keys: 0;
### Motion Module Parameters: 417.1376 M


In [3]:
# Load the optional checkpoints into the pipeline. An empty path string is passed
# for every component that is not used in this run.
pipeline = load_weights(
    pipeline,
    # temporal motion module
    motion_module_path=motion_module_path,
    motion_module_lora_configs=[],
    # domain adapter LoRA
    adapter_lora_path=adapter_path,
    adapter_lora_scale=1.0,
    # personalised image layers (unused here)
    dreambooth_model_path="",
    lora_model_path="",
    lora_alpha=0.8,
)
pipeline = pipeline.to(device)

load motion module from models/Motion_Module/v3_sd15_mm.ckpt
load domain lora from models/Motion_Module/v3_sd15_adapter.ckpt


# Read Video
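Here we read the first 16 frames of a video and generate a corresponding mask. The frames are first scaled to [0, 1], the masked region is zeroed out, and then both the video and the mask are rescaled to [-1, 1]. The video tensor has shape b c f h w.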

In [4]:
# Read a fixed-length clip, mask part of it, and rescale video and mask to [-1, 1].
num_frames = 16   # must equal the `video_length` passed to the pipeline
frame_size = 512  # decoded frames are resized to frame_size x frame_size

video_path = f"outpaint_videos/SB_{data}.mp4"
vr = decord.VideoReader(video_path, width=frame_size, height=frame_size)

# Fail early with a readable message instead of an opaque decoder indexing error.
if len(vr) < num_frames:
    raise ValueError(
        f"{video_path} has only {len(vr)} frames; at least {num_frames} are required."
    )

video = vr.get_batch(list(range(num_frames)))
video = rearrange(video, "f h w c -> c f h w")

# `video / 255` already allocates a new float tensor, so no explicit clone is needed
# before the in-place masking below. Result is laid out as b c f h w with values in [0, 1].
frame = torch.unsqueeze(video / 255, dim=0).to(device, brushnet.dtype)

# NOTE(review): each [0, 0.4] pair is presumably the range of the masked fraction
# on the left / right / top / bottom side -- confirm in RectangularMaskGenerator.
mask_generator = RectangularMaskGenerator(mask_l=[0,0.4],
                                          mask_r=[0,0.4],
                                          mask_t=[0,0.4],
                                          mask_b=[0,0.4])
mask = mask_generator(frame)

# Zero out the pixels where the mask equals 1.
frame[mask==1]=0
mask = mask.to(device, brushnet.dtype)

# Map both tensors from [0, 1] to [-1, 1].
frame = frame*2.-1.
mask = mask*2.-1.

# Save a [0, 1] preview of the masked input.
save_videos_grid(((frame+1)/2).cpu(), f"samples/{data}/masked_video.gif")

In [5]:
prompt = "a man in blue, running"
n_prompt = "worst quality, low quality, letterboxed"

# Sampling settings shared by both runs below, kept in one place so the two
# comparisons cannot drift apart.
common_sampling_kwargs = dict(
    prompt              = prompt,
    negative_prompt     = n_prompt,
    num_inference_steps = 25,
    guidance_scale      = 12.5,
    width               = 512,
    height              = 512,
    video_length        = 16,
)

# ---- Run 1: BrushNet conditioning only (ControlNet scale set to 0) ----
sample = pipeline(
    **common_sampling_kwargs,

    controlnet_images = None,
    controlnet_image_index = [0],
    controlnet_conditioning_scale = 0.0,

    init_video = frame,
    mask_video = mask,
    brushnet_conditioning_scale = 1.0,
    control_guidance_start = 0.0,
    control_guidance_end = 1.0,
).videos
save_videos_grid(sample, f"samples/{data}/brushnet_mm_{use_motion_module}_adapter_{use_adapter}.gif")

# ---- Run 2: ControlNet conditioning on the VAE latents of the masked video ----
controlnet_images = torch.clone(frame)
num_controlnet_images = controlnet_images.shape[2]
controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w")
# Inference only: disable autograd so encoding does not record a graph that is never used.
with torch.no_grad():
    # 0.18215 is the latent scaling constant used in the original cell.
    controlnet_images = vae.encode(controlnet_images).latent_dist.sample() * 0.18215
controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images)

sample = pipeline(
    **common_sampling_kwargs,

    controlnet_images = controlnet_images,
    controlnet_image_index = [0],
    controlnet_conditioning_scale = 1.0,
).videos
save_videos_grid(sample, f"samples/{data}/controlnet_mm_{use_motion_module}_adapter_{use_adapter}.gif")

100%|██████████| 25/25 [00:19<00:00,  1.28it/s]
100%|██████████| 16/16 [00:00<00:00, 26.21it/s]
100%|██████████| 25/25 [00:20<00:00,  1.25it/s]
100%|██████████| 16/16 [00:00<00:00, 27.40it/s]


## Possible trainings

1. Trainable params: options
    1. Motion Lora
    2. Temporal layers in Brushnet (How?)
    3. Whole motion module
2. Data
    1. WebVid or similar video datasets
    2. Maybe some video segmentation datasets? These datasets should provide masks and corresponding tags
3. How to train
    - I have no experience training a large model from scratch
