In [1]:
import tensorflow as tf
from flow_models import *
from model_trainer import AnimeModel
from tensorflow.keras import mixed_precision
import numpy as np
from pathlib import Path
import cv2

# Ask TensorFlow to grow GPU memory on demand rather than reserving it all at
# start-up. An empty `gpus` list simply makes the loop a no-op.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    # Best effort: report the problem and carry on with default allocation.
    # NOTE(review): TF documents this error for GPUs that are already
    # initialised — restart the kernel if it shows up.
    print(err)

# Make 'mixed_float16' the global Keras dtype policy for everything built below.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: GeForce RTX 2070 SUPER, compute capability 7.5


In [2]:
# Run configuration for the interpolation demo.
# (width, height): used as-is for cv2.resize / cv2.VideoWriter and reversed
# when building the Keras input.
frame_size = (960, 540)
# Forwarded to AnimeModel; presumably the temporal positions of the generated
# frames between the two input frames — TODO confirm in model_trainer.
interp_ratio = [0.4, 0.8]
# Model-building callable handed to AnimeModel; presumably provided by the
# `flow_models` star import — TODO confirm.
model_f = hr_3_2_16
# Location passed to `load_weights` for the trained weights.
weight_dir = 'savedmodels/hr3216bilinear7/20'

In [3]:
# Keras wants (height, width, channels) while frame_size is (width, height),
# so reverse it. The 6 channels hold two 3-channel frames stacked along the
# last axis (see the np.concatenate calls in the inference cell).
inputs = tf.keras.Input(shape=(*frame_size[::-1], 6))
anime_model = AnimeModel(inputs, model_f, interp_ratio)
# Left as the final expression so the notebook displays the load status object.
anime_model.load_weights(weight_dir)

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 540, 960, 6) 0                                            
__________________________________________________________________________________________________
HR_0 (HighResolutionModule)     [(None, 540, 960, 16 9712        input_1[0][0]                    
__________________________________________________________________________________________________
HR_1 (HighResolutionModule)     [(None, 540, 960, 16 52144       HR_0[0][0]                       
__________________________________________________________________________________________________
HR_2 (HighResolutionModule)     [(None, 540, 960, 16 268128      HR_1[0][0]                       
                                                                 HR_1[0][1]                   

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x149ac1ea160>

In [4]:
# Collect every entry in the clip directory as a string path.
# Uses pathlib only: the notebook never imports `os`, so the previous
# `os.listdir` call relied on the name leaking in through a star import.
# Sorting makes `vid_paths[0]` reproducible, since directory listing order is
# otherwise unspecified.
vid_dir = Path('data/cut')
vid_paths = [str(vid_path) for vid_path in sorted(vid_dir.iterdir())]

In [5]:
from tqdm.notebook import trange

In [6]:
# Interpolate the first clip and write the result next to it.
#
# Each iteration consumes two new source frames (frame0 -> frame1 -> frame2,
# where frame2 becomes the next iteration's frame0) and writes five frames:
# frame0 followed by four model outputs. frame1 itself is never written.
# NOTE(review): the hard-coded 60 fps output presumably assumes a ~24 fps
# source (5 output frames per 2 input frames) — confirm for other inputs.
cap = cv2.VideoCapture(vid_paths[0])
fourcc = cv2.VideoWriter_fourcc(*'MP4V')
writer = cv2.VideoWriter(f'{vid_paths[0]}_interp.mp4', fourcc, 60, frame_size)
try:
    ret, frame = cap.read()
    for i in trange(1600):  # hard upper bound on processed frame triplets
        # Stop when the capture is closed or the carried-over frame is invalid.
        if not (cap.isOpened() and ret):
            break
        frame0 = frame

        ret, frame1 = cap.read()
        if not ret:
            break

        # Stored in `frame` so it carries over as the next iteration's frame0.
        ret, frame = cap.read()
        if not ret:
            break
        frame2 = frame

        frame0_resized = cv2.resize(frame0, dsize=frame_size)
        frame1_resized = cv2.resize(frame1, dsize=frame_size)
        frame2_resized = cv2.resize(frame2, dsize=frame_size)

        # Batch of two 6-channel inputs scaled to [0, 1]:
        #   [0] frame0 stacked with frame1, [1] frame2 stacked with frame1.
        concated1 = np.concatenate([frame0_resized, frame1_resized], axis=-1).astype(np.float32) / 255.0
        concated2 = np.concatenate([frame2_resized, frame1_resized], axis=-1).astype(np.float32) / 255.0
        outputs = anime_model(np.array([concated1, concated2]))
        # Back to 8-bit pixel values for the video writer.
        outputs = np.round(np.clip(outputs, 0, 1) * 255).astype(np.uint8)

        # Each output holds two 3-channel frames (channels 0:3 and 3:6).
        # The second sample's pair is taken in reverse channel order,
        # presumably because that input runs from frame2 back towards frame1 —
        # TODO confirm against AnimeModel.
        interped1, interped2 = outputs[0][..., 0:3], outputs[0][..., 3:6]
        interped3, interped4 = outputs[1][..., 3:6], outputs[1][..., 0:3]

        for out_frame in (frame0_resized, interped1, interped2, interped3, interped4):
            writer.write(out_frame)
finally:
    # Always release both handles, even on an exception or kernel interrupt,
    # so the output file is closed properly and nothing stays locked in the
    # still-running kernel.
    cap.release()
    writer.release()

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1600.0), HTML(value='')))


