In [1]:
import cv2
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
from typing import List, Tuple, Optional
from torchvision import transforms
from torch.amp import autocast, GradScaler
import logging
import shutil
from dataclasses import dataclass
import torch.utils.checkpoint

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class Config:
    """Configuration class with memory-optimized settings"""
    img_size: Tuple[int, int] = (128, 128)
    num_interpolated: int = 1
    batch_size: int = 8  # Reduced batch size
    epochs: int = 10
    hidden_dim: int = 256  # Reduced hidden dimension
    num_layers: int = 4
    learning_rate: float = 1e-4
    train_ratio: float = 0.8
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    @property
    def input_dim(self) -> int:
        return self.img_size[0] * self.img_size[1] * 3

class FrameProcessor:
    """Handles video frame extraction and processing"""
    def __init__(self, img_size: Tuple[int, int]):
        self.img_size = img_size

    def process_frame(self, frame: np.ndarray) -> np.ndarray:
        """Process a single frame"""
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, self.img_size)
        return frame

    def extract_frames(self, video_path: str) -> List[np.ndarray]:
        """Extracts and processes frames from video"""
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frame = self.process_frame(frame)
                frames.append(frame)
        finally:
            cap.release()
        return frames

class S6Layer(nn.Module):
    """Updated S6 layer with consistent dtype"""
    def __init__(self, hidden_dim: int):
        super().__init__()
        dtype = torch.float32
        self.layer_norm = nn.LayerNorm(hidden_dim).to(dtype)
        self.linear1 = nn.Linear(hidden_dim, hidden_dim * 4).to(dtype)
        self.linear2 = nn.Linear(hidden_dim * 4, hidden_dim).to(dtype)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.layer_norm(x)
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)
        x = self.dropout(x)
        return x + residual

class VFIMamba(nn.Module):
    """AMP-compatible model architecture"""
    def __init__(self, config: Config):
        super().__init__()
        self.config = config

        # Ensure all layers use the same dtype
        dtype = torch.float32

        # Encoder
        self.conv_encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1).to(dtype),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1).to(dtype),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, config.hidden_dim, kernel_size=3, padding=1).to(dtype)
        )

        # Reduced dimension for S6 layers
        reduced_dim = config.hidden_dim // 2
        self.s6_layers = nn.ModuleList([
            S6Layer(reduced_dim).to(dtype)
            for _ in range(config.num_layers)
        ])

        # Decoder
        self.conv_decoder = nn.Sequential(
            nn.ConvTranspose2d(config.hidden_dim, 64, kernel_size=4, stride=2, padding=1).to(dtype),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1).to(dtype),
            nn.ReLU(),
            nn.Conv2d(32, 3, kernel_size=3, padding=1).to(dtype),
            nn.Sigmoid()
        )

        # Projection layers
        self.dim_reduce = nn.Conv2d(config.hidden_dim, reduced_dim, 1).to(dtype)
        self.dim_expand = nn.Conv2d(reduced_dim, config.hidden_dim, 1).to(dtype)

    @torch.cuda.amp.autocast()
    def process_sequence(self, x: torch.Tensor) -> torch.Tensor:
        """Process a sequence through S6 layers with gradient checkpointing"""
        for layer in self.s6_layers:
            x = torch.utils.checkpoint.checkpoint(layer, x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C, H, W = x.shape

        # Ensure input is float32
        x = x.to(torch.float32)

        # Process frames in chunks
        encoded_frames = []
        chunk_size = 2

        for i in range(0, T, chunk_size):
            chunk = x[:, i:i+chunk_size].reshape(-1, C, H, W)
            encoded = self.conv_encoder(chunk)
            encoded_frames.append(encoded)

        x = torch.cat(encoded_frames, dim=0)
        _, C_hidden, H_hidden, W_hidden = x.shape

        # Reduce dimensions
        x = self.dim_reduce(x)
        x = x.view(B, T, -1, H_hidden, W_hidden)

        # Process spatial chunks
        spatial_chunks = []
        chunk_size = H_hidden // 4

        for i in range(0, H_hidden, chunk_size):
            chunk = x[..., i:i+chunk_size, :]
            chunk = chunk.reshape(B * chunk_size * W_hidden, T, -1)
            chunk = self.process_sequence(chunk)
            spatial_chunks.append(chunk)

        x = torch.cat(spatial_chunks, dim=0)
        x = x.view(B, H_hidden, W_hidden, T, -1)
        x = x.permute(0, 3, 4, 1, 2)

        # Expand dimensions
        x = x.reshape(B * T, -1, H_hidden, W_hidden)
        x = self.dim_expand(x)

        # Decode frames
        decoded_frames = []
        chunk_size = 2

        for i in range(0, B * T, chunk_size):
            chunk = x[i:i+chunk_size]
            decoded = self.conv_decoder(chunk)
            decoded_frames.append(decoded)

        x = torch.cat(decoded_frames, dim=0)
        x = x.view(B, T, C, H, W)

        return x

class FrameInterpolationDataset(Dataset):
    """Dataset class for frame sequences"""
    def __init__(self, root_dir: str, img_size: Tuple[int, int], transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.img_size = img_size
        self.sequences = list(self.root_dir.glob("**/frame_0.jpg"))

        if len(self.sequences) == 0:
            raise RuntimeError(f"No sequences found in {root_dir}")

        logger.info(f"Found {len(self.sequences)} sequences in {root_dir}")

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> torch.Tensor:
        sequence_dir = self.sequences[idx].parent
        frames = []

        frame_paths = sorted(sequence_dir.glob("*.jpg"),
                           key=lambda x: int(x.stem.split('_')[1]))

        for frame_path in frame_paths:
            frame = cv2.imread(str(frame_path))
            if frame is None:
                raise RuntimeError(f"Failed to load image: {frame_path}")

            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = cv2.resize(frame, self.img_size)

            frame = torch.from_numpy(frame).float() / 255.0
            frame = frame.permute(2, 0, 1)

            if self.transform:
                frame = self.transform(frame)

            frames.append(frame)

        frames_tensor = torch.stack(frames)

        expected_shape = (len(frame_paths), 3, self.img_size[1], self.img_size[0])
        if frames_tensor.shape != expected_shape:
            raise RuntimeError(
                f"Incorrect tensor shape. Expected {expected_shape}, got {frames_tensor.shape}"
            )

        return frames_tensor

class Trainer:
    """Updated trainer with proper AMP handling"""
    def __init__(self, model: nn.Module, config: Config):
        self.model = model.to(config.device)
        self.config = config
        self.criterion = nn.MSELoss()
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=0.01
        )
        self.scaler = GradScaler()

    def train_epoch(self, dataloader: DataLoader) -> float:
        self.model.train()
        total_loss = 0

        for batch in dataloader:
            # Ensure input data is float32
            batch = batch.to(self.config.device, dtype=torch.float32)

            start_frames = batch[:, [0]]
            end_frames = batch[:, [-1]]
            target_frames = batch[:, 1:-1]

            input_frames = torch.cat([start_frames, end_frames], dim=1)

            with torch.cuda.amp.autocast():
                pred_frames = self.model(input_frames)
                loss = self.criterion(pred_frames, target_frames)

            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            total_loss += loss.item()

        return total_loss / len(dataloader)

def generate_dataset(input_folder: str, train_folder: str, test_folder: str, config: Config):
    """Generate dataset from input videos"""
    processor = FrameProcessor(config.img_size)

    for folder in [train_folder, test_folder]:
        Path(folder).mkdir(exist_ok=True, parents=True)

    video_files = list(Path(input_folder).glob("*.mp4"))
    if not video_files:
        raise RuntimeError(f"No MP4 files found in {input_folder}")

    for video_file in video_files:
        logger.info(f"Processing video: {video_file}")

        frames = processor.extract_frames(str(video_file))
        sequence_length = config.num_interpolated + 2

        if len(frames) < sequence_length:
            logger.warning(f"Video {video_file} too short, skipping")
            continue

        num_sequences = len(frames) - sequence_length + 1
        for i in range(num_sequences):
            sequence = frames[i:i + sequence_length]

            output_dir = Path(train_folder if random.random() < config.train_ratio else test_folder)
            seq_folder = output_dir / f"{video_file.stem}_seq_{i}"
            seq_folder.mkdir(exist_ok=True, parents=True)

            for j, frame in enumerate(sequence):
                frame_path = seq_folder / f"frame_{j}.jpg"
                cv2.imwrite(str(frame_path), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

        logger.info(f"Generated {num_sequences} sequences from {video_file}")

def test_interpolation(model_path: str, video_path: str, output_path: str, config: Config):
    """Test the model on a video"""
    model = VFIMamba(config)
    model.load_state_dict(torch.load(model_path))
    model.eval()
    model.to(config.device)

    processor = FrameProcessor(config.img_size)
    frames = processor.extract_frames(video_path)

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, 60.0, config.img_size)

    with torch.no_grad():
        for i in range(len(frames) - 1):
            frame1 = torch.from_numpy(frames[i]).permute(2, 0, 1).float() / 255.0
            frame2 = torch.from_numpy(frames[i + 1]).permute(2, 0, 1).float() / 255.0

            input_frames = torch.stack([frame1, frame2]).unsqueeze(0).to(config.device)

            interpolated = model(input_frames)

            frame1_np = (frame1.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
            interpolated_np = (interpolated[0, 0].cpu().permute(1, 2, 0).numpy() * 255).astype(np.uint8)

            out.write(cv2.cvtColor(frame1_np, cv2.COLOR_RGB2BGR))
            out.write(cv2.cvtColor(interpolated_np, cv2.COLOR_RGB2BGR))

        final_frame = (frames[-1] * 255).astype(np.uint8)
        out.write(cv2.cvtColor(final_frame, cv2.COLOR_RGB2BGR))

    out.release()
    logger.info(f"Interpolated video saved to {output_path}")

def main():
    # Initialize configuration
    config = Config()
    logger.info(f"Using device: {config.device}")

    # Clean existing datasets
    for folder in ['train', 'test']:
        if Path(folder).exists():
            shutil.rmtree(folder)

    # Generate dataset
    generate_dataset("input", "train", "test", config)

    # Initialize model and trainer
    model = VFIMamba(config)
    trainer = Trainer(model, config)

    # Create dataset and dataloader
    train_dataset = FrameInterpolationDataset(
        "train",
        img_size=config.img_size,
        transform=None
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=0 if config.device == "cpu" else 4,
        pin_memory=True
    )

    # Training loop
    logger.info("Starting training...")
    for epoch in range(config.epochs):
        loss = trainer.train_epoch(train_dataloader)
        logger.info(f"Epoch [{epoch+1}/{config.epochs}] Loss: {loss:.4f}")

    # Save model
    torch.save(model.state_dict(), "frame_interpolation_model.pth")
    logger.info("Training completed. Model saved.")

    # Test model on a video
    if list(Path("input").glob("*.mp4")):
        test_video = next(Path("input").glob("*.mp4"))
        test_interpolation(
            "frame_interpolation_model.pth",
            str(test_video),
            "output_interpolated.mp4",
            config
        )

if __name__ == "__main__":
    main()

  @torch.cuda.amp.autocast()
  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
  model.load_state_dict(torch.load(model_path))


In [2]:
import cv2
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
from typing import List, Tuple, Optional
from torchvision import transforms
from torch.amp import autocast, GradScaler
import logging
import shutil
from dataclasses import dataclass
import torch.utils.checkpoint

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class Config:
    """Configuration class with memory-optimized settings"""
    img_size: Tuple[int, int] = (128, 128)
    num_interpolated: int = 1
    batch_size: int = 8
    epochs: int = 10
    hidden_dim: int = 256
    num_layers: int = 4
    learning_rate: float = 1e-4
    train_ratio: float = 0.8
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    @property
    def input_dim(self) -> int:
        return self.img_size[0] * self.img_size[1] * 3

# [Previous FrameProcessor class remains the same]

class S6Layer(nn.Module):
    """Updated S6 layer with consistent dtype"""
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.linear1 = nn.Linear(hidden_dim, hidden_dim * 4)
        self.linear2 = nn.Linear(hidden_dim * 4, hidden_dim)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.layer_norm(x)
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)
        x = self.dropout(x)
        return x + residual

class VFIMamba(nn.Module):
    """AMP-compatible model architecture"""
    def __init__(self, config: Config):
        super().__init__()
        self.config = config

        # Encoder
        self.conv_encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, config.hidden_dim, kernel_size=3, padding=1)
        )

        # Reduced dimension for S6 layers
        reduced_dim = config.hidden_dim // 2
        self.s6_layers = nn.ModuleList([
            S6Layer(reduced_dim)
            for _ in range(config.num_layers)
        ])

        # Decoder
        self.conv_decoder = nn.Sequential(
            nn.ConvTranspose2d(config.hidden_dim, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

        # Projection layers
        self.dim_reduce = nn.Conv2d(config.hidden_dim, reduced_dim, 1)
        self.dim_expand = nn.Conv2d(reduced_dim, config.hidden_dim, 1)

    # Removed @torch.utils.checkpoint.checkpoint
    def process_sequence(self, x: torch.Tensor) -> torch.Tensor:
        """Process a sequence through S6 layers without gradient checkpointing"""
        for layer in self.s6_layers:
            x = layer(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C, H, W = x.shape

        # Process frames
        x = x.reshape(B * T, C, H, W)
        x = self.conv_encoder(x)

        _, C_hidden, H_hidden, W_hidden = x.shape

        # Reduce dimensions
        x = self.dim_reduce(x)
        x = x.view(B, T, -1, H_hidden, W_hidden)
        x = x.permute(0, 3, 4, 1, 2)  # [B, H, W, T, C]
        x = x.reshape(B * H_hidden * W_hidden, T, -1)

        # Process sequence
        x = self.process_sequence(x)

        # Reshape and expand
        x = x.view(B, H_hidden, W_hidden, T, -1)
        x = x.permute(0, 3, 4, 1, 2)
        x = x.reshape(B * T, -1, H_hidden, W_hidden)
        x = self.dim_expand(x)

        # Decode
        x = self.conv_decoder(x)
        x = x.view(B, T, C, H, W)

        return x

class Trainer:
    """Updated trainer with proper AMP handling"""
    def __init__(self, model: nn.Module, config: Config):
        self.model = model.to(config.device)
        self.config = config
        self.criterion = nn.MSELoss()
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=0.01
        )
        self.scaler = GradScaler()

    def train_epoch(self, dataloader: DataLoader) -> float:
        self.model.train()
        total_loss = 0
        num_batches = 0

        for batch in dataloader:
            batch = batch.to(self.config.device)

            # Extract frames
            start_frames = batch[:, [0]]  # [B, 1, C, H, W]
            end_frames = batch[:, [-1]]   # [B, 1, C, H, W]
            target_frames = batch[:, 1:-1]  # [B, num_interpolated, C, H, W]

            input_frames = torch.cat([start_frames, end_frames], dim=1)

            # Use updated autocast
            with autocast(device_type='cuda' if self.config.device == 'cuda' else 'cpu'):
                pred_frames = self.model(input_frames)
                loss = self.criterion(pred_frames, target_frames)

            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.scaler.step(self.optimizer)
            self.scaler.update()

            total_loss += loss.item()
            num_batches += 1

            if num_batches % 10 == 0:
                logger.info(f"Batch {num_batches}: Loss = {loss.item():.4f}")

        return total_loss / num_batches
def generate_dataset(input_folder: str, train_folder: str, test_folder: str, config: Config):
    """Generate dataset from input videos"""
    processor = FrameProcessor(config.img_size)

    for folder in [train_folder, test_folder]:
        Path(folder).mkdir(exist_ok=True, parents=True)

    video_files = list(Path(input_folder).glob("*.mp4"))
    if not video_files:
        raise RuntimeError(f"No MP4 files found in {input_folder}")

    for video_file in video_files:
        logger.info(f"Processing video: {video_file}")

        frames = processor.extract_frames(str(video_file))
        sequence_length = config.num_interpolated + 2

        if len(frames) < sequence_length:
            logger.warning(f"Video {video_file} too short, skipping")
            continue

        num_sequences = len(frames) - sequence_length + 1
        for i in range(num_sequences):
            sequence = frames[i:i + sequence_length]

            output_dir = Path(train_folder if random.random() < config.train_ratio else test_folder)
            seq_folder = output_dir / f"{video_file.stem}_seq_{i}"
            seq_folder.mkdir(exist_ok=True, parents=True)

            for j, frame in enumerate(sequence):
                frame_path = seq_folder / f"frame_{j}.jpg"
                cv2.imwrite(str(frame_path), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

        logger.info(f"Generated {num_sequences} sequences from {video_file}")

class FrameProcessor:
    """Handles video frame extraction and processing"""
    def __init__(self, img_size: Tuple[int, int]):
        self.img_size = img_size

    def process_frame(self, frame: np.ndarray) -> np.ndarray:
        """Process a single frame"""
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, self.img_size)
        return frame

    def extract_frames(self, video_path: str) -> List[np.ndarray]:
        """Extracts and processes frames from video"""
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frame = self.process_frame(frame)
                frames.append(frame)
        finally:
            cap.release()
        return frames

class FrameInterpolationDataset(Dataset):
    """Dataset class for frame sequences"""
    def __init__(self, root_dir: str, img_size: Tuple[int, int], transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.img_size = img_size
        self.sequences = list(self.root_dir.glob("**/frame_0.jpg"))

        if len(self.sequences) == 0:
            raise RuntimeError(f"No sequences found in {root_dir}")

        logger.info(f"Found {len(self.sequences)} sequences in {root_dir}")

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> torch.Tensor:
        sequence_dir = self.sequences[idx].parent
        frames = []

        frame_paths = sorted(sequence_dir.glob("*.jpg"),
                           key=lambda x: int(x.stem.split('_')[1]))

        for frame_path in frame_paths:
            frame = cv2.imread(str(frame_path))
            if frame is None:
                raise RuntimeError(f"Failed to load image: {frame_path}")

            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = cv2.resize(frame, self.img_size)

            frame = torch.from_numpy(frame).float() / 255.0
            frame = frame.permute(2, 0, 1)

            if self.transform:
                frame = self.transform(frame)

            frames.append(frame)

        frames_tensor = torch.stack(frames)

        expected_shape = (len(frame_paths), 3, self.img_size[1], self.img_size[0])
        if frames_tensor.shape != expected_shape:
            raise RuntimeError(
                f"Incorrect tensor shape. Expected {expected_shape}, got {frames_tensor.shape}"
            )

        return frames_tensor

def main():
    # Initialize configuration
    config = Config()
    logger.info(f"Using device: {config.device}")

    # Check for input directory and videos
    input_dir = Path("input")
    if not input_dir.exists() or not list(input_dir.glob("*.mp4")):
        logger.error("No input directory found or no .mp4 files present")
        return

    # Clean existing datasets
    for folder in ['train', 'test']:
        if Path(folder).exists():
            shutil.rmtree(folder)
            Path(folder).mkdir(parents=True)

    # Generate dataset
    generate_dataset("input", "train", "test", config)

    # Verify dataset creation
    train_files = list(Path("train").rglob("*.jpg"))
    if not train_files:
        logger.error("No training data generated")
        return

    logger.info(f"Generated {len(train_files)} training frames")

    # Initialize model and trainer
    model = VFIMamba(config)
    trainer = Trainer(model, config)

    # Create dataset and dataloader
    train_dataset = FrameInterpolationDataset(
        "train",
        img_size=config.img_size,
        transform=None
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=0 if config.device == "cpu" else 4,
        pin_memory=True
    )

    # Training loop
    logger.info("Starting training...")
    for epoch in range(config.epochs):
        loss = trainer.train_epoch(train_dataloader)
        logger.info(f"Epoch [{epoch+1}/{config.epochs}] Loss: {loss:.4f}")

        # Save checkpoint
        checkpoint_path = f"checkpoint_epoch_{epoch+1}.pth"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': trainer.optimizer.state_dict(),
            'loss': loss,
        }, checkpoint_path)
        logger.info(f"Saved checkpoint: {checkpoint_path}")

    # Save final model
    torch.save(model.state_dict(), "frame_interpolation_model.pth")
    logger.info("Training completed. Model saved.")

if __name__ == "__main__":
    main()

  return F.mse_loss(input, target, reduction=self.reduction)
