In [1]:
from pathlib import Path

from src import checkpoint, data, metrics
from src.device import device
from src.model import Discriminator, Generator
from src.utils import UtilSRGAN
from torch import optim
from torch.utils.data import DataLoader


In [2]:
# Training dataset: high-resolution images are read from data/train/hr.
# NOTE: the keyword is spelled `init_transfrom` in this call; presumably that matches
# the SRGANData signature in src.data — do not "correct" it here without changing it there.
train_data = data.SRGANData(
    root_dir=Path("data/train"),
    scale=UtilSRGAN.scale,
    hr_name="hr",
    # lr_name="low_bicubic_x2",
    init_transfrom=UtilSRGAN.init_transform,
    hr_transform=UtilSRGAN.hr_transforms,
    lr_transform=UtilSRGAN.lr_transforms,
)

# Evaluation datasets: both live under data/test and are configured identically
# except for the sub-folder name, and both use the eval-specific initial transform.
eval_data_set5, eval_data_set14 = (
    data.SRGANData(
        root_dir=Path("data/test"),
        scale=UtilSRGAN.scale,
        hr_name=benchmark_name,
        init_transfrom=UtilSRGAN.init_eval_transform,
        hr_transform=UtilSRGAN.hr_transforms,
        lr_transform=UtilSRGAN.lr_transforms,
    )
    for benchmark_name in ("Set5", "Set14")
)


# Training batches are reshuffled every epoch so the networks do not see the samples
# in the same fixed order (and the same fixed batches) on every pass over the data.
train_loader = DataLoader(train_data, batch_size=16, shuffle=True, num_workers=4)

# Evaluation runs one image per batch in a fixed order so per-epoch scores are comparable.
eval_loader_set5 = DataLoader(eval_data_set5, batch_size=1, shuffle=False, num_workers=1)
eval_loader_set14 = DataLoader(eval_data_set14, batch_size=1, shuffle=False, num_workers=1)

In [3]:
# Instantiate both networks and move them to the configured device.
generator = Generator(upscale_factor=UtilSRGAN.scale).to(device)
discriminator = Discriminator().to(device)

# One Adam optimizer per network, both starting at lr=1e-4.
gen_optimizer, disc_optimizer = (
    optim.Adam(network.parameters(), lr=1e-4) for network in (generator, discriminator)
)

# StepLR multiplies the learning rate by gamma (0.1) once every step_size (1000)
# scheduler steps, i.e. 1e-4 -> 1e-5 only after 1000 calls to .step().
# The training loop steps each scheduler once per epoch, so with a 0..1000 epoch range
# the drop happens only after the final epoch.
# NOTE(review): if a drop after 100 epochs was intended, step_size (or a milestone-based
# scheduler) needs to change — confirm the intended schedule.
gen_scheduler, disc_scheduler = (
    optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)
    for optimizer in (gen_optimizer, disc_optimizer)
)


In [4]:
# Tracker that accumulates the per-epoch train/eval histories and the best PSNR so far.
metric = metrics.MetricSRGAN()

# Run identifier; used as the folder prefix of every checkpoint name written by the loop.
model_name = "srgan-normal-loss-upscale-4"

In [5]:
# Half-open epoch range consumed by the training loop as range(start, end).
start, end = 0, 1000

In [None]:
def _save_snapshot(filename):
    """Save both networks, their optimizers/schedulers and the metric state.

    Args:
        filename: File name inside the ``model_name`` folder, e.g. ``"best.pt"``.
    """
    checkpoint.save(
        name=f"{model_name}/{filename}",
        gen_model=generator.state_dict(),
        gen_optimizer=gen_optimizer.state_dict(),
        gen_scheduler=gen_scheduler.state_dict(),
        disc_model=discriminator.state_dict(),
        disc_optimizer=disc_optimizer.state_dict(),
        disc_scheduler=disc_scheduler.state_dict(),
        **metric.save_checkpoint(),
    )


for i in range(start, end):
    # --- Train for one epoch and record the epoch means. ---
    mean_gen_loss, mean_disc_loss, mean_psnr = UtilSRGAN.train(
        generator,
        discriminator,
        gen_optimizer,
        disc_optimizer,
        train_loader,
    )
    print(f"Epoch: {i} gen-loss: {mean_gen_loss:.5f}, disc-loss: {mean_disc_loss:.5f}, psnr: {mean_psnr:.5f}")
    metric.total_train_gen_loss.append(mean_gen_loss)
    metric.total_train_disc_loss.append(mean_disc_loss)
    metric.total_train_psnr.append(mean_psnr)

    # --- Evaluate on Set5. ---
    mean_loss_set5, mean_psnr_set5 = UtilSRGAN.eval(
        generator,
        discriminator,
        eval_loader_set5,
        metric,
    )
    print(f"  Eval (Set5): gen-loss: {mean_loss_set5:.5f}, psnr: {mean_psnr_set5:.5f}")
    metric.total_eval_loss_set5.append(mean_loss_set5)
    metric.total_eval_psnr_set5.append(mean_psnr_set5)

    # --- Evaluate on Set14. ---
    mean_loss_set14, mean_psnr_set14 = UtilSRGAN.eval(
        generator,
        discriminator,
        eval_loader_set14,
        metric,
    )
    print(f"  Eval (Set14): gen-loss: {mean_loss_set14:.5f}, psnr: {mean_psnr_set14:.5f}")
    metric.total_eval_loss_set14.append(mean_loss_set14)
    metric.total_eval_psnr_set14.append(mean_psnr_set14)

    # Advance both learning-rate schedules once per epoch (before checkpointing, so the
    # saved scheduler state already reflects this epoch).
    disc_scheduler.step()
    gen_scheduler.step()

    # --- Best-model checkpoint: keyed on the combined eval score from the tracker. ---
    curr_psnr = metric.get_eval_score()
    if curr_psnr > metric.best_psnr:
        print(f"  * New best psnr: {curr_psnr}")
        metric.best_epoch = i
        metric.best_psnr = curr_psnr
        _save_snapshot("best.pt")

    # --- Periodic snapshot: first epoch of the run, every 100th epoch, and the last. ---
    # metric.best_epoch is deliberately left untouched here: it must keep pointing at the
    # epoch that produced metric.best_psnr, not at whichever epoch was last snapshotted.
    is_first_epoch = i == start
    is_last_epoch = i == end - 1
    is_hundredth_epoch = (i + 1) % 100 == 0
    if is_hundredth_epoch or is_last_epoch or is_first_epoch:
        _save_snapshot(f"{i}.pt")

Epoch: 0 gen-loss: 0.11420, disc-loss: 0.19062, psnr: 10.13968
  Eval (Set5): gen-loss: 0.10480, psnr: 10.62236
  Eval (Set14): gen-loss: 0.08898, psnr: 11.29387
  * New best psnr: 11.117156631068179
Model saved to model/export/srgan-normal-loss-upscale-4-new/best.pt
Model saved to model/export/srgan-normal-loss-upscale-4-new/0.pt
Epoch: 1 gen-loss: 0.07342, disc-loss: 0.26523, psnr: 11.79801
  Eval (Set5): gen-loss: 0.09718, psnr: 11.18446
  Eval (Set14): gen-loss: 0.08454, psnr: 11.80197
  * New best psnr: 11.639465683384946
Model saved to model/export/srgan-normal-loss-upscale-4-new/best.pt
