# FILM: Frame Interpolation for Large Motion

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nguu/maze/blob/main/FILM.ipynb)
[![GitHub Repository](https://img.shields.io/github/stars/google-research/frame-interpolation?style=social)](https://github.com/google-research/frame-interpolation)


## Environment

In [None]:
import os
import cv2
import torch
import numpy as np
import bisect
import shutil

# Root working directory of the notebook and the sub-folder that holds the
# downloaded model checkpoints; the checkpoint folder is created up front.
FILM_DIR = '/content/film'
CKPT_DIR = FILM_DIR + '/pretrained'
os.makedirs(CKPT_DIR, exist_ok=True)

%pip install -q tqdm

def pad_batch(batch, align):
    """Zero-pad axes 1 and 2 of a 4-D batch up to the next multiple of ``align``.

    The padding is split between both sides of each axis; when the amount is
    odd, the extra row/column goes to the bottom/right.

    Args:
        batch: 4-D array whose axes 1 and 2 are the ones being padded.
        align: Required divisor of the padded sizes of axes 1 and 2.

    Returns:
        A ``(padded, crop_region)`` pair where ``crop_region`` is the list
        ``[y1, x1, y2, x2]`` locating the original data inside ``padded``.
    """
    rows, cols = batch.shape[1:3]

    # ``-n % align`` is the distance from n up to the next multiple of align
    # (zero when n is already aligned).
    extra_rows = -rows % align
    extra_cols = -cols % align

    top, left = extra_rows // 2, extra_cols // 2
    bottom, right = extra_rows - top, extra_cols - left

    padded = np.pad(
        batch,
        [(0, 0), (top, bottom), (left, right), (0, 0)],
        mode='constant',
    )
    return padded, [top, left, top + rows, left + cols]

def load_image(path, align=64):
    """Read an image file into a padded, normalized single-image batch.

    The image is read with OpenCV, converted from BGR to RGB, scaled to
    float32 values in [0, 1], given a leading batch axis and padded with
    ``pad_batch`` so its height and width are multiples of ``align``.

    Args:
        path: Path of the image file to read.
        align: Alignment passed to ``pad_batch``.

    Returns:
        ``(image_batch, crop_region)`` exactly as returned by ``pad_batch``.

    Raises:
        OSError: If OpenCV cannot read ``path``.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        raise OSError(f'Cannot read image (missing or not a decodable image file): {path!r}')
    image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / np.float32(255)
    image_batch, crop_region = pad_batch(np.expand_dims(image, axis=0), align)
    return image_batch, crop_region

def inference(model, img1: str, img2: str, inter_frames: int = 2):
    """Generate ``inter_frames`` frames between two image files.

    Frames are produced one at a time. At every step the pair of adjacent,
    already available frames and the still-missing time step are chosen so
    that the time step lies as close as possible to the middle of the pair.

    Args:
        model: Callable invoked as ``model(x0, x1, dt)``. When it exposes
            ``parameters()``, the inputs are moved to the device and dtype of
            its first parameter; otherwise CUDA / float32 is used.
        img1: Path of the first image.
        img2: Path of the second image.
        inter_frames: Number of frames to generate between the two images.

    Returns:
        A list of ``inter_frames + 2`` uint8 arrays (height, width, channels)
        ordered in time: the first image, the generated frames, the second
        image. The channel axis is reversed relative to the RGB data fed to
        the model, and the padding added by ``load_image`` is cropped away.

    Raises:
        ValueError: If the two images do not have the same padded shape.
    """
    img_batch_1, crop_region_1 = load_image(img1)
    img_batch_2, _ = load_image(img2)
    if img_batch_1.shape != img_batch_2.shape:
        # All output frames are cropped with the first image's region, so the
        # two inputs must share one size.
        raise ValueError(
            f'Input images differ in size: {img_batch_1.shape} vs {img_batch_2.shape}'
        )
    img_batch_1 = torch.from_numpy(img_batch_1).permute(0, 3, 1, 2)
    img_batch_2 = torch.from_numpy(img_batch_2).permute(0, 3, 1, 2)

    # Feed the model tensors that match its weights. Hard-coding float32 breaks
    # a model that was cast to float16; fall back to the previous behaviour
    # (CUDA, float32) when the model exposes no parameters.
    params = model.parameters() if hasattr(model, 'parameters') else iter(())
    reference = next(iter(params), None)
    if reference is not None:
        device, dtype = reference.device, reference.dtype
    else:
        device, dtype = torch.device('cuda'), torch.float32

    results = [img_batch_1, img_batch_2]
    idxes = [0, inter_frames + 1]            # time-step index of each entry in results
    remains = list(range(1, inter_frames + 1))  # time steps still to generate
    splits = torch.linspace(0, 1, inter_frames + 2)

    for _ in range(len(remains)):
        starts = splits[idxes[:-1]]
        ends = splits[idxes[1:]]
        # Rows: adjacent pairs of available frames; columns: remaining steps.
        # Each entry is how far the step's relative position inside the pair
        # is from the pair's midpoint (0.5).
        distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs()
        matrix = torch.argmin(distances).item()
        start_i, step = np.unravel_index(matrix, distances.shape)
        end_i = start_i + 1

        x0 = results[start_i].to(device=device, dtype=dtype)
        x1 = results[end_i].to(device=device, dtype=dtype)
        # Relative position of the new frame between x0 and x1.
        dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])

        with torch.no_grad():
            prediction = model(x0, x1, dt)
        # Keep idxes/results sorted by time step.
        insert_position = bisect.bisect_left(idxes, remains[step])
        idxes.insert(insert_position, remains[step])
        results.insert(insert_position, prediction.clamp(0, 1).cpu().float())
        del remains[step]

    y1, x1, y2, x2 = crop_region_1
    # Per frame: scale to 0-255 bytes, reverse the channel axis (undoing the
    # BGR->RGB conversion done on load), move channels last, crop the padding.
    frames = [(tensor[0] * 255).byte().flip(0).permute(1, 2, 0).numpy()[y1:y2, x1:x2].copy() for tensor in results]
    return frames

### Loading Model

In [None]:
from torch.hub import download_url_to_file
# Base URL of the hosted checkpoints and the checkpoint file to use
# (selectable through the Colab form).
FILM_MODEL_URL = 'https://huggingface.co/nguu/film-pytorch/resolve/main'
FILM_MODEL = 'film_net_fp32.pt' #@param ["film_net_fp16.pt", "film_net_fp32.pt"]
model_link = f'{FILM_MODEL_URL}/{FILM_MODEL}'
model_path = f'{CKPT_DIR}/{FILM_MODEL}'
# Download only when the checkpoint is not already present in CKPT_DIR.
if not os.path.exists(model_path):
  download_url_to_file(model_link, model_path)

# NOTE(review): the device is hard-coded to CUDA, so this cell requires a GPU runtime.
device = torch.device('cuda')
# The weight precision is derived from the selected file name.
precision = torch.float16 if FILM_MODEL == 'film_net_fp16.pt' else torch.float32
# Load the TorchScript module on the CPU first, then switch it to eval mode and
# move/cast it to the target device and precision.
model = torch.jit.load(model_path, map_location='cpu')
model.eval().to(device=device, dtype=precision)


## Processing

### Extract Video Frames (Optional)

In [None]:
import subprocess
INPUT_VIDEO = '/content/video.mp4' #@param {type:'string'}
EXPORT_FRAME_DIR = '/content/film/temp' #@param {type:'string'}
# Start from an empty directory so frames of an earlier run cannot mix with
# the frames extracted now.
if os.path.exists(EXPORT_FRAME_DIR):
  shutil.rmtree(EXPORT_FRAME_DIR)
os.makedirs(EXPORT_FRAME_DIR, exist_ok=True)

# Dump every frame of the video as a zero-padded, numbered PNG.
# The arguments are passed as a list (no shell), so paths containing quotes,
# spaces or '$' are handled correctly, and check=True raises
# CalledProcessError when ffmpeg fails instead of silently continuing.
subprocess.run(['ffmpeg', '-i', INPUT_VIDEO, f'{EXPORT_FRAME_DIR}/%06d.png'], check=True)


### Frame Interpolation

In [None]:
from tqdm import tqdm
INTER_NUM = 2 #@param {type:'integer'}
INPUT_FRAME_DIR = '/content/film/temp' #@param {type:'string'}
OUTPUT_FRAME_DIR = '/content/film/output' #@param {type:'string'}

# Every entry of the input directory, processed in sorted (lexicographic) order.
sources = sorted([f'{INPUT_FRAME_DIR}/{item}' for item in os.listdir(INPUT_FRAME_DIR)])
frames = []

# inference() returns [first image, generated frames..., second image]. The
# second image of one pair is the first image of the next pair, so it is
# dropped for every pair except the final one, where it has to be kept or the
# last source frame would be missing from the result.
last_pair = len(sources) - 2
for index in tqdm(range(0, len(sources) - 1), f'generate frames'):
  output = inference(model, sources[index], sources[index + 1], INTER_NUM)
  if index == last_pair:
    frames += output
  else:
    frames += output[:-1]

if OUTPUT_FRAME_DIR:
  # Recreate the output directory so it only contains frames of this run.
  if os.path.exists(OUTPUT_FRAME_DIR):
    shutil.rmtree(OUTPUT_FRAME_DIR)
  os.makedirs(OUTPUT_FRAME_DIR, exist_ok=True)

  # Write the frames as PNG files with compression level 0, numbered from 0000.
  for index, frame in enumerate(frames):
    file_path = f'{OUTPUT_FRAME_DIR}/{str(index).zfill(4)}.png'
    cv2.imwrite(file_path, frame, [int(cv2.IMWRITE_PNG_COMPRESSION), 0])

### Frames to video

In [None]:
import subprocess
FRAME_DIR = '/content/film/output' #@param {type:'string'}
OUTPUT_VIDEO_PATH = '/content/film/result.mp4' #@param {type:'string'}
OUTPUT_VIDEO_FPS = 24 #@param {type:'integer'}
# Encode the numbered PNG frames into an H.264 video, overwriting any existing
# output file (-y). The arguments are passed as a list (no shell), so paths
# with quotes, spaces or '$' work, and check=True raises CalledProcessError
# when ffmpeg fails instead of silently ignoring the exit status.
subprocess.run([
  'ffmpeg', '-y',
  '-r', str(OUTPUT_VIDEO_FPS),
  '-i', f'{FRAME_DIR}/%04d.png',
  '-c:v', 'libx264',
  '-crf', '18',
  '-pix_fmt', 'yuv420p',
  OUTPUT_VIDEO_PATH,
], check=True)


## Fixer

In [None]:
#@title Fix: A UTF-8 locale is required. Got ANSI_X3.4-1968
import locale
# Force the reported preferred encoding to UTF-8. The replacement keeps the
# optional ``do_setlocale`` parameter of the real function so that callers
# using ``locale.getpreferredencoding(False)`` do not hit a TypeError.
locale.getpreferredencoding = lambda do_setlocale=True: "UTF-8"