In [1]:
import torch
import torch.nn as nn
import imageio
import matplotlib.pyplot as plt
import os
import numpy as np
import time
import torchvision.transforms.functional as F
from torch.nn.functional import interpolate as interpolate
from PIL import Image

resize_scale=1.25
training_height = 128
training_width = 128
shrink_source = 1

In [2]:
class RURCNN(nn.Module):
    def __init__(self):
        super(RURCNN, self).__init__()
        
        self.ratio=nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=1, stride=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.ZeroPad2d((1,0,1,0)),
            nn.MaxPool2d(kernel_size=2, stride=1),
            nn.Conv2d(128, 1, kernel_size=1, stride=1, padding=0, bias=True)
        )
        
        self.assumption=nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=1, stride=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=1, stride=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64, kernel_size=1, stride=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 16, kernel_size=1, stride=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(16, 3, kernel_size=1, stride=1, padding=0, bias=True)
        )
        
    def forward(self, x, h, w):
        #origin = interpolate(x, size=(h,w), scale_factor=None, mode='nearest', align_corners=None)
        ratio = self.ratio(x)
        #assumption = self.assumption(x) 
        result = x*ratio
        #result = assumption
        return result
    
    def init_weight(self):
        
        for layer in self.ratio:
            if hasattr(layer,"weight") :
                layer.weight.data.normal_(mean=0,std=0.01)
            if hasattr(layer,"bias") and layer.bias is not None:
                layer.bias.data.normal_(mean=0,std=0.01)
        
        for layer in self.assumption:
            if hasattr(layer,"weight") :
                layer.weight.data.normal_(mean=1,std=0.01)
            if hasattr(layer,"bias") and layer.bias is not None:
                layer.bias.data.normal_(mean=0,std=0.01)

In [3]:
class Dataloader():
    def __init__(self,path):
        self.load_img(path)
        
    def load_img(self,path):
        imgs = []
        for root, dirs, files in os.walk(path):
            for file in files:
                #img = imageio.imread(root+'\\'+file)\
                img = Image.open(root+'\\'+file)
                width, height = img.size
                width = int(width*shrink_source)
                height = int(height*shrink_source)
                img = F.resize(img,(height,width))
                simg = F.resize(img,(int(height/resize_scale)+1,int(width/resize_scale)+1))
                simg = F.resize(simg,(height,width))
                
                img = F.to_tensor(img)/255.0
                simg = F.to_tensor(simg)/255.0
                timg = torch.as_tensor(img,dtype=torch.float32)
                simg = torch.as_tensor(simg,dtype=torch.float32)
                timg.unsqueeze_(0)
                simg.unsqueeze_(0)
                
                height_num = 2*int(height/training_height)-1
                width_num = 2*int(width/training_width)-1
                height_stride = height/(height_num+1)
                width_stride = width/(width_num+1)
                
                divided_imgs = [] 
                
                for i in range(height_num):
                    for j in range(width_num):
                        h_start = int(height_stride*i)
                        w_start = int(width_stride*j)
                        dtimg = timg[:,:,h_start:h_start+training_height,w_start:w_start+training_width]
                        dsimg = simg[:,:,h_start:h_start+training_height,w_start:w_start+training_width]
                        divided_imgs.append((dsimg, dtimg))
                
                imgs.extend(divided_imgs)
                del img
        self.imgs=imgs
    
    def get_minibatch(self, minibatch_size):
        chosen_idxs = np.random.randint(len(self.imgs), size = minibatch_size)
        chosen_imgs = [self.imgs[chosen_idx] for chosen_idx in chosen_idxs]
        return chosen_imgs
        

In [4]:
minibatch_size = 1000
epoch = 10

time1 = time.time()
rurcnn = RURCNN().to("cuda")
rurcnn.init_weight()
loader=Dataloader("01-64")
l1loss = nn.L1Loss(reduction="sum")
mse = nn.MSELoss()
optimizer = torch.optim.SGD(rurcnn.parameters(), lr= 0.001, momentum=0.9)
for e in range(epoch):
    minibatch=loader.get_minibatch(minibatch_size)
    loss_sum=0
    for (simg, img) in minibatch:
        optimizer.zero_grad()
        simg = simg.to("cuda")
        img = img.to("cuda")
        pimg = rurcnn(simg, img.shape[2],img.shape[3])
        loss = l1loss(img, pimg)
        loss.backward()
        optimizer.step()
        loss_sum += float(loss.item())
        del loss, simg, img
    print("epoch ",e," complete ","avg_loss=",loss_sum/minibatch_size)
    

epoch  0  complete  avg_loss= 16.787336539506914
epoch  1  complete  avg_loss= 14.334095644533635
epoch  2  complete  avg_loss= 14.17923548734188
epoch  3  complete  avg_loss= 15.998909077316522
epoch  4  complete  avg_loss= 15.392855456769466
epoch  5  complete  avg_loss= 16.518263228356837
epoch  6  complete  avg_loss= 14.060387136876583
epoch  7  complete  avg_loss= 16.741891341552137
epoch  8  complete  avg_loss= 17.50336689823866
epoch  9  complete  avg_loss= 15.265599391102791


In [5]:
minibatch=loader.get_minibatch(minibatch_size)
for i,img in enumerate(minibatch):
    simg = resize(img,int(img.shape[2]/resize_scale,img.shape[3]/resize_scale))
    pimg = rurcnn(simg)
    pimg.squeeze()
    imageio.imwrite("test_img\\"+i+".jpg", pimg.numpy())

NameError: name 'resize' is not defined