In [1]:
import re
import os, glob, datetime, time
import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
import torch.nn.init as init
from torch.utils.data import DataLoader
import torch.optim as optim

from skimage import io, color
from skimage.measure import compare_ssim
from skimage.measure import compare_psnr
import matplotlib.pyplot as plt

# os.environ["CUDA_VISIBLE_DEVICES"] = "3"

In [2]:
gamma = 0.5
sigmaU = 1
sigmaW = 10
epochs = 100

model_name = 'SDN_Color_Block3_noAdd_gamma_%.1f_sigmaU_%.1f_sigmaW_%d'%(gamma, sigmaU, sigmaW)
save_dir = os.path.join('Models', model_name)

In [3]:
class Block3(nn.Module):
    def __init__(self, ch, kernel_size=3):
        super(Block3, self).__init__()
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(ch, ch, kernel_size, padding=1),
            nn.BatchNorm2d(ch),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(ch, ch, kernel_size, padding=1),
            nn.BatchNorm2d(ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch, ch, kernel_size, padding=1),
            nn.BatchNorm2d(ch),
            nn.ReLU(inplace=True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(ch, ch, kernel_size, padding=1),
            nn.BatchNorm2d(ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch, ch, kernel_size, padding=1),
            nn.BatchNorm2d(ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch, ch, kernel_size, padding=1),
            nn.BatchNorm2d(ch),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        
        c1 = self.conv1(x)
        c2 = self.conv2(c1+x)
        c3 = self.conv3(c2+x)
        
        return c3

class SDNCNN(nn.Module):
    
    def __init__(self, filters=64, image_channels=3, use_bnorm=True, kernel_size=3):
        super(SDNCNN, self).__init__()
        kernel_size = 3
        padding = 1
        self.conv0 = nn.Conv2d(in_channels=image_channels, out_channels=filters, kernel_size=kernel_size, padding=padding)
        
        self.convOut1 = nn.Conv2d(in_channels=filters, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=True)
        self.convOut2 = nn.Conv2d(in_channels=filters, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=True)
        self.convOut3 = nn.Conv2d(in_channels=filters, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=True)
        
        self.ResBlock1 = Block3(filters)
        self.ResBlock2 = Block3(filters)
        self.ResBlock3 = Block3(filters)
        self.ResBlock4 = Block3(filters)
        self.ResBlock5 = Block3(filters)
        self.ResBlock6 = Block3(filters)
        self.ResBlock7 = Block3(filters)
        
        self._initialize_weights()

    def forward(self, x):

        c0 = self.conv0(x)
        
        c1 = self.ResBlock1(c0)
        
        c2 = self.ResBlock2(c1)
        
        c3 = self.ResBlock3(c2)
        
        c4 = self.ResBlock4(c3)

        c5 = self.ResBlock5(c4)
        c55 = self.convOut1(c5)
        
        c6 = self.ResBlock6(c5)
        c66 = self.convOut2(c6)
        
        c7 = self.ResBlock7(c6)
        c77 = self.convOut3(c7)
        
        noise = c55 + c66 + c77
        rec = x - noise
        
        return rec, noise

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
                
def findLastCheckpoint(save_dir):
    file_list = glob.glob(os.path.join(save_dir, 'model_*.pth'))
    if file_list:
        epochs_exist = []
        for file_ in file_list:
            result = re.findall(".*model_(.*).pth.*", file_)
            epochs_exist.append(int(result[0]))
        initial_epoch = max(epochs_exist)
    else:
        initial_epoch = 0
    return initial_epoch
def calculate_psnr_fast(prediction, target):
    # Calculate PSNR
    # Data have to be in range (0, 1)
    assert prediction.max().cpu().data.numpy() <= 1
    assert prediction.min().cpu().data.numpy() >= 0
    psnr_list = []
    #print(prediction.size(0))
    for i in range(prediction.size(0)):
        mse = torch.mean(torch.pow(prediction.data[i]-target.data[i], 2))
        mse = mse.cpu()
        try:
            psnr_list.append(10 * np.log10(1**2 / mse))
        except:
            print('error in psnr calculation')
            continue
    return psnr_list

def show(x, title=None, cbar=False, figsize=None):
    plt.figure(figsize=figsize)
    plt.axis('off')
    plt.imshow(x, interpolation='nearest')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()

In [4]:
model = SDNCNN()
initial_epoch = findLastCheckpoint(save_dir=save_dir)  # load the last model in matconvnet style
if initial_epoch > 0:
    print('resuming by loading epoch %03d' % initial_epoch)
    model = torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch))

resuming by loading epoch 100


  "type " + container_type.__name__ + ". It won't be checked "
  "type " + container_type.__name__ + ". It won't be checked "


In [6]:
with torch.no_grad():
    np.random.seed(42)
    torch.cuda.manual_seed(42)
    torch.manual_seed(42)
    
    output_dir = 'Results_Log/%s'%(model_name)
    output_file_name = 'Results_Log/%s/Log_output_%s.txt'%(model_name, 'CBSD68')
    
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    test_data = './Test_datasets/CBSD68/*.png'
    test_dir = glob.glob(test_data)
    
    initial_epoch = findLastCheckpoint(save_dir=save_dir)
    for e in range(1, initial_epoch+1):

        model = torch.load(os.path.join(save_dir, 'model_%03d.pth' % e))
        model = model.cuda()
        model.eval()

        psnr_list = []
        ssim_list = []

        for i in range(len(test_dir)):

            start_time = time.time()

            batch_x = io.imread(test_dir[i])
            u = np.random.normal(0, sigmaU, batch_x.shape)
            w = np.random.normal(0, sigmaW, batch_x.shape)
            noise = np.power(batch_x, gamma)*u + w
            batch_y = batch_x + noise

            batch_x = torch.from_numpy(batch_x.transpose(2,0,1).astype('float32'))[None,:,:,:].cuda()
            batch_y = torch.from_numpy(batch_y.transpose(2,0,1).astype('float32'))[None,:,:,:].cuda()

            out, noise_out = model(batch_y)

            batch_x = batch_x.clamp(0, 255)[0,...].cpu().detach().numpy().transpose(1,2,0) /255.0
            out = out.clamp(0, 255)[0,...].cpu().detach().numpy().transpose(1,2,0) /255.0

            psnr_list += [compare_psnr(out, batch_x)]
            
            ssim_list += [compare_ssim(out, batch_x, multichannel=True)]
            
#         output_data = 'Epoch: %d, PSNR: %.2f, SSIM: %.4f\n' % (e, np.mean(psnr_list), np.mean(ssim_list))
#         output_file = open(output_file_name, 'a')
#         output_file.write(output_data)
#         output_file.close()

        print('Epoch:%d, PSNR: %.3f, SSIM: %.4f' % (e, np.mean(psnr_list), np.mean(ssim_list)))

Epoch:1, PSNR: 32.680, SSIM: 0.9487
Epoch:2, PSNR: 33.371, SSIM: 0.9571
Epoch:3, PSNR: 33.546, SSIM: 0.9592
Epoch:4, PSNR: 33.439, SSIM: 0.9577
Epoch:5, PSNR: 33.566, SSIM: 0.9597
Epoch:6, PSNR: 33.554, SSIM: 0.9595
Epoch:7, PSNR: 33.524, SSIM: 0.9589
Epoch:8, PSNR: 33.566, SSIM: 0.9596
Epoch:9, PSNR: 33.491, SSIM: 0.9583
Epoch:10, PSNR: 33.530, SSIM: 0.9590
Epoch:11, PSNR: 33.503, SSIM: 0.9586
Epoch:12, PSNR: 33.581, SSIM: 0.9600
Epoch:13, PSNR: 33.488, SSIM: 0.9584
Epoch:14, PSNR: 33.521, SSIM: 0.9589
Epoch:15, PSNR: 33.529, SSIM: 0.9591
Epoch:16, PSNR: 32.501, SSIM: 0.9459
Epoch:17, PSNR: 33.180, SSIM: 0.9528
Epoch:18, PSNR: 33.439, SSIM: 0.9566
Epoch:19, PSNR: 33.367, SSIM: 0.9555
Epoch:20, PSNR: 33.578, SSIM: 0.9589
Epoch:21, PSNR: 33.555, SSIM: 0.9582
Epoch:22, PSNR: 33.333, SSIM: 0.9551
Epoch:23, PSNR: 33.391, SSIM: 0.9557
Epoch:24, PSNR: 33.815, SSIM: 0.9633
Epoch:25, PSNR: 33.836, SSIM: 0.9634
Epoch:26, PSNR: 33.820, SSIM: 0.9633
Epoch:27, PSNR: 33.840, SSIM: 0.9634
Epoch:28, 

KeyboardInterrupt: 