In [None]:
import os
from math import log10

import pandas as pd
import torch
import torch.optim as optim
import torch.utils.data
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision
from tqdm.notebook import tqdm

from pathlib import Path
from PIL import Image, ImageFile

# import pytorch_ssim
import data_utils
from data_utils import CustumDataset, display_transform
from loss import GeneratorLoss
from model import Generator, Discriminator

# Run configuration shared by the cells below.
# `high_res` is passed to CustumDataset as its second argument and is embedded in
# the statistics/checkpoint file names.
# NOTE(review): presumably the edge length (pixels) of the high-resolution crop —
# confirm in data_utils.CustumDataset.
high_res = 128
# `upscale_factor` is passed to CustumDataset and Generator and is embedded in the
# output directory and file names.
# NOTE(review): presumably the LR -> HR magnification ratio — confirm in model.Generator.
upscale_factor = 2

In [None]:
# Datasets are built from the image folders using the shared resolution settings.
train_set = CustumDataset('data/train/images', high_res, upscale_factor)
val_set = CustumDataset('data/valid/images', high_res, upscale_factor)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(dataset=train_set, batch_size=32, shuffle=True)
# Validation is iterated in a fixed order so that the comparison grids saved after
# each epoch contain the same images in the same positions, and so that the
# per-epoch validation metrics do not depend on a random visiting order.
val_loader = DataLoader(dataset=val_set, batch_size=1, shuffle=False)

In [None]:
# Networks and the generator loss live on the GPU.
netG = Generator(upscale_factor).cuda()
netD = Discriminator().cuda()

generator_criterion = GeneratorLoss().cuda()

# Adam with library-default hyper-parameters for both networks.
optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

# Per-epoch history; one value is appended to every list at the end of each epoch.
results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}

# Non-learned upsamplers used during validation as visual baselines next to the
# generator output. They must use the same factor as the generator so that their
# outputs have the same spatial size as the high-resolution target (the validation
# loop stacks all four images into one tensor).
upsample_nearest = torch.nn.Upsample(scale_factor=upscale_factor, mode='nearest').cuda()
upsample_bilinear = torch.nn.Upsample(scale_factor=upscale_factor, mode='bilinear').cuda()

In [None]:
num_epochs = 100

# Use whatever device the generator already lives on, so inputs always match the model.
device = next(netG.parameters()).device

# Output locations are created once up front; both writers below fail on a missing directory.
out_path = 'training_results/SRF_' + str(upscale_factor) + '/'
stats_path = 'statistics/'
os.makedirs(out_path, exist_ok=True)
os.makedirs(stats_path, exist_ok=True)

for epoch in range(1, num_epochs + 1):
    # ------------------------------------------------------------------ training
    train_bar = tqdm(train_loader)
    # Sample-weighted sums; divided by 'batch_sizes' to obtain running means.
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

    netG.train()
    netD.train()
    for data, target in train_bar:
        batch_size = data.size(0)
        running_results['batch_sizes'] += batch_size

        real_img = target.to(device)
        z = data.to(device)
        # Single generator forward pass, reused by both updates below.
        fake_img = netG(z)

        # Discriminator step: loss is 1 - D(real) + D(fake).
        # The fake batch is detached so this backward pass stops at the
        # discriminator input instead of also traversing the generator.
        netD.zero_grad()
        real_out = netD(real_img).mean()
        d_fake_out = netD(fake_img.detach()).mean()
        d_loss = 1 - real_out + d_fake_out
        d_loss.backward()
        optimizerD.step()

        # Generator step: score the same fake batch with the just-updated
        # discriminator. This is a fresh forward pass through netD, so no graph
        # from the discriminator step has to be retained.
        netG.zero_grad()
        fake_out = netD(fake_img).mean()
        g_loss = generator_criterion(fake_out, fake_img, real_img)
        g_loss.backward()
        optimizerG.step()

        # Statistics for the current batch, measured before the generator update.
        running_results['g_loss'] += g_loss.item() * batch_size
        running_results['d_loss'] += d_loss.item() * batch_size
        running_results['d_score'] += real_out.item() * batch_size
        running_results['g_score'] += fake_out.item() * batch_size

        train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch, num_epochs, running_results['d_loss'] / running_results['batch_sizes'],
                running_results['g_loss'] / running_results['batch_sizes'],
                running_results['d_score'] / running_results['batch_sizes'],
                running_results['g_score'] / running_results['batch_sizes']))

    # ---------------------------------------------------------------- validation
    netG.eval()

    with torch.no_grad():
        val_bar = tqdm(val_loader)
        valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
        val_images = []
        for val_lr, val_hr in val_bar:
            batch_size = val_lr.size(0)
            valing_results['batch_sizes'] += batch_size
            lr = val_lr.to(device)
            hr = val_hr.to(device)
            sr = netG(lr)

            # Non-learned baselines shown next to the generator output.
            lr_upsample = upsample_nearest(lr)
            lr_bilinear = upsample_bilinear(lr)

            # Accumulate as a plain float so no GPU tensor is kept in the stats dict.
            batch_mse = ((sr - hr) ** 2).mean().item()
            valing_results['mse'] += batch_mse * batch_size
            mean_mse = valing_results['mse'] / valing_results['batch_sizes']
            # PSNR uses the running mean MSE and the peak value of the current batch.
            # A zero MSE (perfect reconstruction) maps to +inf instead of dividing by zero.
            peak_sq = hr.max().item() ** 2
            valing_results['psnr'] = 10 * log10(peak_sq / mean_mse) if mean_mse > 0 else float('inf')
            # NOTE: SSIM is disabled because the pytorch_ssim import is commented out;
            # 'ssims' is never incremented, so the reported SSIM stays 0. To re-enable,
            # add pytorch_ssim.ssim(sr, hr).item() * batch_size to 'ssims' here.
            valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
            val_bar.set_description(
                desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                    valing_results['psnr'], valing_results['ssim']))

            # Four images per sample: nearest, bilinear, generator output, ground truth.
            # squeeze(0) drops the batch axis (val_loader is built with batch_size=1).
            val_images.extend(
                display_transform()(img.cpu().squeeze(0))
                for img in (lr_upsample, lr_bilinear, sr, hr))

        # Save the comparison images as grids of up to 32 images (8 per row).
        if val_images:
            val_images = torch.stack(val_images)
            # At least one chunk, so validation sets yielding fewer than 32 images still work.
            num_chunks = max(1, val_images.size(0) // 32)
            val_images = torch.chunk(val_images, num_chunks)
            val_save_bar = tqdm(val_images, desc='[saving training results]')
            for index, image in enumerate(val_save_bar, start=1):
                image = utils.make_grid(image, nrow=8, padding=5)
                utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)

    # -------------------------------------------------------------- bookkeeping
    # Record this epoch's mean loss/scores and the validation PSNR/SSIM.
    results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
    results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
    results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
    results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
    results['psnr'].append(valing_results['psnr'])
    results['ssim'].append(valing_results['ssim'])

    # Dump the full history to CSV every 10 epochs.
    if epoch % 10 == 0:
        data_frame = pd.DataFrame(
            data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                    'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
            # Index by the number of recorded epochs, so the frame stays valid even if
            # `results` already holds entries from an earlier run of this cell.
            index=range(1, len(results['d_loss']) + 1))
        data_frame.to_csv(os.path.join(stats_path, f'res{high_res}_uf{upscale_factor}.csv'), index_label='Epoch')

In [None]:
# Persist everything needed to resume training: both networks, both optimizers,
# the last finished epoch and the last batch losses.
# The losses are stored as plain Python floats rather than as autograd-tracked GPU
# tensors, so the checkpoint does not carry device-bound tensors for scalar values.
# float() accepts both a one-element tensor and an already-converted float.
torch.save({
            'epoch': epoch,
            'G_model_state_dict': netG.state_dict(),
            'G_optimizer_state_dict': optimizerG.state_dict(),
            'G_loss': float(g_loss),
            'D_model_state_dict': netD.state_dict(),
            'D_optimizer_state_dict': optimizerD.state_dict(),
            'D_loss': float(d_loss)
            }, f'./res{high_res}_uf{upscale_factor}.tar')

In [None]:
# Restore networks and optimizers from the checkpoint written by the save cell.
# map_location remaps stored tensors onto the device the generator currently uses,
# so a checkpoint saved on a different device (e.g. another GPU index) still loads.
device = next(netG.parameters()).device
checkpoint = torch.load(f'./res{high_res}_uf{upscale_factor}.tar', map_location=device)
netG.load_state_dict(checkpoint['G_model_state_dict'])
netD.load_state_dict(checkpoint['D_model_state_dict'])
optimizerG.load_state_dict(checkpoint['G_optimizer_state_dict'])
optimizerD.load_state_dict(checkpoint['D_optimizer_state_dict'])

# Restore the bookkeeping values stored alongside the weights.
epoch = checkpoint['epoch']
g_loss = checkpoint['G_loss']
d_loss = checkpoint['D_loss']

# Leave both networks in training mode, ready to continue training.
netG.train()
netD.train()