# 1. Import thư viện

In [1]:
import cv2
import numpy as np
import os
from PIL import Image
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch.nn.parallel import DataParallel
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
from torchvision.utils import save_image

from tqdm import tqdm
from models.srcnn import *
from models.idn import *

from models.sr_model import *

from models.srresnet_ import *
from models.sr_model import *
from models.fsrcnn import *
from models.wdsr import *
from models.edsr import *
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 2. Tạo Mô hình SR

In [3]:
# edsr_rdn = EDSR_rdnfy()
edsr_orig = EDSR().to(device)
edrn_canny = EDRN(use_canny=True).to(device)
edrn_srresnet = EDRN(use_sobel=True).to(device)
srresnet = SRResNet().to(device)
rdn = SRCNN().to(device)
vdsr = VDSR().to(device)
fsrcnn = FSRCNN(scale_factor=4).to(device)


In [4]:

# edsr_orig.load_state_dict(torch.load('best_edsr.pth', map_location=device, weights_))
# edsr_orig.load_state_dict(torch.load('weight/best_edsrx4_orig_model.pth', map_location=device))
# edrn_sobel.load_state_dict(torch.load('weight/best_sobel_srx4_model.pth', map_location=device))
# edrn_canny.load_state_dict(torch.load('weight/best_canny_srx4_model.pth', map_location=device))
# srresnet.load_state_dict(torch.load('best_srresnet.pth', map_location=device))
# rdn.load_state_dict(torch.load('weight/rdn_x4.pth', map_location=device))
# fsrcnn.load_state_dict(torch.load('weight/fsrcnn_x4.pth', map_location=device))


In [12]:
class ImageDataset(Dataset):
    def __init__(self, lr_dir, hr_dir):
        self.lr_files = sorted(os.listdir(lr_dir))
        self.hr_files = sorted(os.listdir(hr_dir))
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir

    def __len__(self):
        return len(self.lr_files)

    def __getitem__(self, idx):
        lr_image = cv2.imread(os.path.join(self.lr_dir, self.lr_files[idx]))
        hr_image = cv2.imread(os.path.join(self.hr_dir, self.hr_files[idx]))
        lr_image = cv2.cvtColor(lr_image, cv2.COLOR_BGR2RGB)
        hr_image = cv2.cvtColor(hr_image, cv2.COLOR_BGR2RGB)
        
        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor()
        ])
        
        lr_image = transform(lr_image)
        hr_image = transform(hr_image)
        return lr_image, hr_image

# 3. Tạo Hyperparameter

In [6]:
# Đường dẫn tới bộ dữ liệu
train_lr_dir = 'Train/LR'
train_hr_dir = 'Train/HR'
valid_lr_dir = 'Test/LR'
valid_hr_dir = 'Test/HR'
# test_hr_dir  = '/kaggle/input/srdataset/sr_data/test/HR'
# test_lr_dir  = '/kaggle/input/srdataset/sr_data/test/LR'

# print(torch.cuda.memory_allocated())
# print(torch.cuda.memory_reserved())

In [7]:
from torch.amp import autocast, GradScaler
scaler = GradScaler()

# Khởi tạo dataset và dataloader
train_dataset = ImageDataset(train_lr_dir, train_hr_dir)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

valid_dataset = ImageDataset(valid_lr_dir, valid_hr_dir)
valid_loader = DataLoader(valid_dataset)

print(len(train_loader))

# Khởi tạo loss function
criterion = nn.MSELoss()

# Khởi tạo optimizers, schedulers cho từng mô hình
optim_edsr = optim.Adam(edsr_orig.parameters(), lr=1e-4, betas=(0.9, 0.999))
scheduler_edsr = optim.lr_scheduler.StepLR(optim_edsr, step_size=10**5, gamma=0.5)

optim_srresnet = optim.Adam(srresnet.parameters(), lr=1e-4, betas=(0.9, 0.999))
scheduler_srresnet = optim.lr_scheduler.StepLR(optim_srresnet, step_size=10**5, gamma=0.5)

optim_rdn = optim.Adam(rdn.parameters(), lr=1e-4, betas=(0.9, 0.999))
scheduler_rdn = optim.lr_scheduler.StepLR(optim_rdn, step_size=10**5, gamma=0.5)

optim_fsrcnn = optim.Adam(fsrcnn.parameters(), lr=1e-4, betas=(0.9, 0.999))
scheduler_fsrcnn = optim.lr_scheduler.StepLR(optim_fsrcnn, step_size=10**5, gamma=0.5)

12500


In [8]:
def calculate_psnr(img1, img2):
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    max_pixel = 1.0
    psnr = 20 * torch.log10(max_pixel / torch.sqrt(mse))
    return psnr.item()

In [9]:
print(device)

cuda


# 4. Training

In [10]:


# Hàm tính PSNR
def calculate_psnr(img1, img2):
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    max_pixel = 1.0
    psnr = 20 * torch.log10(max_pixel / torch.sqrt(mse))
    return psnr.item()

num_epochs = 24

best_psnr_edsr = float('-inf')
best_psnr_srresnet = float('-inf')
best_psnr_rdn = float('-inf')
best_psnr_fsrcnn = float('-inf')
torch.cuda.empty_cache()

losses_edsr = []
losses_srresnet = []
losses_rdn = []
losses_fsrcnn = []

avg_psnr_edsr = []
avg_psnr_srresnet = []
avg_psnr_rdn = []
avg_psnr_fsrcnn = []

val_avg_psnr_edsr = []
val_avg_psnr_srresnet = []
val_avg_psnr_rdn = []
val_avg_psnr_fsrcnn = []

patience = 50
epochs_no_improve = 0
log_file = open('training_log_models.txt', 'a')

for epoch in range(num_epochs):
    # edsr_orig.train()
    srresnet.train()
    rdn.train()
    # fsrcnn.train()

    epoch_loss_edsr, psnr_values_edsr = 0, 0
    epoch_loss_srresnet, psnr_values_srresnet = 0, 0
    epoch_loss_rdn, psnr_values_rdn = 0, 0
    epoch_loss_fsrcnn, psnr_values_fsrcnn = 0, 0
    start_time = time.time()

    # Training loop for each model
    for (lr_images, hr_images) in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch'):
        lr_images = lr_images.cuda()
        hr_images = hr_images.cuda()

        # Train EDSR model
        optim_edsr.zero_grad()
        with autocast(device_type='cuda'):
            outputs_edsr = edsr_orig(lr_images)
            loss_edsr = criterion(outputs_edsr, hr_images)
        psnr_edsr = calculate_psnr(outputs_edsr, hr_images)

        scaler.scale(loss_edsr).backward()
        scaler.step(optim_edsr)
        scaler.update()
        scheduler_edsr.step()

        epoch_loss_edsr += loss_edsr.item()
        psnr_values_edsr += psnr_edsr

        # Train SRResNet model
        optim_srresnet.zero_grad()
        with autocast(device_type='cuda'):
            outputs_srresnet = srresnet(lr_images)
            loss_srresnet = criterion(outputs_srresnet, hr_images)
        psnr_srresnet = calculate_psnr(outputs_srresnet, hr_images)

        scaler.scale(loss_srresnet).backward()
        scaler.step(optim_srresnet)
        scaler.update()
        scheduler_srresnet.step()

        epoch_loss_srresnet += loss_srresnet.item()
        psnr_values_srresnet += psnr_srresnet

        # Train RDN model
        # optim_rdn.zero_grad()
        # with autocast(device_type='cuda'):
        #     outputs_rdn = rdn(lr_images)
        #     loss_rdn = criterion(outputs_rdn, hr_images)
        # psnr_rdn = calculate_psnr(outputs_rdn, hr_images)

        # scaler.scale(loss_rdn).backward()
        # scaler.step(optim_rdn)
        # scaler.update()
        # scheduler_rdn.step()

        # epoch_loss_rdn += loss_rdn.item()
        # psnr_values_rdn += psnr_rdn

        # Train FSRCNN model
        # optim_fsrcnn.zero_grad()
        # with autocast(device_type='cuda'):
        #     outputs_fsrcnn = fsrcnn(lr_images)
        #     loss_fsrcnn = criterion(outputs_fsrcnn, hr_images)
        # psnr_fsrcnn = calculate_psnr(outputs_fsrcnn, hr_images)

        # scaler.scale(loss_fsrcnn).backward()
        # scaler.step(optim_fsrcnn)
        # scaler.update()
        # scheduler_fsrcnn.step()

        # epoch_loss_fsrcnn += loss_fsrcnn.item()
        # psnr_values_fsrcnn += psnr_fsrcnn

    # # Average losses and PSNRs
    avg_epoch_loss_edsr = epoch_loss_edsr / len(train_loader)
    avg_psnr_edsr_epoch = psnr_values_edsr / len(train_loader)
    losses_edsr.append(avg_epoch_loss_edsr)
    avg_psnr_edsr.append(avg_psnr_edsr_epoch)

    avg_epoch_loss_srresnet = epoch_loss_srresnet / len(train_loader)
    avg_psnr_srresnet_epoch = psnr_values_srresnet / len(train_loader)
    losses_srresnet.append(avg_epoch_loss_srresnet)
    avg_psnr_srresnet.append(avg_psnr_srresnet_epoch)

    avg_epoch_loss_rdn = epoch_loss_rdn / len(train_loader)
    avg_psnr_rdn_epoch = psnr_values_rdn / len(train_loader)
    losses_rdn.append(avg_epoch_loss_rdn)
    avg_psnr_rdn.append(avg_psnr_rdn_epoch)

    # avg_epoch_loss_fsrcnn = epoch_loss_fsrcnn / len(train_loader)
    # avg_psnr_fsrcnn_epoch = psnr_values_fsrcnn / len(train_loader)
    # losses_fsrcnn.append(avg_epoch_loss_fsrcnn)
    # avg_psnr_fsrcnn.append(avg_psnr_fsrcnn_epoch)

    # Validation for all models
    edsr_orig.eval()
    srresnet.eval()
    # rdn.eval()
    # fsrcnn.eval()

    val_psnr_edsr, val_psnr_srresnet = 0, 0
    val_psnr_rdn, val_psnr_fsrcnn = 0, 0

    with torch.no_grad():
        for (lr_images, hr_images) in valid_loader:
            lr_images = lr_images.cuda()
            hr_images = hr_images.cuda()

            # # Validate EDSR
            outputs_edsr = edsr_orig(lr_images)
            psnr_edsr = calculate_psnr(outputs_edsr, hr_images)
            val_psnr_edsr += psnr_edsr

            # Validate SRResNet
            outputs_srresnet = srresnet(lr_images)
            psnr_srresnet = calculate_psnr(outputs_srresnet, hr_images)
            val_psnr_srresnet += psnr_srresnet

            # # Validate RDN
            # outputs_rdn = rdn(lr_images)
            # psnr_rdn = calculate_psnr(outputs_rdn, hr_images)
            # val_psnr_rdn += psnr_rdn

            # # Validate FSRCNN
            # outputs_fsrcnn = fsrcnn(lr_images)
            # psnr_fsrcnn = calculate_psnr(outputs_fsrcnn, hr_images)
            # val_psnr_fsrcnn += psnr_fsrcnn

    val_avg_psnr_edsr_epoch = val_psnr_edsr / len(valid_loader)
    val_avg_psnr_edsr.append(val_avg_psnr_edsr_epoch)

    val_avg_psnr_srresnet_epoch = val_psnr_srresnet / len(valid_loader)
    val_avg_psnr_srresnet.append(val_avg_psnr_srresnet_epoch)

    # val_avg_psnr_rdn_epoch = val_psnr_rdn / len(valid_loader)
    # val_avg_psnr_rdn.append(val_avg_psnr_rdn_epoch)

    # val_avg_psnr_fsrcnn_epoch = val_psnr_fsrcnn / len(valid_loader)
    # val_avg_psnr_fsrcnn.append(val_avg_psnr_fsrcnn_epoch)

    # Save best model
    if val_avg_psnr_edsr_epoch > best_psnr_edsr:
        best_psnr_edsr = val_avg_psnr_edsr_epoch
        torch.save(edsr_orig.state_dict(), 'best_edsr.pth')
        print(f"Saved EDSRR model with PSNR {best_psnr_edsr:.4f}")
    if val_avg_psnr_srresnet_epoch > best_psnr_srresnet:
        best_psnr_srresnet = val_avg_psnr_srresnet_epoch
        torch.save(srresnet.state_dict(), 'best_srresnet.pth')
        print(f"Saved SRResNet model with PSNR {best_psnr_srresnet:.4f}")

    torch.save(edsr_orig.state_dict(), 'edsr.pth')
    torch.save(srresnet.state_dict(), 'srresnet.pth')
    # if val_avg_psnr_rdn_epoch > best_psnr_rdn:
    #     best_psnr_rdn = val_avg_psnr_rdn_epoch
    #     torch.save(rdn.state_dict(), 'best_rdn.pth')
    #     print(f"Saved RDN model with PSNR {best_psnr_rdn:.4f}")
    # if val_avg_psnr_fsrcnn_epoch > best_psnr_fsrcnn:
    #     best_psnr_fsrcnn = val_avg_psnr_fsrcnn_epoch
    #     torch.save(fsrcnn.state_dict(), 'best_fsrcnn.pth')

    print(f"Epoch [{epoch+1}/{num_epochs}] completed: EDSR Loss: {avg_epoch_loss_edsr:.4f}, PSNR: {avg_psnr_edsr_epoch:.4f}, Validation PSNR: {val_avg_psnr_edsr_epoch:.4f}, SRResNEt Loss: {avg_epoch_loss_srresnet:.4f}, PSNR: {avg_psnr_srresnet_epoch:.4f}, Validation PSNR: {val_avg_psnr_srresnet_epoch:.4f}")
# 
    log_file.write(f"Epoch {epoch+1}:  EDSR PSNR: {avg_psnr_edsr_epoch:.4f}, Validation PSNR: {val_avg_psnr_edsr_epoch:.4f}\n")
    log_file.write(f"              SRResNet PSNR: {avg_psnr_srresnet_epoch:.4f}, Validation PSNR: {val_avg_psnr_srresnet_epoch:.4f}\n")
    # log_file.write(f"Epoch {epoch+1}: RDN PSNR: {avg_psnr_rdn_epoch:.4f}, Validation PSNR: {val_avg_psnr_rdn_epoch:.4f}\n")
    # log_file.write(f"Epoch {epoch+1}: FSRCNN PSNR: {avg_psnr_fsrcnn_epoch:.4f}, Validation PSNR: {val_avg_psnr_fsrcnn_epoch:.4f}\n")

    log_file.flush()

log_file.close()


Epoch 1/24: 100%|██████████| 12500/12500 [20:45<00:00, 10.03batch/s]


Saved EDSRR model with PSNR 26.3885
Saved SRResNet model with PSNR 25.6223
Epoch [1/24] completed: EDSR Loss: 0.0038, PSNR: 25.0257, Validation PSNR: 26.3885, SRResNEt Loss: 0.0083, PSNR: 23.2386, Validation PSNR: 25.6223


Epoch 2/24: 100%|██████████| 12500/12500 [19:29<00:00, 10.69batch/s]


Saved EDSRR model with PSNR 26.5568
Saved SRResNet model with PSNR 26.3591
Epoch [2/24] completed: EDSR Loss: 0.0028, PSNR: 25.7365, Validation PSNR: 26.5568, SRResNEt Loss: 0.0031, PSNR: 25.3265, Validation PSNR: 26.3591


Epoch 3/24: 100%|██████████| 12500/12500 [19:39<00:00, 10.60batch/s]


Saved EDSRR model with PSNR 26.9843
Saved SRResNet model with PSNR 26.4460
Epoch [3/24] completed: EDSR Loss: 0.0027, PSNR: 25.9230, Validation PSNR: 26.9843, SRResNEt Loss: 0.0029, PSNR: 25.6416, Validation PSNR: 26.4460


Epoch 4/24: 100%|██████████| 12500/12500 [19:38<00:00, 10.61batch/s]


Saved EDSRR model with PSNR 27.1163
Epoch [4/24] completed: EDSR Loss: 0.0027, PSNR: 26.0372, Validation PSNR: 27.1163, SRResNEt Loss: 0.0028, PSNR: 25.7593, Validation PSNR: 22.1151


Epoch 5/24: 100%|██████████| 12500/12500 [19:34<00:00, 10.65batch/s]


Saved EDSRR model with PSNR 27.2259
Saved SRResNet model with PSNR 26.9930
Epoch [5/24] completed: EDSR Loss: 0.0026, PSNR: 26.1187, Validation PSNR: 27.2259, SRResNEt Loss: 0.0027, PSNR: 25.8954, Validation PSNR: 26.9930


Epoch 6/24: 100%|██████████| 12500/12500 [19:24<00:00, 10.73batch/s]


Saved EDSRR model with PSNR 27.2781
Saved SRResNet model with PSNR 27.1589
Epoch [6/24] completed: EDSR Loss: 0.0026, PSNR: 26.1875, Validation PSNR: 27.2781, SRResNEt Loss: 0.0027, PSNR: 26.0340, Validation PSNR: 27.1589


Epoch 7/24: 100%|██████████| 12500/12500 [19:22<00:00, 10.75batch/s]


Saved EDSRR model with PSNR 27.3678
Epoch [7/24] completed: EDSR Loss: 0.0026, PSNR: 26.2325, Validation PSNR: 27.3678, SRResNEt Loss: 0.0027, PSNR: 26.0187, Validation PSNR: 26.9482


Epoch 8/24: 100%|██████████| 12500/12500 [19:21<00:00, 10.76batch/s]


Saved EDSRR model with PSNR 27.4250
Saved SRResNet model with PSNR 27.2849
Epoch [8/24] completed: EDSR Loss: 0.0025, PSNR: 26.2731, Validation PSNR: 27.4250, SRResNEt Loss: 0.0026, PSNR: 26.0832, Validation PSNR: 27.2849


Epoch 9/24: 100%|██████████| 12500/12500 [19:22<00:00, 10.75batch/s]


Saved EDSRR model with PSNR 27.4611
Epoch [9/24] completed: EDSR Loss: 0.0025, PSNR: 26.3297, Validation PSNR: 27.4611, SRResNEt Loss: 0.0026, PSNR: 26.2292, Validation PSNR: 27.2516


Epoch 10/24: 100%|██████████| 12500/12500 [19:31<00:00, 10.67batch/s]


Saved EDSRR model with PSNR 27.5001
Saved SRResNet model with PSNR 27.4006
Epoch [10/24] completed: EDSR Loss: 0.0025, PSNR: 26.3516, Validation PSNR: 27.5001, SRResNEt Loss: 0.0025, PSNR: 26.2727, Validation PSNR: 27.4006


Epoch 11/24: 100%|██████████| 12500/12500 [19:27<00:00, 10.70batch/s]


Saved EDSRR model with PSNR 27.5004
Saved SRResNet model with PSNR 27.4525
Epoch [11/24] completed: EDSR Loss: 0.0025, PSNR: 26.3652, Validation PSNR: 27.5004, SRResNEt Loss: 0.0025, PSNR: 26.2842, Validation PSNR: 27.4525


Epoch 12/24: 100%|██████████| 12500/12500 [19:22<00:00, 10.75batch/s]


Saved EDSRR model with PSNR 27.5314
Saved SRResNet model with PSNR 27.4563
Epoch [12/24] completed: EDSR Loss: 0.0025, PSNR: 26.3768, Validation PSNR: 27.5314, SRResNEt Loss: 0.0025, PSNR: 26.3164, Validation PSNR: 27.4563


Epoch 13/24: 100%|██████████| 12500/12500 [19:27<00:00, 10.70batch/s]


Saved EDSRR model with PSNR 27.5326
Saved SRResNet model with PSNR 27.4645
Epoch [13/24] completed: EDSR Loss: 0.0025, PSNR: 26.3902, Validation PSNR: 27.5326, SRResNEt Loss: 0.0025, PSNR: 26.3432, Validation PSNR: 27.4645


Epoch 14/24: 100%|██████████| 12500/12500 [19:41<00:00, 10.58batch/s]


Saved SRResNet model with PSNR 27.5064
Epoch [14/24] completed: EDSR Loss: 0.0025, PSNR: 26.4064, Validation PSNR: 27.5321, SRResNEt Loss: 0.0025, PSNR: 26.3673, Validation PSNR: 27.5064


Epoch 15/24: 100%|██████████| 12500/12500 [19:40<00:00, 10.59batch/s]


Saved SRResNet model with PSNR 27.5654
Epoch [15/24] completed: EDSR Loss: 0.0024, PSNR: 26.4198, Validation PSNR: 27.5296, SRResNEt Loss: 0.0025, PSNR: 26.3891, Validation PSNR: 27.5654


Epoch 16/24: 100%|██████████| 12500/12500 [19:43<00:00, 10.56batch/s]


Saved EDSRR model with PSNR 27.5748
Epoch [16/24] completed: EDSR Loss: 0.0024, PSNR: 26.4233, Validation PSNR: 27.5748, SRResNEt Loss: 0.0025, PSNR: 26.3995, Validation PSNR: 25.9310


Epoch 17/24: 100%|██████████| 12500/12500 [19:42<00:00, 10.57batch/s]


Saved EDSRR model with PSNR 27.5857
Epoch [17/24] completed: EDSR Loss: 0.0024, PSNR: 26.4495, Validation PSNR: 27.5857, SRResNEt Loss: 0.0024, PSNR: 26.4367, Validation PSNR: 27.4756


Epoch 18/24: 100%|██████████| 12500/12500 [19:43<00:00, 10.56batch/s]


Saved EDSRR model with PSNR 27.5973
Saved SRResNet model with PSNR 27.5789
Epoch [18/24] completed: EDSR Loss: 0.0024, PSNR: 26.4581, Validation PSNR: 27.5973, SRResNEt Loss: 0.0024, PSNR: 26.4489, Validation PSNR: 27.5789


Epoch 19/24: 100%|██████████| 12500/12500 [19:46<00:00, 10.53batch/s]


Saved EDSRR model with PSNR 27.6128
Epoch [19/24] completed: EDSR Loss: 0.0024, PSNR: 26.4615, Validation PSNR: 27.6128, SRResNEt Loss: 0.0024, PSNR: 26.4551, Validation PSNR: 27.4525


Epoch 20/24: 100%|██████████| 12500/12500 [19:42<00:00, 10.57batch/s]


Saved EDSRR model with PSNR 27.6202
Saved SRResNet model with PSNR 27.6191
Epoch [20/24] completed: EDSR Loss: 0.0024, PSNR: 26.4659, Validation PSNR: 27.6202, SRResNEt Loss: 0.0024, PSNR: 26.4627, Validation PSNR: 27.6191


Epoch 21/24: 100%|██████████| 12500/12500 [19:46<00:00, 10.53batch/s]


Epoch [21/24] completed: EDSR Loss: 0.0024, PSNR: 26.4673, Validation PSNR: 27.6111, SRResNEt Loss: 0.0024, PSNR: 26.4664, Validation PSNR: 27.5787


Epoch 22/24: 100%|██████████| 12500/12500 [19:44<00:00, 10.55batch/s]


Epoch [22/24] completed: EDSR Loss: 0.0024, PSNR: 26.4786, Validation PSNR: 27.6161, SRResNEt Loss: 0.0024, PSNR: 26.4799, Validation PSNR: 27.5993


Epoch 23/24: 100%|██████████| 12500/12500 [19:42<00:00, 10.57batch/s]


Epoch [23/24] completed: EDSR Loss: 0.0024, PSNR: 26.4781, Validation PSNR: 27.6116, SRResNEt Loss: 0.0024, PSNR: 26.4822, Validation PSNR: 27.5061


Epoch 24/24: 100%|██████████| 12500/12500 [19:48<00:00, 10.52batch/s]


Saved EDSRR model with PSNR 27.6206
Epoch [24/24] completed: EDSR Loss: 0.0024, PSNR: 26.4856, Validation PSNR: 27.6206, SRResNEt Loss: 0.0024, PSNR: 26.4909, Validation PSNR: 27.5721


# 5. Testing

In [11]:
edsr_orig = edsr_orig.cpu()
srresnet = srresnet.cpu()
srresnet.eval()
edsr_orig.eval()

val_psnr_values_sobel = 0
val_psnr_values_canny = 0
torch.cuda.empty_cache()
with torch.no_grad():  # No gradients during validation
        for (lr_images, hr_images) in tqdm(valid_loader, unit='batch'):
                lr_images = lr_images.cpu()
                hr_images = hr_images.cpu()

                # Sobel SR validation (no loss, only PSNR)
                outputs_sobel = srresnet(lr_images)
                psnr_sobel = calculate_psnr(outputs_sobel, hr_images)

                # Canny SR validation (no loss, only PSNR)
                outputs_canny = edsr_orig(lr_images)
                psnr_canny = calculate_psnr(outputs_canny, hr_images)

                # Update validation PSNR
                val_psnr_values_sobel += psnr_sobel
                val_psnr_values_canny += psnr_canny

        # Calculate average validation PSNR
        val_average_psnr_sobel = val_psnr_values_sobel / len(valid_loader)

        val_average_psnr_canny = val_psnr_values_canny / len(valid_loader)
        print(val_average_psnr_canny, val_average_psnr_sobel)

100%|██████████| 10/10 [00:40<00:00,  4.08s/batch]

27.62064323425293 27.57206916809082



