In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload 
%autoreload 2

In [3]:
import argparse, os, sys, glob, yaml, math, random, json
sys.path.append('.')
sys.path.append('./scripts/evaluation/')

import datetime, time
import numpy as np
from omegaconf import OmegaConf
from collections import OrderedDict
from tqdm import trange, tqdm
from einops import rearrange, repeat
from functools import partial
import torch
from pytorch_lightning import seed_everything

from funcs import load_model_checkpoint, load_prompts, load_image_batch, get_filelist, save_videos
from funcs import batch_ddim_sampling
from utils.utils import instantiate_from_config, encode_attribute_multiple

import torchvision
from pathlib import Path
from PIL import Image
import torch.nn.functional as F
import cv2
import scipy as sp
from scipy import stats

# My utils
from utils.attention_utils import *
from utils.vis_utils import *
from utils.test_list import regularization_dict, collected_prompt_list

## Base T2V

In [8]:
# ---- Sampler settings ----
ddim_steps = 25
unconditional_guidance_scale = 12

# ---- Model files and output location ----
config = 'configs/inference_t2v_512_v2.0.yaml'
ckpt = 'checkpoint/model.ckpt'
savedir = 'results'

# ---- Video geometry ----
# `fps` is stored in `args` and later wrapped in a tensor for the model's
# conditioning dict; `savefps` is the frame rate used when writing results.
fps = 28
height = 320
width = 512
savefps = 8
# Default clip length. The prompt-selection cell further down reassigns
# `frames`, so that later value is the one used for sampling.
frames = 16

# ---- Run bookkeeping ----
gpu_num = 1
mode = "base"
n_samples = bs = 1

# Collect everything in one mapping so the rest of the notebook can use
# attribute access (`args.height`, `args.ddim_steps`, ...).
args_dict = dict(
    ckpt_path=ckpt,
    config=config,
    mode=mode,
    fps=fps,
    width=width,
    height=height,
    n_samples=n_samples,
    bs=bs,
    ddim_steps=ddim_steps,
    ddim_eta=1.0,
    unconditional_guidance_scale=unconditional_guidance_scale,
    savedir=savedir,
    frames=frames,
    savefps=savefps,
)

args = OmegaConf.create(args_dict)
print(args)

{'ckpt_path': 'checkpoint/model.ckpt', 'config': 'configs/inference_t2v_512_v2.0.yaml', 'mode': 'base', 'fps': 28, 'width': 512, 'height': 320, 'n_samples': 1, 'bs': 1, 'ddim_steps': 25, 'ddim_eta': 1.0, 'unconditional_guidance_scale': 12, 'savedir': 'results', 'frames': 16, 'savefps': 8}


# 1: Load Model

In [9]:
## step 1: model config
## -----------------------------------------------------------------
# Check the cheap preconditions first, so a wrong checkpoint path or an
# unsupported resolution fails immediately instead of after the model has
# been instantiated and moved to the GPU.
assert os.path.exists(args.ckpt_path), f"Error: checkpoint [{args.ckpt_path}] Not Found!"
assert (args.height % 16 == 0) and (args.width % 16 == 0), "Error: image size [h,w] should be multiples of 16!"

# Build the model described by the "model" section of the YAML config,
# move it to the GPU, load the weights and switch to inference mode.
config = OmegaConf.load(args.config)
model_config = config.pop("model", OmegaConf.create())
model = instantiate_from_config(model_config)
model = model.cuda()
model = load_model_checkpoint(model, args.ckpt_path)
model.eval()

## latent noise shape
# The latent grid is the pixel resolution divided by 8.
h, w = args.height // 8, args.width // 8
# A negative `args.frames` means "use the model's own temporal length".
frames = model.temporal_length if args.frames < 0 else args.frames
channels = model.channels

AE working on z of shape (1, 4, 64, 64) = 16384 dimensions.
>>> model checkpoint loaded.
Frames:  16


In [10]:
# Change here!
# Suffix appended to the run folder name; set it to '' for no suffix.
postfix = '_vstar'

# Results are grouped per day: <savedir>/<YYYY-MM-DD>-<mode><postfix>
now = datetime.datetime.now().strftime("%Y-%m-%d")
save_dir = os.path.join(args.savedir, f"{now}-{mode}{postfix}")

os.makedirs(save_dir, exist_ok=True)
print("Create save_dir: ", save_dir)

Create save_dir:  results/2024-07-29-base_vstar


# Helper Function

In [11]:
def img_callback(pred_x0, i):
    """Decode an intermediate prediction and save it as an image grid.

    Passed to `batch_ddim_sampling` as `img_callback`. The output is written
    to `step{i}.jpg` inside the notebook-global `save_dir_latest`, which the
    sampling loop sets before sampling starts.

    Args:
        pred_x0: Latent prediction handed over by the sampler.
        i: Index used in the output file name (presumably the sampler's
            step counter; confirm against `batch_ddim_sampling`).
    """
    decoded = model.decode_first_stage_2DAE(pred_x0).clip(-1.0, 1.0)
    # Map the decoder range [-1, 1] onto [0, 1] before saving.
    grid_input = (decoded / 2 + 0.5).clamp(0, 1)
    out_path = os.path.join(save_dir_latest, f"step{i}.jpg")
    save_image_grid(grid_input, out_path, rescale=False, n_rows=8,)

# 2. Prepare Prompts

### Active Prompts

In [12]:
# List every available prompt next to the index to use as `prompt_id`.
for prompt_idx, prompt_entry in enumerate(collected_prompt_list):
    prompt_text = prompt_entry["prompt"]
    print(prompt_idx, prompt_text)

0 A young boy becomes an old man, realistic,best quality, 4k
1 A day at the beach from dawn till dusk,eye-level shot
2 A Ferrari driving on the road, starts to snow
3 The seasonal cycle of a lake from frozen winter to autumn,a shot at eye level
4 A day at the beach from dawn till dusk
5 A young boy becomes an old man, and turns into a young girl, neon-lit urban landscape background,dystopian cyberpunk,futuristic design, electronic artwork,moody lighting, ultra-detailed,high-tech implants
6 Superman flying in the sky, sunny day becomes a dark rainy day, best quality, 4k, realistic
7 Spider-Man standing on the beach from morning to evening
8 A peony starts to bloom, in the field
9 Rainbow appears after the rainy day
10 A flower starts to bloom
11 A pizza is being made
12 A mural being painted on a city wall
13 A landscape transitioning from winter to spring
14 A young girl is aging
15 The sun rises from the sea, making the dark sky bright
16 A night sky changing from dusk till dawn
17 Fr

In [13]:
# Choose prompt
prompt_id = 0
prompt_dict = collected_prompt_list[prompt_id]
prompt_list = [
    prompt_dict["prompt"]
]

# Sub-prompts used to build the per-frame text embedding.
# Use an empty list to disable VSP. A missing value (None) is normalised to
# [] here because the sampling cell calls len() on `attribute_list`.
attribute_list = prompt_dict["subprompts"]
if attribute_list is None:
    attribute_list = []

# Set seeds (one video is generated per seed)
seed_list = [128]

# Set number of frames (overrides the value from the config/model cells)
frames = 32

cond_image_list = None

In [16]:
# VSTAR Setup

keep_timestep_list = []
# NOTE(review): 1..25 looks tied to ddim_steps = 25 in the config cell; confirm.
save_timestep_list = [*range(1,26)]
save_maps = False # True for saving visualization
save_npy = False

# NOTE(review): 64x40 matches the latent grid (512 // 8, 320 // 8) for the
# configured resolution; confirm before changing height/width.
attention_store = AttentionStore(
    base_width=64, base_height=40,
    keep_timestep_list=keep_timestep_list,
    save_timestep_list=save_timestep_list,
    save_maps=save_maps, save_npy=save_npy
)
use_delta_attention = True

# Standard deviation is a hyperparameter controlling the dynamics.
# Smaller values have a stronger effect.
# For short videos: "64D4", "64D1", "64D8"
# For longer videos, regularization at res32 becomes necessary: "64D1_32D8"
setup_key = "64D1"
ablation_dict = regularization_dict[setup_key]
if use_delta_attention:
    register_attention_control_vstar(model, attention_store, ablation_dict)
else:
    register_attention_control(model, attention_store)

print("Use delta attention: ", use_delta_attention)
# The former "ablation_dict is None" branch was removed: it printed the
# undefined name `diag_std`, and the lookups below require a dict anyway.

# Build the folder-name suffix describing the regularization setup, e.g.
# "res64-std1-until13": one "res<k>-std<diag>[-scale<scale>]" tag per
# regularized resolution, joined by "_", followed by "-until<until_time>".
res_tags = []
for k in ablation_dict["regularize_res_list"]:
    diag = ablation_dict[f'diag_{k}']
    scale = ablation_dict[f'scale_{k}']
    tag = f"res{k}-std{diag}"
    if scale != 1.0:
        # A scale of 1.0 is left out of the folder name.
        tag += f"-scale{scale}"
    res_tags.append(tag)

until_time = ablation_dict["until_time"]
post_fix_folder = "_".join(res_tags) + f"-until{until_time}"

if use_delta_attention:
    print(post_fix_folder)

Keep attention maps at:  []
Total attention layers:  34
Use delta attention:  True
res64-std1-until13


## Run Test

In [None]:
indices_list = None
interpolation_mode = "linear" #"linear"

## step 3: run over samples
## -----------------------------------------------------------------
# Generates one video per seed in `seed_list` for the selected prompt and
# writes the GIF, the individual frames and a JSON record of the settings
# into a per-run folder below `save_dir`.
start = time.time()
n_rounds = len(seed_list)

# `attribute_list` may be None when sub-prompts are disabled; treat that the
# same as an empty list instead of crashing in len().
num_attribute = 0 if attribute_list is None else len(attribute_list)

for idx in range(0, n_rounds):
    attention_store.reset()
    cur_prompt = prompt_list[0]
    cur_seed = seed_list[idx]
    seed_everything(cur_seed)

    # Folder / file stem: first 10 words of the prompt plus the seed.
    save_prompt = "-".join((cur_prompt.replace("/", "").split(" ")[:10]))
    save_dir_cur = f"{save_prompt}-{cur_seed}"
    if num_attribute > 0:
        save_dir_cur += f"_embed{num_attribute}"
    if interpolation_mode != "linear":
        save_dir_cur += f"_{interpolation_mode}"

    # `save_dir_latest` is read as a global by `img_callback`.
    save_dir_latest = os.path.join(save_dir, save_dir_cur)
    if use_delta_attention:
        save_dir_latest += f"_deltaAttn_f{frames}_{post_fix_folder}"
    else:
        save_dir_latest += f"_f{frames}"
    # Make sure the run folder exists before the callback and the JSON dump
    # write into it.
    os.makedirs(save_dir_latest, exist_ok=True)
    attention_store.set_save_dir(os.path.join(save_dir_latest, "attention"))

    print(f'Work on prompt {idx + 1} / {n_rounds}... Seed={cur_seed}')
    print(cur_prompt)
    batch_size = args.bs
    noise_shape = [batch_size, channels, frames, h, w]
    fps = torch.tensor([args.fps]*batch_size).to(model.device).long()

    # No explicit start noise is passed to the sampler. The original cell
    # drew a seeded tensor and immediately replaced it with None; that dead
    # allocation is removed. It used its own torch.Generator, so the global
    # RNG state set by seed_everything() is unaffected by the removal.
    x_T = None
    print(f'----> saved in {save_dir_latest}')

    # `cur_prompt` is a str here (str methods were already called on it).
    prompts = [cur_prompt]

    if num_attribute == 0:
        print("Use normal prompt embedding.")
        text_emb = model.get_learned_conditioning(prompts) # (1,77,1024)
    else:
        print("Use attrbites embeddings.")
        text_emb = encode_attribute_multiple(model, attribute_list, frames,interpolation_mode,indices_list=indices_list)

    cond = {"c_crossattn": [text_emb], "fps": fps}

    ## inference
    batch_samples = batch_ddim_sampling(
        model, cond, noise_shape, args.n_samples,
        args.ddim_steps, args.ddim_eta, args.unconditional_guidance_scale,
        verbose=True, img_callback=img_callback,
        x_T=x_T,
    )

    ## b,samples,c,t,h,w
    file_names = [save_dir_cur]
    save_videos(batch_samples, save_dir_latest, file_names, fps=args.savefps,ext_name="gif")
    final_frame_save_dir = os.path.join(save_dir_latest, 'final_video')
    save_image_batch(batch_samples[0,0], final_frame_save_dir, ext_type="jpg")

    # Save config
    config_cur = {
        "seed": cur_seed,
        "prompt": cur_prompt,
        "attribute_list": attribute_list,
        "ablation_dict":ablation_dict,
    }
    with open(os.path.join(save_dir_latest, f"{save_dir_cur}.json"), "w") as outfile:
        json.dump(config_cur, outfile, indent=4)
    print()

print(f"Saved in {args.savedir}. Time used: {(time.time() - start):.2f} seconds")