In [5]:
import torch
from amt.AMT_L import Model

# Instantiate the AMT-L network.
# NOTE(review): these hyper-parameters presumably have to match the ones the
# checkpoint below was trained with -- confirm against the AMT-L config.
amtl = Model(corr_radius=3, corr_lvls=4, num_flows=5)

# Deserialize the checkpoint onto the CPU, then copy the weights stored under
# its "state_dict" entry into the network.
ckpt = torch.load(
    "/Users/teli/www/ml/frame_interpolation/AMT/_pretrained/amt-l.pth",
    map_location="cpu",
)
amtl.load_state_dict(ckpt["state_dict"])

In [22]:
import cv2

In [19]:
import os
import sys
import tqdm
import torch
import numpy as np
import os.path as osp
from torchvision.utils import make_grid

from amt.utils.utils import (
    read, write,
    img2tensor, tensor2img,
    check_dim_and_resize, InputPadder
    )


# ----------------------- Initialization ----------------------- 
# Prefer CUDA, then Apple's MPS backend, and finally fall back to the CPU so
# the notebook still runs on machines that have neither accelerator.
# The getattr guard keeps this working on torch builds without an MPS backend.
_mps_backend = getattr(torch.backends, 'mps', None)
if torch.cuda.is_available():
    device = torch.device('cuda')
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

img0_path = "/Users/teli/www/pyProject/motion_blur_pytorch/dist/images/amt-34-0.jpg"
img1_path = "/Users/teli/www/pyProject/motion_blur_pytorch/dist/images/amt-34-1.jpg"
out_path = "/Users/teli/www/pyProject/motion_blur_pytorch/dist/output"
# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race.
os.makedirs(out_path, exist_ok=True)

# Move the network loaded earlier in the notebook to the chosen device and
# switch it to inference mode.
model = amtl.to(device)
model.eval()

# -----------------------  Load input frames ----------------------- 
# Keep both the raw images (img0/img1) and their tensor versions on `device`.
img0 = read(img0_path)
img1 = read(img1_path)
img0_t = img2tensor(img0).to(device)
img1_t = img2tensor(img1).to(device)


In [20]:

def interpoate(model, img0, img1, frame_ratio=24, iters=4): # 2~7
    """Recursively interpolate between two frames and save the result as an MP4.

    Every pass inserts one predicted frame between each pair of neighbouring
    frames, so ``iters`` passes turn the two inputs into ``2**iters + 1``
    frames. The frames are written to ``{out_path}/demo.mp4``.

    Relies on names defined elsewhere in the notebook: ``device``,
    ``out_path``, ``cv2`` and the helpers imported from ``amt.utils.utils``.

    Args:
        model: Network called as ``model(in_0, in_1, embt, scale_factor=...,
            eval=True)``; its result is indexed with ``'imgt_pred'``.
        img0: First frame, either a tensor or an image that ``img2tensor``
            accepts (such as the value returned by ``read``).
        img1: Last frame, same accepted types as ``img0``.
        frame_ratio: Frame rate handed to ``cv2.VideoWriter``.
        iters: Number of interpolation passes. ``0`` writes only the two
            input frames. (The trailing ``2~7`` note on the signature is
            presumably the suggested range -- confirm.)

    Raises:
        IOError: If the output video cannot be opened for writing.
    """
    # Use the frames passed by the caller instead of reaching for notebook
    # globals. Non-tensor images are converted exactly like the loading cell
    # does it, so callers may pass either representation.
    inputs = [
        frame if torch.is_tensor(frame) else img2tensor(frame).to(device)
        for frame in (img0, img1)
    ]

    # A torch.device never compares equal to a plain string, so test the
    # device *type*; otherwise the VRAM-aware branch can never run.
    if torch.device(device).type == 'cuda':
        anchor_resolution = 1024 * 512
        anchor_memory = 1500 * 1024**2
        anchor_memory_bias = 2500 * 1024**2
        vram_avail = torch.cuda.get_device_properties(device).total_memory
    else:
        # Do not resize when not running on CUDA: these values make the
        # memory factor below equal to 1 and the resolution budget huge.
        anchor_resolution = 8192*8192
        anchor_memory = 1
        anchor_memory_bias = 0
        vram_avail = 1
    # Time embedding 0.5: always predict the frame halfway between the pair.
    embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)

    inputs = check_dim_and_resize(inputs)
    h, w = inputs[0].shape[-2:]
    scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
    scale = 1 if scale > 1 else scale
    # Snap the scale to a value of the form 16 / k (k an integer).
    scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
    if scale < 1:
        print(f"显卡显存限制, 视频将会被缩小 {scale:.2f}倍")
    padding = int(16 / scale)
    padder = InputPadder(inputs[0].shape, padding)
    frames = padder.pad(*inputs)

    with torch.no_grad():
        for i in range(iters):
            print(f'Iter {i+1}. input_frames={len(frames)} output_frames={2*len(frames)-1}')
            doubled = [frames[0]]
            for in_0, in_1 in zip(frames[:-1], frames[1:]):
                in_0 = in_0.to(device)
                in_1 = in_1.to(device)
                imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']
                # Park results on the CPU so accelerator memory is only
                # needed for the pair currently being processed.
                doubled += [imgt_pred.cpu(), in_1.cpu()]
            frames = doubled
    # `frames` is always bound here, even when iters == 0.
    outputs = padder.unpad(*frames)

    # OpenCV expects the frame size as (width, height) plain ints.
    height, width = outputs[0].shape[-2:]
    writer = cv2.VideoWriter(
        f'{out_path}/demo.mp4',
        cv2.VideoWriter_fourcc(*'mp4v'),
        frame_ratio,
        (int(width), int(height)),
    )
    try:
        # VideoWriter does not raise on failure; without this check every
        # write() below would silently be dropped.
        if not writer.isOpened():
            raise IOError(f'Could not open {out_path}/demo.mp4 for writing')
        for frame in outputs:
            # tensor2img output is treated as RGB; OpenCV writes BGR.
            writer.write(cv2.cvtColor(tensor2img(frame), cv2.COLOR_RGB2BGR))
    finally:
        # Always finalize the container, even if a frame fails to convert.
        writer.release()


In [23]:
# Three passes over the two input frames: 2 -> 3 -> 5 -> 9 frames, at 24 fps.
interpoate(model=model, img0=img0, img1=img1, frame_ratio=24, iters=3)

Iter 1. input_frames=2 output_frames=3
Iter 2. input_frames=3 output_frames=5
Iter 3. input_frames=5 output_frames=9


In [11]:
# Inspect the top-level entries of the loaded checkpoint. The notebook never
# defines `data`; the checkpoint dict is bound to `ckpt` in the loading cell.
ckpt.keys()

odict_keys(['epoch', 'state_dict'])