In [1]:
!git clone https://github.com/niazwazir/SUB_PIXEL_CNN.git

Cloning into 'SUB_PIXEL_CNN'...
remote: Enumerating objects: 29, done.[K
remote: Counting objects: 100% (29/29), done.[K
remote: Compressing objects: 100% (20/20), done.[K
remote: Total 29 (delta 10), reused 25 (delta 9), pack-reused 0[K
Unpacking objects: 100% (29/29), done.


In [2]:
!ls

sample_data  SUB_PIXEL_CNN


In [3]:
cd SUB_PIXEL_CNN/

/content/SUB_PIXEL_CNN


In [4]:
!ls

 1609.05158.pdf         Sample.png
 ESPCN_IMAGE_SR.ipynb   Super_resolution.ipynb
 LICENSE	       'Super resolution.odt'
 model_epoch_599.pth    super_resolution.onnx
 README.md	        Super_resolution_on_onnx_Runtime.ipynb


In [5]:
import numpy as np
import os
import matplotlib.pyplot as plt
from math import log10
%matplotlib inline

In [6]:
import torch 
from torch.nn import init
from torchvision import datasets, transforms
from torch import nn, optim
import torch.nn.functional as F
from PIL import Image
import torch.utils.data as data

In [7]:
!git clone https://github.com/niazwazir/DATASETS.git

Cloning into 'DATASETS'...
remote: Enumerating objects: 1423, done.[K
remote: Total 1423 (delta 0), reused 0 (delta 0), pack-reused 1423[K
Receiving objects: 100% (1423/1423), 388.60 MiB | 41.01 MiB/s, done.
Resolving deltas: 100% (617/617), done.


In [8]:
!ls

 1609.05158.pdf         Sample.png
 DATASETS	        Super_resolution.ipynb
 ESPCN_IMAGE_SR.ipynb  'Super resolution.odt'
 LICENSE	        super_resolution.onnx
 model_epoch_599.pth    Super_resolution_on_onnx_Runtime.ipynb
 README.md


In [9]:
root = 'DATASETS/data/'

In [10]:
os.getcwd()

'/content/SUB_PIXEL_CNN'

In [11]:
class DataSuperRes(data.Dataset):
    def __init__(self, path, input_transform = None, target_transform = None):
        super(DataSuperRes, self).__init__()
        self.input_transform = input_transform
        self.target_transform = target_transform
        self.filepath = [os.path.join(path,x) for x in os.listdir(path)]
        
    def __len__(self):
        return len(self.filepath)
    
    def __getitem__(self,index):
        image = Image.open(self.filepath[index]).convert('YCbCr')
        image_rgb = image.convert('RGB')
        y, _ , _ = image.split()
        target = y.copy()
        if self.input_transform:
            img = self.input_transform(y)
        if self.target_transform:
            target = self.target_transform(target)
            image_rgb = self.target_transform(image_rgb)
        #tensor = transforms.ToTensor()
        #image = tensor(image)
        return image_rgb, img, target
    

In [12]:
upscale_factor = 3
batch_size = 4
epochs = 50 #700
lr = 0.001

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [14]:
def Valid_crop_size(crop_size, upscalefactor):
    return crop_size - (crop_size % upscalefactor)

In [15]:
# We resize the image to size/upscale_factor to match the output of the model
crop_size = Valid_crop_size(256,upscale_factor)
input_transforms = transforms.Compose([transforms.CenterCrop(crop_size),
                                      transforms.Resize(crop_size//upscale_factor),
                                      transforms.ToTensor()
                                      ])

target_transforms = transforms.Compose([transforms.CenterCrop(crop_size),
                                      transforms.ToTensor()
                                      ])


In [16]:
train_set = DataSuperRes(root + 'bsd200', input_transform=input_transforms, target_transform= target_transforms)
test_set = DataSuperRes(root + 'set5', input_transform=input_transforms, target_transform=target_transforms)

trainloader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
testloader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [17]:
print(len(trainloader))

50


In [18]:
len(os.listdir(root + 'bsd200'))/8

25.0

In [19]:
w = torch.empty(2,3)
torch.nn.init.orthogonal_(w,gain=1)

tensor([[-0.8958, -0.3955,  0.2028],
        [-0.3224,  0.2639, -0.9091]])

In [20]:
torch.nn.init.orthogonal_(w,gain=1)

tensor([[-0.6843, -0.7140,  0.1482],
        [-0.4149,  0.5484,  0.7260]])

In [21]:
def show_img(epoch, normal, super_resolution):
    
    tensor = transforms.ToTensor()
    PIL = transforms.ToPILImage()
    img_normal = PIL(normal)
    
    img_ycbcr = img_normal.convert('YCbCr')
    _, img_cb, img_cr = img_ycbcr.split()
    
    #super_resolution = np.transpose(super_resolution, (1,2,0))
    normal = np.transpose(normal, (1,2,0))
    
    out_img_y = super_resolution*255.0
    out_img_y = out_img_y.clip(0, 255)
    #print(out_img_y.shape)
    #out_img_y = out_img_y.squeeze()
    out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode='L')
    print(out_img_y.size)
    
    out_img_cb = img_cb.resize(out_img_y.size, Image.BICUBIC)
    out_img_cr = img_cr.resize(out_img_y.size, Image.BICUBIC)
    out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr])
    out_img = out_img.convert('RGB')
    #out_img = Image.merge('RGB', [out_img_y, out_img_cb, out_img_cr])
    print(out_img.size)
    
    if epoch % 150 == 0:
        out_img.save(root+'out.png')
    fig=plt.figure(figsize=[10,5])
    
    fig.add_subplot(1, 2, 1, title='Original Image')
    plt.imshow(normal)
    
    fig.add_subplot(1, 2, 2, title='Super resolution Image')
    plt.imshow(out_img)
    
    
    fig.subplots_adjust(wspace = 0.5)
    plt.show()
    

In [22]:
class Network(nn.Module):
    def __init__(self, upscale_factor):
        super(Network, self).__init__()

        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x
    
    # A way to initialize weights. Read more here: https://pytorch.org/docs/stable/_modules/torch/nn/init.html#orthogonal_
    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)

In [23]:
model = Network(upscale_factor=upscale_factor).to(device)

In [24]:
model

Network(
  (relu): ReLU()
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(32, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pixel_shuffle): PixelShuffle(upscale_factor=3)
)

In [25]:
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(),lr=lr)

In [26]:
train_losses = []
test_losses = []
print_every = 25
steps = 1
for e in range(epochs):
    batch_loss = 0
    test_loss = 0
    avg_psnr = 0
    print(f'Starting epoch: {e+1}/{epochs}')
    for color, images, target in trainloader:
        images, target = images.to(device), target.to(device)
        #print('train images input shape:',images.shape)
        
        output = model(images)
        #print('train model output shape:',output.shape)
        loss = criterion(output,target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        steps = steps + 1
        batch_loss += loss.item()
        
        if steps % print_every == 0:
            with torch.no_grad():
                model.eval()
                for color_test, images_test, target_test in testloader:
                    images_test, target_test = images_test.to(device), target_test.to(device)
                    #print('test images input shape:',images_test.shape)
                    output_test = model(images_test)
                    #print('test model output shape:',output_test.shape)
                    loss = criterion(output_test,target_test)
                    test_loss +=loss.item()

                    psnr = 10 * log10(1 / loss.item())
                    avg_psnr += psnr

            #print(output.shape)
            #print(target.shape)
            #print(type(color))
            image_out = output_test[0].cpu().numpy()
            #target_out = target_test[0].cpu().numpy()
            color = color_test[0]
            if (e+1)%150 == 0:
                show_img(e, color, image_out)
            
    if(e+1)%200 == 0:
        model_out_path = "model_epoch_{}.pth".format(e)
        torch.save(model, model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))
    train_losses.append(batch_loss/len(trainloader))
    test_losses.append(test_loss/len(testloader))
    print(f"Training loss at epoch {e}: {batch_loss/len(trainloader)}")
    print(f"Test loss at epoch {e}: {test_loss/len(testloader)}")
    print(f"Average PSNR: {avg_psnr/len(testloader)}")
    model.train()

Starting epoch: 1/50
Training loss at epoch 0: 0.02471798149868846
Test loss at epoch 0: 0.016669172910042107
Average PSNR: 42.947094617233844
Starting epoch: 2/50
Training loss at epoch 1: 0.00591379898134619
Test loss at epoch 1: 0.00592293287627399
Average PSNR: 52.42121645100697
Starting epoch: 3/50
Training loss at epoch 2: 0.004543734798207879
Test loss at epoch 2: 0.0038431825523730367
Average PSNR: 56.87080895870194
Starting epoch: 4/50
Training loss at epoch 3: 0.003978439818602055
Test loss at epoch 3: 0.0030722775554750115
Average PSNR: 59.00826780887776
Starting epoch: 5/50
Training loss at epoch 4: 0.0037350031756795943
Test loss at epoch 4: 0.0028542785439640284
Average PSNR: 59.17387312191029
Starting epoch: 6/50
Training loss at epoch 5: 0.0036912717088125645
Test loss at epoch 5: 0.0027712090086424723
Average PSNR: 59.29684246316546
Starting epoch: 7/50
Training loss at epoch 6: 0.0035165865067392586
Test loss at epoch 6: 0.0024416307860519737
Average PSNR: 61.05808252