In [None]:
import argparse
import os
from math import log10

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.models import vgg19
from tqdm import tqdm

from dataloader import*

# import pytorch_ssim
# from data_utils import TrainDatasetFromFolder, ValDatasetFromFolder, display_transform

In [2]:
class ResnetBlock(nn.Module):
    """Residual block used in the generator trunk.

    Computes ``x + F(x)``, where ``F`` is Conv3x3 -> BN -> PReLU -> Conv3x3 -> BN.
    Every convolution uses stride 1 and padding 1 and keeps the channel count,
    so ``F(x)`` has the same shape as ``x`` and the two can be summed.

    Args:
        channels: number of input (and output) feature channels.
    """

    def __init__(self, channels=64):
        super().__init__()

        def conv3x3():
            return nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)

        # NOTE: despite its name, `identity` holds the residual branch F(x); the
        # name is kept so state_dict keys (`identity.N.*`) stay unchanged.
        self.identity = nn.Sequential(
            conv3x3(),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            conv3x3(),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        """Return ``x`` plus the residual branch evaluated on ``x``."""
        return self.identity(x) + x
    

In [3]:
class UpSampleBlock(nn.Module):
    """Sub-pixel upsampling stage: Conv3x3 -> PixelShuffle -> PReLU.

    The convolution widens ``channels`` to ``channels * up_scale**2``. Then
    ``nn.PixelShuffle`` folds the extra channels into space, so the output has
    ``channels`` channels and ``up_scale`` times the input height and width.

    Args:
        up_scale: spatial scaling factor applied by this block.
        channels: number of input (and output) feature channels.
    """

    def __init__(self, up_scale = 2, channels = 64):
        super().__init__()
        expanded = channels * up_scale ** 2
        layers = [
            nn.Conv2d(channels, expanded, kernel_size=3, padding=1),
            nn.PixelShuffle(up_scale),
            nn.PReLU(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Upsample ``x`` by ``up_scale`` in both spatial dimensions."""
        return self.model(x)

In [4]:
class generator(nn.Module):
    r"""SRGAN-style super-resolution generator.

    Pipeline:
      1. Stem: Conv9x9 + PReLU.
      2. Trunk: ``B_numResnetBlock`` ``ResnetBlock`` modules, then Conv3x3 + BN.
         Its output is added to the stem output (long skip connection).
      3. log2(``upscale_factor``) ``UpSampleBlock`` modules, each doubling H and W.
      4. Conv9x9 back to ``in_channels``. No output activation is applied.

    Args:
        B_numResnetBlock: number of residual blocks in the trunk.
        in_channels: image channels for both input and output.
        step_channels: feature width used throughout the network.
        upscale_factor: total spatial scaling factor; must be a positive power
            of two. The default of 4 reproduces the original two x2 blocks.

    Raises:
        ValueError: if ``upscale_factor`` is not a positive power of two.
    """

    def __init__(self, B_numResnetBlock = 16, in_channels = 3, step_channels = 64, upscale_factor = 4):

        super().__init__()

        # Every UpSampleBlock below doubles H and W, so only powers of two are reachable.
        if upscale_factor < 1 or upscale_factor & (upscale_factor - 1):
            raise ValueError(
                "upscale_factor must be a positive power of two, got {}".format(upscale_factor)
            )
        self.upscale_factor = upscale_factor

        # Stem: wide 9x9 receptive field, no normalisation.
        self.init_model = nn.Sequential(
            nn.Conv2d(in_channels, step_channels, kernel_size=9, padding=4),
            nn.PReLU()
        )

        # Trunk: residual blocks followed by Conv3x3 + BN; forward() adds the
        # stem output back onto the trunk output.
        mid_model = [ResnetBlock(step_channels) for _ in range(B_numResnetBlock)]
        mid_model.append(
            nn.Sequential(
                nn.Conv2d(step_channels, step_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(step_channels)
            )
        )
        self.mid_model = nn.Sequential(*mid_model)

        # Number of x2 upsampling blocks = log2(upscale_factor) (2 for the default 4).
        self.r = upscale_factor.bit_length() - 1
        end_model = [UpSampleBlock(up_scale = 2, channels = step_channels) for _ in range(self.r)]
        end_model.append(
            nn.Sequential(
                nn.Conv2d(step_channels, in_channels, kernel_size=9, padding=4)
            )
        )
        self.end_model = nn.Sequential(*end_model)

        self._weight_initializer()


    def forward(self, x):
        """Super-resolve ``x``; output H and W are ``upscale_factor`` times the input's."""
        x = self.init_model(x)
        x = self.mid_model(x) + x  # long skip connection around the trunk
        return self.end_model(x)


    def _weight_initializer(self):
        r"""Default weight initializer for the generator.

        Sets every ``nn.BatchNorm2d`` to weight=1, bias=0. Applies Kaiming-normal
        init to ``nn.ConvTranspose2d`` weights.

        NOTE: this model is built only from ``nn.Conv2d`` layers, so the
        ConvTranspose2d branch never fires and the convolutions keep PyTorch's
        default initialisation. Subclasses that need custom initialization can
        override this method.
        """
        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)
        

In [5]:
class discriminator(nn.Module):
    r"""SRGAN-style discriminator returning one real/fake probability per image.

    Layout:
      - First stage: Conv3x3 -> LeakyReLU -> stride-2 Conv3x3 -> BN -> LeakyReLU.
      - Three further stages: each doubles the channel width and halves H/W.
      - Head: global average pool -> 1x1 conv -> LeakyReLU -> 1x1 conv to a
        single channel -> Sigmoid.

    Because of the global pool, any input resolution yields an output of shape
    ``(batch,)``.

    Args:
        in_channels: channels of the input images.
        step_channels: width of the first stage; the last stage is 8x wider.
    """

    def __init__(self, in_channels = 3, step_channels = 64):
        super().__init__()

        stages = [
            nn.Sequential(
                nn.Conv2d(in_channels, step_channels, kernel_size=3, padding=1),
                nn.LeakyReLU(.2),
                nn.Conv2d(step_channels, step_channels, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(step_channels),
                nn.LeakyReLU(.2),
            )
        ]

        width = step_channels
        for _ in range(3):
            wider = width * 2
            stages.append(
                nn.Sequential(
                    nn.Conv2d(width, wider, kernel_size=3, padding=1),
                    nn.BatchNorm2d(wider),
                    nn.LeakyReLU(.2),
                    nn.Conv2d(wider, wider, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(wider),
                    nn.LeakyReLU(.2),
                )
            )
            width = wider

        # Channel width of the last convolutional stage (8 * step_channels).
        self.expansion = width

        stages.append(
            nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(width, width * 2, kernel_size=1),
                nn.LeakyReLU(.2),
                nn.Conv2d(width * 2, 1, kernel_size=1),
                nn.Sigmoid(),
            )
        )

        self.model = nn.Sequential(*stages)
        self._weight_initializer()

    def forward(self, x):
        """Return a flat tensor of probabilities, one entry per image in ``x``."""
        return self.model(x).view(-1)

    def _weight_initializer(self):
        r"""Default weight initializer for the discriminator.

        Resets every ``nn.BatchNorm2d`` to weight=1, bias=0. Applies Kaiming-normal
        init to any ``nn.ConvTranspose2d`` (this model contains none, so its
        ``nn.Conv2d`` layers keep PyTorch's default initialisation). Override
        in a subclass for custom initialization.
        """
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1.0)
                nn.init.constant_(module.bias, 0.0)
            elif isinstance(module, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(module.weight)
        

In [23]:
class discriminatorLoss(nn.Module):
    """Binary cross-entropy loss for training the discriminator.

    Real high-resolution images are scored against target 1 and generator
    outputs against target 0. The result is the sum of the two BCE terms,
    i.e. minimising it maximises ``log D(HR) + log(1 - D(G(LR)))``.

    Args:
        generator: module mapping a low-resolution batch to a super-resolved batch.
        discriminator: module mapping an image batch to probabilities in [0, 1]
            (required by ``nn.BCELoss``).
        device: stored for API compatibility; targets are created with the same
            device and dtype as the predictions.
    """

    def __init__(self, generator, discriminator, device):

        super().__init__()
        self.device = device
        self.generator = generator
        self.discriminator = discriminator
        self.bceloss = nn.BCELoss()

    def forward(self, LR_image, HR_image):
        """Return ``BCE(D(HR), 1) + BCE(D(G(LR)), 0)``.

        G(LR) is computed without autograd. Backpropagating this loss therefore
        only touches discriminator parameters, and no generator graph is kept
        alive, so callers need no ``retain_graph``. The generator is trained by
        its own loss, which recomputes G(LR).
        """
        HR_pred = self.discriminator(HR_image)

        with torch.no_grad():
            SR_image = self.generator(LR_image)
        SR_pred = self.discriminator(SR_image)

        HR_loss = self.bceloss(HR_pred, torch.ones_like(HR_pred))
        SR_loss = self.bceloss(SR_pred, torch.zeros_like(SR_pred))
        return HR_loss + SR_loss
        

In [24]:
class generatorLoss(nn.Module):
    """Composite SRGAN generator loss.

    With ``SR = generator(LR)`` the loss is::

        MSE(HR, SR)
        + adversarial_weight * BCE(D(SR), 1)
        + perceptual_weight  * MSE(VGG(HR), VGG(SR))

    ``VGG`` is the frozen prefix ``features[:9]`` of an ImageNet-pretrained
    torchvision VGG19, which ends at the ReLU after conv2_2.

    Args:
        generator: module mapping a low-resolution batch to a super-resolved batch.
        discriminator: module mapping an image batch to probabilities in [0, 1].
        device: device the VGG feature extractor is moved to.
        adversarial_weight: weight of the adversarial BCE term (default 0.001).
        perceptual_weight: weight of the VGG feature MSE term (default 0.006).

    NOTE(review): images are fed to VGG without ImageNet mean/std
    normalisation. Confirm this matches the value range produced by the dataset.
    """

    def __init__(self, generator, discriminator, device,
                 adversarial_weight=0.001, perceptual_weight=0.006):

        super().__init__()

        # Bug fix: the original never stored the device, yet forward() read it.
        self.device = device
        self.generator = generator
        self.discriminator = discriminator
        self.adversarial_weight = adversarial_weight
        self.perceptual_weight = perceptual_weight

        vgg = vgg19(pretrained=True, progress=True)

        # Frozen feature extractor: no gradients flow into VGG weights.
        vgg_loss = nn.Sequential(*(list(vgg.features)[:9])).eval()
        for param in vgg_loss.parameters():
            param.requires_grad = False

        self.vgg_features = vgg_loss.to(device)

        # Bug fix: these must be loss *instances*. The original stored the classes
        # (nn.MSELoss / nn.BCELoss), so calling them passed the tensors to the
        # constructor instead of computing a loss.
        self.mseloss = nn.MSELoss()
        self.bceloss = nn.BCELoss()


    def forward(self, LR_image, HR_image):
        """Return the weighted sum of content, adversarial and perceptual losses."""
        SR_image = self.generator(LR_image)
        SR_pred = self.discriminator(SR_image)

        # The generator wants the discriminator to label its output as real (1).
        adversial_loss = self.bceloss(SR_pred, torch.ones_like(SR_pred))
        perceptual_loss = self.mseloss(self.vgg_features(HR_image), self.vgg_features(SR_image))
        content_loss = self.mseloss(HR_image, SR_image)

        return (content_loss
                + self.adversarial_weight * adversial_loss
                + self.perceptual_weight * perceptual_loss)
    

In [25]:
class parser():
    """Lightweight container for training hyper-parameters (stand-in for argparse).

    Args:
        crop_size: crop size passed to the training dataset.
        upscale_factor: super-resolution factor passed to the training dataset.
        num_epochs: number of training epochs.
    """

    def __init__(self, crop_size=88, upscale_factor=4, num_epochs=100):

        self.crop_size = crop_size
        self.upscale_factor = upscale_factor
        self.num_epochs = num_epochs

In [26]:
# Hyper-parameter container (crop size, upscale factor, epoch count) read by later cells.
opt = parser()

In [34]:
# Training images: set the SRGAN_DATA_DIR environment variable to point at a local
# copy of the VOC2012 JPEGImages folder instead of editing this absolute path.
DATA_DIR = os.environ.get('SRGAN_DATA_DIR', '/home/nirmal/VOC2012/JPEGImages/')
BATCH_SIZE = 64
# Each worker is a separate process holding its own prefetched batches. The recorded
# run had a worker killed by SIGKILL (typically the OOM killer), so do not spawn
# more workers than there are CPUs.
NUM_WORKERS = min(8, os.cpu_count() or 1)

train_set = dataset_train_from_folder(DATA_DIR, crop_size=opt.crop_size, upscale_factor=opt.upscale_factor)
train_loader = DataLoader(dataset=train_set, num_workers=NUM_WORKERS, batch_size=BATCH_SIZE, shuffle=True)
# TODO: add a validation loader (e.g. ValDatasetFromFolder with batch_size=1) to
# track PSNR/SSIM during training.

In [35]:
# Seed the RNG on all devices so weight initialisation and shuffling are
# reproducible; the cudnn flags below only make the GPU kernels deterministic.
SEED = 0
torch.manual_seed(SEED)

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    # Use deterministic cudnn algorithms, and disable benchmark mode, which
    # may select different (non-deterministic) algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    epochs = 100
else:
    device = torch.device("cpu")
    # CPU training is slow; run a short schedule instead.
    epochs = 5

print("Device: {}".format(device))
print("Epochs: {}".format(epochs))

Device: cpu
Epochs: 5


In [36]:
# Models. The instances deliberately use names distinct from their classes
# (`generator`, `discriminator`, `generatorLoss`, `discriminatorLoss`). Rebinding
# the class name to an instance made re-running these cells call the instance's
# forward() with constructor arguments, which raised TypeErrors.
netG = generator().to(device)
netD = discriminator().to(device)

# Loss modules (each holds references to both networks).
g_criterion = generatorLoss(netG, netD, device)
d_criterion = discriminatorLoss(netG, netD, device)

optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

# Train for `epochs` epochs (set per device in the configuration cell). The original
# range(1, opt.num_epochs) ran one epoch fewer than intended and ignored that setting.
for epoch in range(1, epochs + 1):

    netG.train()
    netD.train()

    train_bar = tqdm(train_loader, desc="epoch {}/{}".format(epoch, epochs))
    # Assumes the dataset yields (HR, LR) pairs -- TODO confirm against dataloader.
    for HR, LR in train_bar:

        HR = HR.to(device)
        LR = LR.to(device)

        ############################
        # (1) Update D: maximize log D(HR) + log(1 - D(G(LR)))
        ###########################
        optimizerD.zero_grad()
        d_loss = d_criterion(LR, HR)
        # No retain_graph needed: the generator loss below rebuilds its own graph.
        d_loss.backward()
        optimizerD.step()

        ############################
        # (2) Update G: minimize content + adversarial + perceptual loss
        ###########################
        # Also clears any gradients the D step left on generator parameters.
        optimizerG.zero_grad()
        g_loss = g_criterion(LR, HR)
        g_loss.backward()
        optimizerG.step()

        train_bar.set_postfix(d_loss=d_loss.item(), g_loss=g_loss.item())
        
        
        


  0%|          | 0/268 [00:00<?, ?it/s][ATraceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 

RuntimeError: DataLoader worker (pid 16805) is killed by signal: Killed. 