In [None]:
# Pytorch
import torch
import torchvision
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torch.nn.functional as F

# Others
import numpy as np
import cv2
import random
import time
import os
from torchsummary import summary
from IPython.display import display


In [None]:
# PReNet
class PReNet(nn.Module):
    """Progressive Recurrent Network (PReNet) for single-image deraining.

    The same weights are applied for ``recurrent`` stages. Each stage:

    * F(in): concatenates the rainy input with the previous stage's estimate
      (3 + 3 = 6 channels) and applies conv + ReLU;
    * a convolutional LSTM cell (gates ``conv_i``, ``conv_f``, ``conv_g``,
      ``conv_o``) that carries hidden state ``h`` and cell state ``c``
      across stages;
    * F(res): five residual blocks of identical structure;
    * F(out): a final convolution back to 3 channels.

    The estimate produced by the last stage is returned.

    Args:
        recurrent: number of recurrent stages.
        use_GPU: kept for backward compatibility. The recurrent state is
            allocated on the input's device, so this flag no longer has to
            match where the data lives. It only affects the verbose message.
        verbose: when True, print the progress messages this network used to
            print unconditionally.
    """

    def __init__(self, recurrent=6, use_GPU=False, verbose=False):
        super(PReNet, self).__init__()
        self.iteration = recurrent
        self.use_GPU = use_GPU
        self.verbose = verbose

        # F(in): 6 input channels = rainy image (3) + previous estimate (3).
        self.conv0 = nn.Sequential(
            nn.Conv2d(6, 32, 3, 1, 1),
            nn.ReLU()
        )

        # F(res): five residual blocks with identical structure. They are built
        # in the original order and under the original attribute names, so
        # state_dict keys, parameter order and seeded initialisation match.
        self.res_conv1 = self._res_block()
        self.res_conv2 = self._res_block()
        self.res_conv3 = self._res_block()
        self.res_conv4 = self._res_block()
        self.res_conv5 = self._res_block()

        # Convolutional LSTM gates; each sees F(in) features (32) + hidden state (32).
        self.conv_i = self._gate(nn.Sigmoid)
        self.conv_f = self._gate(nn.Sigmoid)
        self.conv_g = self._gate(nn.Tanh)
        self.conv_o = self._gate(nn.Sigmoid)

        # F(out): final convolutional layer back to a 3-channel image.
        self.conv = nn.Sequential(
            nn.Conv2d(32, 3, 3, 1, 1)
        )

    @staticmethod
    def _res_block():
        """Build one residual-block body: (conv 3x3 + ReLU) twice, 32 -> 32 channels."""
        return nn.Sequential(
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU()
        )

    @staticmethod
    def _gate(activation):
        """Build one LSTM gate: conv 3x3 from 64 to 32 channels, then ``activation()``."""
        return nn.Sequential(
            nn.Conv2d(32 + 32, 32, 3, 1, 1),
            activation()
        )

    def forward(self, input):
        """Run all recurrent stages and return the final estimate.

        Args:
            input: image batch indexed as (batch, channels, rows, cols). It
                must have 3 channels, because it is concatenated with itself
                (first stage) and with 3-channel estimates, and F(in) expects 6.

        Returns:
            Tensor with the batch and spatial size of ``input`` and 3 channels,
            produced by the last stage. With ``recurrent == 0`` this is
            ``input`` itself.
        """
        batch_size, row, col = input.size(0), input.size(2), input.size(3)

        # LSTM hidden/cell state is created with the input's device and dtype.
        # This avoids device mismatches in the concatenations below and
        # replaces the deprecated Variable wrapper and manual .cuda() calls.
        h = input.new_zeros(batch_size, 32, row, col)
        c = input.new_zeros(batch_size, 32, row, col)
        x = input

        if self.verbose and self.use_GPU:
            print('network cuda is available')

        res_blocks = (self.res_conv1, self.res_conv2, self.res_conv3,
                      self.res_conv4, self.res_conv5)

        for stage in range(self.iteration):
            if self.verbose:
                print("iteration: ", stage)

            # ---- F(in): concatenate the input image with the previous estimate
            x = torch.cat((input, x), 1)
            x = self.conv0(x)

            # ---- Convolutional LSTM (gate names avoid shadowing the loop index)
            x = torch.cat((x, h), 1)
            gate_i = self.conv_i(x)
            gate_f = self.conv_f(x)
            gate_g = self.conv_g(x)
            gate_o = self.conv_o(x)
            c = gate_f * c + gate_i * gate_g
            h = gate_o * torch.tanh(c)
            x = h

            # ---- F(res): x <- ReLU(block(x) + x) for each residual block
            for block in res_blocks:
                x = F.relu(block(x) + x)

            # ---- F(out)
            # NOTE(review): no global residual (x + input) is added here, so the
            # network predicts the derained image directly — confirm intended.
            x = self.conv(x)

        return x

In [None]:
# Disabled torchsummary check of PReNet's layer shapes. The code is wrapped in a
# string literal so it does not execute; the cell's only output is the string's repr.
# NOTE(review): the inline "#640, 480" differs from the size actually passed
# (540, 405) — confirm which input resolution is intended before re-enabling.
'''net = PReNet(6)
#640, 480
summary(net, (3, 540, 405))'''

'net = PReNet(6)\n#640, 480\nsummary(net, (3, 540, 405))'