In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

In [2]:
from config import Config
from dataset import SRDataset
from model import FSRCNN
from utils import PSNR, AverageMeter, save_evaluation_image

In [3]:
# Load the hyper-parameters and build the objects shared by the cells below.
cfg = Config()

# FSRCNN network; every constructor argument comes from the config.
model = FSRCNN(scale_factor=cfg.scale_factor, d=cfg.d, c=cfg.c, s=cfg.s, m=cfg.m)
# Mean-squared error between the network output and the high-resolution target.
criterion = nn.MSELoss()
# Plain SGD (no momentum / weight decay) starting at cfg.learning_rate.
optimizer = optim.SGD(model.parameters(), lr=cfg.learning_rate)

# Scale the learning rate by cfg.factor once the monitored value (the validation
# loss passed to scheduler.step) has stopped decreasing for cfg.patience epochs,
# never going below cfg.min_learning_rate.
# The `verbose` argument is intentionally not passed: it is deprecated in recent
# PyTorch releases and triggers a warning even when set to its default (False).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=cfg.factor,
    patience=cfg.patience,
    min_lr=cfg.min_learning_rate,
)



In [4]:
def _make_loader(dataset_name, batch_size, shuffle, num_workers):
    """Wrap the SRDataset stored under ./data/<dataset_name> in a DataLoader.

    The HR shape and scale factor are always read from the global ``cfg``.
    """
    dataset = SRDataset(
        f"./data/{dataset_name}",
        hr_shape=cfg.hr_shape,
        scale_factor=cfg.scale_factor,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


# Training split: the only loader that shuffles.
dataloader = _make_loader(cfg.dataset_name, cfg.batch_size, shuffle=True, num_workers=cfg.n_cpu)

# Validation split: same batch size and worker count as training, fixed order.
val_dataloader = _make_loader(cfg.val_dataset_name, cfg.batch_size, shuffle=False, num_workers=cfg.n_cpu)

# Test split: its own batch size and a single worker process.
test_dataloader = _make_loader(cfg.test_dataset_name, cfg.test_batch_size, shuffle=False, num_workers=1)

# Echo the architecture hyper-parameters of this run.
print(cfg.model_name, "[ c=", cfg.c, ", d=", cfg.d, ", s=", cfg.s, ", m=", cfg.m, "]")

FSRCNN [ c= 1 , d= 56 , s= 12 , m= 4 ]


In [5]:
# Baseline snapshot: run the model over the test loader before the training
# cell below and save the LR / HR / SR images of the first sample of each batch.
epoch = -1  # epoch index handed to save_evaluation_image for this pre-training snapshot
model.eval()
# Inference only: no_grad avoids building an autograd graph and keeps the
# tensors passed to save_evaluation_image detached from it.
with torch.no_grad():
    for lr_imgs, hr_imgs in test_dataloader:
        # Clamp to [0, 1] so this snapshot is produced the same way as the
        # periodic snapshots saved inside the training loop.
        sr_imgs = model(lr_imgs).clamp(0.0, 1.0)
        lr_tensor = lr_imgs[0]
        hr_tensor = hr_imgs[0]
        sr_tensor = sr_imgs[0]
        # Sanity check of the per-sample tensor shapes (LR, HR, SR).
        print(lr_tensor.shape,hr_tensor.shape,sr_tensor.shape)
        save_evaluation_image(epoch,lr_tensor, hr_tensor, sr_tensor, f"training/{cfg.model_name}")
    

torch.Size([1, 64, 64]) torch.Size([1, 256, 256]) torch.Size([1, 256, 256])


In [6]:
# Main loop: one training pass and one validation pass per epoch, a scheduler
# step on the validation loss, and a snapshot + checkpoint every 10th epoch.
for epoch in range(cfg.num_epochs):
    # ---- training pass ----
    model.train()
    epoch_loss = AverageMeter()

    for lr_imgs, hr_imgs in dataloader:
        optimizer.zero_grad()

        sr_imgs = model(lr_imgs)

        loss = criterion(sr_imgs, hr_imgs)
        loss.backward()
        optimizer.step()

        # Second argument is the batch size; presumably AverageMeter uses it to
        # weight the running average (see utils.AverageMeter).
        epoch_loss.update(loss.item(), len(lr_imgs))

    # ---- validation pass ----
    val_loss = AverageMeter()
    model.eval()

    # The whole pass, loss included, runs without autograd bookkeeping.
    with torch.no_grad():
        for lr_imgs, hr_imgs in val_dataloader:
            # Output is clamped to [0, 1] before the loss, whereas the training
            # loss above uses the raw output, so the two numbers are not
            # strictly comparable.
            sr_imgs = model(lr_imgs).clamp(0.0, 1.0)
            loss = criterion(sr_imgs, hr_imgs)
            val_loss.update(loss.item(), len(lr_imgs))

    # ReduceLROnPlateau monitors the averaged validation loss.
    scheduler.step(val_loss.avg)

    print(f"Epoch [{epoch+1}/{cfg.num_epochs}], Train Loss: {epoch_loss.avg:.6f}, Val Loss: {val_loss.avg:.6f}")

    # ---- periodic snapshot and checkpoint (every 10th epoch) ----
    if (epoch + 1) % 10 == 0:

        model.eval()
        for lr_imgs, hr_imgs in test_dataloader:
            with torch.no_grad():
                sr_imgs = model(lr_imgs).clamp(0.0, 1.0)
            # Only the first sample of each test batch is saved.
            lr_tensor = lr_imgs[0]
            hr_tensor = hr_imgs[0]
            sr_tensor = sr_imgs[0]
            save_evaluation_image(epoch,lr_tensor, hr_tensor, sr_tensor, f"training/{cfg.model_name}")

        # torch.save does not create missing directories; make sure the target
        # exists so a fresh checkout does not crash here after 10 epochs of work.
        checkpoint_dir = f"./saved_models/{cfg.model_name}"
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(model.state_dict(), f"{checkpoint_dir}/fsrcnnx4_epoch_{epoch+1}.pth")


Epoch [1/60], Train Loss: 0.242547, Val Loss: 0.219942
Epoch [2/60], Train Loss: 0.229755, Val Loss: 0.215982
Epoch [3/60], Train Loss: 0.221338, Val Loss: 0.196508
Epoch [4/60], Train Loss: 0.212460, Val Loss: 0.199392
Epoch [5/60], Train Loss: 0.200129, Val Loss: 0.168621
Epoch [6/60], Train Loss: 0.185659, Val Loss: 0.176851
Epoch [7/60], Train Loss: 0.183668, Val Loss: 0.159752
Epoch [8/60], Train Loss: 0.171092, Val Loss: 0.159390
Epoch [9/60], Train Loss: 0.161808, Val Loss: 0.151544
Epoch [10/60], Train Loss: 0.149461, Val Loss: 0.142877
Epoch [11/60], Train Loss: 0.140869, Val Loss: 0.124358
Epoch [12/60], Train Loss: 0.132871, Val Loss: 0.113742
Epoch [13/60], Train Loss: 0.123105, Val Loss: 0.112220
Epoch [14/60], Train Loss: 0.114716, Val Loss: 0.101930
Epoch [15/60], Train Loss: 0.103627, Val Loss: 0.089432
Epoch [16/60], Train Loss: 0.094786, Val Loss: 0.076421
Epoch [17/60], Train Loss: 0.091156, Val Loss: 0.077763
Epoch [18/60], Train Loss: 0.078471, Val Loss: 0.073009
E