In [1]:
from __future__ import print_function, division
import os
import torchvision
import torch
from skimage import io, transform
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn as nn
from math import log10, pi
import time

import utils
from datasetsMultiple import DatasetMultiple
from vgg import Vgg16

c:\users\adams\appdata\local\programs\python\python39\lib\site-packages\numpy\.libs\libopenblas.EL2C6PLE4ZYW3ECEVIV3OXXGRN2NRFM2.gfortran-win_amd64.dll
c:\users\adams\appdata\local\programs\python\python39\lib\site-packages\numpy\.libs\libopenblas.GK7GX5KEQ4F6UYO3P26ULGBQYHGQO7J4.gfortran-win_amd64.dll


In [2]:
class MFFNet(torch.nn.Module):
    """Encoder / residual trunk / decoder image-to-image network with skip concatenations.

    Layout (as built in ``__init__``):
      * encoder: conv1 (3->32, k=9, s=1), conv2 (32->64, k=3, s=2),
        conv3 (64->128, k=3, s=2), each followed by ReLU in ``forward``;
      * trunk: sixteen ``ResidualBlock(128)`` modules, ``res1`` .. ``res16``,
        applied in order;
      * decoder: before each stage the current features are concatenated along
        dim 1 with the matching encoder output (hence the doubled input
        channels): deconv1 (256->64, x2 upsample), deconv2 (128->32, x2
        upsample), deconv3 (64->3, k=9).  ReLU follows deconv1 and deconv2;
        the final output is returned without an activation.

    NOTE(review): ``in1`` .. ``in5`` (InstanceNorm2d) are registered, so they
    appear in ``state_dict()`` / ``parameters()``, but ``forward`` never applies
    them.  They are kept so existing checkpoints still load with strict keys.
    NOTE(review): the skip concatenations only line up when the input height
    and width are multiples of 4 (two stride-2 convs, then two x2 upsamples).
    """

    # Number of residual blocks in the trunk (registered as res1 .. res16).
    _NUM_RES_BLOCKS = 16

    def __init__(self):
        super().__init__()

        # Encoder.  Registration order is significant: it fixes both the
        # state_dict layout and the order in which random initialisation runs.
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3 = torch.nn.InstanceNorm2d(128, affine=True)

        # Residual trunk.  Each block is attached as its own attribute rather
        # than wrapped in nn.Sequential, so checkpoint keys stay 'res<k>.*'.
        for k in range(1, self._NUM_RES_BLOCKS + 1):
            setattr(self, 'res%d' % k, ResidualBlock(128))

        # Decoder; input channels are doubled by the skip concatenation.
        self.deconv1 = UpsampleConvLayer(128 * 2, 64, kernel_size=3, stride=1, upsample=2)
        self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
        self.deconv2 = UpsampleConvLayer(64 * 2, 32, kernel_size=3, stride=1, upsample=2)
        self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
        self.deconv3 = ConvLayer(32 * 2, 3, kernel_size=9, stride=1)

        self.relu = torch.nn.ReLU()

    def forward(self, X):
        """Run encoder, residual trunk and skip-connected decoder on ``X``."""
        skip1 = self.relu(self.conv1(X))
        skip2 = self.relu(self.conv2(skip1))
        skip3 = self.relu(self.conv3(skip2))

        feats = skip3
        for k in range(1, self._NUM_RES_BLOCKS + 1):
            feats = getattr(self, 'res%d' % k)(feats)

        feats = self.relu(self.deconv1(torch.cat((feats, skip3), 1)))
        feats = self.relu(self.deconv2(torch.cat((feats, skip2), 1)))
        return self.deconv3(torch.cat((feats, skip1), 1))

class ConvLayer(torch.nn.Module):
    """2-D convolution preceded by reflection padding of ``kernel_size // 2`` per side.

    For an odd ``kernel_size`` the output spatial size is ``ceil(H / stride)``,
    so ``stride=1`` preserves the input size.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size: square kernel size passed to ``Conv2d``.
        stride: convolution stride.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        # Pad first, then convolve: reflection padding avoids the dark borders
        # that zero padding tends to introduce in image outputs.
        self.reflection_pad = torch.nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        padded = self.reflection_pad(x)
        return self.conv2d(padded)

class ResidualBlock(torch.nn.Module):
    """Two 3x3 reflection-padded convolutions with a ReLU between them, plus an identity skip.

    ``forward(x)`` computes ``conv2(relu(conv1(x))) + x``; input and output both
    have ``channels`` channels and the same spatial size.

    NOTE(review): ``in1`` / ``in2`` (InstanceNorm2d) are created, so they show up
    in ``state_dict()`` / ``parameters()``, but ``forward`` never applies them.
    They are left in place so existing checkpoints keep loading with strict keys.
    """

    def __init__(self, channels):
        super().__init__()
        # Attribute order matters for state_dict layout and init RNG order.
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden) + x


class UpsampleConvLayer(torch.nn.Module):
    """Optional nearest-neighbour upsampling followed by a reflection-padded convolution.

    Upsampling then convolving (instead of a transposed convolution) is the
    usual way to avoid checkerboard artefacts in generated images.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size: square kernel size; reflection padding is ``kernel_size // 2``.
        stride: convolution stride.
        upsample: scale factor for nearest-neighbour interpolation; any falsy
            value (default ``None``) skips the upsampling step.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__()
        self.upsample = upsample
        self.reflection_pad = torch.nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        scaled = (
            torch.nn.functional.interpolate(x, mode='nearest', scale_factor=self.upsample)
            if self.upsample
            else x
        )
        return self.conv2d(self.reflection_pad(scaled))
  

In [3]:
# Reproducibility: fixes model initialisation and DataLoader shuffling order.
SEED = 0
torch.manual_seed(SEED)

# --- Data -------------------------------------------------------------------
train_dataset = DatasetMultiple('rgb train/', 'rgb train/', '380 train a/', '640 train a/', True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=16,
                                           shuffle=True,
                                           num_workers=4)

# --- Models -----------------------------------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
imageFilter = MFFNet().to(device).float()

# Initializing VGG16 model for perceptual loss (constructed with requires_grad=False)
VGG = Vgg16(requires_grad=False)
VGG = VGG.to(device)

# --- Hyper-parameters -------------------------------------------------------
num_epochs = 600
learning_rate = 1e-4
vgg_weight = 0.01        # weight of the perceptual (VGG) term in the total loss
vgg_num_layers = 3       # passed to VGG(...); feature maps 0..vgg_num_layers are compared
                         # (presumably the last VGG block returned — confirm against vgg.Vgg16)
checkpoint_path = 'MFF-net_all3_old.ckpt'
checkpoint_every = 50    # epochs between intermediate checkpoints; 0 disables them

# Epoch (0-based) -> learning rate switched to at the start of that epoch.
lr_schedule = {300: 1e-5, 600: 1e-6}
# range(num_epochs) stops at num_epochs - 1, so milestones >= num_epochs never fire.
unreachable_milestones = sorted(e for e in lr_schedule if e >= num_epochs)
if unreachable_milestones:
    print('Warning: LR milestones {} are never reached with num_epochs={}'.format(
        unreachable_milestones, num_epochs))

criterion_img = nn.MSELoss()
criterion_vgg = nn.MSELoss()

optimizer = torch.optim.Adam(imageFilter.parameters(), lr=learning_rate)
total_step = len(train_loader)

# --- Training loop ----------------------------------------------------------
start_time = time.time()
previous_time = start_time
for epoch in range(num_epochs):
    # Per-epoch sums over all batches (not means), kept as plain Python floats.
    loss_tol = 0.0
    loss_tol_vgg = 0.0
    loss_tol_l2 = 0.0

    if epoch in lr_schedule:
        learning_rate = lr_schedule[epoch]
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate

    for i, im in enumerate(train_loader):
        inputs = im[0].float().to(device)
        target = im[1].float().to(device)

        outputs = imageFilter(inputs)

        # Pixel-wise reconstruction loss.
        loss_l2 = criterion_img(outputs, target)

        # Perceptual loss: MSE between VGG feature maps of output and target,
        # both normalised with ImageNet statistics first.
        outputs_n = utils.normalize_ImageNet_stats(outputs)
        target_n = utils.normalize_ImageNet_stats(target)

        feature_o = VGG(outputs_n, vgg_num_layers)
        feature_t = VGG(target_n, vgg_num_layers)
        VGG_loss = [criterion_vgg(feature_o[l], feature_t[l]) for l in range(vgg_num_layers + 1)]

        loss_vgg = sum(VGG_loss)
        loss = loss_l2 + vgg_weight * loss_vgg

        # .item() detaches the values.  Summing the loss tensors directly would
        # chain every batch's autograd graph into the running totals, so the
        # graph history of the whole epoch would be kept alive.
        loss_tol += loss.item()
        loss_tol_vgg += loss_vgg.item()
        loss_tol_l2 += loss_l2.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print ( 'Epoch [{}/{}], Training Loss: {:.4f}, vgg Loss: {:.4f}, L2 Loss: {:.4f}, time: {:.4f}' .format(epoch+1, num_epochs, loss_tol, loss_tol_vgg, loss_tol_l2, time.time() - previous_time) )
    previous_time = time.time()

    # Intermediate checkpoint: a full run takes many hours, so don't rely on
    # reaching the final save below.
    if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
        torch.save(imageFilter.state_dict(), checkpoint_path)

print("--- %0.4f seconds ---" % (time.time() - start_time))
torch.save(imageFilter.state_dict(), checkpoint_path)

Epoch [1/600], Training Loss: 793.1590, vgg Loss: 66963.1406, L2 Loss: 123.5277, time: 45.6465
Epoch [2/600], Training Loss: 169.8015, vgg Loss: 14634.4258, L2 Loss: 23.4573, time: 45.9326
Epoch [3/600], Training Loss: 104.7223, vgg Loss: 8924.1572, L2 Loss: 15.4807, time: 46.6451
Epoch [4/600], Training Loss: 85.5938, vgg Loss: 7259.2783, L2 Loss: 13.0010, time: 47.6569
Epoch [5/600], Training Loss: 73.8493, vgg Loss: 6196.3823, L2 Loss: 11.8855, time: 48.6194
Epoch [6/600], Training Loss: 61.4747, vgg Loss: 5118.2417, L2 Loss: 10.2923, time: 47.6413
Epoch [7/600], Training Loss: 56.7786, vgg Loss: 4707.0068, L2 Loss: 9.7086, time: 47.7698
Epoch [8/600], Training Loss: 48.4824, vgg Loss: 4063.5142, L2 Loss: 7.8473, time: 56.5676
Epoch [9/600], Training Loss: 42.1711, vgg Loss: 3557.1360, L2 Loss: 6.5997, time: 60.6503
Epoch [10/600], Training Loss: 42.1526, vgg Loss: 3509.6594, L2 Loss: 7.0560, time: 46.6941
Epoch [11/600], Training Loss: 76.9300, vgg Loss: 5867.4961, L2 Loss: 18.2551

Epoch [91/600], Training Loss: 8.2050, vgg Loss: 741.4686, L2 Loss: 0.7903, time: 47.8773
Epoch [92/600], Training Loss: 7.8288, vgg Loss: 711.1799, L2 Loss: 0.7170, time: 47.8338
Epoch [93/600], Training Loss: 7.3973, vgg Loss: 676.1514, L2 Loss: 0.6358, time: 48.2539
Epoch [94/600], Training Loss: 7.4018, vgg Loss: 673.0460, L2 Loss: 0.6713, time: 47.8021
Epoch [95/600], Training Loss: 6.9356, vgg Loss: 633.1016, L2 Loss: 0.6046, time: 47.7496
Epoch [96/600], Training Loss: 6.8190, vgg Loss: 624.4165, L2 Loss: 0.5748, time: 48.7054
Epoch [97/600], Training Loss: 6.6821, vgg Loss: 611.2844, L2 Loss: 0.5693, time: 47.8222
Epoch [98/600], Training Loss: 6.4609, vgg Loss: 593.4432, L2 Loss: 0.5264, time: 49.6294
Epoch [99/600], Training Loss: 6.4360, vgg Loss: 587.1775, L2 Loss: 0.5642, time: 48.1080
Epoch [100/600], Training Loss: 6.4755, vgg Loss: 580.3886, L2 Loss: 0.6716, time: 47.8432
Epoch [101/600], Training Loss: 6.0634, vgg Loss: 557.4421, L2 Loss: 0.4890, time: 47.9391
Epoch [1

Epoch [181/600], Training Loss: 2.9265, vgg Loss: 268.6066, L2 Loss: 0.2404, time: 112.8723
Epoch [182/600], Training Loss: 2.9150, vgg Loss: 264.5136, L2 Loss: 0.2699, time: 106.6943
Epoch [183/600], Training Loss: 2.8244, vgg Loss: 258.2632, L2 Loss: 0.2418, time: 106.6279
Epoch [184/600], Training Loss: 2.8554, vgg Loss: 261.3473, L2 Loss: 0.2420, time: 106.6895
Epoch [185/600], Training Loss: 2.9187, vgg Loss: 266.8355, L2 Loss: 0.2503, time: 106.8366
Epoch [186/600], Training Loss: 3.0960, vgg Loss: 272.2135, L2 Loss: 0.3739, time: 106.7011
Epoch [187/600], Training Loss: 2.7673, vgg Loss: 254.5007, L2 Loss: 0.2223, time: 106.7686
Epoch [188/600], Training Loss: 2.9143, vgg Loss: 261.4978, L2 Loss: 0.2993, time: 106.7588
Epoch [189/600], Training Loss: 2.8086, vgg Loss: 257.0572, L2 Loss: 0.2381, time: 112.6536
Epoch [190/600], Training Loss: 2.8452, vgg Loss: 261.3557, L2 Loss: 0.2316, time: 106.7472
Epoch [191/600], Training Loss: 2.8396, vgg Loss: 259.5612, L2 Loss: 0.2440, tim

Epoch [271/600], Training Loss: 2.2171, vgg Loss: 205.9955, L2 Loss: 0.1572, time: 106.8413
Epoch [272/600], Training Loss: 2.2415, vgg Loss: 207.2918, L2 Loss: 0.1686, time: 106.7240
Epoch [273/600], Training Loss: 2.1787, vgg Loss: 203.9000, L2 Loss: 0.1397, time: 106.6999
Epoch [274/600], Training Loss: 2.1950, vgg Loss: 205.7212, L2 Loss: 0.1378, time: 107.0441
Epoch [275/600], Training Loss: 2.2601, vgg Loss: 209.9336, L2 Loss: 0.1607, time: 106.5435
Epoch [276/600], Training Loss: 2.2418, vgg Loss: 210.0637, L2 Loss: 0.1412, time: 106.8367
Epoch [277/600], Training Loss: 2.3378, vgg Loss: 214.5841, L2 Loss: 0.1919, time: 106.6328
Epoch [278/600], Training Loss: 2.2712, vgg Loss: 211.3019, L2 Loss: 0.1582, time: 106.6286
Epoch [279/600], Training Loss: 2.2523, vgg Loss: 208.5473, L2 Loss: 0.1668, time: 106.7937
Epoch [280/600], Training Loss: 2.2473, vgg Loss: 208.6768, L2 Loss: 0.1605, time: 107.1040
Epoch [281/600], Training Loss: 2.1788, vgg Loss: 203.4180, L2 Loss: 0.1446, tim

Epoch [361/600], Training Loss: 1.9056, vgg Loss: 180.3900, L2 Loss: 0.1017, time: 126.5091
Epoch [362/600], Training Loss: 1.8824, vgg Loss: 178.4774, L2 Loss: 0.0976, time: 126.5623
Epoch [363/600], Training Loss: 1.8357, vgg Loss: 174.2763, L2 Loss: 0.0929, time: 126.4609
Epoch [364/600], Training Loss: 1.8669, vgg Loss: 176.8188, L2 Loss: 0.0987, time: 126.8896
Epoch [365/600], Training Loss: 1.8509, vgg Loss: 175.5612, L2 Loss: 0.0953, time: 126.5665
Epoch [366/600], Training Loss: 1.8774, vgg Loss: 177.9319, L2 Loss: 0.0981, time: 126.6266
Epoch [367/600], Training Loss: 1.8524, vgg Loss: 175.8039, L2 Loss: 0.0944, time: 126.5787
Epoch [368/600], Training Loss: 1.8834, vgg Loss: 178.5221, L2 Loss: 0.0982, time: 126.5863
Epoch [369/600], Training Loss: 1.8017, vgg Loss: 170.8948, L2 Loss: 0.0928, time: 126.4530
Epoch [370/600], Training Loss: 1.8503, vgg Loss: 175.7437, L2 Loss: 0.0928, time: 126.3449
Epoch [371/600], Training Loss: 1.8532, vgg Loss: 175.8519, L2 Loss: 0.0947, tim

Epoch [451/600], Training Loss: 1.7782, vgg Loss: 168.8057, L2 Loss: 0.0901, time: 102.4794
Epoch [452/600], Training Loss: 1.7661, vgg Loss: 168.2996, L2 Loss: 0.0831, time: 106.4496
Epoch [453/600], Training Loss: 1.8176, vgg Loss: 172.6759, L2 Loss: 0.0909, time: 105.6013
Epoch [454/600], Training Loss: 1.8006, vgg Loss: 170.9950, L2 Loss: 0.0906, time: 107.0261
Epoch [455/600], Training Loss: 1.7672, vgg Loss: 167.6921, L2 Loss: 0.0903, time: 106.2017
Epoch [456/600], Training Loss: 1.7604, vgg Loss: 167.0941, L2 Loss: 0.0895, time: 106.8772
Epoch [457/600], Training Loss: 1.8075, vgg Loss: 171.4257, L2 Loss: 0.0933, time: 106.1299
Epoch [458/600], Training Loss: 1.7919, vgg Loss: 170.2957, L2 Loss: 0.0889, time: 111.5677
Epoch [459/600], Training Loss: 1.7921, vgg Loss: 170.0140, L2 Loss: 0.0920, time: 104.9966
Epoch [460/600], Training Loss: 1.7259, vgg Loss: 164.4544, L2 Loss: 0.0814, time: 104.5332
Epoch [461/600], Training Loss: 1.8001, vgg Loss: 171.3415, L2 Loss: 0.0867, tim

Epoch [541/600], Training Loss: 1.8010, vgg Loss: 171.0112, L2 Loss: 0.0909, time: 102.2845
Epoch [542/600], Training Loss: 1.7746, vgg Loss: 169.4372, L2 Loss: 0.0802, time: 102.4294
Epoch [543/600], Training Loss: 1.7769, vgg Loss: 169.5301, L2 Loss: 0.0816, time: 105.6822
Epoch [544/600], Training Loss: 1.7335, vgg Loss: 164.9791, L2 Loss: 0.0837, time: 107.1197
Epoch [545/600], Training Loss: 1.7918, vgg Loss: 170.3321, L2 Loss: 0.0885, time: 111.0040
Epoch [546/600], Training Loss: 1.8050, vgg Loss: 171.6841, L2 Loss: 0.0882, time: 107.5666
Epoch [547/600], Training Loss: 1.7251, vgg Loss: 163.8823, L2 Loss: 0.0863, time: 119.9624
Epoch [548/600], Training Loss: 1.7357, vgg Loss: 165.0745, L2 Loss: 0.0850, time: 135.9485
Epoch [549/600], Training Loss: 1.7417, vgg Loss: 166.0554, L2 Loss: 0.0811, time: 116.8246
Epoch [550/600], Training Loss: 1.7267, vgg Loss: 164.5173, L2 Loss: 0.0815, time: 103.5070
Epoch [551/600], Training Loss: 1.7801, vgg Loss: 168.7897, L2 Loss: 0.0922, tim