# 1. Import thư viện

In [1]:
import cv2
import numpy as np
import os
from PIL import Image
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch.nn.parallel import DataParallel
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
from torchvision.utils import save_image

from tqdm import tqdm


from models.sr_model import *
from models.wdsr import *
from models.srresnet_ import *
from models.sr_model import *
from models.utils import *
from models.vdsr import *
from models.srcnn import *
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 2. Tạo Mô hình SR

In [3]:
# edsr_srcnn = EDSR_srcnnnnfy()
torch.cuda.empty_cache()
# srresnet = SRResNet().to(device)
# srcnn = RCAN().to(device)
vdsr = VDSR().to(device)
srcnn = SRCNN().to(device)


In [4]:
# vdsr.load_state_dict(torch.load('best_vdsr.pth', map_location=device))
# srcnn.load_state_dict(torch.load('best_srcnn.pth', map_location=device))

In [5]:
class ImageDataset(Dataset):
    def __init__(self, lr_dir, hr_dir):
        self.lr_files = sorted(os.listdir(lr_dir))
        self.hr_files = sorted(os.listdir(hr_dir))
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir

    def __len__(self):
        return len(self.lr_files)

    def __getitem__(self, idx):
        lr_image = cv2.imread(os.path.join(self.lr_dir, self.lr_files[idx]))
        hr_image = cv2.imread(os.path.join(self.hr_dir, self.hr_files[idx]))
        lr_image = cv2.cvtColor(lr_image, cv2.COLOR_BGR2RGB)
        hr_image = cv2.cvtColor(hr_image, cv2.COLOR_BGR2RGB)
        # height, width = lr_image.shape[:2]
        # scale_factor = 4
        # # lr_height, lr_width = height // 2, width // 2
        # hr_height, hr_width = height * 4, width * 4

        # # # Rescale ảnh LR lên 4 lần
        # # lr_image = cv2.resize(lr_image, (lr_width, lr_height), interpolation=cv2.INTER_CUBIC)
        # lr_image = cv2.resize(lr_image, (hr_width, hr_height), interpolation=cv2.INTER_CUBIC)
        
        # hr_image = cv2.resize(hr_image, (hr_width, hr_height), interpolation=cv2.INTER_CUBIC)
        size = 32
        transform_hr = transforms.Compose([
            transforms.ToPILImage(),
            transforms.CenterCrop((size*4, size*4)),
            transforms.ToTensor()
        ])

        transform_lr = transforms.Compose([
            transforms.ToPILImage(),
            transforms.CenterCrop((size, size)),
            transforms.Resize((size*4, size*4)),
            transforms.ToTensor()
        ])
        
        lr_image = transform_lr(lr_image)
        hr_image = transform_hr(hr_image)
        return lr_image, hr_image
    
# class ImageDataset(Dataset):
    # def __init__(self, lr_dir, hr_dir):
    #     self.lr_files = sorted(os.listdir(lr_dir))
    #     self.hr_files = sorted(os.listdir(hr_dir))
    #     self.lr_dir = lr_dir
    #     self.hr_dir = hr_dir

    # def __len__(self):
    #     return len(self.lr_files)

    # def __getitem__(self, idx):
    #     lr_image = Image.open(os.path.join(self.lr_dir, self.lr_files[idx])).convert('RGB')
    #     hr_image = Image.open(os.path.join(self.hr_dir, self.hr_files[idx])).convert('RGB')

    #     width, height = lr_image.size
        
    #     # Tính kích thước mới bằng cách nhân với scale_factor
    #     new_width = width * 4
    #     new_height = height * 4

    #     # Resize ảnh LR lên theo kích thước mới
    #     lr_image = lr_image.resize((new_width, new_height), Image.BICUBIC)

    #     hr_image = np.array(hr_image).astype(np.float32)
    #     lr_image = np.array(lr_image).astype(np.float32)
    #     hr_image = convert_rgb_to_y(hr_image)
    #     lr_image = convert_rgb_to_y(lr_image)
    #     hr_image = np.expand_dims(hr_image / 255., 0)
    #     lr_image = np.expand_dims(lr_image / 255., 0)
    #     transform = transforms.Compose([
    #         # transforms.ToPILImage(),
    #         transforms.ToTensor()
    #     ])
        
    #     # lr_image = transform(lr_image)
    #     # hr_image = transform(hr_image)
    #     return lr_image, hr_image

# 3. Tạo Hyperparameter

In [6]:
# Đường dẫn tới bộ dữ liệu
train_lr_dir = 'Train/LR'
train_hr_dir = 'Train/HR'
valid_lr_dir = 'Test/LR'
valid_hr_dir = 'Test/HR'
# test_hr_dir  = '/kaggle/input/srdataset/sr_data/test/HR'
# test_lr_dir  = '/kaggle/input/srdataset/sr_data/test/LR'

# print(torch.cuda.memory_allocated())
# print(torch.cuda.memory_reserved())

In [7]:
from torch.amp import autocast, GradScaler
scaler = GradScaler()

# Khởi tạo dataset và dataloader
train_dataset = ImageDataset(train_lr_dir, train_hr_dir)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

valid_dataset = ImageDataset(valid_lr_dir, valid_hr_dir)
valid_loader = DataLoader(valid_dataset)

# vdsr_train_dataset = vdsrImageDataset(train_lr_dir, train_hr_dir)
# vdsr_train_loader = DataLoader(vdsr_train_dataset, batch_size = 16, shuffle=True)

# vdsr_valid_dataset = vdsrImageDataset(valid_lr_dir, valid_hr_dir)
# vdsr_valid_loader = DataLoader(vdsr_valid_dataset)
print(len(train_loader))

# Khởi tạo loss function
criterion = nn.MSELoss()

# Khởi tạo optimizers, schedulers cho từng mô hình
# optim_edsr = optim.Adam(edsr_orig.parameters(), lr=1e-4, betas=(0.9, 0.999))
# scheduler_edsr = optim.lr_scheduler.StepLR(optim_edsr, step_size=10**5, gamma=0.5)

# optim_srresnet = optim.Adam(srresnet.parameters(), lr=1e-4, betas=(0.9, 0.999))
# scheduler_srresnet = optim.lr_scheduler.StepLR(optim_srresnet, step_size=10**5, gamma=0.5)

optim_srcnn = optim.Adam(srcnn.parameters(), lr=1e-5, betas=(0.9, 0.999))
scheduler_srcnn = optim.lr_scheduler.StepLR(optim_srcnn, step_size=10**5, gamma=0.5)

optim_vdsr = optim.Adam(vdsr.parameters(), lr=1e-4, betas=(0.9, 0.999))
scheduler_vdsr = optim.lr_scheduler.StepLR(optim_vdsr, step_size=10**5, gamma=0.5)

12500


In [8]:
def calculate_psnr(img1, img2):
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    max_pixel = 1.0
    psnr = 20 * torch.log10(max_pixel / torch.sqrt(mse))
    return psnr.item()

# 4. Training

In [9]:
# Hàm tính PSNR
def calculate_psnr(img1, img2):
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    max_pixel = 1.0
    psnr = 20 * torch.log10(max_pixel / torch.sqrt(mse))
    return psnr.item()

num_epochs = 24

best_psnr_srcnn = float('-inf')
best_psnr_vdsr = float('-inf')
torch.cuda.empty_cache()

losses_srcnn = []
losses_vdsr = []

avg_psnr_srcnn = []
avg_psnr_vdsr = []

val_avg_psnr_srcnn = []
val_avg_psnr_vdsr = []

patience = 10
epochs_no_improve = 0
log_file = open('training_log_models_srcnnn_vdsr.txt', 'a')

for epoch in range(num_epochs):
    srcnn.train()
    vdsr.train()

    epoch_loss_srcnn, psnr_values_srcnn = 0, 0
    epoch_loss_vdsr, psnr_values_vdsr = 0, 0
    start_time = time.time()

    # Training loop for srcnn
    for (lr_images, hr_images) in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch'):
        lr_images = lr_images.cuda()
        hr_images = hr_images.cuda()

        # Train srcnn model
        optim_srcnn.zero_grad()
        with autocast(device_type='cuda'):
            outputs_srcnn =srcnn(lr_images)
            loss_srcnn = criterion(outputs_srcnn, hr_images)
        psnr_srcnn = calculate_psnr(outputs_srcnn, hr_images)
        # if psnr_srcnn < 27:
        scaler.scale(loss_srcnn).backward()
        scaler.step(optim_srcnn)
        scaler.update()
        scheduler_srcnn.step()

        epoch_loss_srcnn += loss_srcnn.item()
        psnr_values_srcnn += psnr_srcnn

        optim_vdsr.zero_grad()
        with autocast(device_type='cuda'):
            outputs_vdsr = vdsr(lr_images)
            loss_vdsr = criterion(outputs_vdsr, hr_images)
        psnr_vdsr = calculate_psnr(outputs_vdsr, hr_images)

        scaler.scale(loss_vdsr).backward()
        scaler.step(optim_vdsr)
        scaler.update()
        scheduler_vdsr.step()

        epoch_loss_vdsr += loss_vdsr.item()
        psnr_values_vdsr += psnr_vdsr
      
    # Training loop for vdsr
   
        
    # Average losses and PSNRs
    avg_epoch_loss_srcnn = epoch_loss_srcnn / len(train_loader)
    avg_psnr_srcnn_epoch = psnr_values_srcnn / len(train_loader)
    losses_srcnn.append(avg_epoch_loss_srcnn)
    avg_psnr_srcnn.append(avg_psnr_srcnn_epoch)

    avg_epoch_loss_vdsr = epoch_loss_vdsr / len(train_loader)
    avg_psnr_vdsr_epoch = psnr_values_vdsr / len(train_loader)
    losses_vdsr.append(avg_epoch_loss_vdsr)
    avg_psnr_vdsr.append(avg_psnr_vdsr_epoch)

    # Validation for srcnn and vdsr
    srcnn.eval()
    vdsr.eval()

    val_psnr_srcnn, val_psnr_vdsr = 0, 0

    with torch.no_grad():
        # Validate srcnn
        for (lr_images, hr_images) in valid_loader:
            lr_images = lr_images.cuda()
            hr_images = hr_images.cuda()

            outputs_srcnn = srcnn(lr_images)
            psnr_srcnn = calculate_psnr(outputs_srcnn, hr_images)
            val_psnr_srcnn += psnr_srcnn


            outputs_vdsr = vdsr(lr_images)
            psnr_vdsr = calculate_psnr(outputs_vdsr, hr_images)
            val_psnr_vdsr += psnr_vdsr
        # Validate vdsr
        


    val_avg_psnr_srcnn_epoch = val_psnr_srcnn / len(valid_loader)
    val_avg_psnr_srcnn.append(val_avg_psnr_srcnn_epoch)

    val_avg_psnr_vdsr_epoch = val_psnr_vdsr / len(valid_loader)
    val_avg_psnr_vdsr.append(val_avg_psnr_vdsr_epoch)

    # Save best model for srcnn
    if val_avg_psnr_srcnn_epoch > best_psnr_srcnn:
        best_psnr_srcnn = val_avg_psnr_srcnn_epoch
        torch.save(srcnn.state_dict(), 'best_srcnn.pth')
        print(f"Saved SRCNN model with PSNR {best_psnr_srcnn:.4f}")
    # Save best model for vdsr
    if val_avg_psnr_vdsr_epoch > best_psnr_vdsr:
        best_psnr_vdsr = val_avg_psnr_vdsr_epoch
        torch.save(vdsr.state_dict(), 'best_vdsr.pth')
        print(f"Saved VDSR model with PSNR {best_psnr_vdsr:.4f}")

    torch.save(srcnn.state_dict(), f'path/srcnn_{epoch+10}.pth')
    torch.save(vdsr.state_dict(), f'path/vdsr_{epoch+10}.pth')
    print(f"Epoch [{epoch+1}/{num_epochs}] completed: srcnn Loss: {avg_epoch_loss_srcnn:.4f}, PSNR: {avg_psnr_srcnn_epoch:.4f}, Validation PSNR: {val_avg_psnr_srcnn_epoch:.4f}")
    print(f"Epoch [{epoch+1}/{num_epochs}] completed: vdsr Loss: {avg_epoch_loss_vdsr:.4f}, PSNR: {avg_psnr_vdsr_epoch:.4f}, Validation PSNR: {val_avg_psnr_vdsr_epoch:.4f}")

    log_file.write(f"Epoch {epoch+1}: WDSRA PSNR: {avg_psnr_srcnn_epoch:.4f}, Validation PSNR: {val_avg_psnr_srcnn_epoch:.4f}\n")
    log_file.write(f"Epoch {epoch+1}: vdsr PSNR: {avg_psnr_vdsr_epoch:.4f}, Validation PSNR: {val_avg_psnr_vdsr_epoch:.4f}\n")

    # log_file.flush()

log_file.close()


Epoch 1/24: 100%|██████████| 12500/12500 [19:01<00:00, 10.95batch/s]


Saved SRCNN model with PSNR 20.2167
Saved VDSR model with PSNR 20.0045
Epoch [1/24] completed: srcnn Loss: 0.0049, PSNR: 24.3751, Validation PSNR: 20.2167
Epoch [1/24] completed: vdsr Loss: 0.0031, PSNR: 25.6132, Validation PSNR: 20.0045


Epoch 2/24: 100%|██████████| 12500/12500 [18:31<00:00, 11.24batch/s]


Saved SRCNN model with PSNR 20.2361
Epoch [2/24] completed: srcnn Loss: 0.0033, PSNR: 25.1324, Validation PSNR: 20.2361
Epoch [2/24] completed: vdsr Loss: 0.0028, PSNR: 25.8507, Validation PSNR: 19.9032


Epoch 3/24: 100%|██████████| 12500/12500 [18:33<00:00, 11.23batch/s]


Epoch [3/24] completed: srcnn Loss: 0.0032, PSNR: 25.2385, Validation PSNR: 20.2347
Epoch [3/24] completed: vdsr Loss: 0.0027, PSNR: 25.9459, Validation PSNR: 19.9659


Epoch 4/24: 100%|██████████| 12500/12500 [18:36<00:00, 11.19batch/s]


Epoch [4/24] completed: srcnn Loss: 0.0032, PSNR: 25.2883, Validation PSNR: 20.2339
Epoch [4/24] completed: vdsr Loss: 0.0027, PSNR: 25.9940, Validation PSNR: 19.8760


Epoch 5/24: 100%|██████████| 12500/12500 [18:32<00:00, 11.23batch/s]


Saved SRCNN model with PSNR 20.2434
Epoch [5/24] completed: srcnn Loss: 0.0031, PSNR: 25.3202, Validation PSNR: 20.2434
Epoch [5/24] completed: vdsr Loss: 0.0027, PSNR: 26.0275, Validation PSNR: 19.9240


Epoch 6/24: 100%|██████████| 12500/12500 [18:12<00:00, 11.44batch/s]


Epoch [6/24] completed: srcnn Loss: 0.0031, PSNR: 25.3508, Validation PSNR: 20.2264
Epoch [6/24] completed: vdsr Loss: 0.0027, PSNR: 26.0588, Validation PSNR: 19.9704


Epoch 7/24: 100%|██████████| 12500/12500 [18:11<00:00, 11.46batch/s]


Epoch [7/24] completed: srcnn Loss: 0.0031, PSNR: 25.3713, Validation PSNR: 20.2204
Epoch [7/24] completed: vdsr Loss: 0.0027, PSNR: 26.0767, Validation PSNR: 19.8287


Epoch 8/24: 100%|██████████| 12500/12500 [18:12<00:00, 11.44batch/s]


Epoch [8/24] completed: srcnn Loss: 0.0031, PSNR: 25.3834, Validation PSNR: 20.2035
Epoch [8/24] completed: vdsr Loss: 0.0027, PSNR: 26.0911, Validation PSNR: 19.9163


Epoch 9/24: 100%|██████████| 12500/12500 [18:13<00:00, 11.43batch/s]


Epoch [9/24] completed: srcnn Loss: 0.0031, PSNR: 25.4011, Validation PSNR: 20.2220
Epoch [9/24] completed: vdsr Loss: 0.0026, PSNR: 26.1471, Validation PSNR: 19.8394


Epoch 10/24: 100%|██████████| 12500/12500 [18:11<00:00, 11.45batch/s]


Epoch [10/24] completed: srcnn Loss: 0.0031, PSNR: 25.4066, Validation PSNR: 20.1926
Epoch [10/24] completed: vdsr Loss: 0.0026, PSNR: 26.1567, Validation PSNR: 19.8673


Epoch 11/24: 100%|██████████| 12500/12500 [18:11<00:00, 11.45batch/s]


Epoch [11/24] completed: srcnn Loss: 0.0031, PSNR: 25.4132, Validation PSNR: 20.1840
Epoch [11/24] completed: vdsr Loss: 0.0026, PSNR: 26.1634, Validation PSNR: 19.9007


Epoch 12/24: 100%|██████████| 12500/12500 [18:17<00:00, 11.39batch/s]


Epoch [12/24] completed: srcnn Loss: 0.0031, PSNR: 25.4178, Validation PSNR: 20.1773
Epoch [12/24] completed: vdsr Loss: 0.0026, PSNR: 26.1675, Validation PSNR: 19.9088


Epoch 13/24: 100%|██████████| 12500/12500 [18:26<00:00, 11.29batch/s]


Epoch [13/24] completed: srcnn Loss: 0.0031, PSNR: 25.4259, Validation PSNR: 20.1676
Epoch [13/24] completed: vdsr Loss: 0.0026, PSNR: 26.1760, Validation PSNR: 19.9335


Epoch 14/24: 100%|██████████| 12500/12500 [18:25<00:00, 11.30batch/s]


Epoch [14/24] completed: srcnn Loss: 0.0031, PSNR: 25.4357, Validation PSNR: 20.1437
Epoch [14/24] completed: vdsr Loss: 0.0026, PSNR: 26.1882, Validation PSNR: 19.8119


Epoch 15/24: 100%|██████████| 12500/12500 [18:24<00:00, 11.31batch/s]


Epoch [15/24] completed: srcnn Loss: 0.0031, PSNR: 25.4375, Validation PSNR: 20.1523
Epoch [15/24] completed: vdsr Loss: 0.0026, PSNR: 26.1894, Validation PSNR: 19.8214


Epoch 16/24: 100%|██████████| 12500/12500 [18:30<00:00, 11.26batch/s]


Epoch [16/24] completed: srcnn Loss: 0.0031, PSNR: 25.4423, Validation PSNR: 20.1565
Epoch [16/24] completed: vdsr Loss: 0.0026, PSNR: 26.1956, Validation PSNR: 19.8761


Epoch 17/24: 100%|██████████| 12500/12500 [18:36<00:00, 11.20batch/s]


Epoch [17/24] completed: srcnn Loss: 0.0031, PSNR: 25.4450, Validation PSNR: 20.1650
Epoch [17/24] completed: vdsr Loss: 0.0026, PSNR: 26.2198, Validation PSNR: 19.8797


Epoch 18/24: 100%|██████████| 12500/12500 [18:28<00:00, 11.27batch/s]


Epoch [18/24] completed: srcnn Loss: 0.0031, PSNR: 25.4518, Validation PSNR: 20.1421
Epoch [18/24] completed: vdsr Loss: 0.0026, PSNR: 26.2266, Validation PSNR: 19.8160


Epoch 19/24: 100%|██████████| 12500/12500 [18:33<00:00, 11.22batch/s]


Epoch [19/24] completed: srcnn Loss: 0.0031, PSNR: 25.4593, Validation PSNR: 20.1619
Epoch [19/24] completed: vdsr Loss: 0.0026, PSNR: 26.2365, Validation PSNR: 19.8439


Epoch 20/24: 100%|██████████| 12500/12500 [18:33<00:00, 11.22batch/s]


Epoch [20/24] completed: srcnn Loss: 0.0031, PSNR: 25.4561, Validation PSNR: 20.1435
Epoch [20/24] completed: vdsr Loss: 0.0026, PSNR: 26.2319, Validation PSNR: 19.8995


Epoch 21/24: 100%|██████████| 12500/12500 [18:31<00:00, 11.24batch/s]


Epoch [21/24] completed: srcnn Loss: 0.0031, PSNR: 25.4544, Validation PSNR: 20.1533
Epoch [21/24] completed: vdsr Loss: 0.0026, PSNR: 26.2318, Validation PSNR: 19.8703


Epoch 22/24: 100%|██████████| 12500/12500 [18:19<00:00, 11.37batch/s]


Epoch [22/24] completed: srcnn Loss: 0.0031, PSNR: 25.4552, Validation PSNR: 20.1520
Epoch [22/24] completed: vdsr Loss: 0.0026, PSNR: 26.2318, Validation PSNR: 19.9379


Epoch 23/24: 100%|██████████| 12500/12500 [18:12<00:00, 11.44batch/s]


Epoch [23/24] completed: srcnn Loss: 0.0030, PSNR: 25.4588, Validation PSNR: 20.1508
Epoch [23/24] completed: vdsr Loss: 0.0026, PSNR: 26.2375, Validation PSNR: 19.8199


Epoch 24/24: 100%|██████████| 12500/12500 [18:22<00:00, 11.34batch/s]


Epoch [24/24] completed: srcnn Loss: 0.0030, PSNR: 25.4647, Validation PSNR: 20.1518
Epoch [24/24] completed: vdsr Loss: 0.0026, PSNR: 26.2436, Validation PSNR: 19.8352


# 5. Testing

In [None]:
!nvidia-smi
with torch.no_grad():
    for lr_image_file, hr_image_file in tqdm(zip(lr_image_files, hr_image_files), unit = 'batch'):
        # Đường dẫn đến ảnh
        lr_image_path = os.path.join(test_lr_dir, lr_image_file)
        hr_image_path = os.path.join(test_hr_dir, hr_image_file)
        output_image_path = os.path.join(output_image_dir, lr_image_file[:-4] + '_edsrx4.jpg')

        # Tải và chuyển đổi ảnh
        lr_image = Image.open(lr_image_path)
        hr_image = Image.open(hr_image_path)

        lr_image = transform(lr_image).unsqueeze(0).cuda()  # Thêm batch dimension và chuyển sang CPU
        hr_image = transform(hr_image).unsqueeze(0).cuda()  # Thêm batch dimension và chuyển sang CPU

        # Dự đoán
        sobel = sobelsr(lr_image)
#         canny = cannysr(lr_image)

        # Tính toán PSNR
        psnr_sobel = calculate_psnr(sobel, hr_image)
#         psnr_canny = calculate_psnr(canny, hr_image)

        psnr_values_sobel.append(psnr_sobel)
#         psnr_values_canny.append(psnr_canny)

#         # Chuyển đổi tensor đầu ra thành ảnh và lưu
#         output_image = output.squeeze(0).cuda()  # Loại bỏ batch dimension và chuyển tensor sang CPU
#         output_image = transforms.ToPILImage()(output_image)  # Chuyển tensor thành ảnh PIL
#         output_image.save(output_image_path)  # Lưu ảnh

# Tính toán PSNR trung bình
average_psnr_sobel = sum(psnr_values_sobel) / len(psnr_values_sobel)
print(f"Average PSNR sobel: {average_psnr_sobel:.2f} dB")

# average_psnr_canny = sum(psnr_values_canny) / len(psnr_values_canny)
# print(f"Average PSNR canny: {average_psnr_canny:.2f} dB")