In [8]:
from model import DNet, SNet
import data_generator as dg
from data_generator import DenoisingDataset

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import time, os
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.nn.modules.loss import _Loss

In [10]:
class sum_squared_error(_Loss):  # PyTorch 0.4.1
    """
    Definition: sum_squared_error = 1/2 * nn.MSELoss(reduction = 'sum')
    The backward is defined as: input-target
    """
    def __init__(self, size_average=None, reduce=None, reduction='sum'):
        super(sum_squared_error, self).__init__(size_average, reduce, reduction)

    def forward(self, input, target):
        # return torch.sum(torch.pow(input-target,2), (0,1,2,3)).div_(2)
        return torch.nn.functional.mse_loss(input, target, size_average=None, reduce=None, reduction='sum').div_(2)

### Pretrain DNet

In [13]:
batch_size = 128
cuda = torch.cuda.is_available()
n_epoch = 150
sigma = 25
lr = 1e-3
data_dir = './data/Train400'
model_dir = 'models'
model_name = 'DNet_sigma=25_1.pth'

model = DNet()

model.train()
criterion = sum_squared_error()
if cuda:
    model = model.cuda()

optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.2)  # learning rates
for epoch in range(n_epoch):
    x = dg.datagenerator(data_dir=data_dir).astype('float32')/255.0
    x = torch.from_numpy(x.transpose((0, 3, 1, 2)))
    dataset = DenoisingDataset(x, sigma)
    loader = DataLoader(dataset=dataset, num_workers=4, drop_last=True, batch_size=batch_size, shuffle=True)
    epoch_loss = 0
    start_time = time.time()
    n_count=0
    for cnt, batch_yx in enumerate(loader):
        optimizer.zero_grad()
        if cuda:
            batch_original, batch_noise = batch_yx[1].cuda(), batch_yx[0].cuda()
        loss = criterion(batch_noise- model(batch_noise), batch_original)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
        if cnt%100 == 0:
            print('%4d %4d / %4d loss = %2.4f' % (epoch+1, cnt, x.size(0)//batch_size, loss.item()/batch_size))
        n_count +=1
    
    elapsed_time = time.time() - start_time
    print('epoch = %4d , loss = %4.4f , time = %4.2f s' % (epoch+1, epoch_loss/n_count, elapsed_time))
    torch.save(model, os.path.join(model_dir, model_name))

torch.save(model, os.path.join(model_dir, model_name))    

init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
^_^-training data finished-^_^
   1    0 / 1862 loss = 668.0828
   1  100 / 1862 loss = 7.1625
   1  200 / 1862 loss = 4.9611
   1  300 / 1862 loss = 3.9745
   1  400 / 1862 loss = 3.4851
   1  500 / 1862 loss = 3.2205
   1  600 / 1862 loss = 2.7894
   1  700 / 1862 loss = 2.4739
   1  800 / 1862 loss = 2.3594
   1  900 / 1862 loss = 2.2900
   1 1000 / 1862 loss = 2.1299
   1 1100 / 1862 loss = 2.1909
   1 1200 / 1862 loss = 2.1657
   1 1300 / 1862 loss = 2.0969
   1 1400 / 1862 loss = 1.9320
   1 1500 / 1862 loss = 1.7944
   1 1600 / 1862 loss = 1.9186
   1 1700 / 1862 loss = 1.8426
   1 1800 / 1862 loss = 1.7906
epoch =    1 , loss = 600.3373 , time = 211.44 s
^_^-training data finished-^_^
   2    0 / 1862 loss = 1.9716
   2  100 / 1862 loss = 1.9431
   2  200 / 1862 loss = 1.7412
 

### Pretrain SNet

In [12]:
batch_size = 128
cuda = torch.cuda.is_available()
n_epoch = 10
sigma = 25
lr = 1e-3
data_dir = './data/Train400'
model_dir = 'models'
model_name = 'DNet_sigma=25_1.pth'
save_name_s = 'SNet_sigma=25_1.pth'

SNet_model = SNet()
DNet_model = torch.load(os.path.join(model_dir, model_name))

DNet_model.eval()
SNet_model.train()

criterion = sum_squared_error()

if cuda:
    DNet_model = DNet_model.cuda()
    SNet_model = SNet_model.cuda()

optimizer = optim.Adam(DNet_model.parameters(), lr=lr)
scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.2)  # learning rates

for epoch in range(n_epoch):
    x = dg.datagenerator(data_dir=data_dir).astype('float32')/255.0
    x = torch.from_numpy(x.transpose((0, 3, 1, 2)))
    dataset = DenoisingDataset(x, sigma)
    loader = DataLoader(dataset=dataset, num_workers=4, drop_last=True, batch_size=batch_size, shuffle=True)
    epoch_loss = 0
    start_time = time.time()
    n_count=0
    
    for cnt, batch_yx in enumerate(loader):
        optimizer.zero_grad()
        if cuda:
            batch_original, batch_noise = batch_yx[1].cuda(), batch_yx[0].cuda()
        
        residual = DNet_model(batch_noise)
        denoised = batch_noise - residual
        
        residual = 1.55*(residual+0.5)-0.8
        structure = SNet_model(residual, denoised)
        target = 1.8*(batch_original-denoised+0.5)-0.8
        
        loss = criterion(structure, target)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
        if cnt%100 == 0:
            print('%4d %4d / %4d loss = %2.4f' % (epoch+1, cnt, x.size(0)//batch_size, loss.item()/batch_size))
        n_count +=1
    
    elapsed_time = time.time() - start_time
    print('epoch = %4d , loss = %4.4f , time = %4.2f s' % (epoch+1, epoch_loss/n_count, elapsed_time))
    torch.save(SNet_model, os.path.join(model_dir, save_name_s))

torch.save(SNet_model, os.path.join(model_dir, save_name_s))  

init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
^_^-training data finished-^_^
   1    0 / 1862 loss = 146.3578
   1  100 / 1862 loss = 24.3115
   1  200 / 1862 loss = 16.6896
   1  300 / 1862 loss = 14.8828
   1  400 / 1862 loss = 14.9912
   1  500 / 1862 loss = 12.3780
   1  600 / 1862 loss = 12.3868
   1  700 / 1862 loss = 12.3291
   1  800 / 1862 loss = 11.4862
   1  900 / 1862 loss = 10.6100
   1 1000 / 1862 loss = 10.7344
   1 1100 / 1862 loss = 11.0920
   1 1200 / 1862 loss = 10.7167
   1 1300 / 1862 loss = 10.7239
   1 1400 / 1862 loss = 9.9885
   1 1500 / 1862 loss = 10.2467
   1 1600 / 1862 loss = 9.9037
   1 1700 / 1862 loss = 9.7338
   1 1800 / 1862 loss = 9.6482
epoch =    1 , loss = 1750.1406 , time = 248.19 s
^_^-training data finished-^_^
   2    0 / 1862 loss = 9.3908
   2  100 / 