In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

In [2]:
from config import Config
from dataset import SRDataset
from model import FSRCNN
from utils import PSNR, AverageMeter, save_evaluation_image

In [3]:
# Experiment hyper-parameters come from the project-local Config object.
cfg = Config()

# Build the FSRCNN network; every architecture argument is read from cfg.
model = FSRCNN(scale_factor=cfg.scale_factor, d=cfg.d, c=cfg.c, s=cfg.s, m=cfg.m)

# Mean-squared error between the network output and the high-resolution target.
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=cfg.learning_rate)

# Multiply the learning rate by cfg.factor once the monitored value has stopped
# decreasing for cfg.patience scheduler steps, never going below
# cfg.min_learning_rate.
# `verbose` is intentionally not passed: it is deprecated in recent PyTorch
# releases and the scheduler is already silent by default.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=cfg.factor,
    patience=cfg.patience,
    min_lr=cfg.min_learning_rate,
)



In [4]:
# Constructor options shared by the three SRDataset splits.
sr_dataset_options = {"hr_shape": cfg.hr_shape, "scale_factor": cfg.scale_factor}

# Training split: the only loader that reshuffles its samples.
train_set = SRDataset(f"./data/{cfg.dataset_name}", **sr_dataset_options)
dataloader = DataLoader(train_set, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.n_cpu)

# Validation split: same batch size and worker count, fixed order.
val_set = SRDataset(f"./data/{cfg.val_dataset_name}", **sr_dataset_options)
val_dataloader = DataLoader(val_set, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.n_cpu)

# Test split: its own batch size and a single worker process, fixed order.
test_set = SRDataset(f"./data/{cfg.test_dataset_name}", **sr_dataset_options)
test_dataloader = DataLoader(test_set, batch_size=cfg.test_batch_size, shuffle=False, num_workers=1)

# Echo the model variant and its architecture settings for the run log.
print(cfg.model_name, "[ c=", cfg.c, ", d=", cfg.d, ", s=", cfg.s, ", m=", cfg.m, "]")

FSRCNN-s [ c= 1 , d= 32 , s= 5 , m= 1 ]


In [5]:
# Baseline pass: run the current model (still untrained when the notebook is
# executed top to bottom) over the test set and save LR / HR / SR comparison
# images tagged with epoch -1.
epoch = -1
model.eval()
with torch.no_grad():  # inference only, so no autograd graph is needed
    for lr_imgs, hr_imgs in test_dataloader:
        # Clamp to [0, 1] exactly like the periodic evaluation in the training
        # cells, so the baseline image is comparable with the later ones.
        sr_imgs = model(lr_imgs).clamp(0.0, 1.0)

        # Only the first sample of each batch is printed and saved.
        lr_tensor = lr_imgs[0]
        hr_tensor = hr_imgs[0]
        sr_tensor = sr_imgs[0]
        print(lr_tensor.shape, hr_tensor.shape, sr_tensor.shape)
        save_evaluation_image(epoch, lr_tensor, hr_tensor, sr_tensor, f"training/{cfg.model_name}")
    

torch.Size([1, 64, 64]) torch.Size([1, 256, 256]) torch.Size([1, 256, 256])


In [6]:
# First training stage: epochs 1 .. cfg.num_epochs.
# Each epoch trains on `dataloader`, measures the loss on `val_dataloader`,
# feeds that loss to the LR scheduler, and every 10th epoch saves test images
# plus a checkpoint of the model weights.

# Make sure the checkpoint folder exists up front; otherwise the first
# torch.save would only fail after ten epochs of training.
checkpoint_dir = f"./saved_models/{cfg.model_name}"
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(cfg.num_epochs):
    # ---- optimisation pass over the training split ----
    model.train()
    epoch_loss = AverageMeter()

    for lr_imgs, hr_imgs in dataloader:
        optimizer.zero_grad()

        sr_imgs = model(lr_imgs)
        loss = criterion(sr_imgs, hr_imgs)
        loss.backward()
        optimizer.step()

        # Weight the running average by the number of samples in the batch.
        epoch_loss.update(loss.item(), len(lr_imgs))

    # ---- validation pass (no gradients needed) ----
    val_loss = AverageMeter()
    model.eval()

    with torch.no_grad():
        for lr_imgs, hr_imgs in val_dataloader:
            # Outputs are clamped to [0, 1] before the loss is measured.
            sr_imgs = model(lr_imgs).clamp(0.0, 1.0)
            loss = criterion(sr_imgs, hr_imgs)
            val_loss.update(loss.item(), len(lr_imgs))

    # ReduceLROnPlateau is driven by the average validation loss.
    scheduler.step(val_loss.avg)

    print(f"Epoch [{epoch+1}/{cfg.num_epochs}], Train Loss: {epoch_loss.avg:.6f}, Val Loss: {val_loss.avg:.6f}")

    if (epoch + 1) % 10 == 0:
        # The model is still in eval mode from the validation pass above.
        with torch.no_grad():
            for lr_imgs, hr_imgs in test_dataloader:
                sr_imgs = model(lr_imgs).clamp(0.0, 1.0)
                # Save a comparison for the first sample of each test batch.
                lr_tensor = lr_imgs[0]
                hr_tensor = hr_imgs[0]
                sr_tensor = sr_imgs[0]
                save_evaluation_image(epoch, lr_tensor, hr_tensor, sr_tensor, f"training/{cfg.model_name}")

        torch.save(model.state_dict(), f"{checkpoint_dir}/fsrcnnx4_epoch_{epoch+1}.pth")


Epoch [1/60], Train Loss: 0.240073, Val Loss: 0.226800
Epoch [2/60], Train Loss: 0.234436, Val Loss: 0.215999
Epoch [3/60], Train Loss: 0.229798, Val Loss: 0.211242
Epoch [4/60], Train Loss: 0.221785, Val Loss: 0.199199
Epoch [5/60], Train Loss: 0.216478, Val Loss: 0.197327
Epoch [6/60], Train Loss: 0.203024, Val Loss: 0.182494
Epoch [7/60], Train Loss: 0.191799, Val Loss: 0.174768
Epoch [8/60], Train Loss: 0.187238, Val Loss: 0.158479
Epoch [9/60], Train Loss: 0.175446, Val Loss: 0.166419
Epoch [10/60], Train Loss: 0.173702, Val Loss: 0.152394
Epoch [11/60], Train Loss: 0.164284, Val Loss: 0.142083
Epoch [12/60], Train Loss: 0.161899, Val Loss: 0.139661
Epoch [13/60], Train Loss: 0.151341, Val Loss: 0.137325
Epoch [14/60], Train Loss: 0.147250, Val Loss: 0.123334
Epoch [15/60], Train Loss: 0.137411, Val Loss: 0.129654
Epoch [16/60], Train Loss: 0.131348, Val Loss: 0.126570
Epoch [17/60], Train Loss: 0.131489, Val Loss: 0.120336
Epoch [18/60], Train Loss: 0.126031, Val Loss: 0.114721
E

In [7]:
# Second training stage: epochs cfg.num_epochs + 1 .. 2 * cfg.num_epochs.
# This cell does not re-create model, optimizer or scheduler; it continues
# from whatever state they hold in the kernel when it is run.

# Make sure the checkpoint folder exists before any weights are written, so a
# missing directory cannot abort the run at the first save.
checkpoint_dir = f"./saved_models/{cfg.model_name}"
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(cfg.num_epochs, 2 * cfg.num_epochs):
    # ---- optimisation pass over the training split ----
    model.train()
    epoch_loss = AverageMeter()

    for lr_imgs, hr_imgs in dataloader:
        optimizer.zero_grad()

        sr_imgs = model(lr_imgs)
        loss = criterion(sr_imgs, hr_imgs)
        loss.backward()
        optimizer.step()

        # Weight the running average by the number of samples in the batch.
        epoch_loss.update(loss.item(), len(lr_imgs))

    # ---- validation pass (no gradients needed) ----
    val_loss = AverageMeter()
    model.eval()

    with torch.no_grad():
        for lr_imgs, hr_imgs in val_dataloader:
            # Outputs are clamped to [0, 1] before the loss is measured.
            sr_imgs = model(lr_imgs).clamp(0.0, 1.0)
            loss = criterion(sr_imgs, hr_imgs)
            val_loss.update(loss.item(), len(lr_imgs))

    # ReduceLROnPlateau is driven by the average validation loss.
    scheduler.step(val_loss.avg)

    print(f"Epoch [{epoch+1}/{(2*cfg.num_epochs)}], Train Loss: {epoch_loss.avg:.6f}, Val Loss: {val_loss.avg:.6f}")

    if (epoch + 1) % 10 == 0:
        # The model is still in eval mode from the validation pass above.
        with torch.no_grad():
            for lr_imgs, hr_imgs in test_dataloader:
                sr_imgs = model(lr_imgs).clamp(0.0, 1.0)
                # Save a comparison for the first sample of each test batch.
                lr_tensor = lr_imgs[0]
                hr_tensor = hr_imgs[0]
                sr_tensor = sr_imgs[0]
                save_evaluation_image(epoch, lr_tensor, hr_tensor, sr_tensor, f"training/{cfg.model_name}")

        torch.save(model.state_dict(), f"{checkpoint_dir}/fsrcnnx4_epoch_{epoch+1}.pth")


Epoch [61/120], Train Loss: 0.028608, Val Loss: 0.030253
Epoch [62/120], Train Loss: 0.028380, Val Loss: 0.030297
Epoch [63/120], Train Loss: 0.027990, Val Loss: 0.032593
Epoch [64/120], Train Loss: 0.027488, Val Loss: 0.032095
Epoch [65/120], Train Loss: 0.027651, Val Loss: 0.029454
Epoch [66/120], Train Loss: 0.027769, Val Loss: 0.031216
Epoch [67/120], Train Loss: 0.027577, Val Loss: 0.029154
Epoch [68/120], Train Loss: 0.027723, Val Loss: 0.031928
Epoch [69/120], Train Loss: 0.027250, Val Loss: 0.027945
Epoch [70/120], Train Loss: 0.027341, Val Loss: 0.030625
Epoch [71/120], Train Loss: 0.026612, Val Loss: 0.028580
Epoch [72/120], Train Loss: 0.026934, Val Loss: 0.030954
Epoch [73/120], Train Loss: 0.026964, Val Loss: 0.027903
Epoch [74/120], Train Loss: 0.026687, Val Loss: 0.028779
Epoch [75/120], Train Loss: 0.026000, Val Loss: 0.030368
Epoch [76/120], Train Loss: 0.026694, Val Loss: 0.028630
Epoch [77/120], Train Loss: 0.026547, Val Loss: 0.029455
Epoch [78/120], Train Loss: 0.0