# Import packages

In [2]:
import torch
import numpy as np
import cv2
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import copy

# Load the frame interpolation model

In [3]:
# Inference settings: run on the GPU in half precision.
precision = torch.float16
device = torch.device('cuda')

# Deserialize the TorchScript network onto the CPU first, then put it in eval
# mode and move/cast it to the target device and dtype (both calls act on the
# module in place).
model = torch.jit.load('./check/film_net_fp16.pt', map_location='cpu')
model.eval()
model.to(device=device, dtype=precision)

# Directory the notebook is running from.
path = Path.cwd()

# Define the video reader class and the interpolation and writing functions

In [4]:
class image_data:
    """Random-access reader over the frames of a video file (wraps ``cv2.VideoCapture``).

    Indexing is 1-based: ``images[k]`` seeks to position ``k - 1`` and returns
    that frame converted from BGR to RGB. Index 0 (or anything smaller) is
    clamped to the first frame, so ``images[0]`` and ``images[1]`` return the
    same picture.

    Attributes:
        video_path: ``Path`` of the opened video.
        cap: The underlying ``cv2.VideoCapture`` object.
        fps: Frame rate reported by OpenCV for the file.
        frame_count: Number of frames reported by OpenCV for the file.
    """

    def __init__(self, video_path ='./sc_60.mp4'):
        """Open ``video_path`` and cache its frame rate and frame count.

        Raises:
            OSError: If OpenCV cannot open the file.
        """
        self.video_path = Path(video_path)
        print(self.video_path.as_posix())
        self.cap = cv2.VideoCapture(self.video_path.as_posix())
        # Fail loudly here instead of handing out a reader whose every frame is None.
        if not self.cap.isOpened():
            self.cap.release()
            raise OSError(f"Could not open video file: {self.video_path.as_posix()}")
        self.fps = self.cap.get(cv2.CAP_PROP_FPS)
        self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f'frame_count: {self.frame_count}')

    def close(self):
        """Release the video capture and close any OpenCV windows."""
        self.cap.release()
        cv2.destroyAllWindows()

    def __getitem__(self, frame_number):
        """Return frame ``frame_number`` (1-based) as an RGB array, or ``None`` if it cannot be read."""
        # Clamp so that index 0 does not request a negative seek position,
        # whose handling would otherwise be left to the capture backend.
        position = max(frame_number - 1, 0)
        self.cap.set(cv2.CAP_PROP_POS_FRAMES, position)
        ret, frame = self.cap.read()
        if not ret:
            print("Error: Couldn't read the specified frame.")
            return None
        # OpenCV decodes to BGR; the rest of the notebook works in RGB.
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

def interp_2frames(f1,f2, precision=torch.float16, device='cuda'):
    """Run the global ``model`` on two frames and return the in-between frame.

    Args:
        f1, f2: Height x width x channel image arrays (channels last), expected
            to hold 0-255 values since they are scaled by 1/255 before inference.
        precision: Torch dtype the inputs are cast to before being fed to the model.
        device: Torch device the inputs are moved to.

    Returns:
        A ``uint8`` numpy array (height x width x channel) holding the model
        output for time step 0.5, scaled back to 0-255. Only the interpolated
        frame is returned; the caller already has ``f1`` and ``f2``.
    """
    def _to_model_input(frame):
        # HWC array -> 1 x C x H x W tensor scaled to [0, 1] in the requested dtype/device.
        batch = torch.from_numpy(frame)[None, ...].permute(0, 3, 1, 2) / 255.
        return batch.to(precision).to(device)

    img1 = _to_model_input(f1)
    img2 = _to_model_input(f2)
    # Time step passed to the model, created with the inputs' dtype and device.
    dt = img1.new_full((1, 1), .5)
    with torch.no_grad():
        imgmid = model(img1, img2, dt)  # Will be of the same shape as inputs (1, 3, h, w)
    imgmid.clamp_(0, 1)
    img_np = imgmid[0].cpu().permute(1, 2, 0).to(torch.float32).numpy()
    # Round to nearest instead of truncating: after the /255 scaling (and the
    # low-precision cast) a value such as 199.99 must map back to 200, not 199.
    img_np_uint8 = np.rint(img_np * 255).astype(np.uint8)

    return img_np_uint8

def write_frames(new_frame_list, slow=False, save_index=0, initial_fps=30, initial_nframes=None):
    """Encode RGB frames to ``output_video_{save_index}.avi`` with the XVID codec.

    Args:
        new_frame_list: Non-empty sequence of RGB frames (height x width x channel
            arrays); the output size is taken from the first frame.
        slow: If True, keep ``initial_fps`` so the extra frames stretch the
            duration. If False, scale the frame rate by
            ``len(new_frame_list) / initial_nframes`` so the duration is kept.
        save_index: Integer/string used in the output file name.
        initial_fps: Frame rate of the original video.
        initial_nframes: Frame count of the original video; required when
            ``slow`` is False.

    Raises:
        ValueError: If ``new_frame_list`` is empty, or ``slow`` is False and
            ``initial_nframes`` is missing or zero.
        OSError: If the video writer cannot be opened.
    """
    final_nframes = len(new_frame_list)
    if final_nframes == 0:
        raise ValueError("new_frame_list is empty; there is nothing to write.")
    if not slow:
        if not initial_nframes:
            raise ValueError("initial_nframes must be a positive frame count when slow=False.")
        fps_out = final_nframes/initial_nframes*initial_fps
    else:
        fps_out = initial_fps
    print(f'fps_out: {fps_out}')
    fourcc = cv2.VideoWriter_fourcc(*'XVID')  # 'MJPG' is another codec option
    height, width = new_frame_list[0].shape[:2]  # VideoWriter takes the size as (width, height)
    video_writer = cv2.VideoWriter(f'output_video_{save_index}.avi', fourcc, fps_out, (width, height))
    try:
        # Without this check a writer that failed to open drops every frame silently.
        if not video_writer.isOpened():
            raise OSError(f"Could not open 'output_video_{save_index}.avi' for writing.")
        for f in new_frame_list:
            # Frames are RGB; OpenCV expects BGR.
            video_writer.write(cv2.cvtColor(f, cv2.COLOR_RGB2BGR))
    finally:
        # Release the writer even if a frame fails to convert or write.
        video_writer.release()
    print(f"Video has been written to 'output_video_{save_index}.avi'")

# Load the base video

In [5]:
# Open the source clip; constructing the reader prints its path and frame count.
video_path = Path('smile') / 'output_video_2.mp4'
images = image_data(video_path)
# Cell output: (number of frames, frames per second, duration in seconds).
(images.frame_count, images.fps, images.frame_count / images.fps)

smile/output_video_2.mp4
frame_count: 65


(65, 24.0, 2.7083333333333335)

# Loop through the frames, interpolate, and save the result

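Interpolating between each pair of consecutive frames gives $2n - 1$ frames, where $n$ is the number of original frames. The loop below also interpolates from the last frame back to the first so that the clip can loop, which adds one more frame.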

In [9]:
# Interpolate one frame between every pair of neighbouring frames and stream the
# result straight to '<video stem>_<save_index>.avi'.
save_index = 1
# slow=True keeps the source frame rate, so the added frames stretch the clip;
# slow=False raises the frame rate so the duration stays about the same.
slow=True
if not slow:
    # Counts 2n-1 output frames; the extra loop-closing frame written below is not included.
    final_nframes = (images.frame_count*2)-1
    fps_out = final_nframes/images.frame_count*images.fps
else:
    fps_out = images.fps
print(f'fps_out: {fps_out}')
fourcc = cv2.VideoWriter_fourcc(*'XVID')  # 'MJPG' is another codec option
first_frame = images[0]
if first_frame is None:
    # Happens e.g. when this cell is re-run after images.close() released the capture.
    raise RuntimeError("Could not read the first frame; re-open the video before running this cell.")
height, width = first_frame.shape[:2]  # VideoWriter takes the size as (width, height)
video_writer = cv2.VideoWriter(f'{video_path.stem}_{save_index}.avi', fourcc, fps_out, (width, height))

try:
    for i in tqdm(range(images.frame_count+1)):
        f1 = images[i]
        if f1 is None:
            # Unreadable frame: nothing to write or interpolate from.
            continue
        if i==images.frame_count: # last frame
            f2 = images[0] # make loopable
        else:
            f2 = images[i+1]
        if f2 is None:
            # No successor available: keep the frame itself, without an interpolated one.
            video_writer.write(cv2.cvtColor(f1, cv2.COLOR_RGB2BGR))
            continue

        # Identical neighbours are skipped entirely (neither frame is written).
        if np.array_equal(f1, f2):
            continue

        f_interp = interp_2frames(f1, f2)
        # Frames are RGB here; the writer expects BGR.
        video_writer.write(cv2.cvtColor(f1, cv2.COLOR_RGB2BGR))
        video_writer.write(cv2.cvtColor(f_interp, cv2.COLOR_RGB2BGR))
finally:
    # Release the capture and finalize the file even if interpolation fails mid-way.
    images.close()
    video_writer.release()
print(f"Video has been written to '{video_path.stem}_{save_index}.avi'")

fps_out: 54.060340000000004


100%|████████████████████████████████████████| 101/101 [02:29<00:00,  1.48s/it]

Video has been written to 'generated_svd_2_0_1_1.avi'





In [159]:
# Frame rate cached on the reader when the video was opened; this reads the
# stored attribute and does not query the capture object.
images.fps

2.0