In [1]:
# BSDS300 dataset
!wget https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/segbench/BSDS300-images.tgz
# If not working, uncomment the following line:
# !wget http://nipg12.inf.elte.hu:8000/BSDS300-images.tgz
!tar -xvzf BSDS300-images.tgz

--2021-01-28 17:48:06--  https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/segbench/BSDS300-images.tgz
Resolving www2.eecs.berkeley.edu (www2.eecs.berkeley.edu)... 128.32.244.190
Connecting to www2.eecs.berkeley.edu (www2.eecs.berkeley.edu)|128.32.244.190|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 22211847 (21M) [application/x-tar]
Saving to: ‘BSDS300-images.tgz’


2021-01-28 17:48:09 (6.49 MB/s) - ‘BSDS300-images.tgz’ saved [22211847/22211847]

BSDS300/images/
BSDS300/images/train/
BSDS300/images/train/159029.jpg
BSDS300/images/train/20008.jpg
BSDS300/images/train/155060.jpg
BSDS300/images/train/286092.jpg
BSDS300/images/train/100075.jpg
BSDS300/images/train/61060.jpg
BSDS300/images/train/46076.jpg
BSDS300/images/train/301007.jpg
BSDS300/images/train/26031.jpg
BSDS300/images/train/232038.jpg
BSDS300/images/train/45077.jpg
BSDS300/images/train/365025.jpg
BSDS300/images/train/188091.jpg
BSDS300/images/train/299091.jpg
BSDS300/images/train

In [2]:
!ls

BSDS300  BSDS300-images.tgz  sample_data


In [3]:
import numpy as np
import os
import matplotlib.pyplot as plt
from math import log10
%matplotlib inline

In [4]:
import torch 
from torch.nn import init
from torchvision import datasets, transforms
from torch import nn, optim
import torch.nn.functional as F
from PIL import Image
import torch.utils.data as data

In [5]:
root = 'BSDS300/images/'

In [6]:
os.getcwd()

'/content'

In [7]:
class DataSuperRes(data.Dataset):
    def __init__(self, path, input_transform = None, target_transform = None):
        super(DataSuperRes, self).__init__()
        self.input_transform = input_transform
        self.target_transform = target_transform
        self.filepath = [os.path.join(path,x) for x in os.listdir(path)]
        
    def __len__(self):
        return len(self.filepath)
    
    def __getitem__(self,index):
        image = Image.open(self.filepath[index]).convert('YCbCr')
        image_rgb = image.convert('RGB')
        y, _ , _ = image.split()
        target = y.copy()
        if self.input_transform:
            img = self.input_transform(y)
        if self.target_transform:
            target = self.target_transform(target)
            image_rgb = self.target_transform(image_rgb)
        #tensor = transforms.ToTensor()
        #image = tensor(image)
        return image_rgb, img, target
    

In [8]:
upscale_factor = 3
batch_size = 4
epochs = 7 #700
lr = 0.001

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [10]:
def Valid_crop_size(crop_size, upscalefactor):
    return crop_size - (crop_size % upscalefactor)

In [11]:
# We resize the image to size/upscale_factor to match the output of the model
crop_size = Valid_crop_size(256,upscale_factor)
input_transforms = transforms.Compose([transforms.CenterCrop(crop_size),
                                      transforms.Resize(crop_size//upscale_factor),
                                      transforms.ToTensor()
                                      ])

target_transforms = transforms.Compose([transforms.CenterCrop(crop_size),
                                      transforms.ToTensor()
                                      ])


In [12]:
train_set = DataSuperRes(root + 'train', input_transform=input_transforms, target_transform= target_transforms)
test_set = DataSuperRes(root + 'test', input_transform=input_transforms, target_transform=target_transforms)

trainloader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
testloader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [13]:
#print(len(trainloader))

In [14]:
#len(os.listdir(root + 'train'))/8

In [15]:
#w = torch.empty(2,3)
#torch.nn.init.orthogonal_(w,gain=1)

In [16]:
#torch.nn.init.orthogonal_(w,gain=1)

In [17]:
def show_img(epoch, normal, super_resolution):
    
    tensor = transforms.ToTensor()
    PIL = transforms.ToPILImage()
    img_normal = PIL(normal)
    
    img_ycbcr = img_normal.convert('YCbCr')
    _, img_cb, img_cr = img_ycbcr.split()
    
    #super_resolution = np.transpose(super_resolution, (1,2,0))
    normal = np.transpose(normal, (1,2,0))
    
    out_img_y = super_resolution*255.0
    out_img_y = out_img_y.clip(0, 255)
    #print(out_img_y.shape)
    #out_img_y = out_img_y.squeeze()
    out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode='L')
    print(out_img_y.size)
    
    out_img_cb = img_cb.resize(out_img_y.size, Image.BICUBIC)
    out_img_cr = img_cr.resize(out_img_y.size, Image.BICUBIC)
    out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr])
    out_img = out_img.convert('RGB')
    #out_img = Image.merge('RGB', [out_img_y, out_img_cb, out_img_cr])
    print(out_img.size)
    
    if epoch % 150 == 0:
        out_img.save(root+'out.png')
    fig=plt.figure(figsize=[10,5])
    
    fig.add_subplot(1, 2, 1, title='Original Image')
    plt.imshow(normal)
    
    fig.add_subplot(1, 2, 2, title='Super resolution Image')
    plt.imshow(out_img)
    '''
    fig.add_subplot(2, 3, 3, title='y component')
    plt.imshow(out_img_y)
    
    fig.add_subplot(2, 3, 4, title='cb')
    plt.imshow(out_img_cb)
    
    fig.add_subplot(2, 3, 5, title='cr')
    plt.imshow(out_img_cr)
    
    fig.add_subplot(2, 3, 6, title='ycbcr from rgb')
    plt.imshow(img_normal)'''
    
    fig.subplots_adjust(wspace = 0.5)
    plt.show()
    

In [18]:
class Network(nn.Module):
    def __init__(self, upscale_factor):
        super(Network, self).__init__()

        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x
    
    # A way to initialize weights. Read more here: https://pytorch.org/docs/stable/_modules/torch/nn/init.html#orthogonal_
    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)

In [19]:
model = Network(upscale_factor=upscale_factor).to(device)

In [20]:
model

Network(
  (relu): ReLU()
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(32, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pixel_shuffle): PixelShuffle(upscale_factor=3)
)

In [21]:
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(),lr=lr)

In [22]:
train_losses = []
test_losses = []
print_every = 25
steps = 1
for e in range(epochs):
    batch_loss = 0
    test_loss = 0
    avg_psnr = 0
    print(f'Starting epoch: {e+1}/{epochs}')
    for color, images, target in trainloader:
        images, target = images.to(device), target.to(device)
        #print('train images input shape:',images.shape)
        
        output = model(images)
        #print('train model output shape:',output.shape)
        loss = criterion(output,target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        steps = steps + 1
        batch_loss += loss.item()
        
        if steps % print_every == 0:
            with torch.no_grad():
                model.eval()
                for color_test, images_test, target_test in testloader:
                    images_test, target_test = images_test.to(device), target_test.to(device)
                    #print('test images input shape:',images_test.shape)
                    output_test = model(images_test)
                    #print('test model output shape:',output_test.shape)
                    loss = criterion(output_test,target_test)
                    test_loss +=loss.item()

                    psnr = 10 * log10(1 / loss.item())
                    avg_psnr += psnr

            #print(output.shape)
            #print(target.shape)
            #print(type(color))
            image_out = output_test[0].cpu().numpy()
            #target_out = target_test[0].cpu().numpy()
            color = color_test[0]
            if (e+1)%150 == 0:
                show_img(e, color, image_out)
            
    if(e+1)%200 == 0:
        model_out_path = "model_epoch_{}.pth".format(e)
        torch.save(model, model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))
    train_losses.append(batch_loss/len(trainloader))
    test_losses.append(test_loss/len(testloader))
    print(f"Training loss at epoch {e}: {batch_loss/len(trainloader)}")
    print(f"Test loss at epoch {e}: {test_loss/len(testloader)}")
    print(f"Average PSNR: {avg_psnr/len(testloader)}")
    model.train()

Starting epoch: 1/7
Training loss at epoch 0: 0.01915089404210448
Test loss at epoch 0: 0.021783415414392947
Average PSNR: 39.762492696170455
Starting epoch: 2/7
Training loss at epoch 1: 0.0060701858252286914
Test loss at epoch 1: 0.011849460219964384
Average PSNR: 45.07034771119003
Starting epoch: 3/7
Training loss at epoch 2: 0.004434822204057127
Test loss at epoch 2: 0.00928661067970097
Average PSNR: 47.234458670909596
Starting epoch: 4/7
Training loss at epoch 3: 0.0039649225608445705
Test loss at epoch 3: 0.008565872320905327
Average PSNR: 47.947004550844895
Starting epoch: 5/7
Training loss at epoch 4: 0.003776757246814668
Test loss at epoch 4: 0.008088659956119954
Average PSNR: 48.47980955510162
Starting epoch: 6/7
Training loss at epoch 5: 0.0036252790014259517
Test loss at epoch 5: 0.007856538156047463
Average PSNR: 48.747187806469334
Starting epoch: 7/7
Training loss at epoch 6: 0.003536977848270908
Test loss at epoch 6: 0.007827226724475621
Average PSNR: 48.76347908279851
