In [1]:
%%capture
# Install the Python packages and the ffmpeg binary used by the cells below; %%capture hides the installer output.
!pip install -q mediapy tensorflow tensorflow_hub
!apt-get -qq install -y ffmpeg

In [2]:
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import mediapy as media
from IPython.display import display
import ipywidgets as widgets
from google.colab import files
import time
from typing import List, Tuple, Optional
import cv2
import subprocess

# **1-FILM Interpolator Class**

In [3]:
class FilmInterpolator:
    """Wrapper around the FILM frame-interpolation model loaded from TF-Hub.

    The model is loaded once in the constructor and warmed up with a dummy
    inference. Frames are padded on the bottom/right so that height and width
    are multiples of ``self.align`` before inference, and the result is
    cropped back to the original size afterwards.
    """

    def __init__(self, model_url: str = "https://tfhub.dev/google/film/1"):
        """Load the model from ``model_url`` and run one warm-up inference.

        Args:
            model_url: TF-Hub handle (or path) passed to ``hub.load``.
        """
        self.model = hub.load(model_url)
        # Height and width are padded up to a multiple of this value.
        self.align = 64
        # Resolve the target device once instead of querying TensorFlow on
        # every interpolate() call.
        self._device = '/GPU:0' if tf.config.list_physical_devices('GPU') else '/CPU:0'
        self._warmup_model()

    def _warmup_model(self):
        """Run a single inference on all-zero 64x64 frames to warm the model up."""
        dummy_input = {
            'time': tf.reshape(tf.constant([0.5], dtype=tf.float32), [1, 1]),
            'x0': tf.zeros((1, 64, 64, 3), dtype=tf.float32),
            'x1': tf.zeros((1, 64, 64, 3), dtype=tf.float32)
        }
        _ = self.model(dummy_input, training=False)

    def _pad_to_align(self, image: tf.Tensor) -> tf.Tensor:
        """Pad a rank-3 (height, width, channels) tensor to multiples of ``self.align``.

        Padding is appended at the bottom and right only, so the original
        content stays in the top-left corner and can be recovered by slicing.

        Args:
            image: Rank-3 tensor with statically known height and width.

        Returns:
            The padded tensor, or ``image`` itself when no padding is needed.
        """
        height, width = int(image.shape[0]), int(image.shape[1])
        height_to_pad = (-height) % self.align
        width_to_pad = (-width) % self.align
        if height_to_pad == 0 and width_to_pad == 0:
            return image
        # REFLECT padding requires each pad amount to be strictly smaller than
        # the padded dimension; very small images would make tf.pad raise, so
        # fall back to zero padding for them.
        reflect_ok = height_to_pad < height and width_to_pad < width
        return tf.pad(
            image,
            [[0, height_to_pad], [0, width_to_pad], [0, 0]],
            mode='REFLECT' if reflect_ok else 'CONSTANT'
        )

    def interpolate(self, frame1: np.ndarray, frame2: np.ndarray, time: float = 0.5) -> np.ndarray:
        """Generate the frame at position ``time`` between ``frame1`` and ``frame2``.

        Args:
            frame1: First frame, a (height, width, channels) array convertible
                to float32.
            frame2: Second frame; must have exactly the same shape as ``frame1``.
            time: Value fed to the model's ``'time'`` input (default 0.5).
                Note that this parameter shadows the ``time`` module inside
                this method.

        Returns:
            The model's ``'image'`` output for the single batch element as a
            NumPy array, cropped to the height and width of ``frame1``.
            NOTE(review): the output is not clipped here; callers converting
            to 8-bit should clip first.

        Raises:
            ValueError: If the frames are not rank-3 or their shapes differ.
        """
        frame1_tensor = tf.convert_to_tensor(frame1, dtype=tf.float32)
        frame2_tensor = tf.convert_to_tensor(frame2, dtype=tf.float32)
        shape1, shape2 = tuple(frame1_tensor.shape), tuple(frame2_tensor.shape)
        if len(shape1) != 3 or len(shape2) != 3:
            raise ValueError(
                f"Frames must be (height, width, channels) arrays, got shapes {shape1} and {shape2}"
            )
        if shape1 != shape2:
            # Fail early with a readable message instead of an opaque model error.
            raise ValueError(
                f"Both frames must have the same shape, got {shape1} and {shape2}"
            )
        orig_height, orig_width = shape1[0], shape1[1]
        frame1_tensor = self._pad_to_align(frame1_tensor)
        frame2_tensor = self._pad_to_align(frame2_tensor)
        inputs = {
            'time': tf.reshape(tf.constant([time], dtype=tf.float32), [1, 1]),
            'x0': tf.expand_dims(frame1_tensor, axis=0),
            'x1': tf.expand_dims(frame2_tensor, axis=0)
        }
        with tf.device(self._device):
            result = self.model(inputs, training=False)
        output = result['image'][0].numpy()
        # Drop the alignment padding that was appended at the bottom/right.
        return output[:orig_height, :orig_width, :]

# **2-Helper Functions**

In [4]:
def upload_image(description: str, layout: Optional[widgets.Layout] = None) -> widgets.FileUpload:
    """Create a single-file upload widget restricted to images.

    Args:
        description: Label shown on the upload button.
        layout: Widget layout; when omitted, an auto-width, 40px-high layout
            is created.

    Returns:
        The configured ``widgets.FileUpload`` instance (not yet displayed).
    """
    effective_layout = layout if layout is not None else widgets.Layout(width='auto', height='40px')
    return widgets.FileUpload(
        description=description,
        accept='image/*',
        multiple=False,
        style={'description_width': 'initial'},
        layout=effective_layout,
    )

def process_uploaded_image(upload: widgets.FileUpload, max_dim: int = 1024) -> Optional[np.ndarray]:
    """Decode the file held by an upload widget into a float32 RGB array.

    The image is decoded with three channels, converted to float32 (integer
    pixel data is scaled to [0, 1] by ``tf.image.convert_image_dtype``) and,
    if either side exceeds ``max_dim``, downscaled with area interpolation so
    that the longer side fits while the aspect ratio is preserved. Images
    that already fit are never upscaled.

    Args:
        upload: The ``FileUpload`` widget to read from.
        max_dim: Maximum allowed height/width in pixels.

    Returns:
        A (height, width, 3) float32 NumPy array, or ``None`` when nothing was
        uploaded or the file could not be processed (the error is printed).
    """
    uploaded = upload.value
    if not uploaded:
        return None
    try:
        # ipywidgets 7 exposes a dict keyed by file name; ipywidgets 8 exposes
        # a tuple of per-file dicts. Support both.
        if isinstance(uploaded, dict):
            uploaded_file = next(iter(uploaded.values()))
        else:
            uploaded_file = uploaded[0]
        content = uploaded_file['content']
        if isinstance(content, memoryview):
            # ipywidgets 8 hands over a memoryview; TensorFlow needs bytes.
            content = content.tobytes()
        # expand_animations=False keeps the result 3-D even for animated GIFs
        # (first frame only); otherwise shape[0] would be the frame count.
        image = tf.io.decode_image(content, channels=3, expand_animations=False)
        image = tf.image.convert_image_dtype(image, tf.float32)
        original_height, original_width = image.shape[0], image.shape[1]
        scale = min(max_dim / original_height, max_dim / original_width)
        if scale < 1:
            # Clamp to 1 so extreme aspect ratios never produce a zero-sized side.
            new_height = max(1, int(original_height * scale))
            new_width = max(1, int(original_width * scale))
            image = tf.image.resize(image, [new_height, new_width], method=tf.image.ResizeMethod.AREA)
        return image.numpy()
    except Exception as e:
        print(f"Image processing error: {str(e)}")
        return None

def generate_interpolated_frames(
    frame1: np.ndarray,
    frame2: np.ndarray,
    num_frames: int = 5,
    interpolator: Optional[FilmInterpolator] = None
) -> List[np.ndarray]:
    """Build a frame sequence from ``frame1`` to ``frame2``.

    ``num_frames`` is the TOTAL length of the returned list, including the two
    input frames; ``num_frames - 2`` frames are generated at evenly spaced
    positions ``i / (num_frames - 1)``. Every generated frame is interpolated
    directly between the two inputs. A progress bar widget is displayed while
    the frames are generated and closed afterwards, even on error.

    Args:
        frame1: First frame (returned unchanged as the first element).
        frame2: Last frame (returned unchanged as the last element).
        num_frames: Total number of frames to return; must be at least 2.
        interpolator: Interpolator to use; a new ``FilmInterpolator`` is
            created (and its model loaded) when omitted.

    Returns:
        List of frames starting with ``frame1`` and ending with ``frame2``.

    Raises:
        ValueError: If ``num_frames`` is smaller than 2.
    """
    # Validate before the (expensive) model load and before the progress bar
    # is built with a negative maximum.
    if num_frames < 2:
        raise ValueError(
            f"num_frames must be at least 2 (the two input frames), got {num_frames}"
        )
    if interpolator is None:
        interpolator = FilmInterpolator()
    frames = [frame1]
    progress = widgets.IntProgress(value=0, max=num_frames-2, description='Processing:')
    display(progress)
    try:
        for i in range(1, num_frames-1):
            t = i / (num_frames-1)
            frames.append(interpolator.interpolate(frame1, frame2, t))
            # No artificial sleep: the widget update is sent as soon as the
            # value changes, so a delay only slows the run down.
            progress.value += 1
        frames.append(frame2)
        return frames
    finally:
        progress.close()

def smart_resize(img, target_size):
    """Fit ``img`` into ``target_size`` and zero-pad it to exactly that size.

    The image is scaled by a single factor so that it fits inside the target
    while keeping its aspect ratio, then centred on a black (zero) canvas of
    the target size. When the extra space is odd, the additional pixel goes to
    the bottom/right.

    Args:
        img: Image array whose first two axes are (height, width).
        target_size: ``(target_height, target_width)`` in pixels.

    Returns:
        An array of spatial size ``target_size``. ``img`` itself is returned
        when it already has exactly that size.

    Raises:
        ValueError: If the image or the target has a non-positive dimension.
    """
    h, w = img.shape[:2]
    target_h, target_w = target_size
    if h <= 0 or w <= 0 or target_h <= 0 or target_w <= 0:
        raise ValueError(
            f"Image and target dimensions must be positive, got {(h, w)} and {(target_h, target_w)}"
        )
    scale = min(target_h / h, target_w / w)
    if scale == 1:
        # No rescaling needed, but padding may still be required when only one
        # of the two dimensions already matches the target.
        new_h, new_w = h, w
        resized = img
    else:
        # Clamp to 1 so extreme aspect ratios never request a zero-sized resize.
        new_h = max(1, int(h * scale))
        new_w = max(1, int(w * scale))
        resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
    pad_h = target_h - new_h
    pad_w = target_w - new_w
    if pad_h > 0 or pad_w > 0:
        pad_top = pad_h // 2
        pad_left = pad_w // 2
        pad_width = [(pad_top, pad_h - pad_top), (pad_left, pad_w - pad_left)]
        # Leave any trailing (channel) axes untouched.
        pad_width += [(0, 0)] * (resized.ndim - 2)
        resized = np.pad(resized, pad_width, mode='constant', constant_values=0)
    return resized

def convert_video(input_path, output_path):
    """Re-encode a video with ffmpeg as H.264 (libx264, CRF 23, yuv420p).

    An existing ``output_path`` is overwritten (``-y``).

    Args:
        input_path: Path of the video to read.
        output_path: Path of the re-encoded video to write.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    encoder_options = ['-vcodec', 'libx264', '-crf', '23', '-pix_fmt', 'yuv420p']
    ffmpeg_call = ['ffmpeg', '-y', '-i', input_path, *encoder_options, output_path]
    subprocess.run(ffmpeg_call, check=True)

def create_combined_video(
    image1: np.ndarray,
    image2: np.ndarray,
    interpolated_frames,
    output_path: str = 'combined_video.mp4',
    fps: int = 5
) -> str:
    """Render a comparison video: both originals on top, the animation below.

    Every video frame shows ``image1`` and ``image2`` side by side (separated
    by a 50px black bar) in the top row and one entry of
    ``interpolated_frames`` centred underneath, with text labels. The video is
    first written with OpenCV to ``output_path`` and then re-encoded with
    ffmpeg into a second ``*_mobile.mp4`` file.

    Args:
        image1: First original image, float values scaled to [0, 1].
            Assumed to be in RGB channel order, as produced by
            ``process_uploaded_image`` in this notebook.
        image2: Second original image, same conventions as ``image1``.
        interpolated_frames: Non-empty sequence of frames (same conventions)
            played back in order.
        output_path: Path of the intermediate OpenCV video.
        fps: Playback rate of the video.

    Returns:
        Path of the re-encoded video (``<output stem>_mobile.mp4``).

    Raises:
        ValueError: If ``interpolated_frames`` is empty.
        RuntimeError: If the OpenCV video writer cannot be opened.
        subprocess.CalledProcessError: If the ffmpeg re-encode fails.
    """
    import os  # local import: only needed to derive the output file name

    interpolated_frames = list(interpolated_frames)
    if not interpolated_frames:
        raise ValueError("interpolated_frames must contain at least one frame")

    def to_bgr_uint8(image):
        # Clip before casting: values slightly outside [0, 1] (e.g. model
        # overshoot) would otherwise wrap around in uint8.
        pixels = np.clip(np.asarray(image) * 255, 0, 255).astype(np.uint8)
        # cv2.VideoWriter expects BGR frames, while the inputs are RGB.
        return cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)

    img1 = to_bgr_uint8(image1)
    img2 = to_bgr_uint8(image2)
    min_height = min(img1.shape[0], img2.shape[0], *[f.shape[0] for f in interpolated_frames])
    target_width = max(img1.shape[1], img2.shape[1])
    target_size = (min_height, target_width)
    img1 = smart_resize(img1, target_size)
    img2 = smart_resize(img2, target_size)
    frames = [smart_resize(to_bgr_uint8(frame), target_size) for frame in interpolated_frames]

    separator = np.zeros((min_height, 50, 3), dtype=np.uint8)
    top_row = np.concatenate([img1, separator, img2], axis=1)
    frame_height = top_row.shape[0] + 50 + frames[0].shape[0]
    frame_width = max(top_row.shape[1], frames[0].shape[1])
    # The yuv420p re-encode needs even dimensions; add one black column if needed.
    frame_width += frame_width % 2

    # Save initial video
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
    if not out.isOpened():
        out.release()
        raise RuntimeError(f"Could not open video writer for '{output_path}'")
    font_scale = 2.2
    thickness = 6
    try:
        for frame in frames:
            combined_frame = np.zeros((frame_height, frame_width, 3), dtype=np.uint8)
            combined_frame[:top_row.shape[0], :top_row.shape[1]] = top_row
            # The animation starts 30px below the top row and is centred horizontally.
            video_y = top_row.shape[0] + 30
            video_x = (frame_width - frame.shape[1]) // 2
            combined_frame[video_y:video_y+frame.shape[0], video_x:video_x+frame.shape[1]] = frame
            # Colour tuples below are in OpenCV's BGR order.
            cv2.putText(combined_frame, 'Original',
                (img1.shape[1]//2 - 130, 80),
                cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 0, 0), thickness)
            cv2.putText(combined_frame, 'Original',
                (img1.shape[1] + separator.shape[1] + img2.shape[1]//2 - 130, 80),
                cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 0, 0), thickness)
            cv2.putText(combined_frame, 'Generated',
                (video_x + frame.shape[1]//2 - 180, video_y - 20),
                cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 255), thickness)
            out.write(combined_frame)
    finally:
        # Always finalise the file, even if compositing a frame fails.
        out.release()

    # Convert video for mobile compatibility
    # splitext (instead of str.replace) guarantees a distinct output path even
    # when output_path does not end in a lowercase '.mp4'.
    stem, _ = os.path.splitext(output_path)
    mobile_path = f"{stem}_mobile.mp4"
    convert_video(output_path, mobile_path)
    return mobile_path

# **3-User Interface**

In [6]:
def create_ui():
    """Build and display the interactive interpolation interface.

    Displays a header, two image upload buttons, an "Advanced Settings"
    accordion (number of intermediate frames, video FPS, maximum image size)
    and a run button. Clicking the button decodes both uploads, generates the
    frame sequence, renders the comparison video, shows it inline and
    triggers a Colab download. All progress and error messages are written to
    an output area below the button.
    """
    style = {'description_width': '150px'}
    layout = widgets.Layout(width='auto', height='40px')
    header = widgets.HTML(
        value="<h1 style='text-align:center;color:#1f77b4'>FILM Frame Interpolation System</h1>"
        "<p style='text-align:center'>Please upload two images for interpolation</p>"
    )
    display(header)

    # Only build widgets once and display HBox
    image1_upload = upload_image("First Image:", layout)
    image2_upload = upload_image("Second Image:", layout)
    upload_box = widgets.HBox([image1_upload, image2_upload])
    display(upload_box)

    # Keep direct references to the sliders instead of digging them out of the
    # accordion's children by index later on.
    frames_slider = widgets.IntSlider(
        value=5,
        min=3,
        max=15,
        step=1,
        description='Number of intermediate frames:',
        style=style,
        layout=layout
    )
    fps_slider = widgets.IntSlider(
        value=5,
        min=1,
        max=30,
        step=1,
        description='Video FPS:',
        style=style,
        layout=layout
    )
    max_dim_slider = widgets.IntSlider(
        value=1024,
        min=256,
        max=2048,
        step=128,
        description='Max image size:',
        style=style,
        layout=layout
    )
    advanced_settings = widgets.Accordion([
        widgets.VBox([frames_slider, fps_slider, max_dim_slider])
    ])
    # set_title works on both ipywidgets 7 and 8; the `titles=` constructor
    # argument is only understood by ipywidgets 8.
    advanced_settings.set_title(0, 'Advanced Settings')
    display(advanced_settings)

    run_button = widgets.Button(
        description="Run Interpolation and Create Video",
        button_style='success',
        icon='play',
        layout=widgets.Layout(width='300px', height='50px')
    )
    display(widgets.HBox([run_button], layout=widgets.Layout(justify_content='center')))
    output_area = widgets.Output()
    display(output_area)

    # The model is loaded lazily on the first click and reused afterwards, so
    # repeated runs do not reload it every time.
    interpolator_cache = []

    def get_interpolator():
        if not interpolator_cache:
            interpolator_cache.append(FilmInterpolator())
        return interpolator_cache[0]

    def on_button_clicked(b):
        with output_area:
            output_area.clear_output()
            print("\nProcessing started...")
            num_frames = frames_slider.value
            fps = fps_slider.value
            max_dim = max_dim_slider.value
            image1 = process_uploaded_image(image1_upload, max_dim)
            image2 = process_uploaded_image(image2_upload, max_dim)
            if image1 is None or image2 is None:
                print("Error: Please upload both images.")
                return
            try:
                print(f"\nImage Info:\n- Image 1: {image1.shape[1]}x{image1.shape[0]}\n- Image 2: {image2.shape[1]}x{image2.shape[0]}")
                print(f"\nGenerating {num_frames} interpolated frames...")
                start_time = time.time()
                # The slider counts INTERMEDIATE frames, whereas
                # generate_interpolated_frames expects the total sequence
                # length including both originals, hence the + 2.
                frames = generate_interpolated_frames(
                    image1, image2, num_frames + 2, get_interpolator()
                )
                print(f"Frame generation finished in {time.time()-start_time:.2f} seconds.")
                print("\nCreating combined video...")
                video_path = create_combined_video(image1, image2, frames, fps=fps)
                print("\nFinal video:")
                media.show_video(media.read_video(video_path), fps=fps)
                print("\nClick below to download the video:")
                files.download(video_path)
            except Exception as e:
                print(f"\nProcessing error: {str(e)}")
                raise
    run_button.on_click(on_button_clicked)

# Build and display the interface; processing starts when the run button is clicked.
create_ui()

HTML(value="<h1 style='text-align:center;color:#1f77b4'>FILM Frame Interpolation System</h1><p style='text-ali…

HBox(children=(FileUpload(value={}, accept='image/*', description='First Image:', layout=Layout(height='40px',…

Accordion(children=(VBox(children=(IntSlider(value=5, description='Number of intermediate frames:', layout=Lay…

HBox(children=(Button(button_style='success', description='Run Interpolation and Create Video', icon='play', l…

Output()