In [1]:
from src import checkpoint, data, metrics
from src.device import type
from src.model import SRCNN
from src.utils import UtilSRCNN
from torch import optim
from torch.utils.data import DataLoader


In [2]:
# Every split uses the same SRCNN preprocessing; only the image directory changes.
def _srcnn_split(root):
    """Return an SRCNNData dataset rooted at `root`, using UtilSRCNN.transforms."""
    return data.SRCNNData(root, transform=UtilSRCNN.transforms)


train_data = _srcnn_split("data/T91/cropped")
eval_data_set5 = _srcnn_split("data/test/Set5")
eval_data_set14 = _srcnn_split("data/test/Set14")

# Training: shuffled mini-batches of 16.
train_loader = DataLoader(train_data, batch_size=16, shuffle=True, num_workers=2)

# Evaluation: one image per batch, in a fixed order.
# (Presumably batch_size=1 because the test images differ in size — verify.)
eval_loader_set5, eval_loader_set14 = (
    DataLoader(split, batch_size=1, shuffle=False, num_workers=1)
    for split in (eval_data_set5, eval_data_set14)
)

In [3]:
# Instantiate the network and move it to the device/dtype held in `type` (imported from src.device).
srcnn = SRCNN().to(type)

# Per-group learning rates: block_final learns 10x slower than block_1.
# NOTE(review): only block_1 and block_final are passed to the optimizer. Any other trainable
# submodule of SRCNN would never be updated — confirm this is intended.
param_groups = [
    {"params": srcnn.block_1.parameters(), "lr": 1e-4},
    {"params": srcnn.block_final.parameters(), "lr": 1e-5},
]
optimizer = optim.SGD(param_groups, momentum=0.9)

In [4]:
# Run identifier: every checkpoint of this run is written as f"{model_name}/<file>.pt".
model_name = "srcnn-gaussian-0.55-t91-amp"

# Per-epoch train/eval loss and PSNR history, plus the best-score bookkeeping
# (best_psnr / best_epoch) that the training loop updates.
metric = metrics.MetricSRCNN()


In [5]:
# Half-open epoch range [start, end) iterated by the training loop.
# NOTE(review): changing `start` only shifts the epoch numbering and checkpoint filenames.
# Nothing in this section reloads weights from an earlier checkpoint.
start, end = 0, 2000

In [6]:
# Write a numbered snapshot every CHECKPOINT_EVERY epochs, plus on the first and last epoch.
CHECKPOINT_EVERY = 100


def _evaluate_split(label, loader, losses, psnrs):
    """Evaluate the notebook-level `srcnn` on one test split and record the result.

    Args:
        label: split name used in the log line (e.g. "Set5").
        loader: DataLoader over that split.
        losses: history list that receives this epoch's mean loss.
        psnrs: history list that receives this epoch's mean PSNR.

    Returns:
        (mean_loss, mean_psnr) as returned by UtilSRCNN.eval.
    """
    mean_loss, mean_psnr = UtilSRCNN.eval(srcnn, loader, metric)
    print(f"  Eval ({label}): loss: {mean_loss:.5f}, psnr: {mean_psnr:.5f}")
    losses.append(mean_loss)
    psnrs.append(mean_psnr)
    return mean_loss, mean_psnr


def _save_checkpoint(filename):
    """Save the notebook-level model, optimizer and metric state as f"{model_name}/{filename}"."""
    checkpoint.save(
        name=f"{model_name}/{filename}",
        model=srcnn.state_dict(),
        optimizer=optimizer.state_dict(),
        **metric.save_checkpoint(),
    )


for i in range(start, end):
    mean_loss, mean_psnr = UtilSRCNN.train(srcnn, optimizer, train_loader)
    print(f"Epoch: {i} loss: {mean_loss:.5f}, psnr: {mean_psnr:.5f}")
    metric.total_train_loss.append(mean_loss)
    metric.total_train_psnr.append(mean_psnr)

    mean_loss_set5, mean_psnr_set5 = _evaluate_split(
        "Set5", eval_loader_set5, metric.total_eval_loss_set5, metric.total_eval_psnr_set5
    )
    mean_loss_set14, mean_psnr_set14 = _evaluate_split(
        "Set14", eval_loader_set14, metric.total_eval_loss_set14, metric.total_eval_psnr_set14
    )

    # Model-selection score computed by MetricSRCNN from the eval history.
    curr_psnr = metric.get_eval_score()
    if curr_psnr > metric.best_psnr:
        print(f"  * New best psnr: {curr_psnr}")
        metric.best_epoch = i
        metric.best_psnr = curr_psnr
        _save_checkpoint("best.pt")

    # Periodic snapshot; its epoch is encoded in the filename. metric.best_epoch is deliberately
    # NOT modified here so it keeps pointing at the epoch of the best model. Previously it was
    # overwritten with `i`, which always ended the run with best_epoch == end - 1.
    if (i + 1) % CHECKPOINT_EVERY == 0 or i == end - 1 or i == 0:
        _save_checkpoint(f"{i}.pt")


Epoch: 0 loss: 0.09619, psnr: 11.37278
  Eval (Set5): loss: 0.01955, psnr: 17.31449
  Eval (Set14): loss: 0.02332, psnr: 16.53685
  * New best psnr: 16.74149327529104
Model saved to model/export/srcnn-gaussian-0.55-t91-amp/best.pt
Model saved to model/export/srcnn-gaussian-0.55-t91-amp/0.pt
Epoch: 1 loss: 0.03685, psnr: 14.37310
  Eval (Set5): loss: 0.01765, psnr: 17.77472
  Eval (Set14): loss: 0.02102, psnr: 16.97608
  * New best psnr: 17.18625229283383
Model saved to model/export/srcnn-gaussian-0.55-t91-amp/best.pt
Epoch: 2 loss: 0.03321, psnr: 14.82398
  Eval (Set5): loss: 0.01585, psnr: 18.27312
  Eval (Set14): loss: 0.01850, psnr: 17.51324
  * New best psnr: 17.71320669274581
Model saved to model/export/srcnn-gaussian-0.55-t91-amp/best.pt
Epoch: 3 loss: 0.02980, psnr: 15.29618
  Eval (Set5): loss: 0.01419, psnr: 18.80278
  Eval (Set14): loss: 0.01614, psnr: 18.10282
  * New best psnr: 18.287017922652396
Model saved to model/export/srcnn-gaussian-0.55-t91-amp/best.pt
Epoch: 4 loss: