In [1]:
!pip install torch torchvision

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [2]:
import cv2
import os
import random

# Parameters
# NOTE(review): in the code visible in this notebook only NUM_INTERPOLATED is
# referenced again (as the default of generate_dataset's num_interpolated).
# IMG_SIZE, BATCH_SIZE and EPOCHS are defined here but never read; later cells
# hard-code their own batch size and epoch count.
IMG_SIZE = (128, 128)     # Intended fixed frame resolution (not applied by any function in this cell)
NUM_INTERPOLATED = 1      # Number of frames to interpolate between the start and end
BATCH_SIZE = 64
EPOCHS = 10

def extract_frames(video_path):
    """Extract every frame of a video file.

    Args:
        video_path: Path of the video file handed to ``cv2.VideoCapture``.

    Returns:
        list: The decoded frames in playback order, exactly as returned by
        ``cap.read()``. The list is empty when the video cannot be opened
        (a warning is emitted in that case) or contains no decodable frame.
    """
    import warnings  # local import so this cell's import block stays untouched

    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        if not cap.isOpened():
            # Make the otherwise silent "no frames" outcome visible to the user.
            warnings.warn(f"Could not open video: {video_path}")
            return frames
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
    finally:
        # Always free the decoder handle, even if reading raised.
        cap.release()
    return frames

def save_sequence(sequence, output_dir, video_name, seq_index):
    """Save a sequence of frames into its own sub-folder.

    The folder is ``<output_dir>/<video_name>_seq_<seq_index>`` and the frames
    are written as ``frame_0.jpg``, ``frame_1.jpg``, ... in sequence order.

    Args:
        sequence: Iterable of images accepted by ``cv2.imwrite``.
        output_dir: Directory under which the sequence folder is created.
        video_name: Name used as the prefix of the sequence folder.
        seq_index: Index used as the suffix of the sequence folder.

    Raises:
        OSError: If ``cv2.imwrite`` reports that a frame could not be written.
    """
    seq_folder = os.path.join(output_dir, f"{video_name}_seq_{seq_index}")
    os.makedirs(seq_folder, exist_ok=True)
    for i, frame in enumerate(sequence):
        frame_path = os.path.join(seq_folder, f"frame_{i}.jpg")
        # imwrite signals failure through its return value; do not leave a
        # silently incomplete sequence on disk.
        if not cv2.imwrite(frame_path, frame):
            raise OSError(f"Failed to write frame image: {frame_path}")

def generate_dataset(input_folder, train_folder, test_folder, num_interpolated=NUM_INTERPOLATED, train_ratio=0.8, seed=None):
    """
    Processes each ``.mp4`` video in the input folder.

    For each video, a sliding window of length (num_interpolated + 2) and
    stride 1 is used to generate sequences where the first and last frames are
    the inputs for interpolation, and the frames in between are used as ground
    truth.

    Each sequence is randomly assigned to train (probability ``train_ratio``)
    or test.

    NOTE(review): neighbouring windows share frames, so a per-sequence random
    split places overlapping content in both sets.

    Args:
        input_folder: Folder that is scanned (non-recursively) for ``.mp4`` files.
        train_folder: Output folder for training sequences (created if missing).
        test_folder: Output folder for testing sequences (created if missing).
        num_interpolated: Number of ground-truth frames between the two inputs.
        train_ratio: Probability of assigning a sequence to the training set.
        seed: Optional seed for a private random generator, making the split
            reproducible. With ``None`` (default) the global ``random`` state
            is used, as before.
    """
    # The total sequence length includes the starting and ending frames
    sequence_length = num_interpolated + 2
    rng = random if seed is None else random.Random(seed)

    os.makedirs(train_folder, exist_ok=True)
    os.makedirs(test_folder, exist_ok=True)

    # Sorted so that the order of random draws does not depend on the
    # filesystem's listing order.
    for filename in sorted(os.listdir(input_folder)):
        if not filename.endswith(".mp4"):
            continue
        video_path = os.path.join(input_folder, filename)
        video_name = os.path.splitext(filename)[0]
        frames = extract_frames(video_path)
        # Slide a window over the frames; the range is empty when the video
        # has fewer frames than one full sequence.
        for seq_index in range(len(frames) - sequence_length + 1):
            sequence = frames[seq_index:seq_index + sequence_length]
            # Randomly assign the sequence to training or testing set
            destination = train_folder if rng.random() < train_ratio else test_folder
            save_sequence(sequence, destination, video_name, seq_index)


# Folder layout for dataset generation. The paths are relative, so they are
# resolved against the notebook's current working directory.
input_folder = "input"    # Folder containing your mp4 videos
train_folder = "train"    # Output folder for training sequences
test_folder = "test"      # Output folder for testing sequences
# Writes one sub-folder of JPEG frames per sliding-window sequence into the
# train or test folder.
generate_dataset(input_folder, train_folder, test_folder)

In [4]:
# prompt: Now define 'dataloader' instance considering previous code.

import torch
from torch.utils.data import Dataset, DataLoader
import os
import cv2
import numpy as np

class FrameInterpolationDataset(Dataset):
    """Frame sequences stored on disk, one folder of ``frame_<i>.jpg`` files per sequence.

    Every directory below ``root_dir`` that contains at least one file and
    only ``.jpg`` files is treated as one sequence.

    Args:
        root_dir: Directory that is walked recursively for sequence folders.
        transform: Optional callable invoked as ``transform(image=frame)`` and
            expected to return a mapping with an ``'image'`` entry.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.sequences = []
        for subdir, _, files in os.walk(root_dir):
            if len(files) > 0 and all(f.endswith('.jpg') for f in files):
                self.sequences.append(subdir)
        # os.walk order depends on the filesystem; sort so that a given index
        # always refers to the same sequence.
        self.sequences.sort()

    def __len__(self):
        """Return the number of sequence folders found."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Load one sequence.

        Returns:
            torch.Tensor: float32 tensor of the stacked frames, permuted from
            (T, H, W, C) to (T, C, H, W) and divided by 255.

        Raises:
            OSError: If a ``frame_<i>.jpg`` file is missing or cannot be decoded.
        """
        sequence_dir = self.sequences[idx]
        num_frames = len(os.listdir(sequence_dir))
        frames = []
        for i in range(num_frames):  # Frames are expected to be numbered 0..num_frames-1
            frame_path = os.path.join(sequence_dir, f"frame_{i}.jpg")
            frame = cv2.imread(frame_path)
            if frame is None:
                # imread returns None instead of raising; fail with the path
                # rather than with an obscure cvtColor error.
                raise OSError(f"Could not read frame image: {frame_path}")
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB

            if self.transform:
                frame = self.transform(image=frame)['image']
            frames.append(frame)

        frames = np.stack(frames)  # (T, H, W, C)
        frames = torch.tensor(frames, dtype=torch.float32).permute(0, 3, 1, 2) / 255.0  # Normalize
        return frames


# Example usage (assuming you have 'train' and 'test' folders)
train_dataset = FrameInterpolationDataset(root_dir='train')
test_dataset = FrameInterpolationDataset(root_dir='test')

# Each batch is a single tensor of stacked sequences; the dataset returns no separate target.
# NOTE(review): batch_size is hard-coded to 32; the BATCH_SIZE constant defined earlier is not used.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)  # reshuffled every epoch
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)  # fixed order for evaluation


In [None]:
import torch
import torch.nn as nn

class S6Layer(nn.Module):
    """Residual feed-forward block: Linear -> GELU -> Linear, added back onto the input.

    Args:
        input_dim: Size of the last dimension of the input (and of the output).
        hidden_dim: Width of the intermediate projection.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Expand to the hidden width, then project back so the skip connection lines up.
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.activation = nn.GELU()

    def forward(self, x):
        """Transform the last dimension of ``x`` and add the result to ``x``."""
        # x: (batch_size, sequence_length, input_dim)
        update = self.linear2(self.activation(self.linear1(x)))
        return update + x  # Residual connection

class VFIMamba(nn.Module):
    """Stack of ``S6Layer`` blocks followed by a final linear read-out.

    Args:
        input_dim: Size of the last input dimension; also the output size.
        hidden_dim: Hidden width handed to every ``S6Layer``.
        num_layers: Number of ``S6Layer`` blocks.

    Note: the read-out is ``nn.Linear(input_dim, input_dim)``, so its weight
    matrix has ``input_dim ** 2`` entries.
    """

    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__()
        blocks = [S6Layer(input_dim, hidden_dim) for _ in range(num_layers)]
        self.s6_layers = nn.ModuleList(blocks)
        self.output_layer = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        """Run ``x`` through every block in order, then through the read-out."""
        hidden = x
        for block in self.s6_layers:
            hidden = block(hidden)
        return self.output_layer(hidden)

import torch.optim as optim

# Hyperparameters
input_dim = 128 * 128 * 3  # Flattened 128x128 RGB frame
frame_hw = (128, 128)      # Spatial size implied by input_dim
hidden_dim = 512
num_layers = 4
learning_rate = 1e-4
num_epochs = 10

# Initialize model, loss function, and optimizer
# NOTE(review): VFIMamba's read-out is Linear(input_dim, input_dim), i.e. about
# 2.4 billion weights at this input_dim; check available memory before running.
model = VFIMamba(input_dim, hidden_dim, num_layers)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0
    # train_dataloader yields ONE tensor per batch, shaped (B, T, C, H, W);
    # inputs and targets are derived from it here rather than unpacked.
    for batch in train_dataloader:
        b, t, c, h, w = batch.shape
        if (h, w) != frame_hw:
            # Frames on disk keep the video's native resolution; bring them to
            # the size the model was built for.
            batch = nn.functional.interpolate(
                batch.reshape(b * t, c, h, w),
                size=frame_hw,
                mode="bilinear",
                align_corners=False,
            ).reshape(b, t, c, *frame_hw)

        flat = batch.flatten(start_dim=2)  # (B, T, C*H*W)
        inputs = flat[:, [0, -1]]          # first and last frame of each sequence
        targets = flat[:, 1:-1]            # frame(s) in between = ground truth

        optimizer.zero_grad()
        outputs = model(inputs)            # (B, 2, C*H*W), one output per boundary frame
        # The model treats every frame independently, so the interpolated frame
        # is taken as the mean of the two per-frame outputs (assumption: the
        # same estimate is used for every intermediate frame).
        prediction = outputs.mean(dim=1, keepdim=True).expand_as(targets)
        loss = criterion(prediction, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Report the epoch average instead of whichever batch happened to be last.
    epoch_loss = running_loss / max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")


In [None]:
import cv2
import numpy as np

model.eval()
eval_hw = (128, 128)  # Spatial size matching the model's input_dim of 128 * 128 * 3
with torch.no_grad():
    # test_dataloader yields ONE tensor per batch, shaped (B, T, C, H, W).
    for batch in test_dataloader:
        b, t, c, h, w = batch.shape
        if (h, w) != eval_hw:
            batch = torch.nn.functional.interpolate(
                batch.reshape(b * t, c, h, w),
                size=eval_hw,
                mode="bilinear",
                align_corners=False,
            ).reshape(b, t, c, *eval_hw)

        flat = batch.flatten(start_dim=2)  # (B, T, C*H*W)
        inputs = flat[:, [0, -1]]          # first and last frame
        targets = flat[:, 1:-1]            # ground-truth intermediate frame(s)

        # Mean of the two per-frame outputs is used as the interpolated frame.
        outputs = model(inputs).mean(dim=1, keepdim=True).expand_as(targets)

        # The flattened vectors are channel-first (C, H, W); restore that layout
        # before moving channels last for display.
        output_images = outputs.reshape(-1, c, *eval_hw).permute(0, 2, 3, 1).clamp(0, 1).cpu().numpy()
        target_images = targets.reshape(-1, c, *eval_hw).permute(0, 2, 3, 1).cpu().numpy()

        # Compare output_images with target_images
        # NOTE(review): cv2.imshow needs a GUI session; confirm it is available
        # where this notebook runs.
        for i, (output_img, target_img) in enumerate(zip(output_images, target_images)):
            combined = np.hstack((output_img, target_img))
            # Frames are RGB floats in [0, 1]; OpenCV displays BGR.
            cv2.imshow(f'Output vs Target {i}', cv2.cvtColor(combined, cv2.COLOR_RGB2BGR))
            cv2.waitKey(0)
        cv2.destroyAllWindows()

# Whole code

In [None]:
import cv2
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
from typing import List, Tuple, Optional
from torchvision import transforms
from torch.cuda.amp import autocast, GradScaler
import logging
import yaml
from dataclasses import dataclass

# Configure logging
# INFO level so that the progress messages emitted through logger.info are shown.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by the code in this cell.
logger = logging.getLogger(__name__)

@dataclass
class Config:
    """Configuration class for hyperparameters.

    ``input_dim`` is the length of one flattened frame. When it is not given
    it is derived from ``img_size`` as ``img_size[0] * img_size[1] * 3``
    (three colour channels), so changing ``img_size`` alone keeps the two
    values consistent. An explicitly passed ``input_dim`` is used unchanged.
    """
    img_size: Tuple[int, int] = (128, 128)   # frame size used for resizing and video output
    num_interpolated: int = 1                # frames between the two boundary frames
    batch_size: int = 16
    epochs: int = 10
    input_dim: Optional[int] = None          # filled in by __post_init__ when None
    hidden_dim: int = 512
    num_layers: int = 4
    learning_rate: float = 1e-4
    train_ratio: float = 0.8
    # Evaluated once, when the class is defined.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def __post_init__(self) -> None:
        """Derive ``input_dim`` from ``img_size`` when it was not supplied."""
        if self.input_dim is None:
            self.input_dim = self.img_size[0] * self.img_size[1] * 3

class FrameProcessor:
    """Handles video frame extraction and processing.

    Args:
        img_size: Target frame size. It is passed unchanged both to
            ``cv2.resize`` (which reads it as ``(width, height)``) and to
            ``transforms.Resize`` (which reads it as ``(height, width)``).
            NOTE(review): the two agree only for square sizes.
    """
    def __init__(self, img_size: Tuple[int, int]):
        self.img_size = img_size
        # Tensor pipeline for RGB frames: scale to [0, 1], resize, then
        # normalise per channel with the mean/std below.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(img_size),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ])

    def extract_frames(self, video_path: str) -> List[np.ndarray]:
        """Extract all frames of a video as resized RGB images.

        Args:
            video_path: Path of the video file handed to ``cv2.VideoCapture``.

        Returns:
            The frames in playback order, converted from BGR to RGB and
            resized with ``cv2.resize(frame, self.img_size)``. Empty when the
            video cannot be opened (a warning is logged) or has no frames.
        """
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            if not cap.isOpened():
                # Surface the otherwise silent "no frames" outcome.
                logger.warning("Could not open video: %s", video_path)
                return frames
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(frame, self.img_size)
                frames.append(frame)
        finally:
            cap.release()
        return frames

    def save_sequence(self, sequence: List[np.ndarray],
                     output_dir: Path, video_name: str,
                     seq_index: int) -> None:
        """Save an RGB frame sequence as ``frame_<i>.jpg`` files.

        The files go to ``<output_dir>/<video_name>_seq_<seq_index>``, which
        is created (with parents) when missing.

        Raises:
            OSError: If ``cv2.imwrite`` reports that a frame was not written.
        """
        seq_folder = output_dir / f"{video_name}_seq_{seq_index}"
        seq_folder.mkdir(exist_ok=True, parents=True)
        for i, frame in enumerate(sequence):
            frame_path = seq_folder / f"frame_{i}.jpg"
            # Frames are held as RGB; OpenCV writes BGR.
            written = cv2.imwrite(str(frame_path), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
            if not written:
                raise OSError(f"Failed to write frame image: {frame_path}")

class FrameInterpolationDataset(Dataset):
    """Dataset class for frame sequences.

    A sequence is any directory below ``root_dir`` that contains a
    ``frame_0.jpg``; all ``*.jpg`` files of that directory form the sequence.

    Args:
        root_dir: Directory searched recursively for sequences.
        transform: Optional callable applied to each RGB frame. When omitted,
            frames are converted to float32 ``(C, H, W)`` tensors in [0, 1]
            so that they can still be stacked.
    """
    def __init__(self, root_dir: str, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        # Sorted: glob order is filesystem-dependent, and an index should
        # always refer to the same sequence.
        self.sequences = sorted(self.root_dir.glob("**/frame_0.jpg"))

    def __len__(self) -> int:
        """Return the number of sequences found."""
        return len(self.sequences)

    @staticmethod
    def _frame_order(frame_path: Path):
        """Sort key ordering ``frame_<n>.jpg`` by numeric ``n`` (frame_2 before frame_10)."""
        suffix = frame_path.stem.rsplit("_", 1)[-1]
        if suffix.isdecimal():
            return (0, int(suffix), frame_path.name)
        # Files without a numeric suffix keep name order, after the numbered ones.
        return (1, 0, frame_path.name)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Load one sequence as a tensor of stacked frames.

        Raises:
            OSError: If a frame file cannot be decoded.
        """
        sequence_dir = self.sequences[idx].parent
        frames = []
        for frame_path in sorted(sequence_dir.glob("*.jpg"), key=self._frame_order):
            frame = cv2.imread(str(frame_path))
            if frame is None:
                # imread returns None instead of raising on unreadable files.
                raise OSError(f"Could not read frame image: {frame_path}")
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if self.transform:
                frame = self.transform(frame)
            else:
                # torch.stack needs tensors; use a (C, H, W) float layout in [0, 1].
                frame = torch.from_numpy(np.ascontiguousarray(frame)).permute(2, 0, 1).float() / 255.0
            frames.append(frame)
        return torch.stack(frames)

class S6Layer(nn.Module):
    """Pre-norm residual block: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.

    The block output is added to its input, so input and output share the
    size ``input_dim`` in the last dimension.
    """
    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_dim)
        # Expand to hidden_dim and project back so the skip connection lines up.
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the transformed, normalised ``x``."""
        hidden = self.dropout(self.activation(self.linear1(self.layer_norm(x))))
        update = self.dropout(self.linear2(hidden))
        return update + x

class VFIMamba(nn.Module):
    """Main model architecture.

    Each frame is flattened, projected to ``config.hidden_dim``, passed
    through ``config.num_layers`` ``S6Layer`` blocks and projected back to
    ``config.input_dim``. All of these act on the last dimension only, so
    frames are processed independently of one another.
    """
    def __init__(self, config: "Config"):
        super().__init__()
        self.input_projection = nn.Linear(config.input_dim, config.hidden_dim)
        self.s6_layers = nn.ModuleList([
            S6Layer(config.hidden_dim, config.hidden_dim)
            for _ in range(config.num_layers)
        ])
        self.output_projection = nn.Linear(config.hidden_dim, config.input_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the model to a batch of frame sequences.

        Args:
            x: ``(B, T, C, H, W)`` tensor, or ``(B, C, H, W)`` for single
                frames (treated as ``T == 1``). ``C * H * W`` must equal
                ``config.input_dim``.

        Returns:
            Tensor of shape ``(B, T, C, H, W)``.

        Raises:
            ValueError: If ``x`` is neither 4- nor 5-dimensional.
        """
        if x.dim() == 4:
            x = x.unsqueeze(1)  # single frames: add the sequence dimension
        elif x.dim() != 5:
            raise ValueError(
                f"Expected a 4-D (B, C, H, W) or 5-D (B, T, C, H, W) tensor, got {x.dim()}-D"
            )

        B, T, C, H, W = x.shape
        # reshape (not view) so sliced or permuted, non-contiguous inputs work too
        x = x.reshape(B, T, -1)  # Flatten spatial dimensions
        x = self.input_projection(x)

        for layer in self.s6_layers:
            x = layer(x)

        x = self.output_projection(x)
        return x.reshape(B, T, C, H, W)

class Trainer:
    """Handles model training.

    Args:
        model: Network mapping ``(B, T, C, H, W)`` to a tensor of the same shape.
        config: Object providing ``device`` and ``learning_rate``.
    """
    def __init__(self, model: nn.Module, config: "Config"):
        self.model = model.to(config.device)
        self.config = config
        self.criterion = nn.MSELoss()
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=0.01
        )
        # Mixed precision only applies on CUDA; disable it elsewhere instead
        # of letting autocast/GradScaler warn and turn themselves off.
        self.use_amp = str(config.device).startswith("cuda")
        self.scaler = GradScaler(enabled=self.use_amp)

    def train_epoch(self, dataloader: DataLoader) -> float:
        """Run one pass over ``dataloader`` and return the mean batch loss.

        Each batch is a ``(B, T, C, H, W)`` tensor: the first and last frame
        are the model inputs, the frames in between are the targets.

        Returns:
            Mean of the per-batch losses, or NaN when the loader yields no batch.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        for batch in dataloader:
            batch = batch.to(self.config.device)
            # Stack (not cat) so the two frames form a sequence dimension:
            # (B, 2, C, H, W), which is the layout the model unpacks.
            boundary_frames = torch.stack([batch[:, 0], batch[:, -1]], dim=1)
            target_frames = batch[:, 1:-1]

            with autocast(enabled=self.use_amp):
                per_frame = self.model(boundary_frames)
                # One output per boundary frame; their mean is the interpolated
                # frame (assumption: the same estimate serves every
                # intermediate frame when there are several).
                pred_frames = per_frame.mean(dim=1, keepdim=True).expand_as(target_frames)
                loss = self.criterion(pred_frames, target_frames)

            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            total_loss += loss.item()
            num_batches += 1

        return total_loss / num_batches if num_batches else float("nan")

def main():
    """Train the interpolation model on the sequences under ``train/`` and save its weights.

    Raises:
        FileNotFoundError: If no frame sequence is found under ``train/``.
    """
    # Load configuration
    config = Config()

    # Initialize components
    processor = FrameProcessor(config.img_size)
    model = VFIMamba(config)
    trainer = Trainer(model, config)

    # Setup datasets
    # Reuse the processor's ToTensor -> Resize -> Normalize pipeline: the
    # Resize step brings every frame to config.img_size, which the model's
    # input_dim relies on regardless of the resolution stored on disk.
    transform = processor.transform

    train_dataset = FrameInterpolationDataset("train", transform=transform)
    if len(train_dataset) == 0:
        # Fail with an actionable message instead of the sampler's generic error.
        raise FileNotFoundError(
            "No frame sequences found under 'train'; generate the dataset first."
        )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=0,
        # Pinned memory only helps host-to-GPU copies.
        pin_memory=str(config.device).startswith("cuda")
    )

    # Training loop
    logger.info("Starting training...")
    for epoch in range(config.epochs):
        loss = trainer.train_epoch(train_dataloader)
        logger.info(f"Epoch [{epoch+1}/{config.epochs}] Loss: {loss:.4f}")

    # Save model
    torch.save(model.state_dict(), "frame_interpolation_model.pth")
    logger.info("Training completed. Model saved.")

# Entry point: run training when this code is executed as the main module.
if __name__ == "__main__":
    main()

  self.scaler = GradScaler()


In [5]:
class VFIMamba(nn.Module):
    """Main model architecture.

    Each frame is flattened, projected to ``config.hidden_dim``, passed
    through ``config.num_layers`` ``S6Layer`` blocks and projected back to
    ``config.input_dim``. All of these act on the last dimension only, so
    frames are processed independently of one another.
    """
    def __init__(self, config: "Config"):
        super().__init__()
        self.input_projection = nn.Linear(config.input_dim, config.hidden_dim)
        self.s6_layers = nn.ModuleList([
            S6Layer(config.hidden_dim, config.hidden_dim)
            for _ in range(config.num_layers)
        ])
        self.output_projection = nn.Linear(config.hidden_dim, config.input_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the model to single frames or to frame sequences.

        Args:
            x: ``(B, C, H, W)`` tensor of single frames (treated as a sequence
                of length 1), or ``(B, T, C, H, W)`` tensor of sequences.
                ``C * H * W`` must equal ``config.input_dim``.

        Returns:
            Tensor of shape ``(B, T, C, H, W)`` (``T == 1`` for 4-D input).

        Raises:
            ValueError: If ``x`` is neither 4- nor 5-dimensional.
        """
        if x.dim() == 4:
            # Only add the sequence dimension when it is missing; an input
            # that already is (B, T, C, H, W) must be left as it is.
            x = x.unsqueeze(1)
        elif x.dim() != 5:
            raise ValueError(
                f"Expected a 4-D (B, C, H, W) or 5-D (B, T, C, H, W) tensor, got {x.dim()}-D"
            )

        B, T, C, H, W = x.shape  # Now x has the expected 5 dimensions
        # reshape (not view) so sliced or permuted, non-contiguous inputs work too
        x = x.reshape(B, T, -1)  # Flatten spatial dimensions
        x = self.input_projection(x)

        for layer in self.s6_layers:
            x = layer(x)

        x = self.output_projection(x)
        return x.reshape(B, T, C, H, W)

In [None]:
import torch
from pathlib import Path
import cv2
import numpy as np

def test_interpolation(model_path: str, video_path: str, output_path: str, config: Config):
    """Write a video in which a model-generated frame is inserted between every two source frames.

    Args:
        model_path: File holding the ``state_dict`` of a trained ``VFIMamba``.
        video_path: Source video; its frames are resized to ``config.img_size``.
        output_path: Destination video file (parent folders are created).
        config: Provides the model hyperparameters, ``img_size`` and ``device``.

    Raises:
        ValueError: If no frame can be read from ``video_path``.
        OSError: If the output video cannot be opened for writing.
    """
    # Load model
    model = VFIMamba(config)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=config.device))
    model.to(config.device)
    model.eval()

    # Process video
    processor = FrameProcessor(config.img_size)
    frames = processor.extract_frames(video_path)
    if not frames:
        raise ValueError(f"No frames could be read from video: {video_path}")

    # Per-channel statistics used to undo the normalisation; they must match
    # the Normalize step of FrameProcessor.transform.
    mean = torch.tensor([0.485, 0.456, 0.406], device=config.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=config.device).view(3, 1, 1)

    def to_model_input(frame: np.ndarray) -> torch.Tensor:
        """RGB uint8 (H, W, 3) image -> normalised float (3, H, W) tensor."""
        return processor.transform(frame).to(config.device)

    def to_bgr_image(tensor: torch.Tensor) -> np.ndarray:
        """Normalised (3, H, W) tensor -> BGR uint8 (H, W, 3) image for the writer."""
        rgb = (tensor.float() * std + mean).clamp(0, 1).mul(255).round().byte()
        rgb = np.ascontiguousarray(rgb.permute(1, 2, 0).cpu().numpy())
        return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    # Create output video
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # NOTE(review): the frame rate is fixed at 30 although the output holds
    # roughly twice as many frames as the source; confirm the intended rate.
    out = cv2.VideoWriter(output_path, fourcc, 30.0, config.img_size)
    if not out.isOpened():
        raise OSError(f"Could not open output video for writing: {output_path}")

    try:
        # Process frames in pairs
        with torch.no_grad():
            for i in range(len(frames) - 1):
                # The model handles every frame independently, so the pair is
                # passed as a batch of two single frames: (2, C, H, W).
                pair = torch.stack([to_model_input(frames[i]),
                                    to_model_input(frames[i + 1])])

                # Generate interpolated frame: mean of the two per-frame outputs.
                per_frame = model(pair).reshape(pair.shape)
                interpolated = per_frame.mean(dim=0)

                # Write frames to video
                out.write(cv2.cvtColor(frames[i], cv2.COLOR_RGB2BGR))
                out.write(to_bgr_image(interpolated))

        # Write final frame
        out.write(cv2.cvtColor(frames[-1], cv2.COLOR_RGB2BGR))
    finally:
        # Finalise the file even when interpolation fails part-way.
        out.release()

# Test the model
# Uses the weights file written by main() ("frame_interpolation_model.pth") and
# a default Config; the input clip and output location are relative paths.
test_interpolation(
    model_path="frame_interpolation_model.pth",
    video_path="input/motion.mp4",
    output_path="output/interpolated.mp4",
    config=Config()
)