In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
import torch.nn.init as init
import time

### parameters ###
batchSize = 64           # mini-batch size for both loaders
setEpoch = 300           # epoch budget; presumably read by a training cell not shown here
seed = 0                 # seeds torch's global RNG (DataLoader shuffle order, default weight init)
useAugmentation = False  # True -> random crop + horizontal flip on the training set
dataRoot = '../data'     # CIFAR-10 download/cache directory (relative to the notebook)
numWorkers = 2           # DataLoader worker processes
### parameters ###

# Make Restart & Run All reproducible: without a seed every run shuffles the
# training set differently and the model (built in a later cell) gets different
# initial weights.
torch.manual_seed(seed)

### dataset ###

# Per-channel normalisation with the standard ImageNet mean/std
# (not CIFAR-10's own statistics).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Training transform. The augmentation pipeline is kept behind a flag instead of
# as commented-out code. NOTE: RandomCrop(32) without padding is a no-op on
# 32x32 CIFAR-10 images; pass padding=4 if real crop augmentation is wanted.
if useAugmentation:
    transform_train = transforms.Compose(
        [transforms.RandomCrop(32),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         normalize])
else:
    transform_train = transforms.Compose(
        [transforms.ToTensor(),
         normalize])
# Test transform: deterministic tensor conversion + normalisation only.
transform_test = transforms.Compose(
    [transforms.ToTensor(),
     normalize])

# Datasets (downloaded into dataRoot on first run).
trainset = torchvision.datasets.CIFAR10(root=dataRoot, train=True,
                                        download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root=dataRoot, train=False,
                                       download=True, transform=transform_test)
# Loaders: shuffle only the training set; pin_memory speeds host->GPU copies.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batchSize,
                                          shuffle=True, num_workers=numWorkers, pin_memory=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batchSize,
                                         shuffle=False, num_workers=numWorkers, pin_memory=True)
# CIFAR-10 class names, in label-index order.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

#import matplotlib.pyplot as plt
#import numpy as np


Files already downloaded and verified
Files already downloaded and verified


In [2]:
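# Hand-written pixel-shuffle experiment, made before realizing nn.PixelShuffle already exists; kept for reference only.
# Note: the index (c*channel + k)*scale + l below matches nn.PixelShuffle only when c == 0 (or channel == scale); the general form is (c*scale + k)*scale + l.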

# channel = 1
# orig = 3
# scale = 4
# a = torch.rand(1,channel*scale*scale, orig, orig)
# print(a)
# b = torch.rand(1,channel, orig*scale, orig*scale)
# c = torch.rand(1,channel, orig*scale, orig*scale)
# for c in range(channel):
#     for i in range(orig) :
#         for j in range(orig) :
#             for k in range(scale):
#                 for l in range(scale):
#                     b[0,c, scale*i+k, scale*j+l] = a[0,(c*channel + k)*scale + l, i, j]
# ps=nn.PixelShuffle(scale)
# print(ps(a))
# # for i in range(5) :
# #     for j in range(5) :
# #         c[scale*i+0, 2*j+0] = a[0, i, j]
# #         c[scale*i+0, 2*j+1] = a[1, i, j]
# #         c[scale*i+1, 2*j+0] = a[2, i, j]
# #         c[scale*i+1, 2*j+1] = a[3, i, j]
# print(b)
# print(c)

In [7]:
#define NN

# Unused: an abandoned attempt at a custom upsampling layer (nn.PixelShuffle is used instead).
# class upNet(nn.Conv2d):
#     def __init__(self):
        
#     def forward(self, x, scale, outchannel):
#         channel = 3
#         orig = 3
#         scale = 4
#         a = torch.rand(channel*scale*scale, orig, orig)
#         print(a)
#         b = torch.rand(channel, orig*scale, orig*scale)
#         c = torch.rand(orig*scale, orig*scale)
#         for c in range(channel):
#             for i in range(orig) :
#                 for j in range(orig) :
#                     for k in range(scale):
#                         for l in range(scale):
#                             b[c, scale*i+k, scale*j+l] = a[(c*channel + k)*scale + l, i, j]
        
class RDB(nn.Module):
    """Residual dense block (RDB).

    A stack of ``layerDepth`` 3x3 conv + ReLU layers with dense connectivity:
    each layer receives the concatenation of the block input and every earlier
    layer's output. The final concatenation is fused back to ``inChannel``
    channels by a 1x1 conv (local feature fusion) and added to the block input
    (local residual learning), so output channels == input channels.

    Args:
        inChannel: channels of the input tensor (and of the output).
        growthRate: channels produced by each 3x3 conv layer.
        layerDepth: number of 3x3 conv layers in the block.
    """

    def __init__(self, inChannel, growthRate, layerDepth):
        super(RDB, self).__init__()
        self.layerDepth = layerDepth

        # Layer k sees the input plus the k previous outputs:
        # inChannel + k * growthRate channels.
        self.rdbLayers = nn.ModuleList(
            nn.Conv2d(inChannel + k * growthRate, growthRate,
                      kernel_size=3, padding=1, bias=False)
            for k in range(layerDepth)
        )

        # Local feature fusion: squeeze the full concatenation back to inChannel.
        fusedChannels = inChannel + layerDepth * growthRate
        self.conv1x1 = nn.Conv2d(fusedChannels, inChannel, kernel_size=1)

    def forward(self, x):
        features = x
        for conv in self.rdbLayers:
            # Append this layer's activations to the running dense concatenation.
            features = torch.cat((features, F.relu(conv(features))), dim=1)
        # Local feature fusion followed by the local residual connection.
        return x + self.conv1x1(features)
    
class RDN(nn.Module):
    """Residual Dense Network for single-image super-resolution.

    Pipeline: shallow feature extraction (two 3x3 convs) -> ``blockDepth``
    chained RDBs -> dense feature fusion (1x1 + 3x3 conv over the concatenated
    RDB outputs) -> global residual with the first shallow feature map ->
    sub-pixel upsampling -> 3x3 conv back to ``colorChannel`` channels.

    Args:
        colorChannel: channels of the input/output image.
        scale: integer upsampling factor (>= 1). ``4`` is built as two x2
            sub-pixel stages; any other value uses a single sub-pixel stage.
        growthRate: channels added by each conv inside an RDB (default 64).
        g0: feature-map width used throughout the network (default 64).
        blockDepth: number of RDBs (default 16).
        convDepth: conv layers per RDB (default 8).

    Raises:
        ValueError: if ``scale`` < 1 or ``blockDepth`` < 1.

    The defaults reproduce the previously hard-coded configuration, so
    ``RDN(3, 4)`` builds the same network (same parameter names and shapes).
    """

    def __init__(self, colorChannel, scale, growthRate=64, g0=64,
                 blockDepth=16, convDepth=8):
        super(RDN, self).__init__()
        # Fail fast instead of building zero-channel convs / PixelShuffle(0),
        # or concatenating an empty list of RDB outputs in forward().
        if scale < 1:
            raise ValueError(f"scale must be a positive integer, got {scale!r}")
        if blockDepth < 1:
            raise ValueError(f"blockDepth must be >= 1, got {blockDepth!r}")

        # hyper-parameters (kept as attributes for backward compatibility)
        self.growthRate = growthRate
        self.g0 = g0
        self.blockDepth = blockDepth
        self.convDepth = convDepth

        # SFE: shallow feature extraction
        self.sfe1 = nn.Conv2d(colorChannel, g0, kernel_size=3, padding=1, bias=False)
        self.sfe2 = nn.Conv2d(g0, g0, kernel_size=3, padding=1, bias=False)

        # RDBs
        self.rdbs = nn.ModuleList(
            [RDB(g0, growthRate, convDepth) for _ in range(blockDepth)])

        # DFF: global feature fusion over the concatenated outputs of all RDBs
        self.gff = nn.Sequential(
            nn.Conv2d(g0 * blockDepth, g0, kernel_size=1, bias=False),
            nn.Conv2d(g0, g0, kernel_size=3, padding=1, bias=False))

        # UPNet: sub-pixel upsampling
        if scale == 4:
            # x4 as two x2 stages
            self.pixelShuffle = nn.Sequential(
                nn.Conv2d(g0, g0 * (2 ** 2), kernel_size=3, padding=1, bias=False),
                nn.PixelShuffle(2),
                nn.Conv2d(g0, g0 * (2 ** 2), kernel_size=3, padding=1, bias=False),
                nn.PixelShuffle(2))
        else:
            self.pixelShuffle = nn.Sequential(
                nn.Conv2d(g0, g0 * (scale ** 2), kernel_size=3, padding=1, bias=False),
                nn.PixelShuffle(scale))

        # HR reconstruction
        self.convHR = nn.Conv2d(g0, colorChannel, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        # SFENet: f1 is kept for the global residual, out feeds the RDB chain
        f1 = self.sfe1(x)
        out = self.sfe2(f1)

        # RDBs: run sequentially, keeping every block's output for dense fusion
        rdbOutputs = []
        for rdb in self.rdbs:
            out = rdb(out)
            rdbOutputs.append(out)

        # DFF + global residual learning
        out = f1 + self.gff(torch.cat(rdbOutputs, dim=1))

        # UPNet + HR reconstruction
        return self.convHR(self.pixelShuffle(out))

In [8]:
# Build the x4 RDN and print a layer-by-layer summary for a 32x32 RGB input.
model = RDN(3, 4)
# Fall back to CPU so the notebook also runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
# torchsummary defaults to device="cuda"; pass the actual device type explicitly.
summary(model, (3, 32, 32), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,728
            Conv2d-2           [-1, 64, 32, 32]          36,864
            Conv2d-3           [-1, 64, 32, 32]          36,864
            Conv2d-4           [-1, 64, 32, 32]          73,728
            Conv2d-5           [-1, 64, 32, 32]         110,592
            Conv2d-6           [-1, 64, 32, 32]         147,456
            Conv2d-7           [-1, 64, 32, 32]         184,320
            Conv2d-8           [-1, 64, 32, 32]         221,184
            Conv2d-9           [-1, 64, 32, 32]         258,048
           Conv2d-10           [-1, 64, 32, 32]         294,912
           Conv2d-11           [-1, 64, 32, 32]          36,928
              RDB-12           [-1, 64, 32, 32]               0
           Conv2d-13           [-1, 64, 32, 32]          36,864
           Conv2d-14           [-1, 64,