# Preparation

In [None]:
import os
import sys

# Repository root, resolved against the *current* working directory.
# Presumably the notebook sits one level below the root. Because the next cell
# chdirs into this directory, re-running this cell afterwards would resolve a
# different path, so run it only once per kernel session.
workspace = os.path.abspath('../')

# These paths are relative to `workspace`; the model-loading cell chdirs
# there before they are used.
config_path = './configs/lvcd.yaml'
svd_path = './checkpoints/svd.safetensors'
ckpt_path = './checkpoints/lvcd.ckpt'

# Device passed to the model loader, the lineart detector and the input frames.
device = 'cuda:0'

# Load model

In [None]:
os.chdir(workspace)
sys.path.append(workspace)
%load_ext autoreload

import numpy as np
import argparse
import torch
from utils import *

model = load_model(device, config_path, svd_path, ckpt_path, use_xformer=True)

# Lineart extractor
from lineart_extractor.annotator.lineart import LineartDetector
detector = LineartDetector(device)

In [None]:
# Inspection only: print `model.model.diffusion_model` (presumably a torch
# module, in which case this shows its layer structure). No later cell
# depends on this output.
print(model.model.diffusion_model)

# Load data

In [None]:
%autoreload
import json
from PIL import Image
from glob import glob

root = './inference/test/sample_1'

N = len( glob(f'{root}/*.png') )

inp = argparse.ArgumentParser()
inp.resolution = [320, 576]

inp.imgs = []
inp.skts = []
for i in range(N):
    img = load_img(f'{root}/{i}.png', inp.resolution).to(device).unsqueeze(0)
    inp.imgs.append(img)
    np_img = np.array( Image.open(f'{root}/{i}.png').convert('RGB') )
    with torch.no_grad():
        skt = detector(np_img, coarse=False)
    skt = torch.from_numpy(skt).float()
    skt = (skt / 255.0)
    skt = skt[None, None, :, :].repeat(1, 3, 1, 1)
    skt = 1.0 - skt
    inp.skts.append(skt)

# Sample video

In [None]:
%autoreload
from sample_func import sample_video, decode_video

arg = argparse.ArgumentParser()

arg.ref_mode = 'prevref'
arg.num_frames = 19
arg.num_steps = 25
arg.overlap = 4
arg.prev_attn_steps = 25
arg.scale = [1.0, 1.0]
arg.seed = 1234
arg.decoding_t = 10
arg.decoding_olap = 3
arg.decoding_first = 1
arg.fps_id = 6
arg.motion_bucket_id = 160
arg.cond_aug = 0.0

sample = sample_video(model, device, inp, arg, verbose=True)
frames = decode_video(model, device, sample, arg)

make_video('./inference', frames.unsqueeze(0), fps=20, cols=1, name='output')