# Get Dataset from Google Drive  
Upload your dataset to Google Drive first; the cell below expects it at `MyDrive/lab/Multimedia_dataset.zip`.

In [None]:
from google.colab import drive
# Mount Google Drive so the dataset archive stored there can be copied into the runtime.
drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
import os
import zipfile
import tqdm
from PIL import Image
from tqdm import tqdm

file_name = "Multimedia_dataset.zip"
zip_path = os.path.join('/content/drive/MyDrive/lab/Multimedia_dataset.zip')

!cp "{zip_path}" .
!unzip -q "{file_name}"
!rm "{file_name}"


# Offline augmentation of the training split. For each image in /content/train,
# write a horizontally flipped copy and copies rotated by 90, 180 and 270
# degrees into the same directory, named "<prefix><original file name>.png".
train_dir = '/content/train'
file_list = os.listdir(train_dir)
for file in tqdm(file_list):
  # Skip files written by an earlier run of this cell. Otherwise re-running it
  # would augment the augmented copies (e.g. "inverted_rotated90_x.png.png").
  if file.startswith(('inverted_', 'rotated90_', 'rotated180_', 'rotated270_')):
    continue
  # The context manager releases the source file once all copies are written.
  with Image.open(os.path.join(train_dir, file)) as image:
    inverted_image = image.transpose(Image.FLIP_LEFT_RIGHT)
    inverted_image.save(os.path.join(train_dir, 'inverted_' + file + '.png'))
    for angle in (90, 180, 270):
      # rotate() is called without expand=True, so each output keeps the
      # input's width and height.
      rotated_image = image.rotate(angle)
      rotated_image.save(os.path.join(train_dir, 'rotated{}_'.format(angle) + file + '.png'))

# Noise Transform  
To change the amount of noise added, adjust the `mean` and `stddev` values in the `gaussian_noise` method. `stddev` is given on the 0–255 pixel scale and is divided by 255 before sampling.

In [None]:
import torch
from torch.autograd import Variable
from torchvision import transforms

import random

class NoiseTransform(object):
  """Turn a PIL image into (clean, noise) tensors, or a clean tensor for testing.

  In "training" and "validation" mode the image is resized to (size, size) and
  converted to a tensor. It is returned together with a Gaussian noise tensor
  of the same shape (the noise only, not image + noise). In "testing" mode only
  the un-resized tensor is returned. Any other mode raises NotImplementedError.
  """

  def __init__(self, size=180, mode="training"):
    super(NoiseTransform, self).__init__()
    self.size = size
    self.mode = mode

  def gaussian_noise(self, img):
    """Return freshly sampled Gaussian noise with the same size as ``img``.

    ``stddev`` is expressed on the 0-255 pixel scale and divided by 255.
    ToTensor maps 8-bit images to [0, 1], so the noise must be on that scale.
    """
    mean = 0
    stddev = 25
    return torch.zeros(img.size()).normal_(mean, stddev / 255.)

  def __call__(self, img):
    if self.mode in ("training", "validation"):
      self.gt_transform = transforms.Compose([
        # transforms.RandomCrop(self.size),
        transforms.Resize((self.size, self.size),
                          interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor()])
      # Full resize -> tensor -> noise pipeline, kept for callers that use the
      # attribute directly. __call__ below avoids running it a second time.
      self.noise_transform = transforms.Compose([
        self.gt_transform,
        transforms.Lambda(self.gaussian_noise),
      ])
      # Resize/convert once. The noise depends only on the tensor's size, so it
      # is sampled from the already-converted clean tensor.
      clean = self.gt_transform(img)
      return clean, self.gaussian_noise(clean)

    if self.mode == "testing":
      self.gt_transform = transforms.Compose([
        # transforms.Resize((self.size, self.size), interpolation=2),
        transforms.ToTensor()])
      return self.gt_transform(img)

    # Previously the exception class was *returned*, silently handing callers a
    # non-tensor. Raise so an unsupported mode fails immediately.
    raise NotImplementedError("unsupported mode: {!r}".format(self.mode))


# Dataloader for Noise Dataset

In [None]:
import torch
import torch.utils.data  as data
import os
from PIL import Image

class NoiseDataset(data.Dataset):
  """Images under <root_path>/{train,validation,test}, paired with noise via NoiseTransform.

  Call ``set_mode`` before use. Items are dicts: ``{"img": clean}`` in
  "testing" mode and ``{"img": clean, "noise": noise}`` otherwise.
  """

  # Sub-directory of root_path read for each mode.
  _MODE_DIRS = {"training": "train", "validation": "validation", "testing": "test"}

  def __init__(self, root_path, size):
    super(NoiseDataset, self).__init__()

    self.root_path = root_path
    self.size = size
    self.mode = None
    self.transforms = None
    self.examples = None

  def set_mode(self, mode):
    """Select a split and build its transform.

    Raises:
      NotImplementedError: ``mode`` is not one of the keys of _MODE_DIRS.
        The dataset state is left unchanged in that case.
    """
    if mode not in self._MODE_DIRS:
      raise NotImplementedError
    split_dir = os.path.join(self.root_path, self._MODE_DIRS[mode])
    # Sorted so item order is reproducible (os.listdir order is arbitrary).
    # Only regular files are kept, so stray sub-directories are never opened
    # as images.
    examples = [
      os.path.join(split_dir, name)
      for name in sorted(os.listdir(split_dir))
      if os.path.isfile(os.path.join(split_dir, name))
    ]
    self.mode = mode
    self.transforms = NoiseTransform(self.size, mode)
    self.examples = examples

  def __len__(self):
    # self.examples is None until set_mode has been called.
    return len(self.examples)

  def __getitem__(self, idx):
    file_name = self.examples[idx]
    # Open in a context manager so the file handle is released. convert("RGB")
    # reads the pixels before the file closes and guarantees the 3 channels
    # that Net's first convolution takes.
    with Image.open(file_name) as handle:
      image = handle.convert("RGB")

    if self.mode == "testing":
      sample = {"img": self.transforms(image)}
    else:
      clean, noise = self.transforms(image)
      sample = {"img": clean, "noise": noise}

    return sample

# Example for Loading

In [None]:
import torch
import torch.utils.data  as data
import os
import matplotlib.pyplot as plt
from torchvision import transforms
import tqdm
from PIL import Image

def image_show(img):
  # Tensors are converted to a PIL image first; anything else (e.g. an image
  # that is already PIL) goes to matplotlib unchanged.
  to_draw = transforms.ToPILImage()(img) if isinstance(img, torch.Tensor) else img
  plt.imshow(to_draw)
  plt.show()

# Change to your data root directory
root_path = "/content/"
# Depend on runtime setting
use_cuda = True
# Device the demo batches are moved to; falls back to the CPU when use_cuda is off.
device = 'cuda' if use_cuda else 'cpu'

train_dataset = NoiseDataset(root_path, 128)
train_dataset.set_mode("training")

train_dataloader = data.DataLoader(train_dataset, batch_size=16, shuffle=True)

# The loop variable is `batch`, not `data`. Naming it `data` would rebind the
# torch.utils.data alias imported in this cell, so re-running the cell (or any
# later `data.DataLoader(...)`) would fail.
for i, batch in enumerate(tqdm.tqdm(train_dataloader)):
  img = batch["img"].to(device)
  noise = batch["noise"].to(device)

  model_input = img + noise
  # Clamp the noisy preview back into [0, 1] for display.
  noise_image = torch.clamp(model_input, 0, 1)

  # Every 100 batches, show the clean and noisy version of the first image.
  if i % 100 == 0:
    image_show(img[0])
    image_show(noise_image[0])



# model

In [None]:
import torch 
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
from matplotlib import pyplot as plt
import os
import cv2
import numpy as np
from math import log10
from torch.autograd import Variable
from skimage.util import random_noise
from google.colab.patches import cv2_imshow
from torchsummary import summary

def conv3x3(in_chn, out_chn, bias=True):
    """3x3 convolution, stride 1, padding 1: changes the channel count, keeps H and W."""
    return nn.Conv2d(in_chn, out_chn, kernel_size=3, stride=1, padding=1, bias=bias)

def conv_down(in_chn, out_chn, bias=False):
    """4x4 convolution, stride 2, padding 1: halves H and W (for even sizes)."""
    return nn.Conv2d(in_chn, out_chn, kernel_size=4, stride=2, padding=1, bias=bias)

def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1):
    """Convolution padded by kernel_size // 2 (keeps H and W for odd kernels at stride 1)."""
    pad = kernel_size // 2
    layer = nn.Conv2d(in_channels, out_channels, kernel_size,
                      padding=pad, bias=bias, stride=stride)
    return layer

## Supervised Attention Module
class SAM(nn.Module):
    """Supervised attention module placed between the two restoration stages.

    From features ``x`` it predicts a 3-channel image (conv2(x) + x_img). A
    sigmoid map computed from that image gates conv1(x), and the gated
    features are added back onto ``x``.
    """

    def __init__(self, n_feat, kernel_size=3, bias=True):
        super(SAM, self).__init__()
        # Creation order (conv1, conv2, conv3) is kept so parameter
        # initialisation and state_dict keys are unchanged.
        self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
        self.conv2 = conv(n_feat, 3, kernel_size, bias=bias)
        self.conv3 = conv(3, n_feat, kernel_size, bias=bias)

    def forward(self, x, x_img):
        """Return (refined features, predicted image)."""
        predicted = self.conv2(x) + x_img
        gate = torch.sigmoid(self.conv3(predicted))
        refined = self.conv1(x) * gate + x
        return refined, predicted

class Net(nn.Module):
    """Two-stage U-Net denoiser with a supervised attention module (SAM) between stages.

    forward(x) returns [out_1, out_2]:
    - out_1 is the stage-1 image predicted inside the SAM.
    - out_2 is the stage-2 output, a residual added to the input x.
    """

    def __init__(self):
        super(Net, self).__init__()
        
        self.down_path_1 = nn.ModuleList()
        self.down_path_2 = nn.ModuleList()
        self.up_path_1 = nn.ModuleList()
        self.up_path_2 = nn.ModuleList()
        self.skip_conv_1 = nn.ModuleList()
        self.skip_conv_2 = nn.ModuleList()

        self.conv_01 = nn.Conv2d(3, 64, 3, 1, 1) # 3 -> 64 channels, spatial size unchanged
        self.conv_02 = nn.Conv2d(3, 64, 3, 1, 1) # 3 -> 64 channels, spatial size unchanged

        self.down_path_1.append(UNetDown(64, 64, True, 0.2))
        self.down_path_1.append(UNetDown(64, 128, True, 0.2))
        self.down_path_1.append(UNetDown(128, 256, True, 0.2))
        self.down_path_1.append(UNetDown(256, 512, True, 0.2))
        self.down_path_1.append(UNetDown(512, 1024, False, 0.2))

        self.up_path_1.append(UNetUp(1024, 512, 0.2))
        self.skip_conv_1.append(nn.Conv2d(512, 512, 3, 1, 1))

        self.up_path_1.append(UNetUp(512, 256, 0.2))
        self.skip_conv_1.append(nn.Conv2d(256, 256, 3, 1, 1))

        self.up_path_1.append(UNetUp(256, 128, 0.2))
        self.skip_conv_1.append(nn.Conv2d(128, 128, 3, 1, 1))

        self.up_path_1.append(UNetUp(128, 64, 0.2))
        self.skip_conv_1.append(nn.Conv2d(64, 64, 3, 1, 1))


        self.sam12 = SAM(64)
        # 1x1 conv fusing fresh stage-2 features (64) with SAM features (64).
        self.cat12 = nn.Conv2d(128, 64, 1, 1, 0)

        self.down_path_2.append(UNetDown(64, 64, True, 0.2))
        self.down_path_2.append(UNetDown(64, 128, True, 0.2))
        self.down_path_2.append(UNetDown(128, 256, True, 0.2))
        self.down_path_2.append(UNetDown(256, 512, True, 0.2))
        self.down_path_2.append(UNetDown(512, 1024, False, 0.2))

        self.up_path_2.append(UNetUp(1024, 512, 0.2))
        self.skip_conv_2.append(nn.Conv2d(512, 512, 3, 1, 1))

        self.up_path_2.append(UNetUp(512, 256, 0.2))
        self.skip_conv_2.append(nn.Conv2d(256, 256, 3, 1, 1))

        self.up_path_2.append(UNetUp(256, 128, 0.2))
        self.skip_conv_2.append(nn.Conv2d(128, 128, 3, 1, 1))

        self.up_path_2.append(UNetUp(128, 64, 0.2))
        self.skip_conv_2.append(nn.Conv2d(64, 64, 3, 1, 1))

        self.last = conv3x3(64, 3, bias=True)

    def _run_stage(self, x, down_path, up_path, skip_convs):
        """Encode ``x`` with ``down_path``, then decode with ``up_path`` using skip connections."""
        skips = []
        bottleneck = len(down_path) - 1
        for i, down in enumerate(down_path):
            if i < bottleneck:
                # Downsampling blocks return (downsampled, pre-downsample features).
                x, skip = down(x)
                skips.append(skip)
            else:
                # The last block is built with downsample=False and returns one tensor.
                x = down(x)
        for i, up in enumerate(up_path):
            # Decoder level i is paired with the matching encoder level, deepest first.
            x = up(x, skip_convs[i](skips[-i - 1]))
        return x

    def forward(self, x):
        image = x

        # Stage 1: U-Net, then the SAM predicts an image and refined features.
        x1 = self._run_stage(self.conv_01(image), self.down_path_1,
                             self.up_path_1, self.skip_conv_1)
        sam_feature, out_1 = self.sam12(x1, image)

        # Stage 2: fuse fresh features of the input with the SAM features.
        x2 = self.conv_02(image)
        x2 = self.cat12(torch.cat([x2, sam_feature], dim=1))
        # Bug fix: the stage-2 decoder previously iterated self.up_path_1, so
        # up_path_2 was never used. Checkpoints trained before this fix never
        # sent gradients to up_path_2 and need further training.
        x2 = self._run_stage(x2, self.down_path_2, self.up_path_2, self.skip_conv_2)

        out_2 = self.last(x2)
        out_2 = out_2 + image
        return [out_1, out_2]




class UNetDown(nn.Module):
    """Residual double-convolution block, optionally followed by stride-2 downsampling.

    forward returns (downsampled, features) when built with a truthy
    ``downsample``; otherwise it returns only the features.
    """

    def __init__(self, in_size, out_size, downsample, relu_slope):
        super(UNetDown, self).__init__()
        self.downsample = downsample
        # 1x1 projection so the residual branch has out_size channels.
        self.identity = nn.Conv2d(in_size, out_size, 1, 1, 0)

        # Two 3x3 convs with padding 1 (H and W unchanged); the first changes
        # the channel count. Each is followed by a LeakyReLU with the given slope.
        self.conv_1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1, bias=True)
        self.relu_1 = nn.LeakyReLU(relu_slope, inplace=False)
        self.conv_2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1, bias=True)
        self.relu_2 = nn.LeakyReLU(relu_slope, inplace=False)

        if downsample:
            # Replace the flag with the conv that halves H and W.
            self.downsample = conv_down(out_size, out_size, bias=False)

    def forward(self, x):
        features = self.relu_2(self.conv_2(self.relu_1(self.conv_1(x))))
        features = features + self.identity(x)
        if not self.downsample:
            return features
        return self.downsample(features), features


class UNetUp(nn.Module):
    """Decoder step: 2x transposed-conv upsampling, concatenation with a skip tensor, then a conv block.

    NOTE(review): UNetConvBlock is not defined in the visible source. Its
    argument list matches UNetDown(in_size, out_size, downsample, relu_slope).
    Confirm it is defined in another cell (or alias it); otherwise constructing
    UNetUp raises NameError.
    """

    def __init__(self, in_size, out_size, relu_slope):
        super(UNetUp, self).__init__()
        self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, bias=True)
        # Assumes `bridge` has in_size - out_size channels, so the concatenation
        # has in_size channels. TODO confirm for any new caller.
        self.conv_block = UNetConvBlock(in_size, out_size, False, relu_slope)

    def forward(self, x, bridge):
        merged = torch.cat([self.up(x), bridge], 1)
        return self.conv_block(merged)


class skip_blocks(nn.Module):
    """Two stacked conv blocks (in_size -> 128 -> out_size) plus a 1x1 shortcut.

    NOTE(review): not used by Net in the visible source, and it relies on
    UNetConvBlock, which is also not defined there.
    """

    def __init__(self, in_size, out_size):
        super(skip_blocks, self).__init__()
        self.blocks = nn.ModuleList([
            UNetConvBlock(in_size, 128, False, 0.2),
            UNetConvBlock(128, out_size, False, 0.2),
        ])
        self.shortcut = nn.Conv2d(in_size, out_size, kernel_size=1, bias=True)

    def forward(self, x):
        residual = self.shortcut(x)
        out = x
        for block in self.blocks:
            out = block(out)
        return out + residual



In [None]:
summary(Net().cuda(), (3, 128, 128)) # model summary for a (3, 128, 128) input

In [None]:
def weights_init_kaiming(m):
    """Kaiming-initialise conv/linear weights and re-initialise BatchNorm layers.

    Intended for ``model.apply(weights_init_kaiming)``. Layers are matched by
    class name, so "Conv" also matches e.g. ConvTranspose2d.
    """
    import math  # in this notebook math is otherwise only imported in a later cell

    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        # In-place (underscore) initialiser; the non-underscore alias is deprecated.
        nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif 'BatchNorm' in classname:
        # nn.init.uniform(m.weight.data, 1.0, 0.02)
        m.weight.data.normal_(mean=0, std=math.sqrt(2./9./64.)).clamp_(-0.025, 0.025)
        nn.init.constant_(m.bias.data, 0.0)

In [None]:
import math

# Build the two-stage denoiser on the GPU (this cell requires CUDA).
model = Net().cuda() # model definition
# model.apply(weights_init_kaiming)

# Alternative: resume from a saved checkpoint.
# NOTE(review): as written, the third line below fails, because `model` has
# already been replaced by the module stored under 'arch'. Keep the loaded
# dict in a separate variable if re-enabling this.
# model = torch.load("/content/model_epoch_35.pth")
# model = model['arch'].cuda()
# model.load_state_dict(model['state_dict'])  # load the state_dict and copy it into the model


# Loss used by train(): MSE between each stage output and the clean target.
MSE = nn.MSELoss()
# MSE = PSNRLoss()
# fn_loss = CharbonnierLoss()

# model = torch.load("/content/model_epoch_5.pth")
# model = model['arch'].cuda()

# fn_loss = CharbonnierLoss()
# optimizer = optim.Adam(model.parameters(),lr=1e-4)

# Adam over all model parameters (lr=1e-6, weight decay 1e-8).
optimizer = optim.Adam(model.parameters(), lr=1e-6, betas=(0.9, 0.999),eps=1e-8, weight_decay=1e-8)
# optimizer = optim.Adam(model.parameters(), lr=lr) # Adam Optimizer

In [None]:
def remove_noise(model, image):
    """Denoise ``image`` with ``model`` and display the result.

    Net in this notebook returns a list [stage-1 image, stage-2 image] of
    restored images, so the last entry is shown directly. Subtracting that list
    from the input, as before, raised TypeError. Any other (single-tensor)
    output is treated as a noise estimate and subtracted from the noisy input,
    which was the original intent.
    """
    with torch.no_grad():
        prediction = model(image)
    if isinstance(prediction, (list, tuple)):
        restored = prediction[-1]
    else:
        restored = image - prediction
    out = torch.clamp(restored, 0., 1.).cpu().clone().squeeze(0)
    plt.imshow(transforms.ToPILImage()(out))
    plt.show()

def img_show(image):
    """Display an image tensor with matplotlib.

    Values are clamped to [0, 1] and copied to the CPU. A leading dimension of
    size 1 (a batch of one) is dropped before conversion to a PIL image.
    """
    pixels = torch.clamp(image, 0., 1.).cpu().clone().squeeze(0)
    plt.imshow(transforms.ToPILImage()(pixels))
    plt.show()

In [None]:
try:
    # compare_psnr was removed from skimage.measure in scikit-image 0.18; the
    # same function now lives in skimage.metrics. The old name stays bound for
    # any other code that calls it. batch_PSNR below only needs numpy.
    from skimage.metrics import peak_signal_noise_ratio as compare_psnr
except ImportError:
    compare_psnr = None

def batch_PSNR(img, imclean, data_range):
    """Return the mean PSNR (dB) over the first (batch) dimension of two image tensors.

    Each image's PSNR is 10 * log10(data_range**2 / MSE), the formula
    scikit-image uses; the values are then averaged. Identical images give
    MSE 0 and therefore an infinite PSNR.

    Args:
        img: restored batch, indexed as (N, C, H, W).
        imclean: reference batch with the same shape.
        data_range: span of possible pixel values (1. for [0, 1] tensors).
    """
    restored = img.detach().cpu().numpy().astype(np.float32)
    clean = imclean.detach().cpu().numpy().astype(np.float32)
    scores = []
    for clean_i, restored_i in zip(clean, restored):
        mse = np.mean((clean_i - restored_i) ** 2, dtype=np.float64)
        scores.append(10 * np.log10((data_range ** 2) / mse))
    return sum(scores) / restored.shape[0]

ts = transforms.ToPILImage() # tensor -> PIL image converter
def train(epoch):
    """Run one training epoch over ``train_dataloader`` and print loss/PSNR.

    Uses the notebook-level ``model``, ``optimizer``, ``MSE`` and
    ``train_dataloader``. Both outputs of Net are supervised with MSE against
    the clean image. Every 100 iterations, the first sample's stage outputs,
    target and noisy input are displayed.

    Args:
        epoch: epoch number; used only in the log messages.
    """
    epoch_loss = 0
    # Nothing in the loop changes the mode, so set it once per epoch.
    model.train()
    for iteration, batch in enumerate(train_dataloader, 1):
        # The optimizer owns all model parameters, so this clears every gradient.
        optimizer.zero_grad()

        target = batch["img"]  # clean image
        noise = batch["noise"]  # additive noise only
        model_input = target + noise

        target = target.cuda()
        model_input = model_input.cuda()

        output = model(model_input)

        psnr_stage2 = batch_PSNR(output[1], target, 1.)
        psnr_stage1 = batch_PSNR(output[0], target, 1.)
        # loss = np.sum([fn_loss(torch.clamp(output[j],0,1),target) for j in range(len(output))])
        loss = MSE(output[1], target) + MSE(output[0], target)
        epoch_loss += loss.item()

        loss.backward()
        optimizer.step()

        if iteration % 100 == 0:
            # unsqueeze(0) adds the batch dimension without hard-coding 3x128x128.
            img_show(output[0][0].unsqueeze(0))
            img_show(output[1][0].unsqueeze(0))
            img_show(target[0].unsqueeze(0))
            img_show(model_input[0].unsqueeze(0))

        print("Epoch[{}]({}/{}): Loss: {:.4f} : psnr {},psnr {}".format(epoch, iteration, len(train_dataloader), loss.item(), psnr_stage2, psnr_stage1))

    print("Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, epoch_loss / len(train_dataloader)))

In [None]:
def save_checkpoint(state, epoch_num=None):
    """Save ``state`` with torch.save to model_epoch_<N>.pth in the working directory.

    Args:
        state: object to serialise. The training loop passes a dict with
            'epoch', 'arch', 'state_dict' and 'optimizer'.
        epoch_num: number used in the file name. Defaults to the notebook-level
            ``epoch`` loop variable, which is what this function always read;
            pass it explicitly to avoid depending on that global.
    """
    if epoch_num is None:
        epoch_num = epoch  # global set by the training loop
    model_out_path = "model_epoch_{}.pth".format(epoch_num)
    torch.save(state, model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))

In [None]:
num_epochs = 1000
for epoch in range(1, num_epochs + 1):
    train(epoch)
    # `epoch % 1 == 0` always holds, so a checkpoint is written after every
    # epoch; raise the modulus to save less often. The stored 'epoch' is
    # epoch + 1 (presumably the epoch to resume from).
    if epoch % 1 == 0:
        checkpoint_state = {
            'epoch': epoch + 1,
            'arch': model,
            'state_dict': model.state_dict(),
            'optimizer' : optimizer.state_dict(),
        }
        save_checkpoint(checkpoint_state)

In [None]:
# Load a saved checkpoint. Its 'arch' entry is a whole pickled nn.Module, so
# weights_only=False is required on PyTorch versions that default to
# weights_only=True. Only load checkpoint files you trust.
model = torch.load('/content/model_epoch_35.pth', weights_only=False) # checkpoint path
model = model['arch'] # extract the model object
model.eval()

def image_loader(image_name): # image loader
    # NOTE(review): `loader` is not defined in the visible source, so calling
    # this function raises NameError until a transform named `loader` exists.
    image = Image.open(image_name)
    image = loader(image).float()
    image = Variable(image, requires_grad=True)
    image = image.unsqueeze(0) 

    return image.cuda()

val_dataset = NoiseDataset(root_path, 128)
val_dataset.set_mode("validation")

num_val = len(val_dataset)
print(num_val)
total_psnr = 0
# Inference only: no autograd graph is needed.
with torch.no_grad():
  for sample in val_dataset:
    model_input = sample["img"] + sample["noise"]
    output = model(model_input.unsqueeze(0).cuda())
    psnr_val = batch_PSNR(output[1], sample["img"].unsqueeze(0), 1.)
    total_psnr += psnr_val
    print(psnr_val)

# Average over the actual number of validation images (previously hard-coded to 500).
if num_val:
  print(total_psnr / num_val)