In [9]:
import numpy as np
import torch
from src import data, loss, model
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms as T


def train(
    srcnn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    loader: DataLoader,
    loss_arr: list,
    psnr_arr: list,
):
    srcnn.train()
    for lr, hr in loader:
        lr = lr.to(model.device)
        hr = hr.to(model.device)

        pred_hr = srcnn(lr)
        loss_score = loss.mse_loss(hr, pred_hr)

        optimizer.zero_grad()
        loss_score.backward()
        optimizer.step()

        loss_item = loss_score.detach()
        loss_arr.append(loss_item.item())
        psnr_arr.append(loss.calculate_psnr(loss_item).item())


def eval(
    srcnn: torch.nn.Module,
    loader: DataLoader,
    loss_arr: list,
    psnr_arr: list,
):
    srcnn.eval()
    with torch.no_grad():
        for lr, hr in loader:
            lr = lr.to(model.device)
            hr = hr.to(model.device)

            pred_hr = srcnn(lr)
            loss_score = loss.mse_loss(hr, pred_hr)

            loss_item = loss_score.detach()
            loss_arr.append(loss_item.item())
            psnr_arr.append(loss.calculate_psnr(loss_item).item())

In [2]:
transforms = T.Compose(
    [
        T.ToTensor(),
        # T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

train_data = data.SRCNNData(
    "data/T91/cropped",
    transform=transforms,
)

eval_data_set5 = data.SRCNNData(
    "data/test/Set5",
    transform=transforms,
)

eval_data_set14 = data.SRCNNData(
    "data/test/Set14",
    transform=transforms,
)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=1)
eval_loader_set5 = DataLoader(eval_data_set5, batch_size=1, shuffle=False, num_workers=1)
eval_loader_set14 = DataLoader(eval_data_set14, batch_size=1, shuffle=False, num_workers=1)

In [3]:
srcnn = model.srcnn.SRCNN().to(model.device)

optimizer = optim.SGD(
    [
        {"params": srcnn.block_1.parameters(), "lr": 1e-4},
        {"params": srcnn.block_final.parameters(), "lr": 1e-5},
    ],
    momentum=0.9,
)

In [4]:
best_psnr = float("-inf")
best_epoch = 0

total_train_loss = []
total_train_psnr = []
total_eval_loss_set5 = []
total_eval_psnr_set5 = []
total_eval_loss_set14 = []
total_eval_psnr_set14 = []

for i in range(1000):
    train_loss = []
    train_psnr = []
    loss_set5 = []
    loss_set14 = []

    psnr_set5 = []
    psnr_set14 = []

    train(srcnn, optimizer, train_loader, train_loss, train_psnr)
    mean_train_loss = np.mean(train_loss)
    mean_train_psnr = np.mean(train_psnr)
    print(f"Epoch: {i} loss: {mean_train_loss}, psnr: {mean_train_psnr}")

    eval(srcnn, eval_loader_set5, loss_set5, psnr_set5)
    mean_loss_set5 = np.mean(loss_set5)
    mean_psnr_set5 = np.mean(psnr_set5)
    print(f"  Eval (Set5): loss: {mean_loss_set5}, psnr: {mean_psnr_set5}")

    eval(srcnn, eval_loader_set14, loss_set14, psnr_set14)
    mean_loss_set14 = np.mean(loss_set14)
    mean_psnr_set14 = np.mean(psnr_set14)
    print(f"  Eval (Set14): loss: {mean_loss_set14}, psnr: {mean_psnr_set14}")

    total_train_loss.append(mean_train_loss)
    total_train_psnr.append(mean_train_psnr)

    total_eval_loss_set5.append(mean_loss_set5)
    total_eval_psnr_set5.append(mean_psnr_set5)

    total_eval_loss_set14.append(mean_loss_set14)
    total_eval_psnr_set14.append(mean_psnr_set14)

    eval_psnr = [*psnr_set5, *psnr_set14]
    curr_psnr = np.mean(eval_psnr)
    if curr_psnr > best_psnr:
        print(f"  * New best psnr: {curr_psnr}")
        best_epoch = i
        best_psnr = curr_psnr
        model.export(
            name="best-srcnn.pt",
            model=srcnn,
            optimizer=optimizer,
            best_epoch=i,
            best_psnr=best_psnr,
            total_train_loss=total_train_loss,
            total_train_psnr=total_train_psnr,
            total_eval_loss_set5=total_eval_loss_set5,
            total_eval_psnr_set5=total_eval_psnr_set5,
            total_eval_loss_set14=total_eval_loss_set14,
            total_eval_psnr_set14=total_eval_psnr_set14,
        )


Epoch: 0 loss: 0.10027342812757786, psnr: 10.889632532635186
  Eval (Set5): loss: 0.019408216699957848, psnr: 17.227976608276368
  Eval (Set14): loss: 0.026819699177784578, psnr: 16.059313705989293
  * New best psnr: 16.366856575012207
Model best-srcnn.pt saved to /mnt/Files/code/super-resolution-pytorch/src/model/export/best-srcnn.pt
Epoch: 1 loss: 0.03834249316291376, psnr: 14.189434989748678
  Eval (Set5): loss: 0.017706459388136864, psnr: 17.777454566955566
  Eval (Set14): loss: 0.02281907413687025, psnr: 16.757298742021835
  * New best psnr: 17.025760801214922
Model best-srcnn.pt saved to /mnt/Files/code/super-resolution-pytorch/src/model/export/best-srcnn.pt
Epoch: 2 loss: 0.03626475436786934, psnr: 14.429116949791629
  Eval (Set5): loss: 0.01692101266235113, psnr: 18.013127517700195
  Eval (Set14): loss: 0.021383548820657388, psnr: 17.031706196921213
  * New best psnr: 17.28997496554726
Model best-srcnn.pt saved to /mnt/Files/code/super-resolution-pytorch/src/model/export/best-s

In [6]:
for i in range(1000, 2000):
    train_loss = []
    train_psnr = []
    loss_set5 = []
    loss_set14 = []

    psnr_set5 = []
    psnr_set14 = []

    train(srcnn, optimizer, train_loader, train_loss, train_psnr)
    mean_train_loss = np.mean(train_loss)
    mean_train_psnr = np.mean(train_psnr)
    print(f"Epoch: {i} loss: {mean_train_loss}, psnr: {mean_train_psnr}")

    eval(srcnn, eval_loader_set5, loss_set5, psnr_set5)
    mean_loss_set5 = np.mean(loss_set5)
    mean_psnr_set5 = np.mean(psnr_set5)
    print(f"  Eval (Set5): loss: {mean_loss_set5}, psnr: {mean_psnr_set5}")

    eval(srcnn, eval_loader_set14, loss_set14, psnr_set14)
    mean_loss_set14 = np.mean(loss_set14)
    mean_psnr_set14 = np.mean(psnr_set14)
    print(f"  Eval (Set14): loss: {mean_loss_set14}, psnr: {mean_psnr_set14}")

    total_train_loss.append(mean_train_loss)
    total_train_psnr.append(mean_train_psnr)

    total_eval_loss_set5.append(mean_loss_set5)
    total_eval_psnr_set5.append(mean_psnr_set5)

    total_eval_loss_set14.append(mean_loss_set14)
    total_eval_psnr_set14.append(mean_psnr_set14)

    eval_psnr = [*psnr_set5, *psnr_set14]
    curr_psnr = np.mean(eval_psnr)
    if curr_psnr > best_psnr:
        print(f"  * New best psnr: {curr_psnr}")
        best_epoch = i
        best_psnr = curr_psnr
        model.export(
            name="best-srcnn.pt",
            model=srcnn,
            optimizer=optimizer,
            best_epoch=i,
            best_psnr=best_psnr,
            total_train_loss=total_train_loss,
            total_train_psnr=total_train_psnr,
            total_eval_loss_set5=total_eval_loss_set5,
            total_eval_psnr_set5=total_eval_psnr_set5,
            total_eval_loss_set14=total_eval_loss_set14,
            total_eval_psnr_set14=total_eval_psnr_set14,
        )


Epoch: 1000 loss: 0.0037652496335468955, psnr: 24.305030870259476
  Eval (Set5): loss: 0.0027179419528692962, psnr: 26.948624038696288
  Eval (Set14): loss: 0.0040793493556390914, psnr: 24.628717831202916
  * New best psnr: 25.239219464753802
Model best-srcnn.pt saved to /mnt/Files/code/super-resolution-pytorch/src/model/export/best-srcnn.pt
Epoch: 1001 loss: 0.0037643245840619133, psnr: 24.31086508393436
  Eval (Set5): loss: 0.002717006066814065, psnr: 26.949125671386717
  Eval (Set14): loss: 0.004078175531633731, psnr: 24.629618099757604
  * New best psnr: 25.240014829133685
Model best-srcnn.pt saved to /mnt/Files/code/super-resolution-pytorch/src/model/export/best-srcnn.pt
Epoch: 1002 loss: 0.0037634114730189256, psnr: 24.310746367515097
  Eval (Set5): loss: 0.0027164902770891786, psnr: 26.95069694519043
  Eval (Set14): loss: 0.004077806320440557, psnr: 24.63020351954869
  * New best psnr: 25.240859684191253
Model best-srcnn.pt saved to /mnt/Files/code/super-resolution-pytorch/src/m

KeyboardInterrupt: 

# Continue

In [35]:
checkpoint = model.load_checkpoint("best-srcnn-continue.pt", device=model.device)
checkpoint.keys()

cont_model = model.SRCNN().to(model.device).load_state_dict(checkpoint["model_state_dict"])
cont_model.train()

cont_optimizer = torch.optim.SGD(
    [
        {"params": srcnn.block_1.parameters(), "lr": 1e-4},
        {"params": srcnn.block_final.parameters(), "lr": 1e-5},
    ],
    momentum=0.9,
).load_state_dict(checkpoint["optimizer"])

cont_epoch = checkpoint["best_epoch"]
cont_psnr = checkpoint["best_psnr"]

cont_total_train_loss = checkpoint["total_train_loss"]
cont_total_train_psnr = checkpoint["total_train_psnr"]
cont_total_eval_loss_set5 = checkpoint["total_eval_loss_set5"]
cont_total_eval_psnr_set5 = checkpoint["total_eval_psnr_set5"]
cont_total_eval_loss_set14 = checkpoint["total_eval_loss_set14"]
cont_total_eval_psnr_set14 = checkpoint["total_eval_psnr_set14"]

AttributeError: '_IncompatibleKeys' object has no attribute 'train'

In [None]:
for i in range(cont_epoch, 2000):
    train_loss = []
    train_psnr = []
    loss_set5 = []
    loss_set14 = []

    psnr_set5 = []
    psnr_set14 = []

    train(cont_model, cont_optimizer, train_loader, train_loss, train_psnr)
    mean_train_loss = np.mean(train_loss)
    mean_train_psnr = np.mean(train_psnr)
    print(f"Epoch: {i} loss: {mean_train_loss}, psnr: {mean_train_psnr}")

    eval(cont_model, eval_loader_set5, loss_set5, psnr_set5)
    mean_loss_set5 = np.mean(loss_set5)
    mean_psnr_set5 = np.mean(psnr_set5)
    print(f"  Eval (Set5): loss: {mean_loss_set5}, psnr: {mean_psnr_set5}")

    eval(cont_model, eval_loader_set14, loss_set14, psnr_set14)
    mean_loss_set14 = np.mean(loss_set14)
    mean_psnr_set14 = np.mean(psnr_set14)
    print(f"  Eval (Set14): loss: {mean_loss_set14}, psnr: {mean_psnr_set14}")

    cont_total_train_loss.append(mean_train_loss)
    cont_total_train_psnr.append(mean_train_psnr)

    cont_total_eval_loss_set5.append(mean_loss_set5)
    cont_total_eval_psnr_set5.append(mean_psnr_set5)

    cont_total_eval_loss_set14.append(mean_loss_set14)
    cont_total_eval_psnr_set14.append(mean_psnr_set14)

    eval_psnr = [*psnr_set5, *psnr_set14]
    curr_psnr = np.mean(eval_psnr)
    if curr_psnr > cont_psnr:
        print(f"  * New best psnr: {curr_psnr}")
        best_epoch = i
        cont_psnr = curr_psnr
        model.export(
            name="best-srcnn-continue.pt",
            model=cont_model,
            optimizer=cont_optimizer,
            best_epoch=i,
            best_psnr=cont_psnr,
            total_train_loss=cont_total_train_loss,
            total_train_psnr=cont_total_train_psnr,
            total_eval_loss_set5=cont_total_eval_loss_set5,
            total_eval_psnr_set5=cont_total_eval_psnr_set5,
            total_eval_loss_set14=cont_total_eval_loss_set14,
            total_eval_psnr_set14=cont_total_eval_psnr_set14,
        )
