In [None]:
import torch
import torchvision
import glob
from PIL import Image
import matplotlib.pyplot as plt
import random
import os

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
ANSWER_SIZE = 128
WINDOW_SIZE = ANSWER_SIZE / 2
MINI_BATCH = 10
EPOCH = 500
DATASET_DIR = '/content/drive/MyDrive/Colab Notebooks/super_resolution_ai/dataset/images/*'
OUTPUT_DIR = '/content/drive/MyDrive/Colab Notebooks/super_resolution_ai/outputs/'

In [None]:
files = glob.glob(DATASET_DIR)
train_files = files[0:20]
valid_files = files[20:30]
print('{} images loaded.(train:valid={}:{})'.format(len(files),len(train_files),len(valid_files)))

30 images loaded.(train:valid=20:10)


In [None]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, files):
        self.files = files
        self.toTensor = torchvision.transforms.ToTensor()
        self.randomCrop = torchvision.transforms.RandomCrop(ANSWER_SIZE, padding = ANSWER_SIZE-1)
        self.augmentation = torchvision.transforms.Compose([
            torchvision.transforms.ColorJitter(),
            torchvision.transforms.RandomGrayscale(p=0.1),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.RandomVerticalFlip(),
            torchvision.transforms.RandomInvert(p=0.1)
        ])
        self.downsize = torchvision.transforms.Resize(int(ANSWER_SIZE/2))
        self.upsize = torchvision.transforms.Resize(int(ANSWER_SIZE))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # read image
        image = Image.open(self.files[idx])
        image = self.toTensor(image)
        # random crop
        image = self.randomCrop(image)
        # augmentation
        large = self.augmentation(image)
        # downsize
        small = self.downsize(large)
        small = self.upsize(small)
        return small, large

In [37]:
class UpSampleNet(torch.nn.Module):
    def __init__(self):
        super(UpSampleNet, self).__init__()
        self.input = self.output = 0
        self.layer1 = self.layer2 = self.layer3 = self.layer4 = self.layer5 = 0
        
        self.step1 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=3,out_channels=32,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(32), torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=32,out_channels=32,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(32), torch.nn.ReLU(),
        )

        self.step2 = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2,stride=2),
            torch.nn.Conv2d(in_channels=32,out_channels=64,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(64), torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(64), torch.nn.ReLU(),
        )

        self.step3 = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2,stride=2),
            torch.nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(128), torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=128,out_channels=128,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(128), torch.nn.ReLU(),
        )

        self.deconv1 = torch.nn.ConvTranspose2d(in_channels=128,out_channels=64,kernel_size=2,stride=2)

        self.step4 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=128,out_channels=64,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(64), torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(64), torch.nn.ReLU(),
        )

        self.deconv2 = torch.nn.ConvTranspose2d(in_channels=64,out_channels=32,kernel_size=2,stride=2)

        self.step5 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=64,out_channels=32,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(32), torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=32,out_channels=32,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(32), torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=32,out_channels=3,kernel_size=3,padding=1),
            torch.nn.Sigmoid()
        )

    def forward(self, small):
        self.input = small
        # Encode
        self.layer1 = self.step1(self.input)
        self.layer2 = self.step2(self.layer1)
        self.layer3 = self.step3(self.layer2)
        # Decode
        self.layer3 = self.deconv1(self.layer3)
        self.layer4 = self.step4(torch.cat((self.layer3,self.layer2),dim=1))
        self.layer4 = self.deconv2(self.layer4)
        self.layer5 = self.step5(torch.cat((self.layer4,self.layer1),dim=1))
        self.output = self.layer5
        return self.output

In [None]:
train_dataset, valid_dataset = Dataset(train_files), Dataset(valid_files)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=MINI_BATCH,num_workers=os.cpu_count(),pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset,batch_size=MINI_BATCH,num_workers=os.cpu_count(),pin_memory=True)
train_history, valid_history = {'loss':[]}, {'loss':[]}

net = UpSampleNet()
optim = torch.optim.Adam(params=net.parameters(),lr=0.0001)
loss_function = torch.nn.MSELoss()

In [None]:
def train_valid(mode, model, loader, loss_function, optim, history):
    train = True if mode == 'train' else False
    model.train(train)
    loss_sum = 0

    for small, large in loader:
        if train: optim.zero_grad()
        outputs = model(small)
        loss = loss_function(outputs, large)

        loss_sum += loss.item()

        if train:
            loss.backward()
            optim.step()
        
    history['loss'].append(loss_sum / len(loader))
    
    log = '[Train]' if train else '[Valid]'
    log += ' loss:' + str(history['loss'][-1])
    print(log)

def plot_graph(train_loss, valid_loss):
    plt.figure()
    plt.ioff()
    plt.plot(range(1, len(train_loss)+1), train_loss, label = 'train')
    plt.plot(range(1, len(valid_loss)+1), valid_loss, label = 'valid')
    plt.title('Loss')
    plt.xlabel('epoch')
    plt.legend()
    plt.savefig(OUTPUT_DIR + 'valid/' + '{}epoch_loss.png'.format(len(train_loss)), facecolor='white')
    plt.clf()

In [None]:
for epoch in range(EPOCH):
    print('--- {} Epoch ---'.format(epoch+1))
    train_valid('train',net,train_loader,loss_function,optim,train_history)
    train_valid('valid',net,valid_loader,loss_function,optim,valid_history)
    
    if min(valid_history['loss']) == valid_history['loss'][-1]:
        torch.save(net.state_dict(), OUTPUT_DIR + 'train/' +'best_model.pth')

    if (epoch+1) % 50 == 0:
        plot_graph(train_history['loss'], valid_history['loss'])

print('--- finish ---')
epoch = valid_history['loss'].index(min(valid_history['loss']))+1
print('The best epoch of valid loss : Epoch:{}', epoch)

--- 1 Epoch ---
[Train] loss:0.140010304749012
[Valid] loss:0.10648511350154877
--- 2 Epoch ---
[Train] loss:0.11071434617042542
[Valid] loss:0.14306935667991638
--- 3 Epoch ---
[Train] loss:0.12647808343172073
[Valid] loss:0.13729329407215118
--- 4 Epoch ---
[Train] loss:0.10051754489541054
[Valid] loss:0.1196698248386383
--- 5 Epoch ---
[Train] loss:0.09029032289981842
[Valid] loss:0.07534029334783554
--- 6 Epoch ---
[Train] loss:0.09513825178146362
[Valid] loss:0.15446002781391144
--- 7 Epoch ---
[Train] loss:0.06996199674904346
[Valid] loss:0.14027120172977448
--- 8 Epoch ---
[Train] loss:0.07038303650915623
[Valid] loss:0.10818600654602051
--- 9 Epoch ---
[Train] loss:0.0647780504077673
[Valid] loss:0.10142547637224197
--- 10 Epoch ---
[Train] loss:0.05939783528447151
[Valid] loss:0.12494337558746338
--- 11 Epoch ---
[Train] loss:0.05139869451522827
[Valid] loss:0.0863843634724617
--- 12 Epoch ---
[Train] loss:0.06841539032757282
[Valid] loss:0.08641985803842545
--- 13 Epoch ---
[

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>