In [17]:
import torch
import torchvision
import glob
from PIL import Image
import matplotlib.pyplot as plt
import random
import os

In [18]:
# Side length (pixels) of the high-resolution patch the network produces.
ANSWER_SIZE = 128
# Side length (pixels) of the low-resolution input patch. Integer division keeps
# this an int: it is used as a crop size, a padding width and a loop bound, none
# of which accept a float such as 64.0.
WINDOW_SIZE = ANSWER_SIZE // 2
# Number of samples per mini-batch handed to the DataLoaders.
MINI_BATCH = 10
# Number of passes over the training set.
EPOCH = 500

In [19]:
# glob.glob returns paths in arbitrary, filesystem-dependent order. Sorting makes
# the train/valid split below identical on every run and on every machine.
files = sorted(glob.glob('./dataset/images/*'))
# First 20 images are used for training, the next 10 for validation.
# Any image beyond index 29 is not used.
train_files = files[0:20]
valid_files = files[20:30]
print('{} images loaded.(train:valid={}:{})'.format(len(files),len(train_files),len(valid_files)))

30 images loaded.(train:valid=20:10)


In [20]:
class Dataset(torch.utils.data.Dataset):
    """Yields (small, large) patch pairs for training a 2x up-sampling network.

    Each item is produced by loading one image file, taking a random
    ANSWER_SIZE x ANSWER_SIZE crop, augmenting it, and down-sizing the
    augmented crop to half the side length. The down-sized patch is the
    network input and the augmented crop is the target.
    """

    def __init__(self, files):
        """
        Args:
            files: sequence of image file paths; one dataset item per path.
        """
        self.files = files
        self.toTensor = torchvision.transforms.ToTensor()
        # The image is zero-padded by ANSWER_SIZE-1 on every side before the
        # crop is taken, so a crop may consist mostly of black border.
        self.randomCrop = torchvision.transforms.RandomCrop(ANSWER_SIZE, padding = ANSWER_SIZE-1)
        # NOTE(review): ColorJitter() is built with its default arguments, which
        # appear to leave the image unchanged -- confirm whether jitter strengths
        # were meant to be supplied.
        self.augmentation = torchvision.transforms.Compose([
            torchvision.transforms.ColorJitter(),
            torchvision.transforms.RandomGrayscale(p=0.1),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.RandomVerticalFlip(),
            torchvision.transforms.RandomInvert(p=0.1)
        ])
        self.downsize = torchvision.transforms.Resize(int(ANSWER_SIZE/2))

    def __len__(self):
        """Return the number of image files in this dataset."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return the (small, large) tensor pair for the image at index ``idx``."""
        # read image
        # The context manager closes the file handle once the pixels have been
        # decoded. convert('RGB') guarantees three channels, so grayscale, RGBA
        # or palette images do not break a network built with in_channels=3.
        with Image.open(self.files[idx]) as img:
            image = self.toTensor(img.convert('RGB'))
        # random crop
        image = self.randomCrop(image)
        # augmentation
        large = self.augmentation(image)
        # downsize
        small = self.downsize(large)
        return small, large

In [21]:
class UpSampleNet(torch.nn.Module):
    """Encoder/decoder CNN with skip connections that doubles spatial resolution.

    Two pooling stages halve the resolution twice, three transposed
    convolutions double it three times, so the output is twice as large as
    the input along each spatial axis. The final Sigmoid keeps output values
    in [0, 1]. Input height and width must be divisible by 4 so the skip
    connections line up.

    The most recent forward pass leaves its tensors on the instance as
    ``input``, ``layer1`` .. ``layer5`` and ``output``.
    """

    @staticmethod
    def _double_conv(c_in, c_out, pool=False):
        """Build [MaxPool] + 2 x (Conv3x3 -> BatchNorm -> ReLU) as a Sequential."""
        layers = []
        if pool:
            layers.append(torch.nn.MaxPool2d(kernel_size=2, stride=2))
        for channels in (c_in, c_out):
            layers.extend([
                torch.nn.Conv2d(in_channels=channels, out_channels=c_out, kernel_size=3, padding=1),
                torch.nn.BatchNorm2d(c_out),
                torch.nn.ReLU(),
            ])
        return torch.nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        # Placeholders for the tensors cached by forward().
        self.input = 0
        self.output = 0
        self.layer1 = 0
        self.layer2 = 0
        self.layer3 = 0
        self.layer4 = 0
        self.layer5 = 0

        # Encoder: 3 -> 32 -> 64 -> 128 channels, pooling before stages 2 and 3.
        self.step1 = self._double_conv(3, 32)
        self.step2 = self._double_conv(32, 64, pool=True)
        self.step3 = self._double_conv(64, 128, pool=True)

        # Decoder: each transposed conv doubles H and W; step4/step5 take the
        # up-sampled features concatenated with the matching encoder features.
        self.deconv1 = torch.nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.step4 = self._double_conv(128, 64)

        self.deconv2 = torch.nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=2, stride=2)
        self.step5 = self._double_conv(64, 32)

        self.deconv3 = torch.nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=2, stride=2)

        # Output head: reduce to 3 channels and squash into [0, 1].
        self.step6 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(8),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=8, out_channels=3, kernel_size=3, padding=1),
            torch.nn.Sigmoid(),
        )

    def forward(self, small):
        """Up-sample ``small`` (N, 3, H, W) and return a (N, 3, 2H, 2W) tensor."""
        self.input = small

        # Encode
        skip_full = self.layer1 = self.step1(small)
        skip_half = self.layer2 = self.step2(skip_full)
        bottleneck = self.step3(skip_half)

        # Decode
        self.layer3 = self.deconv1(bottleneck)
        merged_half = torch.cat((self.layer3, skip_half), dim=1)
        self.layer4 = self.deconv2(self.step4(merged_half))
        merged_full = torch.cat((self.layer4, skip_full), dim=1)
        self.layer5 = self.deconv3(self.step5(merged_full))

        result = self.step6(self.layer5)
        self.output = result
        return result


In [22]:
train_dataset, valid_dataset = Dataset(train_files), Dataset(valid_files)

# os.cpu_count() may return None when the count cannot be determined, and
# DataLoader requires an int; fall back to loading in the main process.
NUM_WORKERS = os.cpu_count() or 0

# Training batches are reshuffled every epoch so the model does not see the
# same fixed grouping of images each time. Validation order is left as-is.
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=MINI_BATCH,shuffle=True,num_workers=NUM_WORKERS,pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset,batch_size=MINI_BATCH,num_workers=NUM_WORKERS,pin_memory=True)

# Per-epoch mean loss, appended to by train_valid().
train_history, valid_history = {'loss':[]}, {'loss':[]}

net = UpSampleNet()
optim = torch.optim.Adam(params=net.parameters(),lr=0.0001)
# Pixel-wise mean squared error between the network output and the large patch.
loss_function = torch.nn.MSELoss()

In [23]:
from zipfile import LargeZipFile


def train_valid(mode, model, loader, loss_function, optim, history):
    """Run one epoch over ``loader`` and record the mean batch loss.

    Args:
        mode: ``'train'`` to update the model; any other value runs an
            evaluation-only pass.
        model: network called as ``model(small)``.
        loader: iterable of ``(small, large)`` batches that supports ``len()``.
        loss_function: callable ``(outputs, large) -> scalar loss tensor``.
        optim: optimizer stepped after every batch in training mode; unused
            otherwise.
        history: dict with a ``'loss'`` list; the epoch's mean loss is appended.

    Raises:
        ZeroDivisionError: if ``loader`` contains no batches.
    """
    train = mode == 'train'
    model.train(train)
    loss_sum = 0.0

    # Autograd is only needed when weights are updated. Disabling it for the
    # validation pass avoids building a graph that is never back-propagated.
    with torch.set_grad_enabled(train):
        for small, large in loader:
            if train:
                optim.zero_grad()
            outputs = model(small)
            loss = loss_function(outputs, large)

            loss_sum += loss.item()

            if train:
                loss.backward()
                optim.step()

    history['loss'].append(loss_sum / len(loader))

    log = '[Train]' if train else '[Valid]'
    log += ' loss:' + str(history['loss'][-1])
    print(log)

def infer(model, small, padding=WINDOW_SIZE-1, stride=WINDOW_SIZE):
    """Up-sample a whole image by running ``model`` over sliding windows.

    The image is zero-padded, cut into WINDOW_SIZE x WINDOW_SIZE tiles taken
    every ``stride`` pixels, each tile is passed through ``model``, and the
    results are pasted into an enlarged canvas. The region corresponding to
    the original (unpadded) image is then cut out and returned.

    Args:
        model: network mapping a (1, C, WINDOW_SIZE, WINDOW_SIZE) batch to a
            (1, C, ANSWER_SIZE, ANSWER_SIZE) batch. It is switched to eval mode
            and left in that mode.
        small: image tensor of shape (C, H, W).
        padding: zero border added on every side before tiling. With the
            defaults, every pixel of the original image falls inside a tile.
            A smaller padding may leave the last rows/columns uncovered; they
            then stay zero in the result.
        stride: step between tile origins. Values below WINDOW_SIZE give
            overlapping tiles, where later tiles overwrite earlier ones.

    Returns:
        Tensor of shape (C, H * scale, W * scale), where
        ``scale = ANSWER_SIZE // WINDOW_SIZE``.

    Raises:
        ValueError: if ``stride`` is 0 (raised by ``range``).
    """
    # The module constants and the defaults may be floats (e.g. 64.0); slicing,
    # padding and range() all need ints.
    window = int(WINDOW_SIZE)
    scale = int(ANSWER_SIZE) // window
    padding, stride = int(padding), int(stride)

    model.train(False)
    h, w = small.shape[-2], small.shape[-1]

    # Pad the two spatial dimensions (left, right, top, bottom) with zeros.
    padded = torch.nn.functional.pad(small, (padding, padding, padding, padding))
    padded_h, padded_w = padded.shape[-2], padded.shape[-1]
    large = padded.new_zeros(padded.shape[0], padded_h * scale, padded_w * scale)

    with torch.no_grad():
        for top in range(0, padded_h - window + 1, stride):
            for left in range(0, padded_w - window + 1, stride):
                tile = padded[:, top:top + window, left:left + window]
                # The model expects a batch dimension; add it, then drop it again.
                upscaled = model(tile.unsqueeze(0)).squeeze(0)
                large[:,
                      top * scale:(top + window) * scale,
                      left * scale:(left + window) * scale] = upscaled

    # Remove the (scaled) padding border so only the original image area remains.
    offset = padding * scale
    return large[:, offset:offset + h * scale, offset:offset + w * scale]

def plot_graph(train_loss, valid_loss):
    """Save a train/valid loss curve as a PNG under ``./outputs/valid/``.

    The file is named ``<N>epoch_loss.png`` where N is ``len(train_loss)``.

    Args:
        train_loss: per-epoch training losses.
        valid_loss: per-epoch validation losses.
    """
    out_dir = './outputs/valid/'
    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots()
    plt.ioff()
    try:
        ax.plot(range(1, len(train_loss)+1), train_loss, label = 'train')
        ax.plot(range(1, len(valid_loss)+1), valid_loss, label = 'valid')
        ax.set_title('Loss')
        ax.set_xlabel('epoch')
        ax.legend()
        fig.savefig(out_dir + '{}epoch_loss.png'.format(len(train_loss)), facecolor='white')
    finally:
        # clf() would only clear the figure and pyplot would keep it alive;
        # close it so repeated calls do not accumulate open figures.
        plt.close(fig)

In [24]:
# Neither torch.save nor savefig creates missing directories; make sure the
# checkpoint directory and the plot directory exist before training starts.
os.makedirs('./outputs/valid', exist_ok=True)

for epoch in range(EPOCH):
    print('--- {} Epoch ---'.format(epoch+1))
    train_valid('train',net,train_loader,loss_function,optim,train_history)
    train_valid('valid',net,valid_loader,loss_function,optim,valid_history)

    # Checkpoint whenever the latest validation loss is the best seen so far.
    if min(valid_history['loss']) == valid_history['loss'][-1]:
        torch.save(net.state_dict(), './outputs/'+'best_model.pth')

    # Write a loss curve every 50 epochs.
    if (epoch+1) % 50 == 0:
        plot_graph(train_history['loss'], valid_history['loss'])

print('--- finish ---')
# From here on `epoch` holds the 1-based number of the best validation epoch,
# not the loop counter.
epoch = valid_history['loss'].index(min(valid_history['loss']))+1
# .format() is required; passing `epoch` as a second print argument leaves the
# '{}' placeholder unfilled.
print('The best epoch of valid loss : Epoch:{}'.format(epoch))

--- 1 Epoch ---
[Train] loss:0.1551910862326622
[Valid] loss:0.10784023255109787
--- 2 Epoch ---
[Train] loss:0.12218901515007019
[Valid] loss:0.08526812493801117
--- 3 Epoch ---
[Train] loss:0.13182483986020088
[Valid] loss:0.11819498240947723
--- 4 Epoch ---
[Train] loss:0.11450852826237679
[Valid] loss:0.11753392219543457
--- 5 Epoch ---
[Train] loss:0.11182986199855804
[Valid] loss:0.10486562550067902
--- 6 Epoch ---
[Train] loss:0.11924086138606071
[Valid] loss:0.128995880484581
--- 7 Epoch ---
[Train] loss:0.12303702160716057
[Valid] loss:0.09180113673210144
--- 8 Epoch ---
[Train] loss:0.11128494143486023
[Valid] loss:0.09868933260440826
--- 9 Epoch ---
[Train] loss:0.0932990163564682
[Valid] loss:0.1269991397857666
--- 10 Epoch ---
[Train] loss:0.11625882983207703
[Valid] loss:0.09036464989185333
--- 11 Epoch ---
[Train] loss:0.10877098515629768
[Valid] loss:0.12136247754096985
--- 12 Epoch ---
[Train] loss:0.09628643095493317
[Valid] loss:0.12366890162229538
--- 13 Epoch ---
[