<a href="https://colab.research.google.com/github/moizyousufi/Frame-Interpolation-Deep-Learning-Methods/blob/main/project_memory_optim_mamba.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


### Data Processing

In [None]:
import os
import cv2
import numpy as np

!pip install yt_dlp

import yt_dlp

import matplotlib.pyplot as plt

def download_youtube_video(url, output_path):
    """Download a YouTube video with yt-dlp and write it to ``output_path``.

    The format selector asks for a 1080p-high mp4 video stream and falls back
    to an mp4 format when that is not available.

    Returns:
        ``output_path`` unchanged, so the call can be chained.
    """
    options = dict(
        format='bestvideo[height=1080][ext=mp4]/mp4',
        outtmpl=output_path,
    )
    with yt_dlp.YoutubeDL(options) as downloader:
        downloader.download([url])
    return output_path

def video_to_3d_data_frame(video_path, segment_number, frames_per_segment=3000):
    """Decode one fixed-length segment of a video into a stack of RGB frames.

    Segments are numbered from 1. Segment k covers frames
    [(k - 1) * frames_per_segment, k * frames_per_segment).

    Args:
        video_path: path of a video file readable by OpenCV.
        segment_number: 1-based index of the segment to decode.
        frames_per_segment: number of frames per segment. The default of 3000
            matches the segment counts used in the ``urls`` dictionary.

    Returns:
        A numpy array built from the decoded frames, converted from OpenCV's
        BGR order to RGB. It holds at most ``frames_per_segment`` frames and
        fewer when the video ends inside the segment. A segment that starts past
        the end of the video yields an empty array. Returns None if the video
        cannot be opened.

    Raises:
        ValueError: if ``segment_number`` or ``frames_per_segment`` is < 1.
    """
    if segment_number < 1:
        raise ValueError(f"segment_number must be >= 1, got {segment_number}")
    if frames_per_segment < 1:
        raise ValueError(f"frames_per_segment must be >= 1, got {frames_per_segment}")

    start_frame = (segment_number - 1) * frames_per_segment
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print("Error: Could not open video.")
            return None

        # Seek straight to the first frame of the requested segment.
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        video_data = []
        for _ in range(frames_per_segment):
            ret, frame = cap.read()
            if not ret:  # no more frames: the segment ends early
                break
            video_data.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture on every path, including the "could not open" one.
        cap.release()

    return np.array(video_data)

def yt_to_data(url, segment_number):
    """Download ``url`` to 'downloaded_video.mp4' and decode one segment of it.

    The downloaded file is intentionally left on disk (the training loop deletes
    it after the last segment of each URL).

    Returns:
        Whatever ``video_to_3d_data_frame`` returns for that segment.
    """
    local_path = download_youtube_video(url, 'downloaded_video.mp4')
    return video_to_3d_data_frame(local_path, segment_number)


# YouTube watch URLs of the videos used for training.
(
    train_url1,
    train_url2,
    train_url3,
    train_url4,
    train_url5,
    train_url6,
    train_url7,
    train_url8,
) = (
    'https://www.youtube.com/watch?v=j0HoMaaQj9I',
    'https://www.youtube.com/watch?v=1jzJGcRdxPY',
    'https://www.youtube.com/watch?v=Bs58RoTf-g8',
    'https://www.youtube.com/watch?v=XoyYtqi5u54',
    'https://www.youtube.com/watch?v=rMPkUuMq024',
    'https://www.youtube.com/watch?v=JUWtWPX6hgs',
    'https://www.youtube.com/watch?v=WtnbT6ft710',
    'https://www.youtube.com/watch?v=QLF0FXcW25E',
)

Collecting yt_dlp
  Downloading yt_dlp-2024.7.9-py3-none-any.whl (3.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m40.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting brotli (from yt_dlp)
  Downloading Brotli-1.1.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.0/3.0 MB[0m [31m83.6 MB/s[0m eta [36m0:00:00[0m
Collecting mutagen (from yt_dlp)
  Downloading mutagen-1.47.0-py3-none-any.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.4/194.4 kB[0m [31m24.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pycryptodomex (from yt_dlp)
  Downloading pycryptodomex-3.20.0-cp35-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m41.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting requests<3,>=2.32.2 (fr

### Setting up Training Pairs for Data

Keep in mind how many iterations each video needs when it is split into samples of 3000 frames. To find a video's total frame count, check its FPS in YouTube's 'Stats for nerds' overlay and multiply it by the video's length in seconds. Divide that count by 3000 and round up (the last, partial segment still contains frames) to get the number of iterations. Then enter that number for the video's URL in the `urls` dictionary.

In [None]:
import torch
from torch.nn import functional

def create_training_pairs(video_data_array):
    """Turn a clip into (frame_t, frame_t+2) -> frame_t+1 interpolation samples.

    Samples start at every second frame (t = 0, 2, 4, ...). The two outer frames
    of each triplet are the model inputs, and the frame between them is the
    target, so the target is never one of the inputs. The previous version used
    frames (t, t+1) as inputs and t+1 as the target, which only taught the model
    to copy its second input.

    Args:
        video_data_array: frames stacked along dim 0, e.g. (N, H, W, C). Either a
            torch tensor or anything ``torch.as_tensor`` accepts (such as a numpy
            array).

    Returns:
        (input_frames, target_frames):
            input_frames has shape (K, 2, *frame_shape) and holds [frame_t, frame_t+2].
            target_frames has shape (K, *frame_shape) and holds frame_t+1.
        K is (N - 2) // 2 for N >= 2 and 0 otherwise, which is the same sample
        count as the previous loop-based version. Both outputs keep the input
        dtype.
    """
    frames = torch.as_tensor(video_data_array)
    num_samples = max(len(frames) - 2, 0) // 2
    stop = 2 * num_samples  # exclusive end of the start indices 0, 2, ..., 2K-2

    # Fill a preallocated tensor from strided views, so only the outputs are
    # allocated (no per-frame Python loop and no intermediate copies).
    input_frames = torch.empty((num_samples, 2, *frames.shape[1:]), dtype=frames.dtype)
    input_frames[:, 0] = frames[0:stop:2]          # frame_t
    input_frames[:, 1] = frames[2:stop + 2:2]      # frame_t+2
    # clone() so the targets do not keep the whole source clip alive through a view.
    target_frames = frames[1:stop + 1:2].clone()   # frame_t+1
    return input_frames, target_frames

# Maps each training URL to the number of 3000-frame segments to train on;
# segment j (1-based) is fetched with yt_to_data(url, j).
urls = dict([
    (train_url1, 1),
    (train_url2, 1),
    (train_url3, 5),
    (train_url4, 2),
    (train_url5, 1),
    (train_url6, 6),
    (train_url7, 1),
    (train_url8, 11),
])

### Creating the CNN Model

In [None]:
# Train on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [None]:
!pip install mamba-ssm

Collecting mamba-ssm
  Downloading mamba_ssm-2.2.2.tar.gz (85 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/85.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.4/85.4 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ninja (from mamba-ssm)
  Downloading ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl (307 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m307.2/307.2 kB[0m [31m23.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting einops (from mamba-ssm)
  Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->mamba-ssm)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cud

In [None]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset, ConcatDataset
from functools import partial
from mamba_ssm import Mamba2
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.modules.mamba2 import Mamba2
from mamba_ssm.modules.mha import MHA
from mamba_ssm.modules.mlp import GatedMLP
from mamba_ssm.modules.block import Block

try:
    from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None

# adapted from tridao's state-spaces/mamba
# https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L215
def create_block(
    d_model,
    d_intermediate,
    ssm_cfg=None,
    attn_layer_idx=None,
    attn_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    """Build one mamba_ssm ``Block``: a sequence mixer, an optional gated MLP and norms.

    Args:
        d_model: hidden size of the block.
        d_intermediate: hidden size of the GatedMLP. 0 replaces the MLP with
            ``nn.Identity``.
        ssm_cfg: kwargs for the SSM mixer. The optional "layer" key selects
            "Mamba1" (default, the ``Mamba`` class) or "Mamba2". The caller's
            dict is never modified.
        attn_layer_idx: layer indices that get a multi-head attention (MHA)
            mixer instead of an SSM mixer.
        attn_cfg: kwargs for MHA, used when ``layer_idx`` is in ``attn_layer_idx``.
        norm_epsilon: eps passed to the normalization class.
        rms_norm: use RMSNorm when the Triton import succeeded, otherwise
            ``nn.LayerNorm``.
        residual_in_fp32, fused_add_norm: forwarded to ``Block``.
        layer_idx: index of this layer. It selects the mixer type and is
            stored on the block.
        device, dtype: factory kwargs forwarded to every submodule.

    Returns:
        The constructed ``Block`` with its ``layer_idx`` attribute set.

    Raises:
        ValueError: if ``ssm_cfg["layer"]`` is neither "Mamba1" nor "Mamba2".
    """
    # Imported here because the notebook's import cell does not import `copy`.
    import copy

    if ssm_cfg is None:
        ssm_cfg = {}
    if attn_layer_idx is None:
        attn_layer_idx = []
    if attn_cfg is None:
        attn_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}

    if layer_idx not in attn_layer_idx:
        # Work on a copy so popping "layer" does not mutate the caller's config.
        ssm_cfg = copy.deepcopy(ssm_cfg)
        ssm_layer = ssm_cfg.pop("layer", "Mamba1")
        if ssm_layer not in ["Mamba1", "Mamba2"]:
            raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
        mixer_cls = partial(
            Mamba2 if ssm_layer == "Mamba2" else Mamba,
            layer_idx=layer_idx,
            **ssm_cfg,
            **factory_kwargs
        )
    else:
        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)

    # Fall back to LayerNorm when RMSNorm was requested but the Triton import failed.
    norm_cls = nn.LayerNorm if not rms_norm or RMSNorm is None else RMSNorm
    norm_cls = partial(norm_cls, eps=norm_epsilon, **factory_kwargs)

    if d_intermediate == 0:
        mlp_cls = nn.Identity
    else:
        mlp_cls = partial(
            GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
        )
    block = Block(
        d_model,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


class FrameInterpolationMamba(nn.Module):
    """Predict the frame that lies between two RGB frames.

    Pipeline:
        1. Each frame passes through the same strided-conv encoder. Three
           stride-2 convs take it from 3 to 256 channels; each conv maps a
           spatial size s to ceil(s / 2).
        2. The feature map is flattened into a (batch, H'*W', 256) token sequence
           and run through ``self.mamba_layer``. The second value the block returns
           (its residual stream) is discarded. The result goes through a ReLU.
        3. The two frames' features are concatenated on the channel axis (512
           channels) and decoded by three stride-2 transposed convs into a
           3-channel image. The final Sigmoid maps it to [0, 1].

    The output spatial size equals the input's only when H and W are divisible
    by 8.

    NOTE(review): ``layer_idx=1`` is listed in ``attn_layer_idx=[1]``, so
    ``create_block`` builds a multi-head attention (MHA) mixer here, not Mamba2,
    and ``ssm_cfg`` is unused. Changing this alters the parameter set and breaks
    loading existing checkpoints.

    Attribute names (encoder, decoder, mamba_layer) determine the ``state_dict``
    keys, so renaming them also breaks saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        # Shared per-frame encoder.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        # Decoder over both frames' stacked features; every layer doubles H and W.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256 * 2, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )
        # Sequence block applied to each frame's flattened feature map.
        # `device` is the module-level device chosen earlier in the notebook.
        self.mamba_layer = create_block(
            d_model=256,
            d_intermediate=512,
            ssm_cfg={'layer': 'Mamba2'},
            attn_layer_idx=[1],
            attn_cfg={'num_heads': 8},
            norm_epsilon=1e-5,
            rms_norm=True,
            residual_in_fp32=False,
            fused_add_norm=False,
            layer_idx=1,
            device=device,
        ).to(device)
        self.mamba_activation = nn.ReLU()

    def forward(self, frame1, frame2):
        """Return the interpolated frame for ``frame1`` and ``frame2``.

        Both inputs are expected as (batch, 3, H, W) float tensors, which is
        what ``FrameDataset`` yields after batching.
        """
        mixed_features = []
        for frame in (frame1, frame2):
            features = self.encoder(frame)
            b, c, h, w = features.shape
            # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
            tokens = features.view(b, c, -1).permute(0, 2, 1)
            tokens, _ = self.mamba_layer(tokens, None)
            # (B, H*W, C) -> (B, C, H, W) for the convolutional decoder.
            features = tokens.permute(0, 2, 1).view(b, c, h, w)
            mixed_features.append(self.mamba_activation(features))
        return self.decoder(torch.cat(mixed_features, dim=1))

class FrameDataset(Dataset):
    """Map-style dataset over pre-built interpolation samples.

    ``input_frames[idx]`` holds the two input frames and ``target_frames[idx]``
    the frame to predict. Frames are indexed as (H, W, C), which matches the
    ``permute(2, 0, 1)`` below. Each item is returned as float tensors in
    channel-first (C, H, W) layout, divided by 255.
    """

    def __init__(self, input_frames, target_frames):
        self.input_frames, self.target_frames = input_frames, target_frames

    def __len__(self):
        return len(self.target_frames)

    @staticmethod
    def _to_chw(frame):
        # Scale 0-255 pixel values to [0, 1], then move channels first.
        return (frame.float() / 255.0).permute(2, 0, 1)

    def __getitem__(self, idx):
        pair = self.input_frames[idx]
        first = self._to_chw(pair[0])
        second = self._to_chw(pair[1])
        middle = self._to_chw(self.target_frames[idx])
        return (first, second), middle


# Build the model, resume from the checkpoint on Google Drive, and fine-tune it
# segment by segment on the videos listed in `urls`.
model = FrameInterpolationMamba().to(device)

from google.colab import files

from google.colab import drive
drive.mount('/content/drive')
path = "/content/drive/My Drive/model_mamba_delta4_9epoch.pth"

# load_state_dict is strict by default, so the model definition above must match
# the one this checkpoint was saved from.
state_dict = torch.load(path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)

criterion = nn.MSELoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)

num_epochs = 1
log_every = 10  # mini-batches averaged into each printed loss

train_array = None
input_frames, target_frames = None, None

for epoch in range(num_epochs):
    model.train()
    for counter, (url, num_segments) in enumerate(urls.items(), start=1):
        for j in range(num_segments):
            train_array = yt_to_data(url, j + 1)
            if train_array is None or len(train_array) < 3:
                # The video could not be opened, or this segment lies past its end:
                # there is no frame triplet to train on.
                print(f'Url: {counter}, Sample: {j+1}: no usable frames, skipping')
                train_array = None
                continue

            train_array = torch.from_numpy(train_array)
            input_frames, target_frames = create_training_pairs(train_array)
            train_array = None
            frame_dataset = FrameDataset(input_frames, target_frames)
            train_loader = DataLoader(frame_dataset, batch_size=10, shuffle=False, num_workers=4) # I ran it on 12 CPU Cores, use GPU if you can

            # Reset per segment so a printed average never mixes batches from two segments.
            running_loss = 0.0
            for i, ((frame1, frame2), target_frame) in enumerate(train_loader):
                frame1, frame2, target_frame = frame1.to(device), frame2.to(device), target_frame.to(device)

                optimizer.zero_grad()
                outputs = model(frame1, frame2)
                loss = criterion(outputs, target_frame)
                loss.backward()
                optimizer.step()

                # Print progress so we can follow the model as it trains.
                running_loss += loss.item()
                if i % log_every == log_every - 1:
                    print(f'Url: {counter}, Sample: {j+1}, Epoch: {epoch + 1}, Batch: {i + 1}, Loss: {running_loss / log_every:.4f}')
                    running_loss = 0.0

            # Drop references so this segment's tensors can be freed before the next one.
            input_frames, target_frames = None, None
            frame_dataset = None
            train_loader = None

        # yt_to_data leaves the download on disk across this URL's segments; remove it
        # once here. It may not exist (e.g. a URL configured with 0 segments).
        if os.path.exists('downloaded_video.mp4'):
            os.remove('downloaded_video.mp4')

print('Finished Training')


Mounted at /content/drive
[youtube] Extracting URL: https://www.youtube.com/watch?v=j0HoMaaQj9I
[youtube] j0HoMaaQj9I: Downloading webpage
[youtube] j0HoMaaQj9I: Downloading ios player API JSON
[youtube] j0HoMaaQj9I: Downloading player 8d9f6215
[youtube] j0HoMaaQj9I: Downloading m3u8 information
[info] j0HoMaaQj9I: Downloading 1 format(s): 614
[hlsnative] Downloading m3u8 manifest
[hlsnative] Total fragments: 16
[download] Destination: downloaded_video.mp4
[download] 100% of   14.99MiB in 00:00:00 at 19.46MiB/s                  
Starting the Numpy!
Entering loop
0 / 2011
100 / 2011
200 / 2011
300 / 2011
400 / 2011
500 / 2011
600 / 2011
700 / 2011
800 / 2011
900 / 2011
1000 / 2011
1100 / 2011
1200 / 2011
1300 / 2011
1400 / 2011
1500 / 2011
1600 / 2011
1700 / 2011
1800 / 2011
1900 / 2011
2000 / 2011
Finished loop
Url: 1, Sample: 1, Epoch: 1, Batch: 10, Loss: 0.0004
Url: 1, Sample: 1, Epoch: 1, Batch: 20, Loss: 0.0003
Url: 1, Sample: 1, Epoch: 1, Batch: 30, Loss: 0.0004
Url: 1, Sample: 1,

### Testing the Model

Let's now save the trained weights, both in the Colab session and to Google Drive, so that the model can be reloaded and run again on other devices.

In [None]:
# Re-select the device before saving. Fall back to the CPU so this cell also
# works on a runtime without a GPU (a hard-coded "cuda" fails there).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

In [None]:
# Save the trained weights twice: in the Colab working directory, and in
# Google Drive so they persist beyond this Colab session.
local_checkpoint = 'model_mamba_delta5_10epoch.pth'
torch.save(model.state_dict(), local_checkpoint)

from google.colab import drive, files

drive.mount('/content/drive')
path = "/content/drive/My Drive/model_mamba_delta5_10epoch.pth"
torch.save(model.state_dict(), path)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
