<a href="https://colab.research.google.com/github/hochany95/Multimedia_image_denoising/blob/main/Image_Denoising.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dataset Upload

In [None]:
from google.colab import drive

# Mount Google Drive so the dataset archive stored under MyDrive is reachable.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
import zipfile
import os

import shutil

file_name = 'Multimedia_dataset.zip'
zip_path = "/content/drive/MyDrive/Multimedia_dataset.zip"

print(file_name)

# Work on a local copy of the archive (as the former `!cp` did), then extract it
# with the standard library. `unzip` without `-o` asks before overwriting files
# that already exist, so re-running the cell used to stall; `extractall`
# overwrites silently and keeps the cell re-runnable.
shutil.copy(zip_path, file_name)
with zipfile.ZipFile(file_name) as archive:
  archive.extractall('.')
os.remove(file_name)

In [None]:
import os 

# Sanity check: number of entries in each dataset split directory.
for split_dir in ('./train', './validation'):
  print(len(os.listdir(split_dir)))

In [None]:
import torch
import os 
import matplotlib.pyplot as plt

root_path = '/content/'  # "same root??" — NOTE(review): not referenced by the visible code; confirm it is needed

train_root = './train'
val_root = './validation'

train_examples = os.listdir(train_root)
val_examples = os.listdir(val_root)

for label, examples in (("train_len: ", train_examples), ("validation len: ", val_examples)):
  print(label, len(examples))

# Preview one training image. os.listdir returns entries in arbitrary order,
# so which file comes "first" may differ between machines.
preview_path = os.path.join(train_root, train_examples[0])
img = plt.imread(preview_path)
print(img.shape)
plt.imshow(img)
plt.show()


# CUDA device setup

In [None]:
# Debug setting: synchronous CUDA kernel launches make CUDA errors surface at
# the call that caused them, at the cost of speed. It must be set before CUDA
# is initialised to take effect.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
  DEVICE = torch.device("cuda")
else:
  DEVICE = torch.device("cpu")
print("Using device: ", DEVICE)

# Noise Transform

In [None]:
import torch
from torch.autograd import Variable
from torchvision import transforms
import random

class NoiseTransform(object):
  """Turn a PIL image into tensors for denoising training or inference.

  In "training" / "validation" mode the image is resized to (size, size) and
  converted with ``ToTensor``; the call returns ``(clean, noise)``, where
  ``noise`` is zero-mean Gaussian noise shaped like ``clean`` with standard
  deviation ``sigma / 255``. Adding and clamping the noise is left to the
  caller. In "testing" mode only ``ToTensor`` is applied (no resize, no noise)
  and a single tensor is returned.

  Args:
    size: side length of the square resize used in training/validation.
    mode: one of "training", "validation", "testing".
    sigma: noise standard deviation on the 0-255 intensity scale (default 25).

  Raises:
    ValueError: from ``__call__`` when ``mode`` is not one of the above.
  """

  def __init__(self, size=256, mode="training", sigma=25):
    super(NoiseTransform, self).__init__()
    self.size = size
    self.mode = mode
    self.sigma = sigma
    # Build the pipelines once instead of on every call. BILINEAR is what the
    # former (deprecated) integer argument `interpolation=2` selected.
    self.resize_transform = transforms.Compose([
      transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
      transforms.ToTensor()])
    self.test_transform = transforms.ToTensor()

  def gaussian_noise(self, img):
    """Return zero-mean Gaussian noise shaped like ``img`` with std ``sigma / 255``."""
    return torch.randn_like(img) * (self.sigma / 255.)

  def __call__(self, img):
    if self.mode in ("training", "validation"):
      clean = self.resize_transform(img)
      # The noise depends only on the tensor shape, so derive it from the
      # already-resized clean tensor instead of resizing the image twice.
      return clean, self.gaussian_noise(clean)
    if self.mode == "testing":
      return self.test_transform(img)
    # Previously this printed a message and *returned* the NotImplementedError
    # class, so the caller failed later with a confusing unpacking error.
    raise ValueError("[Noise transform]: mode error: {!r}".format(self.mode))



# Dataset

In [None]:
import torch
import torch.utils.data as data
import os 
import matplotlib.pyplot as plt
from torchvision.transforms import Compose, ToTensor, ToPILImage
from PIL import Image 

def image_show(img):
  """Display an image with matplotlib.

  Accepts anything ``plt.imshow`` accepts, or a C x H x W tensor. Tensors are
  detached and moved to the CPU first, so network outputs on the GPU or with
  autograd history can be shown directly. Float tensors are clamped to [0, 1],
  because ``ToPILImage`` scales floats by 255 before casting to uint8 and
  out-of-range values would otherwise overflow into speckle.
  """
  if isinstance(img, torch.Tensor):
    img = img.detach().cpu()
    if img.is_floating_point():
      img = img.clamp(0, 1)
    img = ToPILImage()(img)
  plt.imshow(img)
  plt.show()

class DenoisingDataSet(data.Dataset):
  """Dataset over every entry of ``root_path`` for the denoising task.

  ``set_mode`` must be called before indexing; it installs a ``NoiseTransform``
  for that mode. Items are dicts:
    - "testing":                 {'img', 'file_name'}
    - "training" / "validation": {'img' (clean), 'noise', 'file_name'}
  """
  def __init__(self, root_path, size=256):
    super(DenoisingDataSet, self).__init__()
    self.root_path = root_path
    self.size = size
    # Sorted so item order (e.g. for an unshuffled loader) is reproducible;
    # os.listdir returns entries in arbitrary order.
    # NOTE(review): every entry is assumed to be an image file — confirm.
    self.examples = sorted(os.listdir(self.root_path))
    self.mode = None
    self.transforms = None

  def set_mode(self, mode):
    self.mode = mode
    self.transforms = NoiseTransform(self.size, mode)

  def __len__(self):
    return len(self.examples)

  def __getitem__(self, idx):
    if self.transforms is None:
      raise RuntimeError("DenoisingDataSet.set_mode() must be called before indexing")
    file_name = self.examples[idx]

    # Close the file once decoded, and force 3-channel RGB so grayscale / RGBA /
    # palette files match the 3-channel default of the DnCNN model and can be
    # batched together with RGB images.
    with Image.open(os.path.join(self.root_path, file_name)) as raw:
      img = raw.convert('RGB')

    if self.mode == "testing":
      sample = {'img': self.transforms(img), 'file_name': file_name}
    else:
      clean, noise = self.transforms(img)
      sample = {'img': clean, 'noise': noise, 'file_name': file_name}
    return sample

# Data loader



In [None]:
import tqdm 
import torch.utils.data as data

BATCH_SIZE = 64
IMAGE_SIZE = 256   # side length every training/validation image is resized to
NUM_WORKERS = 2

train_root = './train'
val_root = './validation'

train_dataset = DenoisingDataSet(train_root, IMAGE_SIZE)
train_dataset.set_mode('training')

val_dataset = DenoisingDataSet(val_root, IMAGE_SIZE)
val_dataset.set_mode('validation')

# drop_last on the training loader keeps every optimisation step at full batch size.
train_dataloader = data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    drop_last=True,
)
# Validation should score every image: drop_last=True silently skipped up to
# BATCH_SIZE - 1 validation images each epoch.
val_dataloader = data.DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    drop_last=False,
)
  

# Network Construction

In [None]:
import re
import os, glob, datetime, time
import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
import torch.nn.init as init
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR

import matplotlib.pyplot as plt
from torchvision import transforms
import tqdm
from PIL import Image

class DnCNN(nn.Module):
  """DnCNN-style residual denoiser.

  Layout: Conv+ReLU, then ``depth - 2`` x (Conv[+BatchNorm]+ReLU), then a final
  Conv back to ``image_channels``. ``forward`` returns ``x`` minus the stack's
  output, i.e. the stack is trained to predict the residual to remove. Every
  convolution uses ``kernel_size // 2`` zero padding, so odd kernel sizes keep
  the spatial size of the input.

  Args:
    depth: total number of convolution layers (expected >= 2).
    n_channels: feature channels of the hidden layers.
    image_channels: channels of the input/output image.
    use_bnorm: insert BatchNorm2d after each hidden convolution.
    kernel_size: convolution kernel size (odd values preserve H x W).
  """
  def __init__(self, depth=17, n_channels=16, image_channels=3, use_bnorm=True, kernel_size=3):
    super(DnCNN, self).__init__()
    # `kernel_size` used to be overwritten with 3 and `use_bnorm` was ignored;
    # both are honoured now (the defaults build exactly the same network).
    padding = kernel_size // 2
    layers = []

    layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True))
    layers.append(nn.ReLU(inplace=True))
    for _ in range(depth-2):
      # A conv bias right before BatchNorm is redundant (BN has its own shift);
      # without BatchNorm the bias is kept.
      layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=not use_bnorm))
      if use_bnorm:
        # NOTE(review): in PyTorch, momentum weights the *new* batch statistic
        # (running = (1 - m) * running + m * batch), the reverse of Keras, so
        # 0.95 tracks mostly the latest batch. Confirm this is intended.
        layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum = 0.95))
      layers.append(nn.ReLU(inplace=True))
    layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False))
    self.dncnn = nn.Sequential(*layers)
    self._initialize_weights()

  def forward(self, x):
    # Residual learning: subtract the stack's output from the input.
    return x - self.dncnn(x)

  def _initialize_weights(self):
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        init.orthogonal_(m.weight)
        if m.bias is not None:
          init.constant_(m.bias, 0)
      # This branch was mis-indented as an `elif` of the bias check, so it never ran.
      elif isinstance(m, nn.BatchNorm2d):
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)

In [None]:
class sum_squared_error(_Loss):
  """Half of the summed squared error: sum((input - target) ** 2) / 2.

  The sum is always taken over every element, whatever ``reduction`` is
  passed to the constructor.
  """
  def __init__(self, size_average=None, reduce=None, reduction='sum'):
    super(sum_squared_error, self).__init__(size_average, reduce, reduction)

  def forward(self, input, target):
    squared_error_total = torch.nn.functional.mse_loss(input, target, reduction='sum')
    return squared_error_total / 2

# Training session

In [None]:
import torch, gc 
def free_cuda():
  """Run Python garbage collection, then release PyTorch's cached, unused CUDA memory."""
  # Order matters: collecting first lets tensors freed by the GC go back to the
  # caching allocator before that cache is emptied.
  for release_step in (gc.collect, torch.cuda.empty_cache):
    release_step()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
import numpy as np

# Seed before building the model so the weight initialisation is reproducible.
SEED = 0
torch.manual_seed(SEED)

# Use the device chosen in the CUDA section instead of hard-coding .cuda(),
# so the model can also be built on a CPU-only runtime.
net = DnCNN().to(DEVICE)
criterion = sum_squared_error()
optimizer = optim.Adam(net.parameters(), lr = 1e-3)

train_info = []
val_info = []
EPOCH = 30

save_path = './DenoisingNetwork'
os.makedirs(save_path, exist_ok = True)
output_path = os.path.join(save_path, 'denoising_model.tar')

# One-epoch training and validation loops

In [None]:
def _noisy_batch(sample, device):
  """Return (noisy input, clean target) for one batch, both float and on ``device``.

  The noise is added to the clean image and the sum is clamped to [0, 1].
  """
  clean = sample['img']
  noisy = torch.clamp(clean + sample['noise'], 0, 1)
  return noisy.float().to(device), clean.float().to(device)


def train_1epoch(net, train_dataloader):
  """Run one optimisation pass over ``train_dataloader``.

  Uses the notebook-level ``criterion`` and ``optimizer``. Returns the mean
  loss per image over the batches actually used (NaN if none were).
  """
  net.train()
  # Follow the model's device instead of assuming CUDA.
  device = next(net.parameters()).device
  total_loss = 0.0
  n_images = 0
  for sample in tqdm.tqdm(train_dataloader):
    noise_image, img = _noisy_batch(sample, device)
    denoised = net(noise_image)
    if denoised.size() != img.size():
      print("다른 크기", denoised.size(), img.size(), sample['file_name'])
      continue

    loss = criterion(denoised, img)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    total_loss += loss.item()
    n_images += img.size(0)

  # Normalise by the images actually processed. Dividing by len(loader) was
  # biased when batches were skipped, and crashed on an empty loader.
  # The criterion (sum_squared_error) sums over the whole batch.
  return total_loss / n_images if n_images else float('nan')


def validation_1epoch(net, val_dataloader):
  """Evaluate ``net`` on ``val_dataloader`` without updating it.

  Uses the notebook-level ``criterion``. Returns the mean loss per image
  (NaN if no batch was usable). Gradients are disabled inside the function.
  """
  net.eval()
  device = next(net.parameters()).device
  total_loss = 0.0
  n_images = 0
  with torch.no_grad():
    for sample in tqdm.tqdm(val_dataloader):
      noise_image, img = _noisy_batch(sample, device)
      denoised = net(noise_image)
      if denoised.size() != img.size():
        print("다른 크기", denoised.size(), img.size(), sample['file_name'])
        continue

      loss = criterion(denoised, img)
      total_loss += loss.item()
      n_images += img.size(0)

  return total_loss / n_images if n_images else float('nan')

# Training


In [None]:
low_loss = float('inf')
for epoch in range(EPOCH):
  print("{} EPOCH".format(epoch + 1))
  train_loss = train_1epoch(net, train_dataloader)
  train_info.append({'loss': train_loss})

  with torch.no_grad():
    val_loss = validation_1epoch(net, val_dataloader)
  val_info.append({'loss': val_loss})
  # Report both losses every epoch with one consistent, 1-based epoch number.
  # Previously the header was 0-based, the loss line 1-based, and the loss was
  # printed only every 10th epoch.
  print('epoch: {} loss: {} val_loss: {}'.format(epoch + 1, train_loss, val_loss))

  # Keep only the checkpoint with the lowest validation loss so far.
  if val_loss < low_loss:
    low_loss = val_loss
    torch.save({
      'memo': 'DenoisingDnCnnModel',
      'epoch': epoch + 1,
      'loss': low_loss,
      'model_weight': net.state_dict()
    }, output_path)

# Loss graph

In [None]:
import numpy as np
import matplotlib.pyplot as plt
# Plot only epochs that have both a train and a validation entry. The x axis
# is sized from the data, not from EPOCH, so an interrupted run or a re-run
# training cell (which appends to the lists) no longer fails with an x/y
# length mismatch.
print(len(train_info), len(val_info))
min_count = min(len(val_info), len(train_info))
epoch_axis = np.arange(1, min_count + 1)
train_losses = [info['loss'] for info in train_info[:min_count]]
val_losses = [info['loss'] for info in val_info[:min_count]]

fig, ax = plt.subplots()
ax.plot(epoch_axis, train_losses, label='TRAIN')
ax.plot(epoch_axis, val_losses, 'r-', label='VALIDATION')
ax.set(title='LOSS', xlabel='epoch', ylabel='loss')
ax.legend()
plt.show()

# Test session

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def transImage(img):
  """Return ``img`` as a PIL image if it is a tensor, otherwise unchanged.

  Previously this called ``ToPILImage(img)``, which only builds a transform
  object (with the tensor as its ``mode`` argument) and never converts the
  image. Tensors are detached and moved to the CPU first; float tensors are
  clamped to [0, 1] so the x255 uint8 conversion cannot overflow.
  """
  if isinstance(img, torch.Tensor):
    img = img.detach().cpu()
    if img.is_floating_point():
      img = img.clamp(0, 1)
    return ToPILImage()(img)
  return img
free_cuda()

# Evaluate the best checkpoint saved during training (lowest validation loss)
# when it exists; otherwise the in-memory weights of the last epoch are used.
device = next(net.parameters()).device
if os.path.exists(output_path):
  checkpoint = torch.load(output_path, map_location=device)
  net.load_state_dict(checkpoint['model_weight'])
net.eval()

N_PREVIEW_BATCHES = 3
with torch.no_grad():  # inference only: don't build an autograd graph per batch
  for step, sample in enumerate(val_dataloader):
    if step == N_PREVIEW_BATCHES:
      break
    original = sample['img']
    noised = torch.clamp(original + sample['noise'], 0, 1).float().to(device)
    # Clamp the output to the valid image range before converting it for display.
    denoised = torch.clamp(net(noised), 0, 1)
    # One image per batch; wrap the index so a short final batch cannot overflow it.
    idx = step % original.size(0)
    image_show(noised[idx].cpu())
    image_show(denoised[idx].cpu())


