## Baseline model

In [5]:
import torch
from omegaconf import OmegaConf

from einops import rearrange

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler

from unipaint.pipelines.pipeline_unipaint import AnimationPipeline
from unipaint.models.unet import UNet3DConditionModel
from unipaint.models.sparse_controlnet import SparseControlNetModel
from unipaint.models.unipaint.brushnet import BrushNetModel

from unipaint.utils.util import load_weights,save_videos_grid
import decord
decord.bridge.set_bridge("torch")
from unipaint.utils.mask import StaticRectangularMaskGenerator

data = "Running4"
path = "models/StableDiffusion/stable-diffusion-v1-5"
brushnet_path = "models/BrushNet/random_mask_brushnet_ckpt"
device = "cuda:4"
dtype = torch.float16

use_motion_module = True
use_adapter = True

motion_module_path = "models/Motion_Module/v3_sd15_mm.ckpt" if use_motion_module else ""
adapter_path = "models/Motion_Module/v3_sd15_adapter.ckpt" if use_adapter else ""

In [6]:
#load base model
tokenizer        = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer", torch_dtype=dtype)
text_encoder     = CLIPTextModel.from_pretrained(path, subfolder="text_encoder").to(device,dtype)
vae              = AutoencoderKL.from_pretrained(path, subfolder="vae").to(device, dtype)

inference_config = OmegaConf.load("configs/inference/inference-v3.yaml")
unet_additional_kwargs = OmegaConf.to_container(inference_config.unet_additional_kwargs)
unet_additional_kwargs["unet_use_moe"] = True
unet             = UNet3DConditionModel.from_pretrained_2d(path, subfolder="unet", unet_additional_kwargs=unet_additional_kwargs).to(device, dtype)

#load controlnet
unet.config.num_attention_heads = 8
unet.config.projection_class_embeddings_input_dim = None
controlnet_config = OmegaConf.load("configs/inference/sparsectrl/latent_condition.yaml")
controlnet = SparseControlNetModel.from_unet(unet, controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {}))
controlnet_state_dict = torch.load("models/Motion_Module/v3_sd15_sparsectrl_rgb.ckpt", map_location="cpu")
controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict
controlnet_state_dict = {name: param for name, param in controlnet_state_dict.items() if "pos_encoder.pe" not in name}
controlnet_state_dict.pop("animatediff_config", "")
controlnet.load_state_dict(controlnet_state_dict)
controlnet.to(device, dtype)

#load brushnet
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=dtype).to(device)

#build pipeline
pipeline = AnimationPipeline(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
            controlnet=controlnet, brushnet = brushnet,
            scheduler=DDIMScheduler(beta_start=0.00085,
                                                beta_end=0.012,
                                                beta_schedule="linear",
                                                steps_offset=0,
                                                clip_sample=False)
                                                ).to(device)

loaded 3D unet's pretrained weights from models/StableDiffusion/stable-diffusion-v1-5 ...
### missing keys: 520; 
### unexpected keys: 0;
Using MoE
Marked down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.ff for replacement with MoEFFN
Marked down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.ff for replacement with MoEFFN
Marked down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.ff for replacement with MoEFFN
Marked down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.ff for replacement with MoEFFN
Marked down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.ff for replacement with MoEFFN
Marked down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.ff for replacement with MoEFFN
Marked down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.ff for replacement with MoEFFN
Marked down_blocks.3.motion_modules.1.temporal_transformer.transformer_bloc

In [None]:
pipeline = load_weights(
    pipeline,
    # motion module
    motion_module_path         = motion_module_path,
    motion_module_lora_configs = [],
    # domain adapter
    adapter_lora_path          = adapter_path,
    adapter_lora_scale         = 1.0,
    # image layers
    dreambooth_model_path      = "",
    lora_model_path            = "",
    lora_alpha                 = 0.8,
).to(device)

load motion module from models/Motion_Module/v3_sd15_mm.ckpt
load domain lora from models/Motion_Module/v3_sd15_adapter.ckpt


# Read Video
Here we read frames of a video and generate a corresponding mask. The data is normalized to [0, 1], in shape (b f) c h w.

In [None]:
video_path = f"outpaint_videos/SB_{data}.mp4"
vr = decord.VideoReader(video_path, width=512, height=512)

video = vr.get_batch(list(range(0,16)))
video = rearrange(video, "f h w c -> c f h w")
frame = torch.clone(torch.unsqueeze(video/255, dim=0)).to(device, brushnet.dtype)
del(vr)
mask_generator = StaticRectangularMaskGenerator(mask_l=[0,0.4],
                                          mask_r=[0,0.4],
                                          mask_t=[0,0.4],
                                          mask_b=[0,0.4])
mask = mask_generator(frame)
frame[mask==1]=0
mask = mask.to(device, brushnet.dtype)
frame = frame*2.-1.
mask = mask*2.-1.

save_videos_grid(((frame+1)/2).cpu(), f"samples/{data}/masked_video.gif")

In [None]:
prompt = "a man in blue, running"
n_prompt = "worst quality, low quality, letterboxed"
sample = pipeline(
    prompt = prompt,
    negative_prompt     = n_prompt,
    num_inference_steps = 25,
    guidance_scale      = 12.5,
    width               = 512,
    height              = 512,
    video_length        = 16,

    controlnet_images = None,
    controlnet_image_index = [0],
    controlnet_conditioning_scale=0.0,

    init_video = frame[:,:,:],
    mask_video = mask[:,:,:],
    brushnet_conditioning_scale = 1.0,
    control_guidance_start = 0.0,
    control_guidance_end = 1.0,
    ).videos
save_videos_grid(sample, f"samples/{data}/brushnet_mm_{use_motion_module}_adapter_{use_adapter}.gif")

controlnet_images = torch.clone(frame)
num_controlnet_images = controlnet_images.shape[2]
controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w")
controlnet_images = vae.encode(controlnet_images).latent_dist.sample() * 0.18215
controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images)

sample = pipeline(
    prompt = prompt,
    negative_prompt     = n_prompt,
    num_inference_steps = 25,
    guidance_scale      = 12.5,
    width               = 512,
    height              = 512,
    video_length        = 16,

    controlnet_images = controlnet_images,
    controlnet_image_index = [0],
    controlnet_conditioning_scale=1.0
).videos
save_videos_grid(sample, f"samples/{data}/controlnet_mm_{use_motion_module}_adapter_{use_adapter}.gif")

100%|██████████| 25/25 [00:34<00:00,  1.38s/it]
100%|██████████| 16/16 [00:01<00:00, 13.69it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 256.00 MiB. GPU 0 has a total capacity of 95.11 GiB of which 88.25 MiB is free. Process 381632 has 320.00 MiB memory in use. Process 381635 has 320.00 MiB memory in use. Process 381630 has 320.00 MiB memory in use. Process 381628 has 64.61 GiB memory in use. Process 381634 has 320.00 MiB memory in use. Process 381629 has 320.00 MiB memory in use. Process 381633 has 320.00 MiB memory in use. Process 381631 has 320.00 MiB memory in use. Process 3075899 has 28.17 GiB memory in use. Of the allocated memory 27.26 GiB is allocated by PyTorch, and 411.43 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

## Possible trainings

1. Trainable params: options
    1. Motion Lora
    2. Temporal layers in Brushnet (How?)
    3. Whole motion module
2. Data
    1. WebVid or similar video datasets
    2. Maybe some video segmentation dataset? These dataset should provide masks and corresponding tags
3. How to train
    - I have no experience of training a large model from scratch


In [None]:
import os
import torch
from omegaconf import OmegaConf

from einops import rearrange

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler

from unipaint.pipelines.pipeline_unipaint import AnimationPipeline
from unipaint.models.unet import UNet3DConditionModel
from unipaint.models.sparse_controlnet import SparseControlNetModel
from unipaint.models.unipaint.brushnet import BrushNetModel

from unipaint.utils.util import load_weights,save_videos_grid
import decord
decord.bridge.set_bridge("torch")
from unipaint.utils.mask import StaticRectangularMaskGenerator
path = "models/StableDiffusion/stable-diffusion-v1-5"
brushnet_path = "models/BrushNet/random_mask_brushnet_ckpt"
device = "cuda:4"
dtype = torch.float16

use_motion_module = True
use_adapter = True

motion_module_path = "models/Motion_Module/v3_sd15_mm.ckpt" if use_motion_module else ""
adapter_path = "models/Motion_Module/v3_sd15_adapter.ckpt" if use_adapter else ""

#load base model
tokenizer        = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer", torch_dtype=dtype)
text_encoder     = CLIPTextModel.from_pretrained(path, subfolder="text_encoder").to(device,dtype)
vae              = AutoencoderKL.from_pretrained(path, subfolder="vae").to(device, dtype)

inference_config = OmegaConf.load("configs/inference/inference-v3.yaml")
unet             = UNet3DConditionModel.from_pretrained_2d(path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)).to(device, dtype)

#load brushnet
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=dtype).to(device)

#build pipeline
pipeline = AnimationPipeline(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
            brushnet = brushnet,
            scheduler=DDIMScheduler(beta_start=0.00085,
                                                beta_end=0.012,
                                                beta_schedule="linear",
                                                steps_offset=0,
                                                clip_sample=False)
                                                ).to(device)

pipeline = load_weights(
    pipeline,
    # motion module
    motion_module_path         = motion_module_path,
    motion_module_lora_configs = [],
    # domain adapter
    adapter_lora_path          = adapter_path,
    adapter_lora_scale         = 1.0,
    # image layers
    dreambooth_model_path      = "",
    lora_model_path            = "",
    lora_alpha                 = 0.8,
).to(device)

unet_checkpoint_path = "outputs/unipaint_training_mask-2024-09-03T18-40-37/checkpoints/checkpoint.ckpt"
unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
if "global_step" in unet_checkpoint_path: print(f"global_step: {unet_checkpoint_path['global_step']}")
state_dict = unet_checkpoint_path["state_dict"] if "state_dict" in unet_checkpoint_path else unet_checkpoint_path
new_state_dict = {}
for k, v in state_dict.items():
    new_key = k.replace('module.', '')  # Remove 'module.' from each key
    new_state_dict[new_key] = v

m, u = unet.load_state_dict(new_state_dict, strict=False)

  from .autonotebook import tqdm as notebook_tqdm


loaded 3D unet's pretrained weights from models/StableDiffusion/stable-diffusion-v1-5 ...
### missing keys: 520; 
### unexpected keys: 0;
### Motion Module Parameters: 417.1376 M
load motion module from models/Motion_Module/v3_sd15_mm.ckpt
load domain lora from models/Motion_Module/v3_sd15_adapter.ckpt
global_step: 13000


In [4]:
import numpy as np
name = "chicken"
if name == "eagle":
    video_path = f"outpaint_videos/SB_Eagle.mp4"
    prompt = "a white head bald eagle"
    n_prompt = ""
    data = np.load('SB_Eagle_mask.npz')
if name == "chicken":
    video_path = f"outpaint_videos/SB_Eagle.mp4"
    prompt = "a chicke"
    n_prompt = ""
    data = np.load('samples/SB_Eagle_mask.npz')
if name == "dog":
    video_path = f"outpaint_videos/SB_Dog1.mp4"
    prompt = "a white fluffy dog walking"
    n_prompt = ""
    data = np.load('SB_Dog1_mask.npz')

vr = decord.VideoReader(video_path, width=512, height=512)

video = vr.get_batch(list(range(0,16)))
video = rearrange(video, "f h w c -> c f h w")
frame = torch.clone(torch.unsqueeze(video/255, dim=0)).to(device, brushnet.dtype)
del(vr)

mask = torch.tensor(np.unpackbits(data['mask']).reshape((16, 512, 512)))
mask = torch.cat([mask.unsqueeze(dim=0)]*3,dim=0).unsqueeze(dim=0)
frame[mask==1]=0
mask = mask.to(device, brushnet.dtype)
frame = frame*2.-1.
mask = mask*2.-1.
save_videos_grid(((frame+1)/2).cpu(), f"./test_masked_video_{name}.gif")
samples = []
samples.append(((frame+1)/2).cpu())
for scale in np.arange(0.0, 1.1, 0.1).tolist():
    sample = pipeline(
        prompt = prompt,
        negative_prompt     = n_prompt,
        num_inference_steps = 25,
        guidance_scale      = 12.5,
        width               = 512,
        height              = 512,
        video_length        = 16,

        init_video = frame[:,:,:],
        mask_video = mask[:,:,:],
        brushnet_conditioning_scale = scale,
        control_guidance_start = 0.0,
        control_guidance_end = 1.0,
        ).videos
    samples.append(sample)
samples = torch.concat(samples)
save_videos_grid(samples, f"./sam_test_{name}.gif")

100%|██████████| 25/25 [00:19<00:00,  1.28it/s]
100%|██████████| 16/16 [00:00<00:00, 26.11it/s]
100%|██████████| 25/25 [00:19<00:00,  1.27it/s]
100%|██████████| 16/16 [00:00<00:00, 27.34it/s]
 96%|█████████▌| 24/25 [00:19<00:00,  1.24it/s]


KeyboardInterrupt: 

In [8]:
from unipaint.utils.convert_to_moe import replace_ffn_with_moeffn, task_context, get_task_name
task_context("interpolation")
task_name = get_task_name()
print(task_name)

None
