<a href="https://colab.research.google.com/github/leticia-chen/DeepLearning_DeblurGan2_SR/blob/main/DeblurGanv2_SR_for_stanCode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive inside the Colab VM at /content/drive.
# force_remount=True is passed so the mount is requested again even when one already exists.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
# Enter the location of the project folder (relative to MyDrive).
# The space is backslash-escaped because the value is interpolated into the
# `%cd` line magic. A raw string keeps that backslash without relying on the
# unrecognized escape sequence '\ ', which newer Python versions warn about.
FOLDERNAME = r'Colab\ Notebooks/DeblurGanv2-SR'
assert FOLDERNAME is not None, "[!] Enter the foldername."

In [None]:
# Now that Drive is mounted, make the project folder importable so the
# Python interpreter of the Colab VM can load python files from within it.
import sys

# FOLDERNAME carries a shell-style '\ ' escape (it is also used by the %cd
# magic). sys.path needs the literal directory name, so the escaping
# backslash is removed before the entry is added.
_project_dir = '/content/drive/MyDrive/{}'.format(FOLDERNAME.replace('\\ ', ' '))
if _project_dir not in sys.path:
  sys.path.append(_project_dir)

In [None]:
# %pwd prints the current working directory
# %cd changes the current working directory

In [None]:
# Change the working directory to the project folder on the mounted Drive.
# NOTE(review): the path is relative, so this presumably relies on the current
# directory being /content when the cell runs — confirm.
%cd drive/MyDrive/$FOLDERNAME/

In [None]:
import torch

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

print('using device:', device)

using device: cuda


In [None]:
# Ask PyTorch to release the GPU memory it is caching but not using.
# The hasattr guard only skips the call when torch.cuda has no empty_cache.
if hasattr(torch.cuda, 'empty_cache'):
    torch.cuda.empty_cache()

In [None]:
# List the contents of the current working directory.
%ls

In [None]:
# Check what each dataset split contains.
import os

# `folders` keeps the listing of the last split visited (the test split).
for _split_dir in ('./GOPRO_Large/train', './GOPRO_Large/test'):
  folders = os.listdir(_split_dir)
  print(folders)

In [None]:
# Print the path of every entry found in the test split.
# `folders` holds the listing of ./GOPRO_Large/test at this point. A train-split
# path used to be built here as well but was overwritten before use, so that
# dead assignment has been dropped; the printed output is unchanged.
for folder in folders:
  blur_sharp_folders = os.path.join('./GOPRO_Large/test', folder)
  print(blur_sharp_folders)

In [None]:
# Count the images in each folder so blurred inputs and sharp targets can be
# compared by number, for the train split first and the test split second.
blur_files = os.listdir('./DeblurGANv2/submit')
print(f'blured: {len(blur_files)}')
sharp_files = os.listdir('./GOPRO_Large/train/sharp')
print(f'sharp: {len(sharp_files)}')
test_blur_files = os.listdir('./DeblurGANv2/submit_test')
print(f'blured: {len(test_blur_files)}')
test_sharp_files = os.listdir('./GOPRO_Large/test/sharp')
print(f'sharp: {len(test_sharp_files)}')

blured: 2103
sharp: 2103
blured: 1111
sharp: 1111


Building Dataset

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

# Per-channel mean and standard deviation handed to transforms.Normalize by the dataset.
# NOTE(review): the values match the commonly used ImageNet RGB statistics — confirm that is the intent.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

class GoproDataset(Dataset):
  """Dataset of (blurred low-resolution, sharp high-resolution) image pairs.

  Sharp images are listed from ``<root>/sharp`` and blurred images from
  ``submit``. Both listings are sorted by path and paired by position, so the
  two folders are expected to sort into matching orders.

  Args:
    root: Folder that contains a ``sharp`` sub-folder.
    submit: Folder that holds the blurred images.
    hr_shape: ``(height, width)`` of the high-resolution target. The blurred
      input is resized to a quarter of that size along each axis.
  """

  def __init__(self, root, submit, hr_shape):
    # hr = high resolution, lr = low resolution
    hr_height, hr_width = hr_shape

    self.lr_transform = transforms.Compose([
            transforms.Resize((hr_height // 4, hr_width // 4), Image.BICUBIC),
            transforms.ToTensor(),                                                  # channel, H, W
            transforms.Normalize(mean, std)])

    self.hr_transform = transforms.Compose([
            transforms.Resize((hr_height, hr_width), Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)])

    # Collect the sharp images. Only a sub-folder literally named 'sharp' is
    # read; when root has no such entry the list simply stays empty.
    self.sharp_file_paths = []
    for folder_name in os.listdir(root):                                            # blur, sharp
      if folder_name == 'sharp':
        sharp_sub_folder = os.path.join(root, folder_name)
        print(sharp_sub_folder)
        self.sharp_file_paths.extend(
            os.path.join(sharp_sub_folder, file_name)
            for file_name in os.listdir(sharp_sub_folder))
    # Sort once after collecting instead of re-sorting after every append.
    self.sharp_file_paths.sort()

    # Collect the blurred images, again sorted a single time.
    self.blur_file_paths = sorted(
        os.path.join(submit, file_name) for file_name in os.listdir(submit))

  def __getitem__(self, index):
    """Return ``{"blur": lr_tensor, "sharp": hr_tensor}`` for ``index``.

    The index wraps around each path list independently (modulo its length).
    """
    sharp_file_path = self.sharp_file_paths[index % len(self.sharp_file_paths)]
    blur_file_path = self.blur_file_paths[index % len(self.blur_file_paths)]

    blur_img = Image.open(blur_file_path).convert('RGB')
    sharp_img = Image.open(sharp_file_path).convert('RGB')

    img_lr = self.lr_transform(blur_img)
    img_hr = self.hr_transform(sharp_img)

    return {"blur": img_lr, "sharp": img_hr}

  # Called when the dataloader is defined and each time an image is read.
  def __len__(self):
    """Return the number of sharp images, which defines the dataset length."""
    return len(self.sharp_file_paths)


Data Load


In [None]:
from torch.utils.data import DataLoader

# The blurred training images live in DeblurGANv2/submit.
train_path = 'GOPRO_Large/train'
deblurred_path = 'DeblurGANv2/submit'
test_path = 'GOPRO_Large/test'
deblurred_test_path = 'DeblurGANv2/submit_test'

train = False

# Only the loader for the selected mode is built.
if train:
  mini_train = DataLoader(GoproDataset(train_path, deblurred_path, (288, 512)),
                          batch_size=5, shuffle=True)
  _active_loader = mini_train
else:
  mini_test = DataLoader(GoproDataset(test_path, deblurred_test_path, (720, 1280)),
                         batch_size=5, shuffle=True)
  _active_loader = mini_test

# Sample one batch from whichever loader exists. Reading from mini_test
# unconditionally raised NameError whenever train was True.
mini_train_shape = next(iter(_active_loader))

In [None]:
# Show the tensor shapes of the sampled batch: the "blur" input first, the "sharp" target second.
print(mini_train_shape['blur'].shape)
print(mini_train_shape['sharp'].shape)

torch.Size([5, 3, 180, 320])
torch.Size([5, 3, 720, 1280])


Generator: Residual block, Upsample block, Generator net

In [None]:
import math
import torch
from torch import nn


class ResidualBlock(nn.Module):
  """Conv -> BN -> PReLU -> Conv -> BN branch added back onto its input.

  Args:
    channels: Number of input channels; the block keeps that count unchanged.
  """

  def __init__(self, channels):
    super().__init__()
    # Both 3x3 convolutions use padding=1 and map channels -> channels.
    self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
    # NOTE(review): the second positional argument of BatchNorm2d is `eps`, so
    # 0.8 is used as epsilon here — confirm that momentum was not intended
    # before changing it, since existing checkpoints were produced this way.
    self.bn1 = nn.BatchNorm2d(channels, 0.8)
    self.prelu = nn.PReLU(channels)
    self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
    self.bn2 = nn.BatchNorm2d(channels, 0.8)

  def forward(self, x):
    """Return the residual branch output plus the untouched input ``x``."""
    branch = self.bn1(self.conv1(x))
    branch = self.conv2(self.prelu(branch))
    return self.bn2(branch) + x

In [None]:
class UpsampleBlock(nn.Module):
  """Conv -> PixelShuffle -> PReLU stage.

  Args:
    in_channels: Channel count of the input; PReLU is built with the same count.
    up_scale: Factor handed to PixelShuffle; the convolution widens the
      channels by ``up_scale ** 2`` beforehand (e.g. 64 -> 256 for a factor of 2).
  """

  def __init__(self, in_channels, up_scale):
    super().__init__()
    expanded_channels = in_channels * up_scale ** 2
    self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
    self.pixel_suffle = nn.PixelShuffle(up_scale)
    self.prelu = nn.PReLU(in_channels)

  def forward(self, x):
    """Apply convolution, pixel shuffle and PReLU in that order."""
    return self.prelu(self.pixel_suffle(self.conv(x)))


In [None]:
class NetG(nn.Module):
  """Generator: head conv, residual trunk with a skip, two upsample stages, output conv.

  Args:
    num_residual: Number of ResidualBlock modules in the trunk.
  """

  def __init__(self, num_residual=16):
    super().__init__()

    # First layer: 9x9 convolution from 3 to 64 channels followed by PReLU.
    self.conv1 = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=9, padding=4),
        nn.PReLU(64),
    )

    # Residual trunk.
    self.res_blocks = nn.Sequential(
        *[ResidualBlock(64) for _ in range(num_residual)])

    # Convolution applied after the residual trunk.
    # NOTE(review): 0.8 lands on BatchNorm2d's `eps` argument — confirm intent.
    self.conv2 = nn.Sequential(
        nn.Conv2d(64, 64, kernel_size=3, padding=1),
        nn.BatchNorm2d(64, 0.8),
    )

    # Two UpsampleBlock(64, 2) stages in a row.
    self.upsample = nn.Sequential(
        *[UpsampleBlock(64, 2) for _ in range(2)])

    # Last layer: 9x9 convolution back to 3 channels, squashed by Tanh.
    self.conv3 = nn.Sequential(
        nn.Conv2d(64, 3, kernel_size=9, stride=1, padding=4),
        nn.Tanh(),
    )

  def forward(self, x):
    """Run the generator on ``x`` and return the Tanh-activated output."""
    head = self.conv1(x)
    trunk = self.conv2(self.res_blocks(head))
    # Long skip connection around the whole residual trunk.
    merged = trunk + head
    return self.conv3(self.upsample(merged))


Discriminator

In [None]:
class NetD(nn.Module):
  """Convolutional discriminator that returns one sigmoid score per image."""

  def __init__(self):
    super().__init__()

    def conv_bn_lrelu(in_ch, out_ch, stride):
      # One 3x3 convolution stage followed by batch norm and LeakyReLU(0.2).
      return [
          nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
          nn.BatchNorm2d(out_ch),
          nn.LeakyReLU(0.2),
      ]

    # The first stage has no batch norm.
    layers = [nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.LeakyReLU(0.2)]

    # (in_channels, out_channels, stride) of the remaining 3x3 stages, in order.
    stage_specs = (
        (64, 64, 2),
        (64, 128, 1),
        (128, 128, 2),
        (128, 256, 1),
        (256, 256, 2),
        (256, 512, 1),
        (512, 512, 2),
    )
    for in_ch, out_ch, stride in stage_specs:
      layers += conv_bn_lrelu(in_ch, out_ch, stride)

    layers += [
        # AdaptiveAvgPool2d resizes the feature map to the requested (H, W); here 1x1.
        nn.AdaptiveAvgPool2d(1),
        nn.Conv2d(512, 1024, kernel_size=1),
        nn.LeakyReLU(0.2),
        nn.Conv2d(1024, 1, kernel_size=1),
    ]
    self.d_net = nn.Sequential(*layers)

  def forward(self, x):
    """Return a 1-D tensor with one sigmoid score per batch element."""
    scores = self.d_net(x)
    return torch.sigmoid(scores.view(x.size(0)))


Load vgg19 pretrained model

Extractor 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'

In [None]:
from torchvision.models import vgg19

class FeatureExtractor(nn.Module):
    """Fixed feature network built from the first 18 layers of pretrained VGG19.

    It is used to compare generated and real images in feature space.
    """

    def __init__(self):
        super(FeatureExtractor, self).__init__()
        vgg19_model = vgg19(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:18])
        # No optimizer in this notebook updates these weights. Freezing them
        # avoids computing and storing parameter gradients on every generator
        # backward pass; gradients still flow through to the input image.
        for param in self.feature_extractor.parameters():
            param.requires_grad_(False)

    def forward(self, img):
        """Return the feature maps of ``img`` from the truncated VGG19 stack."""
        return self.feature_extractor(img)
        

Initialization

In [None]:
# Place every network on the device selected earlier. `.to(device)` follows the
# notebook's CPU fallback, whereas a hard-coded `.cuda()` fails without a GPU.
netG = NetG().to(device)
netD = NetD().to(device)
feature_extractor = FeatureExtractor().to(device)
# The feature extractor is kept in eval mode.
feature_extractor.eval()

Define Optimizers and LOSS

In [None]:
# Adam hyper-parameters shared by the discriminator and the generator.
lr = 0.00008
betas = (0.5, 0.999)
_adam_kwargs = dict(lr=lr, betas=betas)

# One optimizer per network; each only sees its own network's parameters.
d_optimizer = torch.optim.Adam(netD.parameters(), **_adam_kwargs)
g_optimizer = torch.optim.Adam(netG.parameters(), **_adam_kwargs)

In [None]:
# Loss modules used by the helper functions and the train/test loops, placed on `device`.
# `mes_loss` is an MSELoss instance; `l1_loss` is an L1Loss instance.
mes_loss = torch.nn.MSELoss().to(device)
l1_loss = torch.nn.L1Loss().to(device)

def discriminator_loss(real_output, fake_output1):
  """Return the mean of the discriminator's MSE losses on real and fake batches.

  The discriminator outputs sigmoid probabilities, one per batch element.
  Real outputs are compared against all-ones targets and fake outputs against
  all-zeros targets of matching shape.
  """
  real_target = torch.ones_like(real_output)      # same shape as the input, every element 1
  fake_target = torch.zeros_like(fake_output1)    # label used for generated images
  real_loss = mes_loss(real_output, real_target)
  fake_loss = mes_loss(fake_output1, fake_target)
  return (real_loss + fake_loss) / 2

def generator_loss(fake_output2):
  """Return the MSE between the discriminator's fake scores and all-ones targets."""
  real_label = torch.ones_like(fake_output2)
  return mes_loss(fake_output2, real_label)

Model-checkpoint

In [None]:
def model_checkpoint(filename):
  """Split the files of a checkpoint folder into generator and discriminator paths.

  Args:
    filename: Path of the folder to scan (despite the name, a directory).

  Returns:
    ``(generator_file_paths, discriminator_file_paths)``, each a list of
    joined paths sorted as strings. Entries whose name starts with 'g' are
    treated as generator files; every other entry goes to the discriminator list.
  """
  generator_file_paths = []
  discriminator_file_paths = []

  for entry in os.listdir(filename):
    path = os.path.join(filename, entry)
    if entry.startswith('g'):
      generator_file_paths.append(path)
    else:
      discriminator_file_paths.append(path)

  # Sort once after collecting instead of re-sorting after every append.
  generator_file_paths.sort()
  discriminator_file_paths.sort()

  return generator_file_paths, discriminator_file_paths

PSNR Function

In [None]:
import torch
import math

def PSNR(imgs1, imgs2, data_range=255.):
  """Return the average PSNR (in dB) over a batch of image pairs.

  Each image of both batches is divided by ``data_range`` before the mean
  squared error is taken, and the peak value is then treated as 1.

  Args:
    imgs1: Batch of images, indexable along the first dimension.
    imgs2: Batch of images paired with ``imgs1`` by index.
    data_range: Divisor applied to both images. Defaults to 255 so existing
      calls behave as before.
      NOTE(review): the callers in this notebook appear to pass normalized
      tensors rather than 0-255 images — confirm which range is intended.

  Returns:
    ``(avg_p, avg_s)``: ``avg_p`` is the mean PSNR and a pair with zero error
    contributes 100. ``avg_s`` is always 0.0; it is a placeholder for a second
    metric that is not computed and is kept so callers can unpack two values.

  Raises:
    ZeroDivisionError: If ``imgs1`` is empty.
  """
  PIXEL_MAX = 1
  total_p = 0

  for i, img1 in enumerate(imgs1):
    # .item() turns the one-element tensor into a plain Python float.
    mse = torch.mean((img1 / data_range - imgs2[i] / data_range) ** 2).item()

    if mse == 0:
      p = 100
    else:
      p = 20 * math.log10(PIXEL_MAX / math.sqrt(mse))

    total_p += p

  avg_p = total_p / len(imgs1)
  avg_s = 0.0

  return avg_p, avg_s

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def design(title, lst1, lst2, label1, label2, xlabel, ylabel, mum=None):
  """Plot two series on one figure, save it as a PNG, then display it.

  Args:
    title: Figure title; also used as the PNG file name under ``matplotlib/``.
    lst1: First series; its length also determines the x tick positions.
    lst2: Second series.
    label1: Legend label of the first series.
    label2: Legend label of the second series.
    xlabel: Label of the x axis.
    ylabel: Label of the y axis.
    mum: Accepted for backward compatibility and ignored. The body used to
      read an undefined name ``num`` in a ``plt.subplot`` call issued after
      the figure had already been saved; that call has been removed.
  """
  plt.figure(figsize=(6,4))                         # figure size
  plt.title(str(title))                             # figure title
  plt.plot(lst1,label=str(label1))                  # draw the line; label is the legend entry
  plt.plot(lst2,label=str(label2))
  plt.xlabel(str(xlabel))                           # x-axis name
  plt.ylabel(str(ylabel))                           # y-axis name
  # Ticks sit at positions 0..len-1 and are labelled 1..len.
  plt.xticks(ticks=np.arange(len(lst1)), labels=np.arange(1, len(lst1)+1))
  plt.legend()
  # bbox_inches='tight' trims the surplus white space around the figure.
  plt.savefig(f'matplotlib/{title}.png', bbox_inches='tight')
  plt.show()

Select Train or Test Mode

In [None]:
# Mode switches: the training cell runs under `if train:` and the testing cell under `if test:`.
train = False
test = True

Training

In [None]:
from torch.autograd import Variable
import sys
import itertools
from torchvision.utils import save_image, make_grid
from PIL import Image
import time
from torchvision.io import read_image
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Per-batch loss histories; the training loop appends each d_loss / g_loss value to them.
d_losses_0106 = []
g_losses_0106 = []

In [None]:
import os

cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

num_epochs = 100            # 100
num_epochs_test = 1
checkpoint_interval = 5     # 10
num_iters = 100             # 200

start = time.time()

if train:

  # Create the output folders up front so that a missing directory cannot
  # abort the run at the first image or checkpoint save.
  os.makedirs('images_0106', exist_ok=True)
  os.makedirs('saved_models_0106', exist_ok=True)

  for epoch in range(num_epochs):

    # Running loss totals for the current epoch.
    loss_dis, loss_gen = 0, 0

    # Nothing inside the batch loop leaves training mode, so it is set once per epoch.
    netD.train()
    netG.train()

    for i, imgs in enumerate(mini_train):

      # Tensors are already Variables in current PyTorch, so no wrapper is needed.
      imgs_lr = imgs["blur"].type(Tensor).to(device)
      imgs_hr = imgs["sharp"].type(Tensor).to(device)

      # ---- Train discriminator ----
      netD.zero_grad()

      real_output = netD(imgs_hr)                               # sharp photos -> sigmoid probability

      fake_imgs = netG(imgs_lr)                                 # generated photos, output of Tanh
      # detach() cuts the graph so this backward pass does not reach the generator.
      fake_output1 = netD(fake_imgs.detach())

      d_loss = discriminator_loss(real_output, fake_output1)
      d_loss.backward()
      d_optimizer.step()

      # ---- Train generator ----
      netG.zero_grad()

      # The discriminator weights were just updated, so it is evaluated again,
      # this time with the graph attached so gradients reach the generator.
      fake_output2 = netD(fake_imgs)
      gen_loss = generator_loss(fake_output2)

      # Content loss: L1 distance between feature maps of fake and real images.
      gen_features = feature_extractor(fake_imgs)
      # The real features are only a target, so no graph is built for them.
      with torch.no_grad():
        real_features = feature_extractor(imgs_hr)
      content_loss = l1_loss(gen_features, real_features)

      # Total generator loss = content loss + 1e-3 * adversarial loss.
      g_loss = content_loss + 1e-3 * gen_loss
      g_loss.backward()
      g_optimizer.step()

      # ---- Bookkeeping ----
      loss_dis += d_loss.item()         # .item() -> plain Python value
      loss_gen += g_loss.item()

      d_losses_0106.append(d_loss.item())
      g_losses_0106.append(g_loss.item())

      if i % 100 == 0:
        sys.stdout.write(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, num_epochs, i, len(mini_train), d_loss.item(), g_loss.item())+'\n'
            )

      batches_done = epoch * len(mini_train) + i
      if batches_done % num_iters == 0:
        # Save the upsampled input, the generator output and the target side by side.
        imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)          # size=(H,W) or scale_factor both work
        fake_imgs = make_grid(fake_imgs, nrow=1, normalize=True)
        imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
        imgs_hr = make_grid(imgs_hr, nrow=1, normalize=True)
        img_grid = torch.cat((imgs_lr, fake_imgs, imgs_hr), -1)
        save_image(img_grid, "images_0106/%d.png" % batches_done, normalize=False)

    if epoch % checkpoint_interval == 0:
      torch.save(netG.state_dict(), "saved_models_0106/generator_%d.pth" % epoch)
      torch.save(netD.state_dict(), "saved_models_0106/discriminator_%d.pth" % epoch)

  end = time.time()
  print(f'Total time: {end - start} seconds.')

Testing

In [None]:
import os

if test:

  generator_w, discriminator_w = model_checkpoint('Pre_trained_models_by_Jay')    # images 720x1280
  # generator_w, discriminator_w = model_checkpoint('saved_models_0107')          # images 288x512
  generator_w.sort()
  discriminator_w.sort()

  # Create the output folders up front so a missing directory cannot abort the evaluation.
  os.makedirs('images_test_128X512_all_big', exist_ok=True)
  os.makedirs('saved_models_test_128X512_all_big', exist_ok=True)

  # One (average value, checkpoint index) tuple per evaluated checkpoint.
  all_pre_train_dloss_status, all_pre_train_gloss_status = [], []
  all_pre_train_hr_lr_p, all_pre_train_hr_fake_p, all_pre_train_lr_fake_p = [], [], []

  # Number of batches averaged into each plotted point.
  plot_window = 50

  for j in range(len(generator_w)):

    # map_location lets checkpoints saved on a GPU load on whichever device is in use.
    netG.load_state_dict(torch.load(generator_w[j], map_location=device))
    netD.load_state_dict(torch.load(discriminator_w[j], map_location=device))
    netD.eval()
    netG.eval()

    # Per-batch losses of this checkpoint. They are also published under the
    # names d_losses_test<j> / g_losses_test<j> that this cell has always created.
    d_losses_test, g_losses_test = [], []
    globals()['d_losses_test'+str(j)] = d_losses_test
    globals()['g_losses_test'+str(j)] = g_losses_test

    p_hr_lr, p_hr_fake, p_lr_fake = [], [], []       # per-batch PSNR values
    l_d, l_g, p_h_l, p_h_f = [], [], [], []          # windowed averages used for the plots
    window_d, window_g, window_hl, window_hf = 0, 0, 0, 0

    start = time.time()
    num = 0   # not used in this cell; kept in case another cell reads it — TODO confirm

    # Evaluation only: no backward pass is run, so gradient tracking is switched off.
    with torch.no_grad():
      for i, imgs in enumerate(mini_test):

        imgs_lr = imgs["blur"].type(Tensor).to(device)
        imgs_hr = imgs["sharp"].type(Tensor).to(device)

        # ---- Discriminator loss ----
        real_output = netD(imgs_hr)
        fake_imgs = netG(imgs_lr)
        fake_output1 = netD(fake_imgs)
        d_loss = discriminator_loss(real_output, fake_output1)

        # ---- Generator loss ----
        # The discriminator is not updated between the two losses here, so its
        # scores for the fake batch are reused instead of being recomputed.
        gen_loss = generator_loss(fake_output1)

        # Content loss: L1 distance between feature maps of fake and real images.
        gen_features = feature_extractor(fake_imgs)
        real_features = feature_extractor(imgs_hr)
        content_loss = l1_loss(gen_features, real_features)

        # Total generator loss = content loss + 1e-3 * adversarial loss.
        g_loss = content_loss + 1e-3 * gen_loss

        d_loss_value = d_loss.item()
        g_loss_value = g_loss.item()
        d_losses_test.append(d_loss_value)
        g_losses_test.append(g_loss_value)

        # ---- PSNR ----
        # Upsample the input x4 so it can be compared with the other two batches.
        imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)

        p1, s1 = PSNR(imgs_hr, imgs_lr)
        p2, s2 = PSNR(imgs_hr, fake_imgs)
        p3, s3 = PSNR(imgs_lr, fake_imgs)

        p_hr_lr.append(p1)
        p_hr_fake.append(p2)
        p_lr_fake.append(p3)

        # ---- Windowed averages for the plots ----
        window_d += d_loss_value
        window_g += g_loss_value
        window_hl += p1
        window_hf += p2
        # Emit a point only after a full window has been accumulated. Testing
        # `i % 50 == 0` fired at i == 0 and divided a single batch by 50.
        if (i + 1) % plot_window == 0:
          l_d.append(window_d / plot_window)
          l_g.append(window_g / plot_window)
          p_h_l.append(window_hl / plot_window)
          p_h_f.append(window_hf / plot_window)
          window_d, window_g, window_hl, window_hf = 0, 0, 0, 0

        # ---- Progress line and sample image every 100 batches ----
        if i % 100 == 0:          # i counts mini batches
          sys.stdout.write(
                  "[Pre train model %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [psnr hr_lr: %f] [psnr hr_fake: %f] [psnr lr_fake: %f]"
                  % (j, len(generator_w), i, len(mini_test), d_loss_value, g_loss_value, p1, p2, p3)+'\n'
              )

          fake_imgs = make_grid(fake_imgs, nrow=1, normalize=True)
          imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
          imgs_hr = make_grid(imgs_hr, nrow=1, normalize=True)
          img_grid = torch.cat((imgs_lr, fake_imgs, imgs_hr), -1)
          save_image(img_grid, f"images_test_128X512_all_big/pre_train_model_Jay:{j}_{i} of {len(mini_test)}.png", normalize=False)

    # Modulo, not bitwise AND: save every checkpoint_interval-th checkpoint.
    if j % checkpoint_interval == 0:
      torch.save(netG.state_dict(), "saved_models_test_128X512_all_big/generator_%d.pth" % j)
      torch.save(netD.state_dict(), "saved_models_test_128X512_all_big/discriminator_%d.pth" % j)

    # ---- Per-checkpoint averages ----
    every_pre_train_dw = (sum(d_losses_test)/len(d_losses_test), int(j))
    every_pre_train_gw = (sum(g_losses_test)/len(g_losses_test), int(j))

    all_pre_train_dloss_status.append(every_pre_train_dw)
    all_pre_train_gloss_status.append(every_pre_train_gw)

    every_pre_train_hr_lr_p = (sum(p_hr_lr)/len(p_hr_lr), int(j))
    every_pre_train_hr_fake_p = (sum(p_hr_fake)/len(p_hr_fake), int(j))
    every_pre_train_lr_fake_p = (sum(p_lr_fake)/len(p_lr_fake), int(j))

    all_pre_train_hr_lr_p.append(every_pre_train_hr_lr_p)
    all_pre_train_hr_fake_p.append(every_pre_train_hr_fake_p)
    all_pre_train_lr_fake_p.append(every_pre_train_lr_fake_p)

    ###### Loss figure for this checkpoint
    plt.figure(figsize=(20,5))
    ax1 = plt.subplot(1,2,1)
    ax2 = plt.subplot(1,2,2)
    ax1.set_title('D_loss and G_loss_'+str(j))
    ax1.plot(l_d,label='D_loss')
    ax1.plot(l_g,label='G_loss')
    ax1.set_xlabel('Batch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ###### PSNR figure for this checkpoint
    ax2.set_title('PSNR hr_lr X hr_fake_'+str(j))
    ax2.plot(p_h_l, label='hr_lr')
    ax2.plot(p_h_f, label='hr_fake')
    ax2.set_xlabel('Batch')
    ax2.set_ylabel('PSNR')
    ax2.legend()
    plt.show()

    end = time.time()
    print(f'Total time: {end - start} seconds.')

Save Losses and PSNR in CSV file

In [None]:
# Write one CSV row per evaluated checkpoint. Every status entry is a
# (value, checkpoint_index) tuple; the two lists of a file are read in lockstep.
with open('losses.csv', 'w') as out:
  out.write('d_loss,index_d,g_loss,index_g\n')
  for row, (d_value, d_index) in enumerate(all_pre_train_dloss_status):
    g_value, g_index = all_pre_train_gloss_status[row][0], all_pre_train_gloss_status[row][1]
    out.write(','.join(map(str, (d_value, d_index, g_value, g_index))) + '\n')

with open('psnr.csv', 'w') as out:
  out.write('hr_lr,index,hr_fake,index\n')
  for row, (lr_value, lr_index) in enumerate(all_pre_train_hr_lr_p):
    fake_value, fake_index = all_pre_train_hr_fake_p[row][0], all_pre_train_hr_fake_p[row][1]
    out.write(','.join(map(str, (lr_value, lr_index, fake_value, fake_index))) + '\n')

In [None]:
# Print every per-checkpoint summary list, with a row of asterisks between sections.
_summary_sections = (
    ('D loss:', all_pre_train_dloss_status, '******************************************'),
    ('G loss:', all_pre_train_gloss_status, '*********************************************'),
    ('PSNR hr x lr:', all_pre_train_hr_lr_p, '***********************************************'),
    ('PSNR hr x fake:', all_pre_train_hr_fake_p, None),
)
for _label, _values, _divider in _summary_sections:
  print(_label, _values)
  if _divider is not None:
    print(_divider)

Find the best

In [None]:
# Pick one entry per metric. Every entry is a (value, checkpoint_index) tuple,
# so max()/min() compare by value first and by index on ties.
# NOTE(review): the discriminator entry is chosen with max() while the
# generator entry uses min() — confirm that this is the intended criterion.
d_L, g_L = all_pre_train_dloss_status, all_pre_train_gloss_status
d = max(d_L)
print('d loss:', d)
g = min(g_L)
print('g loss:', g)

P1, P2, P3 = all_pre_train_hr_lr_p, all_pre_train_hr_fake_p, all_pre_train_lr_fake_p
p1, p2, p3 = max(P1), max(P2), max(P3)
print('psnr_hr_lr:', p1, 'psnr_hr_fake:', p2, 'psnr_lr_fake:', p3)