In [1]:
import argparse
import os
from math import log10

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.models import vgg19
from tqdm import tqdm
import pytorch_ssim

from dataloader import*
from fastprogress import master_bar, progress_bar

# import pytorch_ssim
# from data_utils import TrainDatasetFromFolder, ValDatasetFromFolder, display_transform

In [2]:
class ResnetBlock(nn.Module):
    """Residual block: conv -> batch-norm -> PReLU -> conv -> batch-norm, plus a skip.

    Both convolutions use kernel 3, stride 1 and padding 1, and the block's
    input is added to the output of that stack, so the result has the same
    shape as the input.

    Args:
        channels: Number of feature channels entering and leaving the block.
    """

    def __init__(self, channels=64):
        super().__init__()

        def conv3x3():
            # Shape-preserving convolution shared by both halves of the block.
            return nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)

        # The attribute name and the layer order determine the state_dict keys
        # (`identity.0` ... `identity.4`), so both are kept as they were.
        residual_layers = [
            conv3x3(),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            conv3x3(),
            nn.BatchNorm2d(channels),
        ]
        self.identity = nn.Sequential(*residual_layers)

    def forward(self, x):
        """Return the output of ``self.identity`` added to the input ``x``."""
        return self.identity(x) + x
    

In [3]:
class UpSampleBlock(nn.Module):
    """Upsampling stage: 3x3 conv -> ``nn.PixelShuffle`` -> PReLU.

    Args:
        up_scale: Factor handed to ``nn.PixelShuffle``.
        channels: Number of feature channels entering the block; the conv
            produces ``channels * up_scale**2`` channels for the shuffle.
    """

    def __init__(self, up_scale = 2, channels = 64):
        super().__init__()

        # Widen the features first; PixelShuffle consumes the extra channels.
        expanded_channels = channels * (up_scale ** 2)
        stages = (
            nn.Conv2d(channels, expanded_channels, kernel_size=3, padding=1),
            nn.PixelShuffle(up_scale),
            nn.PReLU(),
        )
        # Attribute name kept so state_dict keys stay `model.*`.
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the conv / shuffle / activation stack."""
        return self.model(x)

In [4]:
class generator(nn.Module):
    """Super-resolution generator built from residual and upsampling blocks.

    Layout: a 9x9 conv head, a trunk of ``ResnetBlock`` modules closed by a
    conv + batch-norm pair (with a skip connection around the whole trunk),
    then ``self.r`` ``UpSampleBlock`` stages of factor 2 each and a final
    9x9 conv back to ``in_channels``.

    Args:
        B_numResnetBlock: Number of ``ResnetBlock`` modules in the trunk.
        in_channels: Channels of the input image and of the produced image.
        step_channels: Width of the internal feature maps.
    """

    def __init__(self, B_numResnetBlock = 4, in_channels = 3, step_channels = 64):
        super().__init__()

        # Head: lift the image into feature space.
        self.init_model = nn.Sequential(
            nn.Conv2d(in_channels, step_channels, kernel_size=9, padding=4),
            nn.PReLU(),
        )

        # Trunk: residual blocks followed by a conv + batch-norm pair.
        trunk = [ResnetBlock(step_channels) for _ in range(B_numResnetBlock)]
        trunk.append(
            nn.Sequential(
                nn.Conv2d(step_channels, step_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(step_channels),
            )
        )
        self.mid_model = nn.Sequential(*trunk)

        # Tail: `r` x2 upsampling stages, then project back to image channels.
        self.r = 2
        tail = [
            UpSampleBlock(up_scale = 2, channels = step_channels)
            for _ in range(self.r)
        ]
        tail.append(
            nn.Sequential(
                nn.Conv2d(step_channels, in_channels, kernel_size=9, padding=4)
            )
        )
        self.end_model = nn.Sequential(*tail)

        self._weight_initializer()

    def forward(self, x):
        """Map an input image batch to its upsampled counterpart."""
        features = self.init_model(x)
        # Skip connection around the whole residual trunk.
        trunk_out = self.mid_model(features) + features
        return self.end_model(trunk_out)

    def _weight_initializer(self):
        r"""Default weight initializer for all generator models.
        Models that require custom weight initialization can override this method
        """
        # NOTE(review): the convolutions built in __init__ are nn.Conv2d, so the
        # ConvTranspose2d branch below matches no layer created by this class.
        # Confirm whether nn.Conv2d was meant to be covered as well.
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1.0)
                nn.init.constant_(module.bias, 0.0)
            elif isinstance(module, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(module.weight)
        

In [5]:
class generatorLoss(nn.Module):
    """Generator objective: content MSE + adversarial BCE + VGG-feature MSE.

    The returned value is
    ``content + 0.001 * adversarial + 0.006 * perceptual``.

    Args:
        generator: Module mapping a low-resolution batch to a super-resolved one.
        discriminator: Module scoring an image batch. Its output is passed to
            ``nn.BCELoss``, so it has to be a probability in ``[0, 1]``.
        device: Device the VGG feature extractor and the loss modules are
            moved to.
    """

    def __init__(self, generator, discriminator, device):
        super().__init__()

        self.generator = generator
        self.discriminator = discriminator
        # Kept as an attribute: forward() used to read `self.device`, which was
        # never assigned and therefore raised AttributeError.
        self.device = device

        vgg = vgg19(pretrained=True, progress=True)

        # First nine layers of the VGG19 feature stack, frozen and in eval mode,
        # serve as a fixed feature extractor for the perceptual term.
        vgg_loss = nn.Sequential(*(list(vgg.features)[:9])).eval()
        for param in vgg_loss.parameters():
            param.requires_grad = False

        self.vgg_features = vgg_loss.to(device)
        # The loss modules must be instantiated before `.to(device)`; calling
        # `.to` on the class object itself fails.
        self.mseloss = nn.MSELoss().to(device)
        self.bceloss = nn.BCELoss().to(device)

    def forward(self, LR_image, HR_image):
        """Compute the combined generator loss for one batch.

        Args:
            LR_image: Low-resolution input batch fed to the generator.
            HR_image: Ground-truth high-resolution batch.

        Returns:
            Scalar tensor holding the weighted sum of the three loss terms.
        """
        SR_image = self.generator(LR_image)
        SR_pred = self.discriminator(SR_image)
        # Target "real" labels, created directly with the prediction's shape,
        # dtype and device.
        real_ = torch.ones_like(SR_pred)

        adversarial_loss = self.bceloss(SR_pred, real_)
        perceptual_loss = self.mseloss(self.vgg_features(HR_image), self.vgg_features(SR_image))
        content_loss = self.mseloss(HR_image, SR_image)

        return content_loss + 0.001*adversarial_loss + 0.006*perceptual_loss
    

In [6]:
class parser():
    """Plain container for the run's settings.

    Attributes set on every instance:
        crop_size: 88
        upscale_factor: 4
        num_epochs: 100
    """

    def __init__(self):
        # All options are fixed values; nothing is read from the command line.
        vars(self).update(
            crop_size=88,
            upscale_factor=4,
            num_epochs=100,
        )

In [7]:
# Run settings used by the data pipeline and the training loop.
opt = parser()
# Taken from `opt` instead of repeating the literal, so the value handed to the
# validation set and used in output file names cannot drift from
# opt.upscale_factor.
UPSCALE_FACTOR = opt.upscale_factor

In [8]:
# Dataset directories. The defaults are the locations this notebook was written
# against; set VOC_TRAIN_DIR / VOC_VAL_DIR in the environment to point it at
# another copy of the data without editing the cell.
path = os.environ.get('VOC_TRAIN_DIR', '/data/nirmalps/VOC2012/VOC2012/JPEGImages/')
valpath = os.environ.get('VOC_VAL_DIR', '/data/nirmalps/VOC2012/VOC2012/Val/')

In [9]:
# Datasets. The training loop unpacks each training batch as (HR, LR) and each
# validation batch as (LR, restored HR, HR).
train_set = dataset_train_from_folder(
    path,
    upscale_factor=opt.upscale_factor,
    crop_size=opt.crop_size,
)
val_set = dataset_val_from_folder(
    valpath,
    crop_size=opt.crop_size,
    upscale_factor=UPSCALE_FACTOR,
)

# Loaders: shuffled batches of 64 for training, one image at a time in a fixed
# order for validation; four worker processes each.
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=4)

In [10]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

cuda:0


In [11]:
# `generator` names the class until this cell has run and the model instance
# afterwards. The guard makes the cell safe to re-run: without it, a second
# execution would call the instance with no input and raise.
if isinstance(generator, type):
    generator = generator().to(device)

In [12]:
# Replicate the model over all visible GPUs. DataParallel splits each batch
# along dim 0, e.g. [30, ...] -> [10, ...], [10, ...], [10, ...] on 3 GPUs.
# The isinstance check stops a re-run of this cell from wrapping an already
# wrapped model, which would add a second 'module.' prefix to every
# state_dict key and break checkpoint loading.
if torch.cuda.device_count() > 1 and not isinstance(generator, nn.DataParallel):
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    generator = nn.DataParallel(generator)

Let's use 4 GPUs!


In [13]:
# generatorLoss = generatorLoss(generator, discriminator, device)
# discriminatorLoss = discriminatorLoss(generator, discriminator, device)

In [14]:
# Adam over the generator's parameters. No hyper-parameters are passed, so the
# optimizer runs with torch.optim.Adam's defaults.
optimizerG = optim.Adam(generator.parameters())

In [15]:

# Fixed VGG19 feature extractor used for the perceptual term of the loss.
vgg = vgg19(pretrained=True, progress=True)

# Only the first nine layers of the feature stack are kept, in eval mode.
vgg_loss = nn.Sequential(*(list(vgg.features)[:9])).eval()
for param in vgg_loss.parameters():
    # The extractor acts as a metric, so its weights are never updated.
    param.requires_grad = False

vgg_features = vgg_loss.to(device)
# Created once and placed with `.to(device)` like the extractor above, so the
# cell follows the device selection instead of requiring CUDA.
mseloss = nn.MSELoss().to(device)

In [16]:
# Number of mini-batches the training loader yields per epoch.
print(len(train_loader))
# Re-assigns the value UPSCALE_FACTOR already received right after `opt` was
# created; it is used for the validation set and in output file names.
UPSCALE_FACTOR = 4

268


In [18]:
# Generator weights stored by the training cell's
# 'netG_epoch_<upscale>_<epoch>.pth' naming, here upscale 4 / epoch 50.
# map_location re-homes the stored tensors onto `device`, so the file also
# loads on a machine that lacks the GPU it was written from.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(
    '/data/nirmalps/SRGAN/training_results/pretrainedResnet/epochs/netG_epoch_4_50.pth',
    map_location=device,
)

In [23]:
# Copy the checkpoint tensors into the model. The result displayed below lists
# no missing and no unexpected keys.
generator.load_state_dict(checkpoint)

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [24]:
# Write the model's current state dict to a file named 'new1' in the working
# directory.
torch.save(generator.state_dict(), 'new1')

In [25]:
# Read the file written above back in (presumably to compare it with the
# original `checkpoint`; both are displayed in the following cells).
checkpoint1 = torch.load('./new1')

In [26]:
# Display the re-loaded state dict; the keys shown carry a 'module.' prefix.
checkpoint1

OrderedDict([('module.init_model.0.weight',
              tensor([[[[-0.0068,  0.0318, -0.0456,  ..., -0.0509, -0.0424,  0.0391],
                        [ 0.0354,  0.0578,  0.0389,  ...,  0.0376, -0.0438,  0.0220],
                        [ 0.0168, -0.0496, -0.0589,  ..., -0.0308,  0.0104,  0.0517],
                        ...,
                        [-0.0201,  0.0233,  0.0311,  ...,  0.0165,  0.0360, -0.0129],
                        [ 0.0157,  0.0335,  0.0113,  ...,  0.0431, -0.0108, -0.0268],
                        [ 0.0024,  0.0248, -0.0198,  ..., -0.0396, -0.0305, -0.0164]],
              
                       [[ 0.0493, -0.0125, -0.0070,  ...,  0.0495, -0.0475, -0.0427],
                        [-0.0420, -0.0736,  0.0176,  ...,  0.0077, -0.0771,  0.0278],
                        [ 0.0375, -0.0007, -0.0084,  ...,  0.0234,  0.0340, -0.0624],
                        ...,
                        [ 0.0384,  0.0328, -0.0283,  ...,  0.0400, -0.0243, -0.0731],
                      

In [27]:
# Display the original checkpoint for a side-by-side look at the same keys.
checkpoint

OrderedDict([('module.init_model.0.weight',
              tensor([[[[-0.0068,  0.0318, -0.0456,  ..., -0.0509, -0.0424,  0.0391],
                        [ 0.0354,  0.0578,  0.0389,  ...,  0.0376, -0.0438,  0.0220],
                        [ 0.0168, -0.0496, -0.0589,  ..., -0.0308,  0.0104,  0.0517],
                        ...,
                        [-0.0201,  0.0233,  0.0311,  ...,  0.0165,  0.0360, -0.0129],
                        [ 0.0157,  0.0335,  0.0113,  ...,  0.0431, -0.0108, -0.0268],
                        [ 0.0024,  0.0248, -0.0198,  ..., -0.0396, -0.0305, -0.0164]],
              
                       [[ 0.0493, -0.0125, -0.0070,  ...,  0.0495, -0.0475, -0.0427],
                        [-0.0420, -0.0736,  0.0176,  ...,  0.0077, -0.0771,  0.0278],
                        [ 0.0375, -0.0007, -0.0084,  ...,  0.0234,  0.0340, -0.0624],
                        ...,
                        [ 0.0384,  0.0328, -0.0283,  ...,  0.0400, -0.0243, -0.0731],
                      

In [29]:
# Move the model to the CPU before exporting its weights.
# NOTE(review): this cell is numbered In [29] but sits above the training cell
# (In [28]), which sends its batches to `device`. Run top to bottom on a CUDA
# machine, the model would have to be moved back first — confirm the intended
# order.
generator = generator.to('cpu')

In [30]:
# Export the CPU-resident weights to './generator_cpu'.
# NOTE(review): if `generator` is still wrapped in nn.DataParallel here, the
# saved keys keep the 'module.' prefix seen in the state dicts displayed above;
# verify that whatever loads this file expects that.
torch.save(generator.state_dict(),'./generator_cpu')

In [28]:
# Generator-only training: each step minimises pixel-wise MSE plus a small
# VGG-feature MSE term. No discriminator takes part in this cell.
# After every epoch the model is evaluated on the validation loader (PSNR and
# SSIM), comparison images are written, and checkpoints / statistics are saved
# periodically.
results = {'g_loss': [], 'psnr': [], 'ssim': []}

# Output directories, created once up front instead of on every epoch.
image_dir = 'training_results/pretrainedResnet/images/SRF_' + str(UPSCALE_FACTOR) + '/'
checkpoint_dir = 'training_results/pretrainedResnet/epochs/'
stats_dir = 'training_results/pretrainedResnet/statistics/'
for directory in (image_dir, checkpoint_dir, stats_dir):
    os.makedirs(directory, exist_ok=True)

# Epochs are numbered 1..num_epochs inclusive, matching the "[epoch/num_epochs]"
# progress line printed below.
mb = master_bar(range(1, opt.num_epochs + 1))
for epoch in mb:
    generator.train()

    # The sample counter starts at 0 so g_loss / batch_sizes is the true mean.
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0}
    for HR_image, LR_image in progress_bar(train_loader, parent=mb):
        HR_image = HR_image.to(device)
        LR_image = LR_image.to(device)

        batch_size = HR_image.size(0)
        running_results['batch_sizes'] += batch_size

        ############################
        # Update G network: minimize content loss + 0.001 * perceptual loss
        ###########################
        generator.zero_grad()

        SR_image = generator(LR_image)

        perceptual_loss = mseloss(vgg_features(HR_image), vgg_features(SR_image))
        content_loss = mseloss(HR_image, SR_image)
        g_loss = content_loss + 0.001*perceptual_loss

        # Weighted by batch size so a smaller final batch does not skew the mean.
        running_results['g_loss'] += g_loss.item() * batch_size

        g_loss.backward()
        optimizerG.step()

    ############################
    # Post-epoch summary
    ###########################
    # max(..., 1) only guards the division when the loader produced no batches.
    epoch_g_loss = running_results['g_loss'] / max(running_results['batch_sizes'], 1)
    print('[%d/%d] Loss_G: %.4f' % (epoch, opt.num_epochs, epoch_g_loss))

    generator.eval()

    # Validation
    val_bar = tqdm(val_loader)
    valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
    val_images = []

    # torch.no_grad() is what actually disables graph construction here; the
    # `volatile=True` flag on Variable no longer has that effect.
    with torch.no_grad():
        for val_lr, val_hr_restore, val_hr in val_bar:
            batch_size = val_lr.size(0)
            valing_results['batch_sizes'] += batch_size
            LR = val_lr.to(device)
            HR = val_hr.to(device)
            SR = generator(LR)

            # .item() keeps the running sums as Python floats rather than tensors.
            batch_mse = ((SR - HR) ** 2).mean().item()
            valing_results['mse'] += batch_mse * batch_size
            batch_ssim = pytorch_ssim.ssim(SR, HR).item()
            valing_results['ssims'] += batch_ssim * batch_size

            mean_mse = valing_results['mse'] / valing_results['batch_sizes']
            # PSNR with a peak value of 1 (assumes images scaled to [0, 1] —
            # TODO confirm against the dataset transforms). A zero MSE would
            # otherwise divide by zero.
            valing_results['psnr'] = 10 * log10(1 / mean_mse) if mean_mse > 0 else float('inf')
            valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
            val_bar.set_description(
                desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                    valing_results['psnr'], valing_results['ssim']))

            # Three views per sample, in this order: restored HR, ground-truth
            # HR, generator output.
            # NOTE(review): display_transform is not imported explicitly in this
            # notebook; presumably it comes from `from dataloader import *`.
            val_images.extend(
                [display_transform()(val_hr_restore.squeeze(0)), display_transform()(HR.cpu().squeeze(0)),
                 display_transform()(SR.cpu().squeeze(0))])

    # Save the comparison grids, three images per row.
    if val_images:
        val_images = torch.stack(val_images)
        # Roughly 15 images (5 samples x 3 views) per grid; at least one chunk
        # so that fewer than 15 images does not request zero chunks.
        val_images = torch.chunk(val_images, max(val_images.size(0) // 15, 1))
        val_save_bar = tqdm(val_images, desc='[saving training results]')
        for index, image in enumerate(val_save_bar, start=1):
            image = utils.make_grid(image, nrow=3, padding=5)
            utils.save_image(image, image_dir + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)

    # Save model parameters every fifth epoch.
    if epoch % 5 == 0:
        torch.save(generator.state_dict(), checkpoint_dir + 'netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))

    # Record this epoch's loss / PSNR / SSIM.
    results['g_loss'].append(epoch_g_loss)
    results['psnr'].append(valing_results['psnr'])
    results['ssim'].append(valing_results['ssim'])

    # Write the statistics table every tenth epoch. The write sits inside the
    # condition so it runs whenever the frame is rebuilt, and never refers to
    # a frame that was not created.
    if epoch % 10 == 0:
        data_frame = pd.DataFrame(
            data={'Loss_G': results['g_loss'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
            index=range(1, epoch + 1))
        data_frame.to_csv(stats_dir + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')

    
        
        

Traceback (most recent call last):
  File "/data/nirmalps/Anaconda3/envs/env1/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/data/nirmalps/Anaconda3/envs/env1/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/data/nirmalps/Anaconda3/envs/env1/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/data/nirmalps/Anaconda3/envs/env1/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/data/nirmalps/Anaconda3/envs/env1/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/data/nirmalps/Anaconda3/envs/env1/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
Traceback (most recent call last):
  Fil

KeyboardInterrupt: 

In [45]:
def load_checkpoint(generator,discriminator, optimizerD, optimizerG, filename=resume):
    """Restore both networks and both optimizers from a checkpoint file.

    The objects passed in must already be constructed; this routine only
    overwrites their state, in place. If ``filename`` is not an existing file,
    a message is printed and the objects are returned unchanged.

    Args:
        generator: Module receiving ``checkpoint['g_state_dict']``.
        discriminator: Module receiving ``checkpoint['d_state_dict']``.
        optimizerD: Optimizer receiving ``checkpoint['d_optimizer']``.
        optimizerG: Optimizer receiving ``checkpoint['g_optimizer']``.
        filename: Path of the checkpoint file.

    Returns:
        The tuple ``(generator, discriminator, optimizerD, optimizerG)``.
        The epoch stored in the checkpoint is only reported, not returned.
    """
    if not os.path.isfile(filename):
        print("=> no checkpoint found at '{}'".format(filename))
        return generator,discriminator, optimizerD, optimizerG

    print("=> loading checkpoint '{}'".format(filename))
    state = torch.load(filename)
    # Read first, so a checkpoint without an 'epoch' entry fails before any
    # state is modified.
    saved_epoch = state['epoch']

    generator.load_state_dict(state['g_state_dict'])
    optimizerG.load_state_dict(state['g_optimizer'])
    discriminator.load_state_dict(state['d_state_dict'])
    optimizerD.load_state_dict(state['d_optimizer'])

    print("=> loaded checkpoint '{}' (epoch {})"
              .format(filename, saved_epoch))
    return generator,discriminator, optimizerD, optimizerG