In [1]:
import torch
from torch import optim,nn
from torchvision import transforms,datasets
import torch.nn.functional as F
import torch.nn.init as init
import matplotlib as plt
import torch.utils.data as data

import os
from os import listdir
from math import log10
from PIL import Image


%matplotlib inline

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
def Downloaddata():
    import urllib
    #import wget
    path = "./"
    tgz_file = "BSDS300-images.tgz"
    print("Data Download Initialing...")
    url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz"
    #wget.download(url)
    data = urllib.request.urlretrieve(url,tgz_file)
    print("Data Download Complete")
    print("Data Extraction Initializing...")
    with tarfile.open(tgz_file) as tar:
        for item in tar:
            tar.extract(item,path)
    print("Data Extraction Complete")

In [4]:
def dataset():
    data_dir = "BSDS300/images"
    if(not os.path.exists(data_dir)):
        Downloaddata()
    return data_dir

In [5]:
No_of_Epoch = 30
upscale_Factor = 3
batch_Size = 4
Learning_Rate = 0.001

In [6]:
def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png",".jpg",".jpeg"])

def load_img(filepath):
    img = Image.open(filepath).convert('YCbCr')
    y,_,_ = img.split()
    return y


In [7]:

class DatasetFromFolder(data.Dataset):
    def __init__(self,img_dir,input_transform=None,target_transform=None):
        super(DatasetFromFolder,self).__init__()
        self.image_filenames = [img_dir+"/"+x for x in listdir(img_dir) if is_image_file(x)]
        
        self.input_transform = input_transform
        self.target_transform = target_transform
        
    
    
    def __getitem__(self,index):
        Input=load_img(self.image_filenames[index])
        target = Input.copy()
        if self.input_transform:
            Input = self.input_transform(Input)
            
        if self.target_transform:
            target = self.target_transform(target)
            
        return Input, target
    
    def __len__(self):
        return len(self.image_filenames)
        

In [8]:
def valid_crop_size(crop_size,upscale_factor):
    return crop_size - (crop_size % upscale_factor)

def input_transform(crop_size,upscale_factor):
    return transforms.Compose([transforms.CenterCrop(crop_size),
                               transforms.Resize(crop_size//upscale_factor),
                               transforms.ToTensor(),
                              ])

def target_transform(crop_size,upscale_factor):
    return transforms.Compose([transforms.CenterCrop(crop_size),
                               transforms.ToTensor(),
                              ])

def fetch_training_data(upscale_factor):
    dataset_dir = dataset()
    train_dir = dataset_dir+"/train/"
    crop_size = valid_crop_size(256,upscale_factor)
    
    return DatasetFromFolder(train_dir,
                             input_transform=input_transform(crop_size,upscale_factor),
                             target_transform=target_transform(crop_size,upscale_factor)
                            )

def fetch_test_Data(upscale_factor):
    datasetDir = dataset()
    test_dir = datasetDir+"/test"
    crop_size = valid_crop_size(256,upscale_factor)
    
    return DatasetFromFolder(test_dir,
                            input_transform = input_transform(crop_size,upscale_factor),
                            target_transform = target_transform(crop_size,upscale_factor)
                            )



In [9]:
train_data = fetch_training_data(upscale_Factor)
test_data = fetch_test_Data(upscale_Factor)

In [10]:
train_loader = torch.utils.data.DataLoader(train_data,batch_size=4,shuffle=True)
test_data = torch.utils.data.DataLoader(test_data,batch_size=4,shuffle=True)

In [11]:
class Model(nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1,64,(5,5),(1,1),(2,2))
        self.conv2 = nn.Conv2d(64,64,(3,3),(1,1),(1,1))
        self.conv3 = nn.Conv2d(64,32,(3,3),(1,1),(1,1))
        self.conv4 = nn.Conv2d(32,3**2,(3,3),(1,1),(1,1))
        self.pixel_shuffle = nn.PixelShuffle(3)
        self._initialize_weights()
    
    def forward(self,x):
        x=self.relu(self.conv1(x))
        x=self.relu(self.conv2(x))
        x=self.relu(self.conv3(x))
        x=self.pixel_shuffle(self.conv4(x))
        
        return x
    
    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight,init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight,init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)

In [12]:
model = Model()
model

Model(
  (relu): ReLU()
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(32, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pixel_shuffle): PixelShuffle(upscale_factor=3)
)

In [13]:
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(),lr=0.001)

In [14]:
model.to(device)
epoch=10
train_loss = 0.0
valid_loss = 0.0
accuracy = 0.0

for i in range(epoch):
    model.train()
    for input_image,target_image  in enumerate(train_loader,1):
        input_image,target_image = target_image[0].to(device),target_image[1].to(device)
        
        optimizer.zero_grad()
        loss = criterion(model(input_image),target_image)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        
    for test_batch in test_data:
        input_image,target_image = test_batch[0].to(device),test_batch[1].to(device)
        output = model(input_image)
        loss = criterion(output,target_image)
        psnr = 10*log10(1/loss.item())
        valid_loss += psnr
        
    train_loss = train_loss/len(train_loader)
    valid_loss = valid_loss/len(test_data)
    accuracy = (valid_loss/40.0)*100
    print("Epoch {} \tTraining Loss: {:.4f} \tPSNR: {:.4f}\tAccuracy: {:.3f}%".format((i+1),train_loss,valid_loss,accuracy))
        



    

Epoch 1 	Training Loss: 0.0255 	PSNR: 21.4474	Accuracy: 53.619%
Epoch 2 	Training Loss: 0.0063 	PSNR: 23.7683	Accuracy: 59.421%
Epoch 3 	Training Loss: 0.0045 	PSNR: 24.6475	Accuracy: 61.619%
Epoch 4 	Training Loss: 0.0040 	PSNR: 25.1416	Accuracy: 62.854%
Epoch 5 	Training Loss: 0.0037 	PSNR: 25.3844	Accuracy: 63.461%
Epoch 6 	Training Loss: 0.0037 	PSNR: 25.1373	Accuracy: 62.843%
Epoch 7 	Training Loss: 0.0036 	PSNR: 25.6141	Accuracy: 64.035%
Epoch 8 	Training Loss: 0.0036 	PSNR: 25.6396	Accuracy: 64.099%
Epoch 9 	Training Loss: 0.0035 	PSNR: 25.7738	Accuracy: 64.435%
Epoch 10 	Training Loss: 0.0034 	PSNR: 25.6628	Accuracy: 64.157%
