In [1]:
from src import checkpoint, data, metrics
from src.device import device
from src.model import SRCNN
from src.utils import UtilSRCNN
from torch import optim
from torch.utils.data import DataLoader

In [2]:
train_data = data.SRCNNData("data/T91/cropped", transform=UtilSRCNN.transforms)
eval_data_set5 = data.SRCNNData("data/test/Set5", transform=UtilSRCNN.transforms)
eval_data_set14 = data.SRCNNData("data/test/Set14", transform=UtilSRCNN.transforms)

train_loader = DataLoader(train_data, batch_size=16, shuffle=True, num_workers=2)
eval_loader_set5 = DataLoader(eval_data_set5, batch_size=1, shuffle=False, num_workers=1)
eval_loader_set14 = DataLoader(eval_data_set14, batch_size=1, shuffle=False, num_workers=1)

In [4]:
model_name = "srcnn-gaussian-0.55"
ckpt = checkpoint.load(f"{model_name}/best.pt", device=device)
metric = metrics.MetricSRCNN()
metric.load_checkpoint(ckpt)


In [6]:
srcnn = SRCNN().to(device)
# srcnn.load_state_dict(ckpt["model"])
srcnn.load_state_dict(ckpt["model"])
srcnn.train()

optimizer = optim.SGD(
    [
        {"params": srcnn.block_1.parameters(), "lr": 1e-4},
        {"params": srcnn.block_final.parameters(), "lr": 1e-5},
    ],
    momentum=0.9,
)
optimizer.load_state_dict(ckpt["optimizer"])

In [8]:
start = metric.best_epoch
end = 3000

start, end

(2, 3000)

In [None]:
for i in range(start, end):
    mean_loss, mean_psnr = UtilSRCNN.train(srcnn, optimizer, train_loader)
    print(f"Epoch: {i} loss: {mean_loss:.5f}, psnr: {mean_psnr:.5f}")
    metric.total_train_loss.append(mean_loss)
    metric.total_train_psnr.append(mean_psnr)

    mean_loss_set5, mean_psnr_set5 = UtilSRCNN.eval(srcnn, eval_loader_set5, metric)
    metric.total_eval_loss_set5.append(mean_loss_set5)
    metric.total_eval_psnr_set5.append(mean_psnr_set5)
    print(f"  Eval (Set5): loss: {mean_loss_set5:.5f}, psnr: {mean_psnr_set5:.5f}")

    mean_loss_set14, mean_psnr_set14 = UtilSRCNN.eval(srcnn, eval_loader_set14, metric)
    metric.total_eval_loss_set14.append(mean_loss_set14)
    metric.total_eval_psnr_set14.append(mean_psnr_set14)
    print(f"  Eval (Set14): loss: {mean_loss_set14:.5f}, psnr: {mean_psnr_set14:.5f}")

    curr_psnr = metric.get_eval_score()
    if curr_psnr > metric.best_psnr:
        print(f"  * New best psnr: {curr_psnr}")
        metric.best_epoch = i
        metric.best_psnr = curr_psnr

        checkpoint.save(
            name=f"{model_name}/best.pt",
            model=srcnn.state_dict(),
            optimizer=optimizer.state_dict(),
            **metric.save_checkpoint(),
        )

    if not (i + 1) % 100 or i == end - 1 or i == 0:
        metric.best_epoch = i
        checkpoint.save(
            name=f"{model_name}/{i}.pt",
            model=srcnn.state_dict(),
            optimizer=optimizer.state_dict(),
            **metric.save_checkpoint(),
        )


Epoch: 2 loss: 0.03193, psnr: 14.99736
  Eval (Set5): loss: 0.01465, psnr: 18.67755
  Eval (Set14): loss: 0.01795, psnr: 17.74719
  * New best psnr: 17.99202191202264
Model saved to model/export/srcnn-gaussian-0.55/best.pt
Model saved to model/export/srcnn-gaussian-0.55/2.pt
Epoch: 3 loss: 0.02892, psnr: 15.43102
  Eval (Set5): loss: 0.01325, psnr: 19.14385
  Eval (Set14): loss: 0.01584, psnr: 18.28322
  * New best psnr: 18.509701678627415
Model saved to model/export/srcnn-gaussian-0.55/best.pt
Epoch: 4 loss: 0.02585, psnr: 15.91511
  Eval (Set5): loss: 0.01193, psnr: 19.63783
  Eval (Set14): loss: 0.01382, psnr: 18.87186
  * New best psnr: 19.073434227391292
Model saved to model/export/srcnn-gaussian-0.55/best.pt
Epoch: 5 loss: 0.02299, psnr: 16.41929
  Eval (Set5): loss: 0.01072, psnr: 20.15387
  Eval (Set14): loss: 0.01214, psnr: 19.44261
  * New best psnr: 19.629781923796003
Model saved to model/export/srcnn-gaussian-0.55/best.pt
Epoch: 6 loss: 0.02071, psnr: 16.87472
  Eval (Set5)