# Play with text-to-video

In [1]:
# Set environment variables for this kernel; both Hugging Face cache variables point at the same hub directory.
# NOTE(review): these are absolute, user-specific paths — adjust them before running on another machine.
# NOTE(review): MYHOME is not referenced anywhere else in the visible notebook — confirm it is still needed.
%env MYHOME=/home/dcor/itanlevin
%env HF_HUB_CACHE=/home/dcor/itanlevin/.cache/huggingface/hub
%env TRANSFORMERS_CACHE=/home/dcor/itanlevin/.cache/huggingface/hub

env: MYHOME=/home/dcor/itanlevin
env: HF_HUB_CACHE=/home/dcor/itanlevin/.cache/huggingface/hub
env: TRANSFORMERS_CACHE=/home/dcor/itanlevin/.cache/huggingface/hub


In [17]:
import torch
from diffusers import DiffusionPipeline
import imageio
import cv2
import os
from ipywidgets import Video

# Load the damo-vilab text-to-video checkpoint (fp16 weights) and move it to GPU index 1.
pipe = DiffusionPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda:1")

# Memory optimization: run the VAE on slices of the batch instead of all at once.
pipe.enable_vae_slicing()

# "box1" "yoga4" "crabdraw" "breakdancing" "penguin" "yoga3" "yoyo" "car1" "fish" "box2" "sax_play2" "dancers3" "runner1" "biking" "hike2" "boat2" "sax_play1"

# Candidate prompts; exactly one is left active.
# prompt = "the man sailing the boat, his hands deftly manipulate the oars, while his body shifts subtly to maintain balance, while his boat moves foraward in the river"
prompt = "A color drawing of a humanoid frog singing and playing the mandolin"
# prompt = "A fencer stands in the distance in en garde position. His entire body is visible, and is ready to advance"
# Note: "A stickman is jumping." or "A black and white smiley face emoji turns from happy to sad." work pretty badly ...

# Sample a 24-frame clip for the active prompt.
video_frames = pipe(prompt, num_frames=24).frames

Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

In [18]:
def frames_to_vid(video_frames, output_vid_path, fps=8):
    """Write a sequence of frames to a video file with imageio.

    Args:
        video_frames: Iterable of frames; each one is handed unchanged to the
            writer's ``append_data``.
        output_vid_path: Destination path passed to ``imageio.get_writer``.
        fps: Frame rate of the written video. Defaults to 8, the value that
            used to be hard-coded.

    Raises:
        Whatever ``append_data`` raises for a frame it cannot encode; the
        writer is closed before the exception propagates.
    """
    writer = imageio.get_writer(output_vid_path, fps=fps)
    try:
        for frame in video_frames:
            writer.append_data(frame)
    finally:
        # Always release the writer, even when a frame fails to encode, so the
        # underlying file/encoder is not leaked in a long-lived kernel.
        writer.close()
    
# Encode the sampled frames to an mp4 file, then load that file into a Video widget;
# as the last expression of the cell, the widget is what gets displayed.
# NOTE(review): squeeze() drops size-1 axes — presumably a leading batch axis of the pipeline output; confirm for the installed diffusers version.
frames_to_vid(video_frames.squeeze(), output_vid_path="stickman_basic_text2vid.mp4")
Video.from_file("stickman_basic_text2vid.mp4")



Video(value=b'\x00\x00\x00 ftypisom\x00\x00\x02\x00isomiso2avc1mp41\x00\x00\x00\x08free...')

# SDS Video
## The SDS loss is based on Word-as-Image (a first simple attempt)

In [None]:
from diffusers import DiffusionPipeline
import torch.nn as nn
import matplotlib.pyplot as plt
import torch
from torch.optim.lr_scheduler import LambdaLR
import pydiffvg
import kornia.augmentation as K
from easydict import EasyDict as edict
import numpy as np
from tqdm import tqdm
import imageio
import cv2
import os
from ipywidgets import Video

# Utils

In [None]:
# ==================================
# ====== video related utils =======
# ==================================
def frames_to_vid(video_frames, output_vid_path, fps=8):
    """Write a sequence of frames to a video file with imageio.

    Args:
        video_frames: Iterable of frames; each one is handed unchanged to the
            writer's ``append_data``.
        output_vid_path: Destination path passed to ``imageio.get_writer``.
        fps: Frame rate of the written video. Defaults to 8, the value that
            used to be hard-coded.

    Raises:
        Whatever ``append_data`` raises for a frame it cannot encode; the
        writer is closed before the exception propagates.
    """
    writer = imageio.get_writer(output_vid_path, fps=fps)
    try:
        for frame in video_frames:
            writer.append_data(frame)
    finally:
        # Always release the writer, even when a frame fails to encode, so the
        # underlying file/encoder is not leaked in a long-lived kernel.
        writer.close()

def render_frames_to_tensor(frames_shapes, frames_shapes_grous, w, h, render, device):
    """Rasterize every frame's vector scene and stack the results into one tensor.

    Each frame is rendered to an RGBA image by ``render`` and then
    alpha-composited over a white background.

    Args:
        frames_shapes: Sequence with one list of shapes per frame.
        frames_shapes_grous: Sequence with one list of shape groups per frame,
            index-aligned with ``frames_shapes``.
        w, h: Canvas width and height forwarded to the scene serializer and
            to ``render``.
        render: Callable invoked as ``render(w, h, 2, 2, 0, None, *scene_args)``;
            its result is indexed as a (rows, cols, 4) RGBA image.
        device: Device on which the white background tensor is allocated.

    Returns:
        Tensor of shape (num_frames, rows, cols, 3), e.g. [16, 256, 256, 3]
        for 16 frames rendered at 256x256.
    """
    composited_frames = []
    for frame_idx, shapes in enumerate(frames_shapes):
        shape_groups = frames_shapes_grous[frame_idx]
        scene_args = pydiffvg.RenderFunction.serialize_scene(w, h, shapes, shape_groups)
        rgba = render(w, h, 2, 2, 0, None, *scene_args)

        alpha = rgba[:, :, 3:4]
        white = torch.ones(rgba.shape[0], rgba.shape[1], 3, device=device)
        # out = alpha * rgb + (1 - alpha) * white; the result already has 3 channels.
        composited_frames.append(alpha * rgba[:, :, :3] + white * (1 - alpha))
    return torch.stack(composited_frames)


def save_mp4_from_tensor(frames_tensor, output_vid_path):
    """Convert a float video tensor to 8-bit frames and write it to a video file.

    Args:
        frames_tensor: Tensor indexed as (frame, row, col, channel), e.g.
            [16, 256, 256, 3]. Only the first three channels are used. Values
            are expected in [0, 1]; anything outside that range is clamped so
            it cannot wrap around in the uint8 conversion.
        output_vid_path: Destination path forwarded to ``frames_to_vid``.

    The input tensor is left untouched (it is only read, detached and copied
    to host memory).
    """
    rgb = frames_tensor[:, :, :, :3].detach().cpu().numpy()
    # Clamp before scaling: e.g. 1.01 * 255 would otherwise overflow uint8.
    rgb_uint8 = (np.clip(rgb, 0.0, 1.0) * 255).astype(np.uint8)
    frames_to_vid(list(rgb_uint8), output_vid_path=output_vid_path)
    
    
# ==================================
# ====== word-as-image utils =======
# ==================================
def learning_rate_decay(step,
                        lr_init,
                        lr_final,
                        max_steps,
                        lr_delay_steps=0,
                        lr_delay_mult=1):
    """Log-linear learning-rate schedule with an optional sine warm-up.

    The rate is interpolated in log space from ``lr_init`` (at step 0) to
    ``lr_final`` (at ``max_steps`` and beyond). When ``lr_delay_steps`` is
    positive, the result is additionally scaled by a factor that rises from
    ``lr_delay_mult`` to 1 along a quarter sine wave over the first
    ``lr_delay_steps`` steps.

    Args:
        step: Current optimisation step.
        lr_init: Learning rate at the start of the schedule.
        lr_final: Learning rate at the end of the schedule.
        max_steps: Step at which ``lr_final`` is reached.
        lr_delay_steps: Length of the warm-up; 0 disables it.
        lr_delay_mult: Warm-up scale applied at step 0.

    Returns:
        The learning rate for ``step``.
    """
    if lr_delay_steps > 0:
        # A kind of reverse cosine decay.
        warmup_progress = np.clip(step / lr_delay_steps, 0, 1)
        warmup_scale = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
            0.5 * np.pi * warmup_progress)
    else:
        warmup_scale = 1.

    frac = np.clip(step / max_steps, 0, 1)
    log_lr = np.log(lr_init) * (1 - frac) + np.log(lr_final) * frac
    return warmup_scale * np.exp(log_lr)

def get_data_augs(cut_size):
    """Build the augmentation pipeline: random perspective followed by a random crop.

    Args:
        cut_size: Side length of the square crop produced by the second stage.

    Returns:
        ``nn.Sequential`` holding the two augmentations in application order.
    """
    perspective = K.RandomPerspective(distortion_scale=0.5, p=0.7)
    crop = K.RandomCrop(size=(cut_size, cut_size), pad_if_needed=True, padding_mode='reflect', p=1.0)
    return nn.Sequential(perspective, crop)

def init_shapes(svg_path, num_frames=16):
    """Load an SVG once per frame and collect its path points as trainable parameters.

    Args:
        svg_path: Path to the SVG file *without* the ``.svg`` extension (it is
            appended here).
        num_frames: Number of frames to initialise from the SVG. Defaults to
            16, the value that used to be hard-coded.

    Returns:
        A tuple ``(frames_shapes, frames_shapes_group, parameters)`` where the
        first two are lists with one entry per frame (the shapes and shape
        groups returned by ``pydiffvg.svg_to_scene``) and ``parameters.point``
        is a flat list of every path's ``points`` tensor, frame after frame,
        each with ``requires_grad`` enabled.
    """
    parameters = edict()
    parameters.point = []
    frames_shapes = []
    frames_shapes_group = []

    svg = f'{svg_path}.svg'
    for _ in range(num_frames):
        # The SVG is parsed again for every frame — presumably so that each
        # frame owns its own point tensors and can be optimised independently.
        _, _, shapes_init, shape_groups_init = pydiffvg.svg_to_scene(svg)
        for path in shapes_init:
            path.points.requires_grad = True
            parameters.point.append(path.points)
        frames_shapes.append(shapes_init)
        frames_shapes_group.append(shape_groups_init)
    return frames_shapes, frames_shapes_group, parameters


# ==================================
# ======= sds loss (naive) =========
# ==================================
# TODO: think about a better loss
class SDSVideoLoss(nn.Module):
    """Naive score-distillation (SDS) loss driven by a text-to-video diffusion pipeline.

    The rendered video is encoded to VAE latents, noised at a random timestep,
    and denoised by the pipeline's UNet with classifier-free guidance. The
    weighted difference between predicted and injected noise is treated as a
    constant gradient with respect to the latents.

    Args:
        cfg: Configuration object; reads ``model_name``, ``caption``,
            ``batch_size``, ``timesteps`` and ``guidance_scale``.
        device: Device the pipeline and the schedule tensors are moved to.
    """

    def __init__(self, cfg, device):
        super(SDSVideoLoss, self).__init__()
        self.cfg = cfg
        self.device = device
        self.pipe = DiffusionPipeline.from_pretrained(cfg.model_name, torch_dtype=torch.float16, variant="fp16")
        self.pipe = self.pipe.to(self.device)
        # alphas = cumulative alpha products; "sigmas" here is 1 - alphas_cumprod.
        self.alphas = self.pipe.scheduler.alphas_cumprod.to(self.device)
        self.sigmas = (1 - self.pipe.scheduler.alphas_cumprod).to(self.device)

        self.text_embeddings = None
        self.embed_text()

    def embed_text(self):
        """Encode the caption and the empty prompt once, then drop the text modules.

        Stores ``self.text_embeddings`` as ``[uncond] * batch_size`` followed
        by ``[cond] * batch_size``, matching the ``torch.cat([z] * 2)`` layout
        used in ``forward``. The tokenizer and text encoder are deleted from
        the pipeline afterwards, so this method can only be called once.
        """
        # tokenize and embed text
        text_input = self.pipe.tokenizer(self.cfg.caption, padding="max_length",
                                         max_length=self.pipe.tokenizer.model_max_length,
                                         truncation=True, return_tensors="pt")
        uncond_input = self.pipe.tokenizer([""], padding="max_length",
                                           max_length=text_input.input_ids.shape[-1],
                                           return_tensors="pt")
        with torch.no_grad():
            text_embeddings = self.pipe.text_encoder(text_input.input_ids.to(self.device))[0]
            uncond_embeddings = self.pipe.text_encoder(uncond_input.input_ids.to(self.device))[0]
        self.text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        self.text_embeddings = self.text_embeddings.repeat_interleave(self.cfg.batch_size, 0)
        del self.pipe.tokenizer
        del self.pipe.text_encoder

    def forward(self, x_aug):
        """Compute the SDS loss for a batch of rendered videos.

        Args:
            x_aug: Tensor of shape (batch, frames, channels, height, width)
                with values in [0, 1], e.g. (1, 16, 3, 256, 256). The batch
                size must equal ``cfg.batch_size`` because the text embeddings
                were repeated for that size.

        Returns:
            Scalar loss tensor whose gradient with respect to the latents is
            the weighted noise residual.
        """
        x = x_aug * 2. - 1.  # map [0, 1] renders to the [-1, 1] range fed to the VAE
        batch_size, num_frames, channels, height, width = x.shape
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            # The VAE encodes images, so frames are folded into the batch axis.
            x = x.reshape(batch_size * num_frames, channels, height, width)
            init_latent_z = self.pipe.vae.encode(x).latent_dist.sample()  # e.g. [16, 4, 32, 32]
            _, channel, h_, w_ = init_latent_z.shape
            # Unfold frames again and move channels first: (B, C, F, h, w).
            init_latent_z = init_latent_z.reshape(batch_size, num_frames, channel, h_, w_).permute(0, 2, 1, 3, 4)

        latent_z = self.pipe.vae.config.scaling_factor * init_latent_z  # scaling_factor * init_latents

        with torch.inference_mode():
            # sample one timestep per batch element
            timestep = torch.randint(
                low=400,
                high=min(950, self.cfg.timesteps) - 1,  # avoid highest timestep | diffusion.timesteps=1000
                size=(latent_z.shape[0],),
                device=self.device, dtype=torch.long)

            # add noise
            eps = torch.randn_like(latent_z)
            # zt = alpha_t * latent_z + sigma_t * eps
            noised_latent_zt = self.pipe.scheduler.add_noise(latent_z, eps, timestep)

            # denoise
            z_in = torch.cat([noised_latent_zt] * 2)  # expand latents for classifier free guidance
            # The timesteps must be duplicated alongside the latents; passing
            # the un-duplicated tensor only works when batch_size == 1.
            timestep_in = torch.cat([timestep] * 2)
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                eps_t_uncond, eps_t = self.pipe.unet(
                    z_in, timestep_in, encoder_hidden_states=self.text_embeddings
                ).sample.float().chunk(2)

            eps_t = eps_t_uncond + self.cfg.guidance_scale * (eps_t - eps_t_uncond)

            # w = alphas[timestep]^0.5 * (1 - alphas[timestep]) = alphas[timestep]^0.5 * sigmas[timestep]
            # Reshape to (B, 1, 1, 1, 1) so the per-sample weight broadcasts
            # over the batch axis rather than the last spatial axis.
            weight = (self.alphas[timestep] ** 0.5 * self.sigmas[timestep]).view(-1, 1, 1, 1, 1)
            grad_z = weight * (eps_t - eps)
            assert torch.isfinite(grad_z).all()
            grad_z = torch.nan_to_num(grad_z.detach().float(), 0.0, 0.0, 0.0)

        # clone() turns the inference-mode tensor into a normal one that
        # autograd is allowed to save for the backward pass.
        sds_loss = grad_z.clone() * latent_z
        del grad_z

        sds_loss = sds_loss.sum(1).mean()
        return sds_loss

# Main run

In [None]:
class Config:
    """Static settings for the SDS video optimisation run (read as class attributes)."""

    # Hugging Face id of the text-to-video pipeline loaded by SDSVideoLoss.
    model_name = "damo-vilab/text-to-video-ms-1.7b"
    # Upper bound used when sampling diffusion timesteps (capped at 950 in the loss).
    timesteps = 1000
    # Classifier-free guidance weight applied to the noise prediction.
    guidance_scale = 100
    # Number of copies of the rendered video per step; also sizes the text embeddings.
    batch_size = 1
    # Width and height (pixels) of the rendered frames.
    render_size = 256#600
    # Crop size handed to get_data_augs.
    cut_size = 256
    # Number of optimisation steps; 501 so that step 500 is still reached and saved.
    num_iter = 501
    # num_iter = 10
    
    # Base Adam learning rate for the path points; the scheduler multiplies it
    # by learning_rate_decay(step) / lr_init.
    lr_base = 0.5 # we need to play with the lr
    # Start and end values of the log-linear decay passed to learning_rate_decay.
    lr_init = 0.002
    lr_final = 0.0008
    # Warm-up: scale at step 0 and number of warm-up steps.
    lr_delay_mult = 0.1
    lr_delay_steps = 100
    
    # Path of the initial SVG without the ".svg" extension (init_shapes appends it).
    target = "svg_input/smile2-01"
    # Text prompt embedded by SDSVideoLoss.
    caption = "A black and white smiley face emoji turns from happy to sad."
    
    
# this is the main run
cfg = Config()
# exist_ok avoids the separate existence check and the race it implies.
os.makedirs("./output_videos", exist_ok=True)

pydiffvg.set_use_gpu(torch.cuda.is_available())
device = pydiffvg.get_device()

sds_loss = SDSVideoLoss(cfg, device)

h, w = cfg.render_size, cfg.render_size
# Built for later use; the training loop below currently skips augmentation.
data_augs = get_data_augs(cfg.cut_size)

render = pydiffvg.RenderFunction.apply
# stack the initial svg into 16 frames (for now, the SVG is duplicated and the gradients are backpropagated to all frames)
shapes_lst, shape_groups_lst, parameters = init_shapes(svg_path=cfg.target)

# Save the un-optimised animation for reference.
output_vid_path = "output_videos/init_vid.mp4"
frames_tensor = render_frames_to_tensor(shapes_lst, shape_groups_lst, w, h, render, device)
save_mp4_from_tensor(frames_tensor, output_vid_path)
# NOTE: not the last expression of the cell, so the widget is not rendered
# here; the video is displayed by a later cell.
Video.from_file(output_vid_path)

num_iter = cfg.num_iter
pg = [{'params': parameters["point"], 'lr': cfg.lr_base}]
optim = torch.optim.Adam(pg, betas=(0.9, 0.9), eps=1e-6)

# LambdaLR multiplies lr_base by this factor, which starts at
# lr_delay_mult (warm-up) and decays towards lr_final / lr_init.
lr_lambda = lambda step: learning_rate_decay(step, cfg.lr_init, cfg.lr_final, num_iter,
                                             lr_delay_steps=cfg.lr_delay_steps,
                                             lr_delay_mult=cfg.lr_delay_mult) / cfg.lr_init

scheduler = LambdaLR(optim, lr_lambda=lr_lambda, last_epoch=-1)

t_range = tqdm(range(num_iter))
for step in t_range:
    optim.zero_grad()

    # render the current shapes of every frame
    vid_tensor = render_frames_to_tensor(shapes_lst, shape_groups_lst, w, h, render, device)
    if step % 100 == 0:
        # periodic snapshot of the optimisation progress
        save_mp4_from_tensor(vid_tensor, f"output_videos/{step}.mp4")

    # cur shape before is (16, 256, 256, 3)
    x = vid_tensor.unsqueeze(0).permute(0, 1, 4, 2, 3)  # (16, 256, 256, 3) -> (1, 16, 3, 256, 256)
    x = x.repeat(cfg.batch_size, 1, 1, 1, 1)
    x_aug = x # for now we skip the augment, but we can add it later (x_aug = data_augs.forward(x))

    # compute the diffusion (SDS) loss
    loss = sds_loss(x_aug)

    t_range.set_postfix({'loss': loss.item()})
    loss.backward()
    optim.step()
    scheduler.step()

In [None]:
# Display the animation rendered from the initial (un-optimised) SVG.
Video.from_file("output_videos/init_vid.mp4")

In [None]:
# Display the snapshot written by the training loop at step 500.
Video.from_file("output_videos/500.mp4")