In [None]:
import os
import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input, Conv2D, Concatenate, UpSampling2D, MaxPooling2D, Activation, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow_addons.image import dense_image_warp

In [None]:
# === 1. Parameters ===
# Size handed to cv2.resize for every training frame.
# NOTE(review): cv2.resize interprets dsize as (width, height), so this yields
# frames 540 px wide and 960 px tall -- confirm that ordering is intended.
target_size = (540,960)
# Construct the full path to the dataset directory
dataset_dir = os.path.join(os.getcwd(), "dataset", "videos", "thirtyfps")
# Interpolated videos are written here (relative to the working directory).
output_dir = "dataset/videos/sixtyfps"
os.makedirs(output_dir, exist_ok=True)

batch_size = 1 # Reduce if out-of-memory
epochs = 10

print(f"Constructed dataset_dir: {dataset_dir}")

# Listing the directory is only a sanity check, so filesystem errors are
# reported rather than aborting the cell. Only OSError is caught so that
# unrelated programming errors are not hidden.
try:
    files_in_dir = os.listdir(dataset_dir)
except OSError as e:
    print(f"Could not list contents of directory: {e}")
else:
    print(f"Contents of dataset directory: {files_in_dir}")
    if not files_in_dir:
        print("Warning: Dataset directory is empty.")

In [None]:
# === 2. Data Loader and Generator ===
def load_triplets_from_video(video_path):
    """Decode a video and return every run of three consecutive frames.

    Frames are read with OpenCV and, when the module-level ``target_size`` is
    truthy, resized with ``cv2.resize(frame, target_size)``.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.

    Returns:
        A list of ``(frames[i], frames[i+1], frames[i+2])`` tuples; empty when
        fewer than three frames could be read.
    """
    capture = cv2.VideoCapture(video_path)
    decoded = []
    ok, image = capture.read()
    while ok:
        decoded.append(cv2.resize(image, target_size) if target_size else image)
        ok, image = capture.read()
    capture.release()
    # Zipping three staggered views of the list gives the width-3 sliding window.
    return list(zip(decoded, decoded[1:], decoded[2:]))

def get_all_triplets():
    """Gather frame triplets from every ``.mp4`` file in ``dataset_dir``.

    Returns:
        One flat list holding the triplets of all matching videos, in the
        order ``os.listdir`` reports the files.
    """
    video_paths = (
        os.path.join(dataset_dir, name)
        for name in os.listdir(dataset_dir)
        if name.endswith(".mp4")
    )
    collected = []
    for path in video_paths:
        collected += load_triplets_from_video(path)
    return collected

class FlowTripletGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves batches built from frame triplets.

    For each ``(f0, f1, f2)`` triplet the input is ``f0`` and ``f2`` scaled by
    1/255 and concatenated along the last axis; the target is ``f1`` scaled by
    1/255.
    """

    def __init__(self, triplets, batch_size):
        self.batch_size = batch_size
        self.triplets = triplets

    def __len__(self):
        # Floor division: a trailing partial batch is never counted.
        full_batches, _ = divmod(len(self.triplets), self.batch_size)
        return full_batches

    def __getitem__(self, idx):
        start = idx * self.batch_size
        window = self.triplets[start:start + self.batch_size]
        pairs = [
            np.concatenate([first / 255.0, last / 255.0], axis=-1)
            for first, _, last in window
        ]
        middles = [middle / 255.0 for _, middle, _ in window]
        return np.array(pairs), np.array(middles)

In [None]:
def _conv_stage(x, filters):
    """Apply Conv -> BatchNorm -> ReLU -> Conv -> ReLU at one channel width."""
    x = Conv2D(filters, 3, padding='same')(x)
    x = BatchNormalization()(x)
    # Activation('relu') is used because ReLU is not among the layer names
    # this notebook imports, while Activation is.
    x = Activation('relu')(x)
    x = Conv2D(filters, 3, padding='same')(x)
    x = Activation('relu')(x)
    return x


def build_deeper_model():
    """Build the convolutional encoder/decoder that predicts a flow map.

    The network takes a 6-channel image of arbitrary spatial size and outputs
    a 2-channel map with no final activation. It downsamples twice with
    ``MaxPooling2D`` and upsamples twice with ``UpSampling2D`` (both at their
    default factor), so the output has the same height and width as the input
    only when both are multiples of 4.

    Returns:
        An uncompiled ``tensorflow.keras.Model``.
    """
    inp = Input(shape=(None, None, 6))

    # Encoder
    x = _conv_stage(inp, 32)
    x = MaxPooling2D()(x)
    x = _conv_stage(x, 64)
    x = MaxPooling2D()(x)

    # Bottleneck
    x = _conv_stage(x, 128)

    # Decoder
    x = UpSampling2D()(x)
    x = _conv_stage(x, 64)
    x = UpSampling2D()(x)
    x = _conv_stage(x, 32)

    # Output Flow Map
    flow = Conv2D(2, 3, padding='same', activation=None)(x)

    return Model(inputs=inp, outputs=flow)


In [None]:
# === 4. Frame Warping Utility ===
def warp_frame(frame, flow):
    """Warp a single frame with the first element of a batched flow field.

    Args:
        frame: Image whose first two dimensions are (height, width).
        flow: Flow field with a leading batch dimension. It is resized to the
            frame's height and width, and only element 0 is used.

    Returns:
        The warped frame with the batch dimension removed.
    """
    h, w = frame.shape[:2]
    # NOTE(review): resizing changes the field's resolution but leaves the
    # vector magnitudes untouched -- confirm callers pass flow already
    # expressed in frame pixels.
    flow = tf.image.resize(flow, (h, w))
    warped = dense_image_warp(tf.expand_dims(frame, 0), tf.expand_dims(flow[0], 0))
    return tf.squeeze(warped, 0)

In [None]:
# === 5. Training ===
# Load every frame triplet up front and wrap the list in the batch generator.
triplets = get_all_triplets()
gen = FlowTripletGenerator(triplets, batch_size)
# build_deeper_model is the only model builder this notebook defines.
model = build_deeper_model()
optimizer = Adam(1e-4)

@tf.function
def train_step(x, y):
    """Run one optimisation step on a batch and return its loss.

    The global ``model`` predicts a flow map from ``x``. Channels 0-2 of ``x``
    are warped with +flow/2 and the remaining channels with -flow/2; the mean
    of the two warps is compared with ``y`` by mean squared error.

    Args:
        x: Batch whose last axis holds two frames back to back.
        y: Batch of target frames.

    Returns:
        The scalar loss tensor for this batch.
    """
    first, last = x[..., :3], x[..., 3:]
    with tf.GradientTape() as tape:
        half_flow = model(x, training=True) / 2.0
        blended = (
            dense_image_warp(first, half_flow) + dense_image_warp(last, -half_flow)
        ) / 2.0
        loss = tf.reduce_mean(tf.square(y - blended))
    weights = model.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(loss, weights), weights))
    return loss

# Sweep the generator once per epoch, in index order.
for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    for i in range(len(gen)):
        x_batch, y_batch = gen[i]
        x_tensor = tf.convert_to_tensor(x_batch, dtype=tf.float32)
        y_tensor = tf.convert_to_tensor(y_batch, dtype=tf.float32)
        loss = train_step(x_tensor, y_tensor)
        # Report progress on every tenth batch.
        if not i % 10:
            print(f"Batch {i}, Loss: {loss.numpy():.4f}")

In [None]:
# Write the trained model to disk; a later cell reloads it from this same path.
model.save("flow_interpolation_model.keras")

In [None]:
from tensorflow.keras.models import load_model

# Load the model before using it
# This rebinds the global `model`, which interpolate_frame reads, to the copy
# stored in "flow_interpolation_model.keras".
model = load_model("flow_interpolation_model.keras")


In [None]:
# === 6. Inference and Visualization ===
def interpolate_frame(f0, f2):
    """Synthesize the frame halfway between ``f0`` and ``f2``.

    The global ``model`` predicts a flow map from the two frames; ``f0`` is
    warped with +flow/2, ``f2`` with -flow/2, and the two warps are averaged.

    Args:
        f0: First frame as an array with values on a 0-255 scale
            (presumably a BGR uint8 image from OpenCV -- confirm with callers).
        f2: Second frame, same shape as ``f0``.

    Returns:
        The interpolated frame as a ``uint8`` NumPy array with the same height
        and width as the inputs.
    """
    # Assumes the loaded model halves the resolution twice, as
    # build_deeper_model does: its output only matches the input size when
    # height and width are multiples of 4.
    stride = 4

    f0 = f0.astype(np.float32) / 255.0
    f2 = f2.astype(np.float32) / 255.0
    h, w = f0.shape[:2]

    # Concatenate frames; pad bottom/right by edge replication when needed so
    # the predicted flow can be cropped back to exactly (h, w).
    pair = np.concatenate([f0, f2], axis=-1)
    pad_h, pad_w = (-h) % stride, (-w) % stride
    if pad_h or pad_w:
        pair = np.pad(pair, ((0, pad_h), (0, pad_w), (0, 0)), mode="edge")

    # Calling the model directly avoids the per-call overhead and progress
    # output of model.predict when processing one small batch per iteration.
    # training=False keeps BatchNormalization in inference mode.
    flow_batch = model(tf.convert_to_tensor(pair[np.newaxis]), training=False).numpy()
    flow_batch = flow_batch[:, :h, :w, :]

    # Warp both frames (batch dimension added for dense_image_warp)
    warped_f0 = dense_image_warp(f0[np.newaxis], flow_batch / 2.0)[0]
    warped_f2 = dense_image_warp(f2[np.newaxis], -flow_batch / 2.0)[0]

    # Average warped results and convert back to 8-bit pixels
    interp = ((warped_f0 + warped_f2) / 2.0).numpy()
    return np.clip(interp * 255.0, 0, 255).astype(np.uint8)


In [None]:
# === 7. Video Conversion Utilities ===
def extract_frames_from_video(video_path):
    """Read every frame of a video into memory.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.

    Returns:
        A list of the frames in reading order; empty when nothing could be
        read.
    """
    capture = cv2.VideoCapture(video_path)
    collected = []
    ok, image = capture.read()
    # read() reports False once no further frame is available.
    while ok:
        collected.append(image)
        ok, image = capture.read()
    capture.release()
    return collected

def interpolate_video_frames(frames):
    """Insert one synthesized frame between every pair of neighbouring frames.

    Args:
        frames: Sequence of frames in playback order.

    Returns:
        A list laid out as ``[frames[0], mid, frames[1], mid, ..., frames[-1]]``
        where each ``mid`` comes from ``interpolate_frame``. An empty input
        yields an empty list.
    """
    # Without this guard, frames[-1] below raises IndexError on empty input
    # (for example when the video could not be read).
    if len(frames) == 0:
        return []

    output_frames = []
    for f0, f2 in zip(frames, frames[1:]):
        output_frames.append(f0)
        output_frames.append(interpolate_frame(f0, f2))
    output_frames.append(frames[-1])
    return output_frames

def save_video_from_frames(frames, path, fps):
    """Encode a list of frames into an ``mp4v`` video file.

    The frame size is taken from the first frame.

    Args:
        frames: Non-empty sequence of image arrays.
        path: Destination file path.
        fps: Frame rate passed to ``cv2.VideoWriter``.

    Raises:
        IndexError: If ``frames`` is empty.
        OSError: If the video writer could not be opened for ``path``.
    """
    h, w = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(path, fourcc, fps, (w, h))
    try:
        # An unopened writer ignores write() calls, which would silently
        # produce no output; fail loudly instead.
        if not out.isOpened():
            raise OSError(f"Could not open video writer for: {path}")
        for frame in frames:
            out.write(frame)
    finally:
        out.release()

In [None]:
import os

# Specify your video file path
# Built with os.path.join so the separators are correct on every OS; on
# Windows this produces the same backslash-separated string as before.
input_video_path = os.path.join(
    "dataset", "videos", "thirtyfps", "video-25- Made with Clipchamp.mp4"
)

# Check if the file exists
if os.path.exists(input_video_path):
    print(f"File exists: {input_video_path}")
else:
    print(f"File does not exist at the specified path: {input_video_path}")


In [None]:
import os
import matplotlib.pyplot as plt
import cv2

# === 8. Perform Full Video Conversion with Visualization of All Interpolated Frames ===
# os.path.join keeps the path valid on every OS (same string as before on Windows).
input_video_path = os.path.join(
    "dataset", "videos", "thirtyfps", "video-25- Made with Clipchamp.mp4"
)
video_name = os.path.splitext(os.path.basename(input_video_path))[0]
output_video_path = os.path.join(output_dir, f"{video_name}_60fps.mp4")

# Step 1: Extract frames
input_frames = extract_frames_from_video(input_video_path)
# Stop with a clear message instead of an IndexError further down when the
# video is missing or unreadable.
if not input_frames:
    raise FileNotFoundError(f"No frames could be read from: {input_video_path}")

# Step 2: Interpolate frames
interpolated_frames = interpolate_video_frames(input_frames)

# Step 3: Save video
save_video_from_frames(interpolated_frames, output_video_path, fps=60)

# Step 4: Visualize all interpolated frames between pairs
# interpolate_video_frames returns [f0, mid, f1, mid, ..., last], so the frame
# synthesized for pair i sits at index 2*i + 1. Reusing it avoids running the
# model a second time for every pair.
for i in range(len(input_frames) - 1):
    f0 = input_frames[i]
    f2 = input_frames[i + 1]
    mid = interpolated_frames[2 * i + 1]

    fig, axes = plt.subplots(1, 3, figsize=(15, 3))
    panels = (
        (f0, f"Original Frame {i}"),
        (mid, "Interpolated Frame"),
        (f2, f"Original Frame {i + 1}"),
    )
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        ax.set_title(title)
        ax.axis('off')

    fig.tight_layout()
    plt.show()
    # One figure is created per frame pair; close it so they do not pile up in memory.
    plt.close(fig)

print(f"\nSaved interpolated video to: {output_video_path}")
print("Model ready for high-resolution interpolation and video conversion.")
