Training of the SRGAN

This notebook calls the training function; the generated models can then be evaluated.

In [None]:
from train import train_gan
from dataclass.MicroImageClass import MicroImages
import torch
from matplotlib import pyplot as plt
from dataclass.prepare_data import prepare_datasets
from dataclass.dataset import ImagesDataset, convert_image
import torchvision
import numpy as np
from torchvision.utils import save_image
from models.Generator import GeneratorV0
from piqa import SSIM

In [None]:
# Upscaling ratio between low- and high-resolution images; shared by the datasets and the generator config.
scaling_factor: int = 4

In [None]:
# Candidate dataset folders (swap them into train_folders / test_folders as needed):
#   '../data/DIV2K_train_LR_bicubic/X2'
#   '../data/DIV2K_train_HR'
#   '../data/val2014'
prepare_datasets(
    train_folders=['../data/val2014'],
    test_folders=['../data/DIV2K_train_HR'],
    # 3x the 96-pixel crop used below; presumably smaller images are filtered out — TODO confirm in prepare_data.py
    min_size=96 * 3,
    output_folder='dataclass/',
)

In [None]:
# Value range of the HR targets / generator outputs and of the LR inputs.
output_format = "[0, 1]"
input_format = "[0, 1]"

# Both splits share crop size, scale and formats; only the `train` flag differs.
dataset_kwargs = dict(
    crop_size=96,
    scaling_factor=scaling_factor,
    lr_format=input_format,
    hr_format=output_format,
)
trainset = ImagesDataset("dataclass/", **dataset_kwargs)
testset = ImagesDataset("dataclass/", train=False, **dataset_kwargs)


In [None]:
# Base name for the checkpoint (models/save/<name>.pt) and the figures saved under figures/.
model_name: str = "new-gan-model"

In [None]:
# Reset CUDA's peak-memory counter, presumably so peak usage during training can be inspected afterwards.
# Guarded because the call raises when no CUDA device is available.
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()

In [None]:
# Architecture settings for the generator and discriminator, named separately for readability.
generator_config = {
    "nbr_channels": 64,
    "nbr_blocks": 5,
    "normalize": True,
    "scaling_factor": scaling_factor,
}
discriminator_config = {"nbr_channels": 64}
checkpoint_path = "models/save/" + model_name + ".pt"

# train_gan lives in train.py; see there for the meaning of each argument.
gen, dis, g_losses, d_losses, ssim_scores = train_gan(
    trainset=trainset,
    testset=testset,
    batch_size=32,
    epochs=8,
    lr=0.0001,
    gpu=True,
    gen_args=generator_config,
    dis_args=discriminator_config,
    num_workers=4,
    alpha=0.001,
    r1_penalty=0.01,
    labels="smooth",
    content_loss_type="MSE_SSIM",
    save_file=checkpoint_path,
)

In [None]:
# Draw on a fresh, explicit figure so this plot never lands on a figure left open by another
# cell (e.g. with the ipympl widget backend or when the notebook is run as a script).
# NOTE(review): the x-axis is scaled by 32, matching batch_size in the training cell; confirm how
# often train_gan records a loss so that the "Iterations" label is accurate.
gen_loss_fig, gen_loss_ax = plt.subplots()
gen_loss_ax.plot([x * 32 for x in range(len(g_losses))], g_losses)
gen_loss_ax.set_xlabel("Iterations")
gen_loss_ax.set_ylabel("Generator loss")
gen_loss_ax.set_title("Generator loss")
gen_loss_fig.savefig('figures/gen_loss_' + model_name + '.pdf')

In [None]:
# Draw on a fresh, explicit figure so the saved PDF contains only the discriminator curve,
# even if a previous figure is still open (widget backend or script execution).
# NOTE(review): x-axis scaled by 32 as in the generator plot; confirm the logging interval.
dis_loss_fig, dis_loss_ax = plt.subplots()
dis_loss_ax.plot([x * 32 for x in range(len(d_losses))], d_losses)
dis_loss_ax.set_xlabel("Iterations")
dis_loss_ax.set_ylabel("Discriminator loss")
dis_loss_ax.set_title("Discriminator loss")
dis_loss_fig.savefig('figures/dis_loss_' + model_name + '.pdf')

In [None]:
# Plot the per-epoch SSIM returned by train_gan, skipping the plot when no scores were recorded.
# NOTE(review): `trainset` is always set at this point; the scores presumably come from the
# test set — confirm which guard was intended.
if trainset is not None and ssim_scores:
    # Explicit figure so this plot cannot overlay a figure left open by another cell.
    ssim_fig, ssim_ax = plt.subplots()
    ssim_ax.plot(range(len(ssim_scores)), ssim_scores)
    ssim_ax.set_xlabel("Epochs")
    ssim_ax.set_ylim((0, 1.0))
    ssim_ax.set_ylabel("SSIM score")
    ssim_ax.set_title("SSIM score")
    ssim_fig.savefig('figures/ssim_score_' + model_name + '.pdf')

In [None]:
# Put the generator in inference mode (use gen.train() to switch back).
# The trailing ';' suppresses the lengthy module repr that would otherwise be printed.
gen.eval();


In [None]:
# Evaluation datasets, all built with train=False (the same split as `testset`) so that
# qualitative comparisons are not made on images the generator was trained on.
evalset = ImagesDataset("dataclass/", crop_size=200, scaling_factor=scaling_factor, lr_format=input_format, hr_format=output_format, train=False)
# Fixed 2x / 4x variants, independent of `scaling_factor`; evalset_x4 feeds the compare_hr_images_x4* helpers.
evalset_x2 = ImagesDataset("dataclass/", crop_size=96, scaling_factor=2, lr_format=input_format, hr_format=output_format, train=False)
evalset_x4 = ImagesDataset("dataclass/", crop_size=96, scaling_factor=4, lr_format=input_format, hr_format=output_format, train=False)

In [None]:
# Helper functions to inspect the output of the trained model
def show_images(img):
    """Display a channel-first (C, H, W) image tensor with matplotlib.

    Args:
        img: image tensor whose values are expected in [0, 1]; it may live on
            any device and may require grad.
    """
    # Detach and move to CPU so GPU / grad-tracking tensors can be shown directly.
    # Clamp because imshow clips float RGB data outside [0, 1] anyway (emitting a warning).
    npimg = img.detach().cpu().clamp(0, 1).numpy()
    # (C, H, W) -> (H, W, C), the layout imshow expects.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

def show_input_images(index: int, evalset):
    """Show the low-resolution input, then the high-resolution target, of sample `index`.

    Both images are converted to the [0, 1] range before display.
    """
    sample_lr, sample_hr = evalset[index]
    displayable = [
        convert_image(sample_lr, input_format, "[0, 1]"),
        convert_image(sample_hr, output_format, "[0, 1]"),
    ]
    for picture in displayable:
        # make_grid on a single image only adds a border around it.
        show_images(torchvision.utils.make_grid([picture]))

def compare_hr_images(model, evalset, device="cuda:0", index=0):
    """Show the ground-truth HR image next to the model's super-resolved output.

    Args:
        model: generator applied to a batch containing the single LR image.
        evalset: indexable dataset yielding (lr, hr) pairs.
        device: device the LR input is moved to before the forward pass.
        index: which sample of `evalset` to display.
    """
    lr_tensor, hr_tensor = evalset[index]
    with torch.no_grad():
        sr_batch = convert_image(model(lr_tensor.unsqueeze(0).to(device)), output_format, "[0, 1]")
    hr_display = convert_image(hr_tensor, output_format, "[0, 1]")
    sr_display = sr_batch.cpu().detach()[0]
    show_images(torchvision.utils.make_grid([hr_display, sr_display]))

def compare_hr_images_with_input(model, evalset, device="cuda:0", index=0):
    """Show the LR input, then the HR target next to the model's super-resolved output.

    Args:
        model: generator applied to a batch containing the single LR image.
        evalset: indexable dataset yielding (lr, hr) pairs.
        device: device the LR input is moved to before the forward pass.
        index: which sample of `evalset` to display.
    """
    image_lr, image_hr = evalset[index]
    with torch.no_grad():
        output_sr = model(image_lr.unsqueeze(0).to(device))
        output_sr = convert_image(output_sr, output_format, "[0, 1]")
    image_hr = convert_image(image_hr, output_format, "[0, 1]")
    # Convert the LR input to [0, 1] before display, as show_input_images does; it used to be
    # shown in its raw `input_format` range, which breaks for formats other than "[0, 1]".
    image_lr_display = convert_image(image_lr, input_format, "[0, 1]")
    show_images(torchvision.utils.make_grid([image_lr_display]))
    show_images(torchvision.utils.make_grid([image_hr, output_sr.cpu().detach()[0]]))

def compare_hr_images_x4(model, device="cuda:0", index=0):
    """Compare an HR image of `evalset_x4` with the model applied twice to its LR input.

    The second pass consumes the first pass's output, so the result only matches the
    x4 HR size when the model itself upsamples by 2x (presumably an x2 generator);
    with a 4x generator the two images differ in size and make_grid cannot stack them.
    """
    lr_tensor, hr_tensor = evalset_x4[index]
    with torch.no_grad():
        first_pass = model(lr_tensor.unsqueeze(0).to(device))
        # Re-express the first output in the input range before feeding it back in.
        second_pass = model(convert_image(first_pass, output_format, input_format))
        sr_display = convert_image(second_pass, output_format, "[0, 1]")
    hr_display = convert_image(hr_tensor, output_format, "[0, 1]")
    show_images(torchvision.utils.make_grid([hr_display, sr_display.cpu().detach()[0]]))

def compare_hr_images_x4_and_save(model, device="cuda:0", index=0, name=""):
    """Two-pass x4 comparison on `evalset_x4`, then save both images as PNGs.

    Files are written to figures/<model_name>_<name>_real_x4.png (ground truth)
    and figures/<model_name>_<name>_fake_x4.png (model output). Like the display-only
    variant, this presumably assumes a 2x model applied twice.
    """
    lr_tensor, hr_tensor = evalset_x4[index]
    with torch.no_grad():
        upscaled = model(lr_tensor.unsqueeze(0).to(device))
        upscaled = model(convert_image(upscaled, output_format, input_format))
        upscaled = convert_image(upscaled, output_format, "[0, 1]")
    hr_display = convert_image(hr_tensor, output_format, "[0, 1]")
    show_images(torchvision.utils.make_grid([hr_display, upscaled.cpu().detach()[0]]))
    path_prefix = "figures/" + model_name + "_" + name
    save_image(hr_display, path_prefix + "_real_x4.png")
    save_image(upscaled, path_prefix + "_fake_x4.png")

def compare_hr_images_and_save(model, evalset, device="cuda:0", index=0, name=''):
    """Show the HR target next to the model output for sample `index`, then save both.

    Files are written to figures/<model_name>_<name>_real.png (ground truth)
    and figures/<model_name>_<name>_fake.png (model output).
    """
    lr_tensor, hr_tensor = evalset[index]
    with torch.no_grad():
        sr_batch = model(lr_tensor.unsqueeze(0).to(device))
        sr_batch = convert_image(sr_batch, output_format, "[0, 1]")
    hr_display = convert_image(hr_tensor, output_format, "[0, 1]")
    show_images(torchvision.utils.make_grid([hr_display, sr_batch.cpu().detach()[0]]))
    path_prefix = "figures/" + model_name + "_" + name
    save_image(hr_display, path_prefix + "_real.png")
    save_image(sr_batch, path_prefix + "_fake.png")



In [None]:
show_input_images(index=4, evalset=evalset)

In [None]:
# Pass index by keyword: the third positional parameter of compare_hr_images is `device`.
compare_hr_images(gen, evalset, index=0)

In [None]:
compare_hr_images(model=gen, evalset=evalset, index=1)

In [None]:
compare_hr_images(model=gen, evalset=evalset, index=2)

In [None]:
compare_hr_images(model=gen, evalset=evalset, index=3)

In [None]:
compare_hr_images(model=gen, evalset=evalset, index=4)

In [None]:
compare_hr_images(model=gen, evalset=evalset, index=5)

In [None]:
compare_hr_images(model=gen, evalset=evalset, index=6)

Comparing with previous models

TODO: evaluate quality; test SSIM as a loss function; write the report; test with a larger crop size.

In [None]:
def test_ssim_score(model, testset, device="cuda:0"):
    """Compute the SSIM of `model`'s outputs against the HR targets of `testset`.

    Args:
        model: generator; it is moved to `device` (and left there, as before).
        testset: dataset yielding (lr, hr) pairs.
        device: device used for the forward passes.

    Returns:
        The value of ignite's SSIM metric with data_range=1.0, which assumes outputs
        and targets lie in [0, 1] (the case for output_format "[0, 1]").
    """
    # Local, aliased import so it does not shadow piqa's SSIM imported at the top of the notebook.
    from ignite.metrics import SSIM as IgniteSSIM

    # Shuffling is unnecessary for an aggregate score over the whole test set.
    testloader = torch.utils.data.DataLoader(testset, batch_size=8, shuffle=False)
    test_ssim = IgniteSSIM(data_range=1.0, device="cpu")
    model.to(device)
    # Score in inference mode without building autograd graphs, then restore the caller's mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in testloader:  # [batch_size x 3 x w x h]
                outputs = model(x.to(device))
                test_ssim.update((outputs, y.to(device)))
        ssim_score = test_ssim.compute()
    finally:
        model.train(was_training)
        test_ssim.reset()
    return ssim_score

In [None]:
# test_ssim_score(gen, testset=testset, device="cuda:0")