In [1]:
import torch
from torch import optim,nn
from torchvision import transforms,datasets
import torch.nn.functional as F
import torch.nn.init as init
import matplotlib as plt
import torch.utils.data as data

import os
from os import listdir
from math import log10
from PIL import Image


%matplotlib inline

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
def Downloaddata():
    import urllib
    #import wget
    path = "./"
    tgz_file = "BSDS300-images.tgz"
    print("Data Download Initialing...")
    url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz"
    #wget.download(url)
    data = urllib.request.urlretrieve(url,tgz_file)
    print("Data Download Complete")
    print("Data Extraction Initializing...")
    with tarfile.open(tgz_file) as tar:
        for item in tar:
            tar.extract(item,path)
    print("Data Extraction Complete")

In [4]:
def dataset():
    data_dir = "BSDS300/images"
    if(not os.path.exists(data_dir)):
        Downloaddata()
    return data_dir

In [5]:
upscale_Factor = 3
batch_Size = 4
Learning_Rate = 0.001

In [6]:
def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png",".jpg",".jpeg"])

def load_img(filepath):
    img = Image.open(filepath).convert('YCbCr')
    y,_,_ = img.split()
    return y


In [7]:

class DatasetFromFolder(data.Dataset):
    def __init__(self,img_dir,input_transform=None,target_transform=None):
        super(DatasetFromFolder,self).__init__()
        self.image_filenames = [img_dir+"/"+x for x in listdir(img_dir) if is_image_file(x)]
        
        self.input_transform = input_transform
        self.target_transform = target_transform
        
    
    
    def __getitem__(self,index):
        Input=load_img(self.image_filenames[index])
        target = Input.copy()
        if self.input_transform:
            Input = self.input_transform(Input)
            
        if self.target_transform:
            target = self.target_transform(target)
            
        return Input, target
    
    def __len__(self):
        return len(self.image_filenames)
        

In [8]:
def valid_crop_size(crop_size,upscale_factor):
    return crop_size - (crop_size % upscale_factor)

def input_transform(crop_size,upscale_factor):
    return transforms.Compose([transforms.CenterCrop(crop_size),
                               transforms.Resize(crop_size//upscale_factor),
                               transforms.ToTensor(),
                              ])

def target_transform(crop_size,upscale_factor):
    return transforms.Compose([transforms.CenterCrop(crop_size),
                               transforms.ToTensor(),
                              ])

def fetch_training_data(upscale_factor):
    dataset_dir = dataset()
    train_dir = dataset_dir+"/train/"
    crop_size = valid_crop_size(256,upscale_factor)
    
    return DatasetFromFolder(train_dir,
                             input_transform=input_transform(crop_size,upscale_factor),
                             target_transform=target_transform(crop_size,upscale_factor)
                            )

def fetch_test_Data(upscale_factor):
    datasetDir = dataset()
    test_dir = datasetDir+"/test"
    crop_size = valid_crop_size(256,upscale_factor)
    
    return DatasetFromFolder(test_dir,
                            input_transform = input_transform(crop_size,upscale_factor),
                            target_transform = target_transform(crop_size,upscale_factor)
                            )
def checkpoint(epoch):
    model_out_path = "model_epoch_{}.pth".format(epoch)
    torch.save(model,model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))



In [9]:
train_data = fetch_training_data(upscale_Factor)
test_data = fetch_test_Data(upscale_Factor)

In [10]:
train_loader = torch.utils.data.DataLoader(train_data,batch_size=4,shuffle=True)
test_data = torch.utils.data.DataLoader(test_data,batch_size=4,shuffle=True)

In [11]:
class Model(nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1,64,(5,5),(1,1),(2,2))
        self.conv2 = nn.Conv2d(64,64,(3,3),(1,1),(1,1))
        self.conv3 = nn.Conv2d(64,32,(3,3),(1,1),(1,1))
        self.conv4 = nn.Conv2d(32,3**2,(3,3),(1,1),(1,1))
        self.pixel_shuffle = nn.PixelShuffle(3)
        self._initialize_weights()
    
    def forward(self,x):
        x=self.relu(self.conv1(x))
        x=self.relu(self.conv2(x))
        x=self.relu(self.conv3(x))
        x=self.pixel_shuffle(self.conv4(x))
        
        return x
    
    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight,init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight,init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)

In [12]:
model = Model()
model

Model(
  (relu): ReLU()
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(32, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pixel_shuffle): PixelShuffle(upscale_factor=3)
)

In [13]:
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(),lr=Learning_Rate)

In [14]:
model.to(device)
epoch=500
train_loss = 0.0
valid_loss = 0.0
accuracy = 0.0

for i in range(epoch):
    model.train()
    for input_image,target_image  in enumerate(train_loader,1):
        input_image,target_image = target_image[0].to(device),target_image[1].to(device)
        
        optimizer.zero_grad()
        loss = criterion(model(input_image),target_image)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        
    for test_batch in test_data:
        input_image,target_image = test_batch[0].to(device),test_batch[1].to(device)
        output = model(input_image)
        loss = criterion(output,target_image)
        psnr = 10*log10(1/loss.item())
        valid_loss += psnr
        
    train_loss = train_loss/len(train_loader)
    valid_loss = valid_loss/len(test_data)
    accuracy = (valid_loss/40.0)*100
    print("Epoch {} \tTraining Loss: {:.4f} \tPSNR: {:.4f}\tAccuracy: {:.3f}%".format((i+1),train_loss,valid_loss,accuracy))
        



    

Epoch 1 	Training Loss: 0.0251 	PSNR: 20.4208	Accuracy: 51.052%
Epoch 2 	Training Loss: 0.0074 	PSNR: 23.2980	Accuracy: 58.245%
Epoch 3 	Training Loss: 0.0051 	PSNR: 24.3642	Accuracy: 60.910%
Epoch 4 	Training Loss: 0.0043 	PSNR: 24.8232	Accuracy: 62.058%
Epoch 5 	Training Loss: 0.0040 	PSNR: 24.8854	Accuracy: 62.213%
Epoch 6 	Training Loss: 0.0038 	PSNR: 25.1964	Accuracy: 62.991%
Epoch 7 	Training Loss: 0.0037 	PSNR: 25.3176	Accuracy: 63.294%
Epoch 8 	Training Loss: 0.0036 	PSNR: 25.5771	Accuracy: 63.943%
Epoch 9 	Training Loss: 0.0042 	PSNR: 23.4886	Accuracy: 58.721%
Epoch 10 	Training Loss: 0.0045 	PSNR: 25.3841	Accuracy: 63.460%
Epoch 11 	Training Loss: 0.0036 	PSNR: 25.5794	Accuracy: 63.949%
Epoch 12 	Training Loss: 0.0035 	PSNR: 25.8045	Accuracy: 64.511%
Epoch 13 	Training Loss: 0.0034 	PSNR: 25.5532	Accuracy: 63.883%
Epoch 14 	Training Loss: 0.0034 	PSNR: 25.6897	Accuracy: 64.224%
Epoch 15 	Training Loss: 0.0034 	PSNR: 25.7931	Accuracy: 64.483%
Epoch 16 	Training Loss: 0.0034 	P

Epoch 127 	Training Loss: 0.0030 	PSNR: 26.1964	Accuracy: 65.491%
Epoch 128 	Training Loss: 0.0030 	PSNR: 26.1591	Accuracy: 65.398%
Epoch 129 	Training Loss: 0.0030 	PSNR: 26.1259	Accuracy: 65.315%
Epoch 130 	Training Loss: 0.0030 	PSNR: 26.0762	Accuracy: 65.190%
Epoch 131 	Training Loss: 0.0030 	PSNR: 26.1462	Accuracy: 65.365%
Epoch 132 	Training Loss: 0.0030 	PSNR: 26.1422	Accuracy: 65.355%
Epoch 133 	Training Loss: 0.0030 	PSNR: 26.1300	Accuracy: 65.325%
Epoch 134 	Training Loss: 0.0030 	PSNR: 26.2610	Accuracy: 65.653%
Epoch 135 	Training Loss: 0.0030 	PSNR: 26.2045	Accuracy: 65.511%
Epoch 136 	Training Loss: 0.0031 	PSNR: 26.1772	Accuracy: 65.443%
Epoch 137 	Training Loss: 0.0030 	PSNR: 25.9680	Accuracy: 64.920%
Epoch 138 	Training Loss: 0.0030 	PSNR: 26.2692	Accuracy: 65.673%
Epoch 139 	Training Loss: 0.0031 	PSNR: 25.5454	Accuracy: 63.863%
Epoch 140 	Training Loss: 0.0032 	PSNR: 26.0976	Accuracy: 65.244%
Epoch 141 	Training Loss: 0.0030 	PSNR: 26.2614	Accuracy: 65.654%
Epoch 142 

Epoch 252 	Training Loss: 0.0029 	PSNR: 26.2888	Accuracy: 65.722%
Epoch 253 	Training Loss: 0.0029 	PSNR: 26.4498	Accuracy: 66.125%
Epoch 254 	Training Loss: 0.0029 	PSNR: 26.3116	Accuracy: 65.779%
Epoch 255 	Training Loss: 0.0029 	PSNR: 26.3303	Accuracy: 65.826%
Epoch 256 	Training Loss: 0.0029 	PSNR: 26.3424	Accuracy: 65.856%
Epoch 257 	Training Loss: 0.0029 	PSNR: 26.3205	Accuracy: 65.801%
Epoch 258 	Training Loss: 0.0029 	PSNR: 26.1289	Accuracy: 65.322%
Epoch 259 	Training Loss: 0.0029 	PSNR: 26.2063	Accuracy: 65.516%
Epoch 260 	Training Loss: 0.0030 	PSNR: 25.9292	Accuracy: 64.823%
Epoch 261 	Training Loss: 0.0029 	PSNR: 26.2871	Accuracy: 65.718%
Epoch 262 	Training Loss: 0.0029 	PSNR: 26.4066	Accuracy: 66.016%
Epoch 263 	Training Loss: 0.0029 	PSNR: 26.3716	Accuracy: 65.929%
Epoch 264 	Training Loss: 0.0029 	PSNR: 26.2423	Accuracy: 65.606%
Epoch 265 	Training Loss: 0.0029 	PSNR: 26.4048	Accuracy: 66.012%
Epoch 266 	Training Loss: 0.0029 	PSNR: 26.1548	Accuracy: 65.387%
Epoch 267 

Epoch 377 	Training Loss: 0.0028 	PSNR: 26.2467	Accuracy: 65.617%
Epoch 378 	Training Loss: 0.0028 	PSNR: 26.1843	Accuracy: 65.461%
Epoch 379 	Training Loss: 0.0028 	PSNR: 26.3692	Accuracy: 65.923%
Epoch 380 	Training Loss: 0.0028 	PSNR: 26.1524	Accuracy: 65.381%
Epoch 381 	Training Loss: 0.0028 	PSNR: 26.2087	Accuracy: 65.522%
Epoch 382 	Training Loss: 0.0028 	PSNR: 26.1974	Accuracy: 65.494%
Epoch 383 	Training Loss: 0.0028 	PSNR: 26.1748	Accuracy: 65.437%
Epoch 384 	Training Loss: 0.0028 	PSNR: 26.3225	Accuracy: 65.806%
Epoch 385 	Training Loss: 0.0028 	PSNR: 26.2397	Accuracy: 65.599%
Epoch 386 	Training Loss: 0.0028 	PSNR: 26.3047	Accuracy: 65.762%
Epoch 387 	Training Loss: 0.0028 	PSNR: 26.1825	Accuracy: 65.456%
Epoch 388 	Training Loss: 0.0028 	PSNR: 26.2186	Accuracy: 65.546%
Epoch 389 	Training Loss: 0.0028 	PSNR: 26.2202	Accuracy: 65.550%
Epoch 390 	Training Loss: 0.0028 	PSNR: 26.2735	Accuracy: 65.684%
Epoch 391 	Training Loss: 0.0028 	PSNR: 26.3554	Accuracy: 65.889%
Epoch 392 

In [16]:
model_out_path = "model_epoch.pth".format(epoch)
torch.save(model,model_out_path)
print("Checkpoint saved to".format(model_out_path))

Checkpoint saved to


  "type " + obj.__name__ + ". It won't be checked "
