In [1]:
# Pytorch
import torch
import torchvision
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torch.nn.functional as F

# Others
import numpy as np
import cv2
import random
import time
import os
from torchsummary import summary
from IPython.display import display


In [2]:
# PReNet
class PReNet(nn.Module):
    """Progressive recurrent deraining network (PReNet).

    Every stage concatenates the original 3-channel input image with the
    previous stage's 3-channel estimate and runs it through:

    * F(in):  3x3 conv + ReLU (6 -> 32 channels),
    * a convolutional LSTM step (gates ``conv_i/f/g/o``; state ``h``, ``c``),
    * F(res): five residual blocks, each ``x = relu(block(x) + x)``,
    * F(out): a 3x3 conv back to 3 channels, giving the new estimate.

    All convolutions are 3x3 with stride 1 and padding 1, so spatial size is
    preserved throughout.

    Args:
        recurrent: number of stages (recurrent iterations) run by ``forward``.
    """

    # Channel width of every intermediate feature map and of the LSTM state.
    _FEATURES = 32

    def __init__(self, recurrent):
        super(PReNet, self).__init__()
        self.iteration = recurrent
        feats = self._FEATURES

        # F(in): 6 input channels = input image (3) + previous estimate (3).
        self.conv0 = nn.Sequential(
            nn.Conv2d(6, feats, 3, 1, 1),
            nn.ReLU()
        )

        # F(res): five residual blocks with identical structure. The attribute
        # names and the (conv, ReLU, conv, ReLU) layout are kept so existing
        # state_dict keys (e.g. ``res_conv1.0.weight``) still match.
        self.res_conv1 = self._res_block(feats)
        self.res_conv2 = self._res_block(feats)
        self.res_conv3 = self._res_block(feats)
        self.res_conv4 = self._res_block(feats)
        self.res_conv5 = self._res_block(feats)

        # Convolutional LSTM gates; each sees [F(in) features, previous h].
        self.conv_i = self._gate(feats, nn.Sigmoid())  # input gate
        self.conv_f = self._gate(feats, nn.Sigmoid())  # forget gate
        self.conv_g = self._gate(feats, nn.Tanh())     # candidate cell state
        self.conv_o = self._gate(feats, nn.Sigmoid())  # output gate

        # F(out): final convolutional layer back to a 3-channel image.
        self.conv = nn.Sequential(
            nn.Conv2d(feats, 3, 3, 1, 1)
        )

    @staticmethod
    def _res_block(channels):
        """Build one residual branch: (3x3 conv -> ReLU) twice, channel-preserving."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.ReLU()
        )

    @staticmethod
    def _gate(channels, activation):
        """Build one LSTM gate: 3x3 conv over 2*channels inputs, then ``activation``."""
        return nn.Sequential(
            nn.Conv2d(channels + channels, channels, 3, 1, 1),
            activation
        )

    def forward(self, input):
        """Run ``self.iteration`` progressive stages on ``input``.

        Args:
            input: 4-D image batch (batch, 3, H, W). It must have 3 channels
                because F(in) takes the input concatenated with the 3-channel
                estimate produced by F(out).

        Returns:
            Tuple ``(x, x_list)``: the final estimate (same shape as
            ``input``) and the list of every stage's estimate, in order.
            With zero stages, ``x`` is ``input`` itself and ``x_list`` is empty.
        """
        batch_size, row, col = input.size(0), input.size(2), input.size(3)

        x_list = []
        # LSTM hidden/cell state. They are allocated with input's device and
        # dtype so the network also works on GPU or in non-float32 precision.
        h = input.new_zeros(batch_size, self._FEATURES, row, col)
        c = input.new_zeros(batch_size, self._FEATURES, row, col)

        res_blocks = (self.res_conv1, self.res_conv2, self.res_conv3,
                      self.res_conv4, self.res_conv5)

        x = input
        for _ in range(self.iteration):
            # ---- F(in): concatenate the input image with the previous estimate.
            x = self.conv0(torch.cat((input, x), 1))

            # ---- Convolutional LSTM step. The gates get distinct names so
            # they do not shadow any loop variable.
            x = torch.cat((x, h), 1)
            gate_i = self.conv_i(x)
            gate_f = self.conv_f(x)
            gate_g = self.conv_g(x)
            gate_o = self.conv_o(x)
            c = gate_f * c + gate_i * gate_g
            h = gate_o * torch.tanh(c)
            x = h

            # ---- F(res): five residual blocks, each relu(block(x) + x).
            for block in res_blocks:
                x = F.relu(block(x) + x)

            # ---- F(out): project back to a 3-channel image.
            x = self.conv(x)
            # NOTE(review): the authors' public PReNet code adds the input back
            # here (``x = x + input``) so the net predicts a residual; this
            # notebook omits it -- confirm whether that is intentional.

            x_list.append(x)
        return x, x_list

In [3]:
#net = PReNet(6)
#summary(net, (3, 25, 25))