# In [None]:
import os
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models

# ---- Configuration --------------------------------------------------------
# Frame size handed to cv2.resize, which interprets it as (width, height).
IMG_SIZE = (128, 128)
# How many in-between frames the model predicts for each (start, end) pair.
NUM_INTERPOLATED = 3
# Training hyper-parameters.
BATCH_SIZE = 16
EPOCHS = 10

def _frame_sort_key(filename):
    """
    Return a sort key that orders "frame_<n>.jpg" names by the integer <n>.

    Plain string sorting places "frame_10.jpg" before "frame_2.jpg"; parsing
    the numeric part restores temporal order. Names whose middle part is not
    an integer sort after all numeric ones, alphabetically.
    """
    index_text = filename[len("frame_"):-len(".jpg")]
    try:
        # Filename is the final tie-breaker (e.g. "frame_01" vs "frame_1").
        return (0, int(index_text), filename)
    except ValueError:
        return (1, 0, filename)


def load_sequence(seq_path):
    """
    Load the frames of one sequence folder in temporal order.

    Only files named like "frame_0.jpg", "frame_1.jpg", ... are considered.
    They are ordered by their integer index (so "frame_10.jpg" follows
    "frame_9.jpg"), read with OpenCV, converted from BGR to RGB, resized to
    IMG_SIZE (which cv2.resize interprets as (width, height)), and scaled to
    float32 values in [0, 1].

    Files that OpenCV cannot decode are skipped.

    Args:
        seq_path: Path to a directory containing the frame images.

    Returns:
        A list of float32 RGB image arrays, in frame order (empty if the
        folder contains no matching, decodable frames).
    """
    names = [
        f for f in os.listdir(seq_path)
        if f.endswith(".jpg") and f.startswith("frame_")
    ]
    # Numeric, not lexicographic, ordering keeps the frames in time order.
    names.sort(key=_frame_sort_key)

    frames = []
    for filename in names:
        img = cv2.imread(os.path.join(seq_path, filename))
        if img is None:
            # Best effort: skip unreadable files. NOTE(review): a skipped
            # middle frame shifts the positions of all later frames.
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, IMG_SIZE)
        frames.append(img.astype(np.float32) / 255.0)
    return frames

def process_sequence(frames, num_interpolated=None):
    """
    Turn one frame sequence into an (input, target) training pair.

    - Input: the first and last frame concatenated along the channel axis
      (H x W x 6 for RGB frames).
    - Target: exactly ``num_interpolated`` intermediate frames concatenated
      along the channel axis (H x W x (num_interpolated * 3) for RGB frames).

    With exactly ``num_interpolated + 2`` frames the targets are simply
    ``frames[1:-1]``. Longer sequences are subsampled at evenly spaced
    indices between the first and last frame, rounded to the nearest
    integer. This keeps every sample's target width the same, which the
    model's fixed output width and the stacking in dataset_from_folder
    both require.

    Args:
        frames: List of per-frame arrays of equal shape (RGB arrays of shape
            (H, W, 3) when produced by load_sequence).
        num_interpolated: Number of intermediate target frames. Defaults to
            the module-level NUM_INTERPOLATED.

    Returns:
        (input_tensor, target_tensor), or (None, None) when the sequence has
        fewer than ``num_interpolated + 2`` frames.

    Raises:
        ValueError: If ``num_interpolated`` is less than 1.
    """
    if num_interpolated is None:
        num_interpolated = NUM_INTERPOLATED
    if num_interpolated < 1:
        raise ValueError(f"num_interpolated must be >= 1, got {num_interpolated}")

    segments = num_interpolated + 1          # gaps between consecutive picks
    if len(frames) < segments + 1:
        return None, None

    last = len(frames) - 1
    # Integer round-half-up of i * last / segments. Index 0 maps to the first
    # frame and index `segments` to the last. Because last >= segments, the
    # picks are strictly increasing (no frame is repeated).
    picks = [(i * last + segments // 2) // segments for i in range(segments + 1)]

    input_tensor = np.concatenate([frames[picks[0]], frames[picks[-1]]], axis=-1)
    target_tensor = np.concatenate([frames[k] for k in picks[1:-1]], axis=-1)
    return input_tensor, target_tensor

def dataset_from_folder(folder):
    """
    Build a tf.data.Dataset of (input, target) pairs from a folder of sequences.

    Each immediate subdirectory of ``folder`` is treated as one sequence. It
    is loaded with load_sequence and converted with process_sequence.
    Sequences that are too short are skipped, and non-directory entries are
    ignored. Subdirectories are visited in sorted order, so the element
    order is reproducible across runs and platforms.

    Args:
        folder: Directory whose subdirectories each hold one frame sequence.

    Returns:
        A tf.data.Dataset built with from_tensor_slices over the stacked
        inputs and targets.

    Raises:
        ValueError: If no subdirectory yields a usable sequence. An empty
            dataset would otherwise only fail later, with an obscure error,
            during training or evaluation.
    """
    inputs, targets = [], []
    for seq_folder in sorted(os.listdir(folder)):
        seq_path = os.path.join(folder, seq_folder)
        if not os.path.isdir(seq_path):
            continue
        inp, tar = process_sequence(load_sequence(seq_path))
        if inp is None or tar is None:
            continue
        inputs.append(inp)
        targets.append(tar)

    if not inputs:
        raise ValueError(f"No usable frame sequences found in {folder!r}")

    # np.stack raises a clear error if sample shapes disagree, instead of
    # silently building a ragged/object array.
    return tf.data.Dataset.from_tensor_slices((np.stack(inputs), np.stack(targets)))

def build_model(input_shape, num_interpolated):
    """
    Create a small convolutional encoder-decoder for frame interpolation.

    The network is meant to take the start and end frames stacked on the
    channel axis (``input_shape`` of (H, W, 6)). It predicts
    ``num_interpolated`` RGB frames stacked on the channel axis, giving an
    output of shape (H, W, num_interpolated * 3). A sigmoid keeps the output
    values in [0, 1].

    Layer stack: 3x3 conv (64, ReLU) -> 2x2 max-pool -> 3x3 conv (128, ReLU)
    -> 2x upsample -> 3x3 conv (num_interpolated * 3, sigmoid), all convs with
    'same' padding. The pool/upsample pair restores the input resolution
    only when H and W are even.
    """
    inputs = layers.Input(shape=input_shape)
    out_channels = num_interpolated * 3
    stack = [
        # Encoder: extract features, then halve the spatial resolution.
        layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
        # Decoder: restore resolution and project to the stacked RGB frames.
        layers.UpSampling2D((2, 2)),
        layers.Conv2D(out_channels, (3, 3), activation='sigmoid', padding='same'),
    ]
    features = inputs
    for layer in stack:
        features = layer(features)
    return models.Model(inputs, features)

# ---- Data pipelines -------------------------------------------------------
train_folder = "train"  # one subfolder per training sequence
test_folder = "test"    # one subfolder per evaluation sequence

train_ds = (
    dataset_from_folder(train_folder)
    .shuffle(100)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
test_ds = (
    dataset_from_folder(test_folder)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# ---- Model ----------------------------------------------------------------
# IMG_SIZE is (width, height) for cv2.resize; Keras wants (height, width, C).
input_shape = (*IMG_SIZE[::-1], 6)
model = build_model(input_shape, NUM_INTERPOLATED)
model.compile(optimizer='adam', loss='mse', metrics=['mae'])

# ---- Training -------------------------------------------------------------
print("Starting training...")
# NOTE(review): the test split doubles as validation data, so the reported
# test metrics are not a clean held-out estimate if used for model selection.
history = model.fit(train_ds, epochs=EPOCHS, validation_data=test_ds)

# ---- Evaluation -----------------------------------------------------------
print("Evaluating on test data...")
loss, mae = model.evaluate(test_ds)
print("Test Loss:", loss)
print("Test MAE:", mae)

# Optional sanity check: predict on a single test batch and show its shape.
for inputs, targets in test_ds.take(1):
    predictions = model.predict(inputs)
    print("Predictions shape:", predictions.shape)
