In [2]:
import torch
import torchvision
import glob
from PIL import Image
import matplotlib.pyplot as plt
import random
import os

In [3]:
# Mount Google Drive (Colab only) so the project folder under MyDrive is reachable.
from google.colab import drive
drive.mount('/content/drive')

# Single root for every project path. All directory constants end in '/' because
# later cells build file paths by plain string concatenation.
PROJECT_DIR = '/content/drive/MyDrive/Colab Notebooks/super_resolution_ai/'
DATA_DIR = PROJECT_DIR + 'dataset/test_images/'
OUTPUT_DIR = PROJECT_DIR + 'outputs/infer/'

Mounted at /content/drive


In [4]:
# Collect the test images in a stable (sorted) order; glob's own order is arbitrary.
# Subdirectories are skipped because PIL cannot open them.
files = sorted(path for path in glob.glob(DATA_DIR + '*') if os.path.isfile(path))
print('{} images loaded.'.format(len(files)))

10 images loaded.


In [5]:
class UpSampleNet(torch.nn.Module):
    """Small U-Net-style network mapping an RGB image batch to an RGB batch of the same size.

    Architecture: ``step1`` keeps the resolution. ``step2`` and ``step3`` each start
    with a 2x2 max-pool that halves H and W. Two stride-2 transposed convolutions
    (``deconv1``, ``deconv2``) double H and W back. Each decoder stage concatenates
    the matching encoder activation along dim 1 (skip connection). A final Sigmoid
    keeps every output value in [0, 1].

    The activations of the most recent forward pass are kept as attributes
    (``input``, ``layer1``..``layer5``, ``output``) so they can be inspected after
    a call. NOTE: because they live on the module, the last batch's tensors (and,
    outside ``torch.no_grad()``, its autograd graph) stay alive until the next call.
    """

    # Two 2x down-samplings: H and W must be multiples of 4 so that the skip
    # connections' torch.cat calls see matching spatial sizes.
    _SIZE_MULTIPLE = 4

    def __init__(self):
        super().__init__()
        # Placeholders for the activations cached by forward().
        self.input = self.output = 0
        self.layer1 = self.layer2 = self.layer3 = self.layer4 = self.layer5 = 0

        # Encoder stage 1: full resolution, 3 -> 32 channels.
        self.step1 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=3,out_channels=32,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(32), torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=32,out_channels=32,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(32), torch.nn.ReLU(),
        )

        # Encoder stage 2: 1/2 resolution, 32 -> 64 channels.
        self.step2 = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2,stride=2),
            torch.nn.Conv2d(in_channels=32,out_channels=64,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(64), torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(64), torch.nn.ReLU(),
        )

        # Bottleneck: 1/4 resolution, 64 -> 128 channels.
        self.step3 = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2,stride=2),
            torch.nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(128), torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=128,out_channels=128,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(128), torch.nn.ReLU(),
        )

        # 1/4 -> 1/2 resolution, 128 -> 64 channels.
        self.deconv1 = torch.nn.ConvTranspose2d(in_channels=128,out_channels=64,kernel_size=2,stride=2)

        # Decoder stage 1: input is deconv1 output (64) + step2 skip (64) = 128 channels.
        self.step4 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=128,out_channels=64,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(64), torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(64), torch.nn.ReLU(),
        )

        # 1/2 -> full resolution, 64 -> 32 channels.
        self.deconv2 = torch.nn.ConvTranspose2d(in_channels=64,out_channels=32,kernel_size=2,stride=2)

        # Decoder stage 2: deconv2 output (32) + step1 skip (32) = 64 channels -> 3-channel image.
        self.step5 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=64,out_channels=32,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(32), torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=32,out_channels=32,kernel_size=3,padding=1),
            torch.nn.BatchNorm2d(32), torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=32,out_channels=3,kernel_size=3,padding=1),
            torch.nn.Sigmoid()
        )

    def forward(self, small):
        """Refine ``small`` and return a tensor with the same height and width.

        Args:
            small: image batch. The first convolution expects 3 channels, so this is
                presumably (N, 3, H, W) — TODO confirm there are no unbatched callers.
                H and W need not be multiples of 4. The input is replicate-padded on
                the bottom/right up to the next multiple of 4, and the output is
                cropped back to (H, W). Sizes that are already multiples of 4 are
                processed exactly as before.

        Returns:
            3-channel tensor with the same H and W as ``small``, values in [0, 1].
        """
        self.input = small
        height, width = small.shape[-2], small.shape[-1]
        pad_h = -height % self._SIZE_MULTIPLE
        pad_w = -width % self._SIZE_MULTIPLE
        padded = small
        if pad_h or pad_w:
            # Pad only bottom/right so the crop at the end is a plain slice.
            # 'replicate' works for any spatial size (reflect needs pad < size).
            padded = torch.nn.functional.pad(small, (0, pad_w, 0, pad_h), mode='replicate')

        # Encode
        self.layer1 = self.step1(padded)
        self.layer2 = self.step2(self.layer1)
        self.layer3 = self.step3(self.layer2)
        # Decode. layer3/layer4 end up holding the upsampled activations,
        # not the encoder outputs.
        self.layer3 = self.deconv1(self.layer3)
        self.layer4 = self.step4(torch.cat((self.layer3, self.layer2), dim=1))
        self.layer4 = self.deconv2(self.layer4)
        self.layer5 = self.step5(torch.cat((self.layer4, self.layer1), dim=1))

        self.output = self.layer5
        if pad_h or pad_w:
            self.output = self.output[..., :height, :width]
        return self.output

In [6]:
def infer(model, image, scale=2):
    """Upscale ``image`` by ``scale`` with torchvision's resize, then refine it with ``model``.

    Side effect: switches ``model`` to evaluation mode (``model.train(False)``), so
    BatchNorm layers use their running statistics.

    Args:
        model: a ``torch.nn.Module`` applied to the resized image.
        image: image tensor whose last two dims are (H, W); presumably (N, C, H, W)
            with N = 1 in this notebook — TODO confirm for other callers.
        scale: integer spatial upscaling factor (default 2, the original behavior).

    Returns:
        The model output for the resized image, computed without autograd tracking.
    """
    model.train(False)
    height, width = image.shape[-2], image.shape[-1]
    # Inference only: no_grad avoids building an autograd graph that would otherwise
    # stay alive through the activations the model caches on itself.
    with torch.no_grad():
        upscaled = torchvision.transforms.functional.resize(image, size=[height * scale, width * scale])
        return model(upscaled)

In [7]:
# Rebuild the architecture and load the best checkpoint written by training.
# map_location='cpu' lets a checkpoint saved from a GPU session load on a CPU-only
# runtime; this notebook runs the model on CPU tensors.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
net = UpSampleNet()
state_dict = torch.load(
    '/content/drive/MyDrive/Colab Notebooks/super_resolution_ai/outputs/train/best_model.pth',
    map_location='cpu',
)
net.load_state_dict(state_dict)

<All keys matched successfully>

In [8]:
# Upscale every test image and save the input/output pair into OUTPUT_DIR.
os.makedirs(OUTPUT_DIR, exist_ok=True)  # save_image does not create missing directories
for filename in files:
    # convert('RGB') normalizes grayscale / palette / RGBA files to the 3 channels
    # the network's first convolution expects; `with` closes the file handle.
    with Image.open(filename) as img:
        image = torchvision.transforms.functional.to_tensor(img.convert('RGB'))
    # Add a batch dim for the model, then drop only that dim again.
    after_image = infer(net, image.unsqueeze(0)).squeeze(0)

    basename = os.path.basename(filename)
    torchvision.utils.save_image(image, OUTPUT_DIR + 'before_' + basename)
    torchvision.utils.save_image(after_image, OUTPUT_DIR + 'after_' + basename)