<a href="https://colab.research.google.com/github/jj132535/DeblurGAN-for-Video-Sharpness/blob/main/DNCNN_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Get Dataset from Video File
import cv2
import os
import torch
import torch.nn as nn
import torch.utils.data as data
import numpy as np
import tqdm.notebook as tq
from torchvision import transforms
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as compare_psnr

# Paths to be configured by the user
video_path = "/content/16-9_CHUU_Strawberry.mp4"  # Path to the input video file
frame_output_path = "/content/frames"  # Path to save extracted frames
checkpoint_path = "/content/drive/MyDrive/DNCNN_models/30.tar"  # Path to the trained model checkpoint
result_save_path = "/content/denoised_frames"  # Path to save denoised frames
output_video_path = "/content/denoised_video.mp4"  # Path to save the final denoised video

os.makedirs(frame_output_path, exist_ok=True)

# Extract every frame of the video into frame_output_path as frame_0000.png,
# frame_0001.png, ...; afterwards frame_count holds the number of frames written.
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    # Fail fast: otherwise the read loop is skipped silently and later cells
    # operate on an empty frame directory.
    raise IOError(f"Could not open video file: {video_path}")

frame_count = 0
try:
    while True:
        ret, frame = cap.read()
        if not ret:  # no more frames (end of stream or read failure)
            break
        frame_filename = os.path.join(frame_output_path, f"frame_{frame_count:04d}.png")
        # cv2.imwrite signals failure through its boolean return value.
        if not cv2.imwrite(frame_filename, frame):
            raise IOError(f"Failed to write frame: {frame_filename}")
        frame_count += 1
finally:
    # Always release the capture handle, even if writing a frame failed.
    cap.release()

# Noise Transform
class NoiseTransform(object):
  """Resize a PIL image to ``size x size`` and build clean / noise tensors.

  Modes:
    * ``"training"`` / ``"validation"``: ``__call__`` returns ``(clean, noise)``.
      ``clean`` is the resized image after ``ToTensor``. ``noise`` is a tensor
      of the same shape drawn from a normal distribution with mean ``mean``
      and std ``sigma / 255``. The noise is returned on its own; it is NOT
      added to ``clean``.
    * ``"testing"``: ``__call__`` returns only the resized ``clean`` tensor.

  Args:
    size: Side length of the square resize.
    mode: One of the modes listed above.
    sigma: Noise standard deviation before division by 255 (default 25).
    mean: Noise mean, passed to ``normal_`` unscaled (default 0).
  """

  _MODES = ("training", "validation", "testing")

  def __init__(self, size=180, mode="training", sigma=25, mean=0):
    super(NoiseTransform, self).__init__()
    self.size = size
    self.mode = mode
    self.sigma = sigma
    self.mean = mean

  def gaussian_noise(self, img):
    """Return Gaussian noise with the same size as ``img``; ``img`` is not modified."""
    return torch.zeros(img.size()).normal_(self.mean, self.sigma / 255.)

  def __call__(self, img):
    """Apply the transform selected by ``self.mode`` to the PIL image ``img``.

    Raises:
      NotImplementedError: if ``self.mode`` is not a supported mode.
    """
    if self.mode not in self._MODES:
      # Raise (rather than return the exception class) so a bad mode fails loudly.
      raise NotImplementedError(f"Unsupported NoiseTransform mode: {self.mode!r}")

    # BILINEAR is the same filter as the previous integer code 2.
    resize = transforms.Resize(
      (self.size, self.size), interpolation=transforms.InterpolationMode.BILINEAR)
    self.gt_transform = transforms.Compose([resize, transforms.ToTensor()])
    if self.mode == "testing":
      return self.gt_transform(img)

    # Still exposed as an attribute for callers that use it directly.
    self.noise_transform = transforms.Compose([
      resize,
      transforms.ToTensor(),
      transforms.Lambda(self.gaussian_noise),
    ])
    clean = self.gt_transform(img)
    # The noise depends only on the tensor's shape, so draw it from the
    # already-resized clean tensor instead of resizing the image a second time.
    return clean, self.gaussian_noise(clean)

# Dataloader for Noise Dataset
class NoiseDataset(data.Dataset):
  """Dataset over the ``.png`` files directly inside ``root_path``.

  ``set_mode`` must be called before indexing; it builds the
  ``NoiseTransform`` that decides what ``__getitem__`` returns.

  Args:
    root_path: Directory scanned (non-recursively) for ``*.png`` files.
    size: Side length passed to ``NoiseTransform`` for the square resize.
  """

  def __init__(self, root_path, size):
    super(NoiseDataset, self).__init__()
    self.root_path = root_path
    self.size = size
    self.mode = None
    self.transforms = None
    # os.listdir order is arbitrary; sort so dataset indices are reproducible.
    self.examples = sorted(
      os.path.join(root_path, f) for f in os.listdir(root_path) if f.endswith(".png"))

  def set_mode(self, mode):
    """Select the mode ("training", "validation" or "testing") and build its transform."""
    self.mode = mode
    self.transforms = NoiseTransform(self.size, mode)

  def __len__(self):
    return len(self.examples)

  def __getitem__(self, idx):
    """Load example ``idx`` and apply the transform for the current mode.

    Returns:
      In "testing" mode: ``{"img": tensor, "file_name": "image_NNNNNN.png"}``.
      The number is the integer after the last ``_`` in the source file's stem
      (e.g. ``frame_0042.png`` -> ``image_000042.png``).
      Otherwise: ``{"img": clean, "noise": noise}`` as produced by the transform.

    Raises:
      RuntimeError: if ``set_mode`` has not been called yet.
    """
    if self.transforms is None:
      raise RuntimeError("NoiseDataset.set_mode() must be called before indexing")

    file_name = self.examples[idx]
    # Load inside ``with`` so the file handle is closed right away. Converting
    # to RGB gives the 3 channels DNCNN expects by default (in_planes=3).
    with Image.open(file_name) as im:
      image = im.convert("RGB")

    if self.mode == "testing":
      input_img = self.transforms(image)
      stem = os.path.splitext(os.path.basename(file_name))[0]
      # rsplit keeps this working when the prefix itself contains underscores.
      frame_idx = int(stem.rsplit('_', 1)[-1])
      sample = {"img": input_img, "file_name": "image_%06d.png" % frame_idx}
    else:
      clean, noise = self.transforms(image)
      sample = {"img": clean, "noise": noise}

    return sample

# Simplified DNCNN network
class DNCNN(nn.Module):
  """DnCNN-style network: conv+ReLU, (blocks-2) x [conv, BN, ReLU], conv.

  ``forward`` returns the last conv's output directly; no skip connection
  with the input is applied inside the module.

  NOTE(review): with the default ``share_hidden=True``, every hidden layer
  reuses the *same* ``self.conv_h`` and ``self.bn`` modules. The hidden stack
  therefore has a single set of conv weights and a single BatchNorm, whose
  running statistics are updated once per hidden layer on each training-mode
  forward pass. This default matches the original architecture and any
  checkpoint trained with it.

  Pass ``share_hidden=False`` to give each hidden layer its own conv and
  BatchNorm. The first hidden layer still uses ``self.conv_h``/``self.bn``, so
  the ``state_dict`` keys are identical in both modes and a checkpoint saved
  from either one loads into the other with ``strict=True``.

  Args:
    in_planes: Input/output channel count.
    blocks: Total conv depth; ``blocks - 2`` hidden [conv, BN, ReLU] groups.
    hidden: Channel width of the hidden layers.
    kernel_size, padding, bias: Forwarded to every ``nn.Conv2d``.
    share_hidden: Reuse one conv/BN for all hidden layers (original behavior).
  """

  def __init__(self, in_planes=3, blocks=17, hidden=64, kernel_size=3, padding=1, bias=False,
               share_hidden=True):
    super(DNCNN, self).__init__()
    self.share_hidden = share_hidden
    self.conv_f = nn.Conv2d(in_channels=in_planes, out_channels=hidden, kernel_size=kernel_size, padding=padding, bias=bias)
    self.conv_h = nn.Conv2d(in_channels=hidden, out_channels=hidden, kernel_size=kernel_size, padding=padding, bias=bias)
    self.conv_l = nn.Conv2d(in_channels=hidden, out_channels=in_planes, kernel_size=kernel_size, padding=padding, bias=bias)

    self.bn = nn.BatchNorm2d(hidden)
    # Stateless, so it is safe to reuse between layers.
    self.relu = nn.ReLU(inplace=True)

    self.hidden_layer = self.mk_hidden_layer(blocks)

  def _new_hidden_pair(self):
    """Return a fresh (conv, bn) pair configured like ``self.conv_h`` / ``self.bn``."""
    ref = self.conv_h
    conv = nn.Conv2d(in_channels=ref.in_channels, out_channels=ref.out_channels,
                     kernel_size=ref.kernel_size, padding=ref.padding,
                     bias=ref.bias is not None)
    return conv, nn.BatchNorm2d(self.bn.num_features)

  def mk_hidden_layer(self, blocks=17):
    """Build the ``blocks - 2`` hidden [conv, BN, ReLU] groups as an ``nn.Sequential``."""
    share = getattr(self, "share_hidden", True)
    layers = []
    for i in range(blocks - 2):
      if share or i == 0:
        conv, bn = self.conv_h, self.bn
      else:
        conv, bn = self._new_hidden_pair()
      layers.extend([conv, bn, self.relu])
    return nn.Sequential(*layers)

  def forward(self, x):
    """Run conv_f -> ReLU -> hidden stack -> conv_l and return the result."""
    out = self.conv_f(x)
    out = self.relu(out)
    out = self.hidden_layer(out)
    out = self.conv_l(out)
    return out

