In [1]:
import re
import os, glob, datetime, time
import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
import torch.nn.init as init
from torch.utils.data import DataLoader
import torch.optim as optim

from skimage import io, color
from skimage.measure import compare_ssim
from skimage.measure import compare_psnr
import matplotlib.pyplot as plt

# os.environ["CUDA_VISIBLE_DEVICES"] = "3"

In [2]:
# Noise-model parameters. The evaluation cells corrupt each clean image x
# (assumed 8-bit, values 0..255) as  y = x + x**gamma * u + w,  with per-pixel
# u ~ N(0, sigmaU) and w ~ N(0, sigmaW).
gamma, sigmaU, sigmaW = 0.5, 1, 10
epochs = 100  # presumably the training length; not referenced by the evaluation cells

# Checkpoints (model_<NNN>.pth) are read from Models/<model_name>; the name
# encodes the noise parameters, e.g. 'SDN_Color_Block2_gamma_0.5_sigmaU_1.0_sigmaW_10'.
model_name = 'SDN_Color_Block2_gamma_%.1f_sigmaU_%.1f_sigmaW_%d' % (gamma, sigmaU, sigmaW)
save_dir = os.path.join('Models', model_name)

In [3]:
def _same_padding(kernel_size):
    """Return the padding that preserves height/width for an odd kernel size.

    ``kernel_size`` may be an int or a (kh, kw) pair, as accepted by nn.Conv2d.
    """
    if isinstance(kernel_size, (tuple, list)):
        return tuple(k // 2 for k in kernel_size)
    return kernel_size // 2


class Block2(nn.Module):
    """Three (Conv-BN-ReLU) x 2 branches with additive skips from the block input.

    forward(x) computes
        c1 = conv1(x)
        c2 = conv2(c1 + x)
        c3 = conv3(c2 + x)
    and returns c3, so every branch must preserve the tensor shape.

    Args:
        ch: number of input and output feature channels.
        kernel_size: spatial size of every convolution (odd int or pair).
            Padding is ``kernel_size // 2`` so height/width are preserved; for
            the default 3 this is a padding of 1.
    """

    def __init__(self, ch, kernel_size=3):
        super(Block2, self).__init__()
        # Keep the conv1/conv2/conv3 attribute names: checkpoints in this
        # notebook are whole pickled modules (torch.load returns the model),
        # whose submodules are restored under these names.
        self.conv1 = self._make_branch(ch, kernel_size)
        self.conv2 = self._make_branch(ch, kernel_size)
        self.conv3 = self._make_branch(ch, kernel_size)

    @staticmethod
    def _make_branch(ch, kernel_size):
        """Build one shape-preserving (Conv-BN-ReLU) x 2 branch of width ``ch``."""
        padding = _same_padding(kernel_size)
        return nn.Sequential(
            nn.Conv2d(ch, ch, kernel_size, padding=padding),
            nn.BatchNorm2d(ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch, ch, kernel_size, padding=padding),
            nn.BatchNorm2d(ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        c1 = self.conv1(x)
        # Both later branches add the block input x (not the previous branch's input).
        c2 = self.conv2(c1 + x)
        c3 = self.conv3(c2 + x)

        return c3


class SDNCNN(nn.Module):
    """Residual denoiser: predicts a noise map and subtracts it from the input.

    Layout (every convolution preserves height/width):
        c0 = conv0(x)                          image -> ``filters`` channels
        c1..c4 = ResBlock1..4 applied in sequence
        c5 = ResBlock5(c4 + c3);  head convOut1(c5)
        c6 = ResBlock6(c5 + c2);  head convOut2(c6)
        c7 = ResBlock7(c6 + c1);  head convOut3(c7)
        noise = sum of the three heads,  rec = x - noise

    Args:
        filters: number of feature channels.
        image_channels: channels of the input / output image.
        use_bnorm: accepted for interface compatibility; not used (Block2
            always contains BatchNorm).
        kernel_size: spatial size (odd int or pair) of every convolution.

    forward(x) returns ``(rec, noise)``, both shaped like ``x``.
    """

    def __init__(self, filters=64, image_channels=3, use_bnorm=True, kernel_size=3):
        super(SDNCNN, self).__init__()
        # Honour the kernel_size argument; 'same' padding keeps the spatial size
        # so the heads' outputs can be subtracted from x.
        padding = _same_padding(kernel_size)
        self.conv0 = nn.Conv2d(in_channels=image_channels, out_channels=filters, kernel_size=kernel_size, padding=padding)

        # One noise head after each of ResBlock5/6/7; forward sums their outputs.
        self.convOut1 = nn.Conv2d(in_channels=filters, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=True)
        self.convOut2 = nn.Conv2d(in_channels=filters, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=True)
        self.convOut3 = nn.Conv2d(in_channels=filters, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=True)

        # Explicit ResBlock1..7 attributes (not a ModuleList) so that previously
        # pickled models keep resolving their submodules by name.
        self.ResBlock1 = Block2(filters, kernel_size)
        self.ResBlock2 = Block2(filters, kernel_size)
        self.ResBlock3 = Block2(filters, kernel_size)
        self.ResBlock4 = Block2(filters, kernel_size)
        self.ResBlock5 = Block2(filters, kernel_size)
        self.ResBlock6 = Block2(filters, kernel_size)
        self.ResBlock7 = Block2(filters, kernel_size)

        self._initialize_weights()

    def forward(self, x):

        c0 = self.conv0(x)

        c1 = self.ResBlock1(c0)

        c2 = self.ResBlock2(c1)

        c3 = self.ResBlock3(c2)

        c4 = self.ResBlock4(c3)

        # Second half: each block also receives the mirrored earlier feature map.
        c5 = self.ResBlock5(c4+c3)
        c55 = self.convOut1(c5)

        c6 = self.ResBlock6(c5+c2)
        c66 = self.convOut2(c6)

        c7 = self.ResBlock7(c6+c1)
        c77 = self.convOut3(c7)

        noise = c55 + c66 + c77
        rec = x - noise

        return rec, noise

    def _initialize_weights(self):
        """Orthogonal init for every Conv2d (bias 0); BatchNorm weight 1, bias 0.

        self.modules() recurses, so the layers inside each Block2 are included.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)


def findLastCheckpoint(save_dir):
    """Return the highest epoch N among ``model_<N>.pth`` files in ``save_dir``.

    Only files whose name is exactly ``model_<digits>.pth`` are counted; other
    files matching the glob (e.g. ``model_best.pth``) are ignored instead of
    raising ValueError. Returns 0 when no numbered checkpoint exists, including
    when ``save_dir`` does not exist.
    """
    epochs_exist = []
    for path in glob.glob(os.path.join(save_dir, 'model_*.pth')):
        # Match on the basename so directory names can never affect the result.
        match = re.fullmatch(r'model_(\d+)\.pth', os.path.basename(path))
        if match:
            epochs_exist.append(int(match.group(1)))
    return max(epochs_exist, default=0)

In [4]:
# Start from a freshly initialised network ...
model = SDNCNN()

# ... and swap in the newest numbered checkpoint from save_dir, if there is one.
initial_epoch = findLastCheckpoint(save_dir=save_dir)
if initial_epoch > 0:
    checkpoint_file = 'model_%03d.pth' % initial_epoch
    print('resuming by loading epoch %03d' % initial_epoch)
    # The file holds a whole pickled module (torch.load returns the model, not a
    # state_dict): the class definitions above must exist, and only trusted
    # files should be loaded.
    model = torch.load(os.path.join(save_dir, checkpoint_file))

resuming by loading epoch 082


  "type " + container_type.__name__ + ". It won't be checked "
  "type " + container_type.__name__ + ". It won't be checked "


In [None]:
# Score the last saved checkpoints on CBSD68: for each epoch, corrupt every clean
# image with x**gamma * u + w noise, denoise it, and append the mean PSNR/SSIM to
# Results_Log/<model_name>/Log_output_CBSD68.txt.
with torch.no_grad():
    test_set = 'CBSD68'
    first_epoch = 95  # evaluate checkpoints first_epoch .. last saved epoch

    output_dir = 'Results_Log/%s' % (model_name)
    output_file_name = 'Results_Log/%s/Log_output_%s.txt' % (model_name, test_set)
    os.makedirs(output_dir, exist_ok=True)

    # glob order is filesystem-dependent; sorting fixes which noise draw each
    # image receives, so scores are reproducible across machines.
    test_dir = sorted(glob.glob('./Test_datasets/%s/*.png' % test_set))
    if not test_dir:
        raise FileNotFoundError('No test images found in ./Test_datasets/%s' % test_set)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Stop at the newest checkpoint on disk instead of a fixed epoch 100;
    # asking torch.load for a file that does not exist raises.
    initial_epoch = findLastCheckpoint(save_dir=save_dir)
    if initial_epoch < first_epoch:
        print('No checkpoints at or after epoch %d in %s (last: %d)' % (first_epoch, save_dir, initial_epoch))

    for e in range(first_epoch, initial_epoch + 1):
        ckpt_path = os.path.join(save_dir, 'model_%03d.pth' % e)
        if not os.path.exists(ckpt_path):
            print('Skipping missing checkpoint %s' % ckpt_path)
            continue

        # torch.load unpickles a whole module here; only load trusted files.
        model = torch.load(ckpt_path, map_location=device).to(device)
        model.eval()

        # Re-seed per checkpoint so every epoch is scored on identical noisy inputs.
        np.random.seed(42)
        torch.manual_seed(42)
        torch.cuda.manual_seed(42)

        psnr_list = []
        ssim_list = []

        for img_path in test_dir:
            clean = io.imread(img_path)
            u = np.random.normal(0, sigmaU, clean.shape)
            w = np.random.normal(0, sigmaW, clean.shape)
            noise = np.power(clean, gamma)*u + w
            noisy = clean + noise

            noisy_t = torch.from_numpy(noisy.transpose(2, 0, 1).astype('float32'))[None, :, :, :].to(device)
            out = model(noisy_t)[0]

            # Both images as float32 HxWxC scaled to [0, 1].
            clean_np = np.clip(clean.astype('float32'), 0, 255) / 255.0
            out_np = out.clamp(0, 255)[0, ...].cpu().numpy().transpose(1, 2, 0) / 255.0

            # Reference image first (im_true, im_test). data_range=1.0 is explicit
            # because for float inputs skimage infers the dtype range [-1, 1]
            # (data_range 2) in compare_ssim, which misstates SSIM for [0, 1] images.
            psnr_list.append(compare_psnr(clean_np, out_np, data_range=1.0))
            ssim_list.append(compare_ssim(clean_np, out_np, data_range=1.0, multichannel=True))

        mean_psnr = np.mean(psnr_list)
        mean_ssim = np.mean(ssim_list)
        with open(output_file_name, 'a') as output_file:
            output_file.write('Epoch: %d, PSNR: %.2f, SSIM: %.4f\n' % (e, mean_psnr, mean_ssim))

        print('Epoch:%d, PSNR: %.3f, SSIM: %.4f' % (e, mean_psnr, mean_ssim))

In [None]:
# Score the last saved checkpoints on Kodak: for each epoch, corrupt every clean
# image with x**gamma * u + w noise, denoise it, and append the mean PSNR/SSIM to
# Results_Log/<model_name>/Log_output_kodak.txt.
with torch.no_grad():
    test_set = 'kodak'
    first_epoch = 95  # evaluate checkpoints first_epoch .. last saved epoch

    output_dir = 'Results_Log/%s' % (model_name)
    output_file_name = 'Results_Log/%s/Log_output_%s.txt' % (model_name, test_set)
    os.makedirs(output_dir, exist_ok=True)

    # glob order is filesystem-dependent; sorting fixes which noise draw each
    # image receives, so scores are reproducible across machines.
    test_dir = sorted(glob.glob('./Test_datasets/%s/*.png' % test_set))
    if not test_dir:
        raise FileNotFoundError('No test images found in ./Test_datasets/%s' % test_set)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Only evaluate checkpoints that can exist: up to the newest one on disk.
    initial_epoch = findLastCheckpoint(save_dir=save_dir)
    if initial_epoch < first_epoch:
        print('No checkpoints at or after epoch %d in %s (last: %d)' % (first_epoch, save_dir, initial_epoch))

    for e in range(first_epoch, initial_epoch + 1):
        ckpt_path = os.path.join(save_dir, 'model_%03d.pth' % e)
        if not os.path.exists(ckpt_path):
            print('Skipping missing checkpoint %s' % ckpt_path)
            continue

        # torch.load unpickles a whole module here; only load trusted files.
        model = torch.load(ckpt_path, map_location=device).to(device)
        model.eval()

        # Re-seed per checkpoint so every epoch is scored on identical noisy inputs.
        np.random.seed(42)
        torch.manual_seed(42)
        torch.cuda.manual_seed(42)

        psnr_list = []
        ssim_list = []

        for img_path in test_dir:
            clean = io.imread(img_path)
            u = np.random.normal(0, sigmaU, clean.shape)
            w = np.random.normal(0, sigmaW, clean.shape)
            noise = np.power(clean, gamma)*u + w
            noisy = clean + noise

            noisy_t = torch.from_numpy(noisy.transpose(2, 0, 1).astype('float32'))[None, :, :, :].to(device)
            out = model(noisy_t)[0]

            # Both images as float32 HxWxC scaled to [0, 1].
            clean_np = np.clip(clean.astype('float32'), 0, 255) / 255.0
            out_np = out.clamp(0, 255)[0, ...].cpu().numpy().transpose(1, 2, 0) / 255.0

            # Reference image first (im_true, im_test). data_range=1.0 is explicit
            # because for float inputs skimage infers the dtype range [-1, 1]
            # (data_range 2) in compare_ssim, which misstates SSIM for [0, 1] images.
            psnr_list.append(compare_psnr(clean_np, out_np, data_range=1.0))
            ssim_list.append(compare_ssim(clean_np, out_np, data_range=1.0, multichannel=True))

        mean_psnr = np.mean(psnr_list)
        mean_ssim = np.mean(ssim_list)
        with open(output_file_name, 'a') as output_file:
            output_file.write('Epoch: %d, PSNR: %.2f, SSIM: %.4f\n' % (e, mean_psnr, mean_ssim))

        print('Epoch:%d, PSNR: %.3f, SSIM: %.4f' % (e, mean_psnr, mean_ssim))