# 1. Import thư viện

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchsummary import summary
from torch.utils.data import DataLoader, Dataset
import cv2
import numpy as np
import os
from PIL import Image
from tqdm import tqdm
import time
import matplotlib.pyplot as plt
from models.hqsr import *

In [None]:
import os
# PYTORCH_CUDA_ALLOC_CONF takes a comma-separated list of `option:value` pairs.
# Assigning it twice (as before) silently kept only the last option.
# NOTE(review): the caching allocator presumably reads this when CUDA is first
# initialised, so re-running this cell after CUDA is in use has no effect.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128,expandable_segments:True'
# Use CUDA when available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 2. Tạo Dataset cho mô hình SR

In [None]:
class ImageDataset(Dataset):
    """Paired low-/high-resolution image dataset for super-resolution.

    Files in ``lr_dir`` and ``hr_dir`` are paired by position after sorting
    their names, so both directories must contain the same number of files.

    Args:
        lr_dir: Directory of low-resolution input images.
        hr_dir: Directory of high-resolution target images.
        scale: Upscaling factor; only used when ``valid`` is True.
        valid: When True, the HR image is cropped so that its width and height
            are multiples of ``scale``, and the LR image is resized to the HR
            size integer-divided by ``scale`` (so an HR folder can serve as LR).

    Raises:
        ValueError: If the two directories hold different numbers of files.

    Each item is an ``(lr, hr)`` pair of tensors produced by ``ToTensor``.
    """

    def __init__(self, lr_dir, hr_dir, scale, valid = False):
        self.lr_files = sorted(os.listdir(lr_dir))
        self.hr_files = sorted(os.listdir(hr_dir))
        # Pairs are matched by sorted index; a count mismatch would silently
        # pair the wrong images or raise IndexError in the middle of an epoch.
        if len(self.lr_files) != len(self.hr_files):
            raise ValueError(
                f'LR/HR file count mismatch: {len(self.lr_files)} files in '
                f'{lr_dir!r} vs {len(self.hr_files)} files in {hr_dir!r}'
            )
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.scale = scale
        self.valid = valid
        # Built once here instead of on every __getitem__ call.
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.lr_files)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB PIL image and close the underlying file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        lr_image = self._load_rgb(os.path.join(self.lr_dir, self.lr_files[idx]))
        hr_image = self._load_rgb(os.path.join(self.hr_dir, self.hr_files[idx]))

        if self.valid:
            w, h = hr_image.size
            # Crop HR to a multiple of `scale` so that an output upscaled from
            # the (w // scale, h // scale) LR input has the same size as the
            # target. Without this, the PSNR computation fails on a shape
            # mismatch (e.g. x3 with a 512px image). Assumes the model upsamples
            # by exactly `scale`.
            w, h = w - w % self.scale, h - h % self.scale
            hr_image = hr_image.crop((0, 0, w, h))
            lr_image = lr_image.resize((w // self.scale, h // self.scale))

        return self.transform(lr_image), self.transform(hr_image)

# 3. Tạo Hyperparameter

In [None]:
# Đường dẫn tới bộ dữ liệu

# test_hr_dir  = '/kaggle/input/srdataset/sr_data/test/HR'
# test_lr_dir  = '/kaggle/input/srdataset/sr_data/test/LR'

# print(torch.cuda.memory_allocated())
# print(torch.cuda.memory_reserved())

In [None]:
import os
# Debugging switches for CUDA errors.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is expected to make kernel launches
# synchronous, which helps locate errors but slows training, and
# TORCH_USE_CUDA_DSA appears to be a PyTorch build-time option. Confirm whether
# both should stay enabled for real training runs.
os.environ.update({
    "TORCH_USE_CUDA_DSA": "1",
    "CUDA_LAUNCH_BLOCKING": "1",
})

In [None]:
def calculate_psnr(img1, img2, max_pixel=1.0):
    """Return the peak signal-to-noise ratio between two image tensors, in dB.

    Args:
        img1: First image tensor.
        img2: Second image tensor, same (or broadcastable) shape as ``img1``.
        max_pixel: Largest possible pixel value (1.0 for images in [0, 1];
            use 255 for 8-bit data).

    Returns:
        PSNR as a Python float, or ``inf`` when the images are identical.
    """
    # Detach so calls on training outputs do not record autograd history, and
    # upcast so half-precision (autocast) outputs are compared in float32.
    mse = torch.mean((img1.detach().float() - img2.detach().float()) ** 2)
    if mse == 0:
        return float('inf')
    # 10*log10(MAX^2 / MSE), equivalent to 20*log10(MAX / sqrt(MSE)).
    return (10 * torch.log10(max_pixel ** 2 / mse)).item()

# 4. Training

In [None]:
# from torch.amp import autocast, GradScaler
# from torchsummary import summary
# scaler = GradScaler()

# # Khởi tạo dataset và dataloader
# for scale in [2, 3, 4]:
#     train_lr_dir = f'dataset/Train/LR_{scale}'
#     train_hr_dir = 'dataset/Train/HR'
#     valid_lr_dir = 'dataset/Test/HR'
#     valid_hr_dir = 'dataset/Test/HR'
#     train_dataset = ImageDataset(train_lr_dir, train_hr_dir, scale=scale)
#     train_loader = DataLoader(train_dataset, batch_size = 16, shuffle=True)

#     valid_dataset = ImageDataset(valid_lr_dir, valid_hr_dir, scale=scale, valid=True)
#     valid_loader = DataLoader(valid_dataset)

#     # print(len(train_loader))
#     # Khởi tạo mô hình, loss function và optimizer
#     torch.cuda.empty_cache()

#     sobelsr = HQSR(scale_factor = scale, use_sobel = True).to(device)
#     # sobelsr.load_state_dict(torch.load('weight/best_sobel_srx4_model.pth', map_location=device))
#     criterion = nn.MSELoss()
#     optim_sobel = optim.Adam(sobelsr.parameters(), lr=1e-4,betas =(0.9, 0.999))
#     scheduler_sobel = optim.lr_scheduler.StepLR(optim_sobel, step_size=10**5, gamma=0.5)
#     # summary(sobelsr.cuda(), input_size=(3, 510, 339), device='cuda')
#     cannysr = HQSR(scale_factor = scale, use_canny = True).to(device)
#     # cannysr = nn.DataParallel(cannysr).to(device)
#     # cannysr.load_state_dict(torch.load('weight/best_canny_srx4_model.pth', map_location=device))
    
#     optim_canny = optim.Adam(cannysr.parameters(), lr=1e-4,betas =(0.9, 0.999))
#     scheduler_canny = optim.lr_scheduler.StepLR(optim_canny, step_size=10**5, gamma=0.5)
#     num_epochs = 24

#     best_psnr_sobel = float('-inf')
#     best_psnr_canny = float('-inf')
#     torch.cuda.empty_cache()

#     losses_sobel = []
#     losses_canny = []
#     avg_psnr_sobel = []
#     avg_psnr_canny = []

#     val_avg_psnr_sobel = []  # Validation PSNR
#     val_avg_psnr_canny = []

#     patience = 5
#     epochs_no_improve = 0
#     log_file = open('outputs/train_log/hqsr.txt', 'a')
#     scaler = GradScaler()

#     for epoch in range(num_epochs):
#         sobelsr.train()
#         cannysr.train()

#         epoch_loss_sobel = 0
#         psnr_values_sobel = 0
#         epoch_loss_canny = 0
#         psnr_values_canny = 0
#         start_time = time.time()

#         # Training loop
#         # for (lr_images, hr_images) in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch'):
#         #     lr_images = lr_images.cuda()
#         #     hr_images = hr_images.cuda()

#         #     # Sobel SR training
#         #     optim_sobel.zero_grad()  
#         #     with autocast(device_type='cuda'):
#         #         outputs_sobel = sobelsr(lr_images)
#         #         loss_sobel = criterion(outputs_sobel, hr_images)
#         #     psnr_sobel = calculate_psnr(outputs_sobel, hr_images)
                
#         #     scaler.scale(loss_sobel).backward()
#         #     scaler.step(optim_sobel)
#         #     scaler.update()
#         #     scheduler_sobel.step()

#         #     # Canny SR training
#         #     optim_canny.zero_grad()  
#         #     with autocast(device_type='cuda'):
#         #         outputs_canny = cannysr(lr_images)
#         #         loss_canny = criterion(outputs_canny, hr_images)
#         #     psnr_canny = calculate_psnr(outputs_canny, hr_images)

#         #     scaler.scale(loss_canny).backward()
#         #     scaler.step(optim_canny)
#         #     scaler.update()
#         #     scheduler_canny.step()
            
#         #     # Update metrics
#         #     epoch_loss_sobel += loss_sobel.item()
#         #     psnr_values_sobel += psnr_sobel
#         #     epoch_loss_canny += loss_canny.item()
#         #     psnr_values_canny += psnr_canny

#         # Calculate average training metrics
#         avg_epoch_loss_sobel = epoch_loss_sobel / len(train_loader)
#         average_psnr_sobel = psnr_values_sobel / len(train_loader)
#         losses_sobel.append(avg_epoch_loss_sobel)
#         avg_psnr_sobel.append(average_psnr_sobel)

#         avg_epoch_loss_canny = epoch_loss_canny / len(train_loader)
#         average_psnr_canny = psnr_values_canny / len(train_loader)
#         losses_canny.append(avg_epoch_loss_canny)
#         avg_psnr_canny.append(average_psnr_canny)

#         # Validation step
#         sobelsr.eval()
#         cannysr.eval()

#         val_psnr_values_sobel = 0
#         val_psnr_values_canny = 0

#         with torch.no_grad():  # No gradients during validation
#             for (lr_images, hr_images) in tqdm(valid_loader, desc=f'Validation Epoch {epoch+1}/{num_epochs}', unit='batch'):
#                 lr_images = lr_images.cuda()
#                 hr_images = hr_images.cuda()

#                 # Sobel SR validation (no loss, only PSNR)
#                 outputs_sobel = sobelsr(lr_images)
#                 psnr_sobel = calculate_psnr(outputs_sobel, hr_images)

#                 # Canny SR validation (no loss, only PSNR)
#                 outputs_canny = cannysr(lr_images)
#                 psnr_canny = calculate_psnr(outputs_canny, hr_images)

#                 # Update validation PSNR
#                 val_psnr_values_sobel += psnr_sobel
#                 val_psnr_values_canny += psnr_canny

#         # Calculate average validation PSNR
#         val_average_psnr_sobel = val_psnr_values_sobel / len(valid_loader)
#         val_avg_psnr_sobel.append(val_average_psnr_sobel)

#         val_average_psnr_canny = val_psnr_values_canny / len(valid_loader)
#         val_avg_psnr_canny.append(val_average_psnr_canny)

#         end_time = time.time()

#         # Logging results
#         log_string = (f"Epoch {epoch+1}/{num_epochs}, Loss sobel: {avg_epoch_loss_sobel:.4f}, "
#                     f"Loss canny: {avg_epoch_loss_canny:.4f}, Time training: {end_time - start_time:.4f}s, "
#                     f"PSNR sobel: {average_psnr_sobel:.2f} dB, PSNR canny: {average_psnr_canny:.2f} dB, "
#                     f"Val PSNR sobel: {val_average_psnr_sobel:.2f} dB, Val PSNR canny: {val_average_psnr_canny:.2f} dB")
#         print(log_string)
#         log_file.write(log_string + '\n')
#         log_file.flush()

#         # Save best models based on validation PSNR
#         if val_average_psnr_sobel > best_psnr_sobel:
#             best_psnr_sobel = val_average_psnr_sobel
#             torch.save(sobelsr.state_dict(), f'outputs/weight_sr/x{scale}/best_hqsr_sobel.pth')
#             print(f"Saved Sobel SR model with PSNR {best_psnr_sobel:.4f}")
#             epochs_no_improve=0
        

#         if val_average_psnr_canny > best_psnr_canny:
#             best_psnr_canny = val_average_psnr_canny
#             torch.save(cannysr.state_dict(), f'outputs/weight_sr/x{scale}/best_hqsr_canny.pth')
#             print(f"Saved Canny SR model with PSNR {best_psnr_canny:.4f}")
#             epochs_no_improve=0
        
#         if (val_average_psnr_sobel < best_psnr_sobel) and (val_average_psnr_canny < best_psnr_canny):
#             epochs_no_improve+=1
#         if epochs_no_improve >= patience:
#             print(f"PSNR did not improve for 50 epochs. Early stopping at epoch {epoch+1}")
#             break
#         # Clear cache and optionally save models at each epoch
#         save_dir = f'outputs/path/x{scale}'
#         if not os.path.exists(save_dir):
#             os.makedirs(save_dir)

#         torch.save(sobelsr.state_dict(), os.path.join(save_dir, f'hqsr_sobel_{epoch}.pth'))
#         torch.save(cannysr.state_dict(), os.path.join(save_dir, f'hqsr_canny_{epoch}.pth'))
#             # Close log file after training
#     log_file.close()

#     # Plotting results
#     plt.figure(figsize=(12, 10))

#     # Plot loss
#     plt.subplot(2, 1, 1)
#     plt.plot(losses_sobel, label='Sobel SR Loss (Train)')
#     plt.plot(losses_canny, label='Canny SR Loss (Train)')
#     plt.xlabel('Epoch')
#     plt.ylabel('Loss')
#     plt.legend()
#     plt.title('Training Loss')

#     # Plot PSNR
#     plt.subplot(2, 1, 2)
#     plt.plot(avg_psnr_sobel, label='Sobel SR PSNR (Train)')
#     plt.plot(val_avg_psnr_sobel, label='Sobel SR PSNR (Val)')
#     plt.plot(avg_psnr_canny, label='Canny SR PSNR (Train)')
#     plt.plot(val_avg_psnr_canny, label='Canny SR PSNR (Val)')
#     plt.xlabel('Epoch')
#     plt.ylabel('PSNR (dB)')
#     plt.legend()
#     plt.title('Average PSNR (Train and Val)')

#     plt.tight_layout()
#     plt.show()

# E2DSR

In [None]:
import cv2
import numpy as np
import os
from PIL import Image
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch.nn.parallel import DataParallel
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
from torchvision.utils import save_image
import torchsummary
from tqdm import tqdm
from models.e2dsr import *
import os
# PYTORCH_CUDA_ALLOC_CONF takes a comma-separated list of `option:value` pairs.
# Assigning it twice (as before) silently kept only the last option.
# NOTE(review): the caching allocator presumably reads this when CUDA is first
# initialised, so it likely has no effect if an earlier cell already used CUDA.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128,expandable_segments:True'
# Use CUDA when available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import time
class ImageDataset(Dataset):
    """Paired low-/high-resolution image dataset for super-resolution.

    Files in ``lr_dir`` and ``hr_dir`` are paired by position after sorting
    their names, so both directories must contain the same number of files.

    Args:
        lr_dir: Directory of low-resolution input images.
        hr_dir: Directory of high-resolution target images.
        valid: When True, the HR image is cropped so that its width and height
            are multiples of ``scale``, and the LR image is resized to the HR
            size integer-divided by ``scale`` (so an HR folder can serve as LR).
        scale: Upscaling factor; only used when ``valid`` is True.

    Raises:
        ValueError: If the two directories hold different numbers of files.

    Each item is an ``(lr, hr)`` pair of tensors produced by ``ToTensor``.
    """

    def __init__(self, lr_dir, hr_dir, valid = False, scale=4):
        self.lr_files = sorted(os.listdir(lr_dir))
        self.hr_files = sorted(os.listdir(hr_dir))
        # Pairs are matched by sorted index; a count mismatch would silently
        # pair the wrong images or raise IndexError in the middle of an epoch.
        if len(self.lr_files) != len(self.hr_files):
            raise ValueError(
                f'LR/HR file count mismatch: {len(self.lr_files)} files in '
                f'{lr_dir!r} vs {len(self.hr_files)} files in {hr_dir!r}'
            )
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.valid = valid
        self.scale = scale
        # Built once here instead of on every __getitem__ call.
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.lr_files)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB PIL image and close the underlying file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        lr_image = self._load_rgb(os.path.join(self.lr_dir, self.lr_files[idx]))
        hr_image = self._load_rgb(os.path.join(self.hr_dir, self.hr_files[idx]))

        if self.valid:
            w, h = hr_image.size
            # Crop HR to a multiple of `scale` so that an output upscaled from
            # the (w // scale, h // scale) LR input has the same size as the
            # target. Without this, the PSNR computation fails on a shape
            # mismatch (e.g. x3 with a 512px image). Assumes the model upsamples
            # by exactly `scale`.
            w, h = w - w % self.scale, h - h % self.scale
            hr_image = hr_image.crop((0, 0, w, h))
            lr_image = lr_image.resize((w // self.scale, h // self.scale))

        return self.transform(lr_image), self.transform(hr_image)
def calculate_psnr(img1, img2, max_pixel=1.0):
    """Return the peak signal-to-noise ratio between two image tensors, in dB.

    Args:
        img1: First image tensor.
        img2: Second image tensor, same (or broadcastable) shape as ``img1``.
        max_pixel: Largest possible pixel value (1.0 for images in [0, 1];
            use 255 for 8-bit data).

    Returns:
        PSNR as a Python float, or ``inf`` when the images are identical.
    """
    # Detach so calls on training outputs do not record autograd history, and
    # upcast so half-precision (autocast) outputs are compared in float32.
    mse = torch.mean((img1.detach().float() - img2.detach().float()) ** 2)
    if mse == 0:
        return float('inf')
    # 10*log10(MAX^2 / MSE), equivalent to 20*log10(MAX / sqrt(MSE)).
    return (10 * torch.log10(max_pixel ** 2 / mse)).item()

from torch.amp import autocast, GradScaler

# Mixed precision is only used on CUDA. On CPU, autocast and the GradScaler are
# disabled instead of being pointed at an unavailable device.
use_amp = device.type == 'cuda'

# Train and validate the Canny- and Sobel-edge E2DSR variants for each upscaling factor.
for scale in [2, 3, 4]:
    print(f'Training Scale {scale}')
    train_lr_dir = f'dataset/Train/LR_{scale}'
    train_hr_dir = 'dataset/Train/HR'
    # Validation LR inputs are produced by downscaling the HR test images inside
    # ImageDataset (valid=True), so both validation paths point at the HR set.
    valid_lr_dir = 'dataset/Test/HR'
    valid_hr_dir = 'dataset/Test/HR'

    # torch.save() and open() do not create parent folders; create them up front.
    for out_dir in (f'outputs/weight_sr/x{scale}', f'outputs/path/x{scale}', 'outputs/train_log'):
        os.makedirs(out_dir, exist_ok=True)

    train_dataset = ImageDataset(train_lr_dir, train_hr_dir)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

    valid_dataset = ImageDataset(valid_lr_dir, valid_hr_dir, valid = True, scale=scale)
    valid_loader = DataLoader(valid_dataset)

    # Loss function and models
    criterion = nn.MSELoss()
    e2dsr_sobel = E2DSR(edge_option='sobel', scale_factor=scale).to(device)
    e2dsr_canny = E2DSR(edge_option='canny', scale_factor=scale).to(device)

    # One optimizer and StepLR schedule per model; schedulers are stepped per batch.
    optim_e2dsr_canny = optim.Adam(e2dsr_canny.parameters(), lr=1e-4, betas=(0.9, 0.999))
    scheduler_e2dsr_canny = optim.lr_scheduler.StepLR(optim_e2dsr_canny, step_size=10**5, gamma=0.5)

    optim_e2dsr_sobel = optim.Adam(e2dsr_sobel.parameters(), lr=1e-4, betas=(0.9, 0.999))
    scheduler_e2dsr_sobel = optim.lr_scheduler.StepLR(optim_e2dsr_sobel, step_size=10**5, gamma=0.5)

    # A fresh loss scaler per scale, so scale state from the previous run does not carry over.
    scaler = GradScaler(enabled=use_amp)

    num_epochs = 24

    best_psnr_e2dsr_canny = float('-inf')
    best_psnr_e2dsr_sobel = float('-inf')
    torch.cuda.empty_cache()

    # Per-epoch history: training loss, training PSNR, validation PSNR.
    losses_e2dsr_canny, losses_e2dsr_sobel = [], []
    avg_psnr_e2dsr_canny, avg_psnr_e2dsr_sobel = [], []
    val_avg_psnr_e2dsr_canny, val_avg_psnr_e2dsr_sobel = [], []

    # `with` guarantees the log is closed even if training raises.
    with open('outputs/train_log/e2dsr.txt', 'a') as log_file:
        for epoch in range(num_epochs):
            e2dsr_canny.train()
            e2dsr_sobel.train()

            epoch_loss_e2dsr_canny, psnr_values_e2dsr_canny = 0, 0
            epoch_loss_e2dsr_sobel, psnr_values_e2dsr_sobel = 0, 0

            torch.cuda.empty_cache()
            for lr_images, hr_images in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch'):
                lr_images = lr_images.to(device)
                hr_images = hr_images.to(device)

                # Canny-edge model
                optim_e2dsr_canny.zero_grad()
                with autocast(device_type=device.type, enabled=use_amp):
                    outputs_e2dsr_canny = e2dsr_canny(lr_images)
                    loss_e2dsr_canny = criterion(outputs_e2dsr_canny, hr_images)
                psnr_e2dsr_canny = calculate_psnr(outputs_e2dsr_canny, hr_images)
                scaler.scale(loss_e2dsr_canny).backward()
                scaler.step(optim_e2dsr_canny)

                # Sobel-edge model
                optim_e2dsr_sobel.zero_grad()
                with autocast(device_type=device.type, enabled=use_amp):
                    outputs_e2dsr_sobel = e2dsr_sobel(lr_images)
                    loss_e2dsr_sobel = criterion(outputs_e2dsr_sobel, hr_images)
                psnr_e2dsr_sobel = calculate_psnr(outputs_e2dsr_sobel, hr_images)
                scaler.scale(loss_e2dsr_sobel).backward()
                scaler.step(optim_e2dsr_sobel)

                # Both optimizers share one scaler, so update() runs exactly once
                # per iteration, after every step(). Calling it after each model
                # advanced the scaler's growth/backoff schedule twice per batch.
                scaler.update()
                scheduler_e2dsr_canny.step()
                scheduler_e2dsr_sobel.step()

                epoch_loss_e2dsr_canny += loss_e2dsr_canny.item()
                psnr_values_e2dsr_canny += psnr_e2dsr_canny
                epoch_loss_e2dsr_sobel += loss_e2dsr_sobel.item()
                psnr_values_e2dsr_sobel += psnr_e2dsr_sobel

            # Average training losses and PSNRs over the batches of this epoch.
            avg_epoch_loss_e2dsr_canny = epoch_loss_e2dsr_canny / len(train_loader)
            avg_psnr_e2dsr_canny_epoch = psnr_values_e2dsr_canny / len(train_loader)
            losses_e2dsr_canny.append(avg_epoch_loss_e2dsr_canny)
            avg_psnr_e2dsr_canny.append(avg_psnr_e2dsr_canny_epoch)

            avg_epoch_loss_e2dsr_sobel = epoch_loss_e2dsr_sobel / len(train_loader)
            avg_psnr_e2dsr_sobel_epoch = psnr_values_e2dsr_sobel / len(train_loader)
            losses_e2dsr_sobel.append(avg_epoch_loss_e2dsr_sobel)
            avg_psnr_e2dsr_sobel.append(avg_psnr_e2dsr_sobel_epoch)

            # Validation: PSNR only, no gradients.
            e2dsr_canny.eval()
            e2dsr_sobel.eval()

            val_psnr_e2dsr_canny, val_psnr_e2dsr_sobel = 0, 0
            with torch.no_grad():
                for lr_images, hr_images in valid_loader:
                    lr_images = lr_images.to(device)
                    hr_images = hr_images.to(device)
                    val_psnr_e2dsr_canny += calculate_psnr(e2dsr_canny(lr_images), hr_images)
                    val_psnr_e2dsr_sobel += calculate_psnr(e2dsr_sobel(lr_images), hr_images)

            val_avg_psnr_e2dsr_canny_epoch = val_psnr_e2dsr_canny / len(valid_loader)
            val_avg_psnr_e2dsr_canny.append(val_avg_psnr_e2dsr_canny_epoch)

            val_avg_psnr_e2dsr_sobel_epoch = val_psnr_e2dsr_sobel / len(valid_loader)
            val_avg_psnr_e2dsr_sobel.append(val_avg_psnr_e2dsr_sobel_epoch)

            # Keep the weights with the best validation PSNR so far.
            if val_avg_psnr_e2dsr_canny_epoch > best_psnr_e2dsr_canny:
                best_psnr_e2dsr_canny = val_avg_psnr_e2dsr_canny_epoch
                torch.save(e2dsr_canny.state_dict(), f'outputs/weight_sr/x{scale}/best_e2dsr_canny.pth')
                print(f"Saved e2dsr_canny model with PSNR {best_psnr_e2dsr_canny:.4f}")
            if val_avg_psnr_e2dsr_sobel_epoch > best_psnr_e2dsr_sobel:
                best_psnr_e2dsr_sobel = val_avg_psnr_e2dsr_sobel_epoch
                torch.save(e2dsr_sobel.state_dict(), f'outputs/weight_sr/x{scale}/best_e2dsr_sobel.pth')
                print(f"Saved e2dsr_sobel model with PSNR {best_psnr_e2dsr_sobel:.4f}")

            # Per-epoch snapshots.
            torch.save(e2dsr_canny.state_dict(), f'outputs/path/x{scale}/e2dsr_canny_{epoch}.pth')
            torch.save(e2dsr_sobel.state_dict(), f'outputs/path/x{scale}/e2dsr_sobel_{epoch}.pth')

            print(f"Epoch [{epoch+1}/{num_epochs}] completed: e2dsr_canny Loss: {avg_epoch_loss_e2dsr_canny:.4f}, PSNR: {avg_psnr_e2dsr_canny_epoch:.4f}, Validation PSNR: {val_avg_psnr_e2dsr_canny_epoch:.4f}, "
                  f"e2dsr_sobel Loss: {avg_epoch_loss_e2dsr_sobel:.4f}, PSNR: {avg_psnr_e2dsr_sobel_epoch:.4f}, Validation PSNR: {val_avg_psnr_e2dsr_sobel_epoch:.4f}")

            log_file.write(f"Epoch {epoch+1}:  e2dsr_canny PSNR: {avg_psnr_e2dsr_canny_epoch:.4f}, Validation PSNR: {val_avg_psnr_e2dsr_canny_epoch:.4f}\n")
            log_file.write(f"              e2dsr_sobel PSNR: {avg_psnr_e2dsr_sobel_epoch:.4f}, Validation PSNR: {val_avg_psnr_e2dsr_sobel_epoch:.4f}\n")
            log_file.flush()


# SRResNet EDSR

In [None]:
import cv2
import numpy as np
import os
from PIL import Image
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch.nn.parallel import DataParallel
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
from torchvision.utils import save_image
import torchsummary
from tqdm import tqdm
from models.srresnet import *
from models.edsr import *

import time
import os
# PYTORCH_CUDA_ALLOC_CONF takes a comma-separated list of `option:value` pairs.
# Assigning it twice (as before) silently kept only the last option.
# NOTE(review): the caching allocator presumably reads this when CUDA is first
# initialised, so it likely has no effect if an earlier cell already used CUDA.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128,expandable_segments:True'
# Use CUDA when available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class ImageDataset(Dataset):
    """Paired low-/high-resolution image dataset for super-resolution.

    Files in ``lr_dir`` and ``hr_dir`` are paired by position after sorting
    their names, so both directories must contain the same number of files.

    Args:
        lr_dir: Directory of low-resolution input images.
        hr_dir: Directory of high-resolution target images.
        valid: When True, the HR image is cropped so that its width and height
            are multiples of ``scale``, and the LR image is resized to the HR
            size integer-divided by ``scale`` (so an HR folder can serve as LR).
        scale: Upscaling factor; only used when ``valid`` is True.

    Raises:
        ValueError: If the two directories hold different numbers of files.

    Each item is an ``(lr, hr)`` pair of tensors produced by ``ToTensor``.
    """

    def __init__(self, lr_dir, hr_dir, valid = False, scale=4):
        self.lr_files = sorted(os.listdir(lr_dir))
        self.hr_files = sorted(os.listdir(hr_dir))
        # Pairs are matched by sorted index; a count mismatch would silently
        # pair the wrong images or raise IndexError in the middle of an epoch.
        if len(self.lr_files) != len(self.hr_files):
            raise ValueError(
                f'LR/HR file count mismatch: {len(self.lr_files)} files in '
                f'{lr_dir!r} vs {len(self.hr_files)} files in {hr_dir!r}'
            )
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.valid = valid
        self.scale = scale
        # Built once here instead of on every __getitem__ call.
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.lr_files)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB PIL image and close the underlying file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        lr_image = self._load_rgb(os.path.join(self.lr_dir, self.lr_files[idx]))
        hr_image = self._load_rgb(os.path.join(self.hr_dir, self.hr_files[idx]))

        if self.valid:
            w, h = hr_image.size
            # Crop HR to a multiple of `scale` so that an output upscaled from
            # the (w // scale, h // scale) LR input has the same size as the
            # target. Without this, the PSNR computation fails on a shape
            # mismatch (e.g. x3 with a 512px image). Assumes the model upsamples
            # by exactly `scale`.
            w, h = w - w % self.scale, h - h % self.scale
            hr_image = hr_image.crop((0, 0, w, h))
            lr_image = lr_image.resize((w // self.scale, h // self.scale))

        return self.transform(lr_image), self.transform(hr_image)
    
def calculate_psnr(img1, img2, max_pixel=1.0):
    """Return the peak signal-to-noise ratio between two image tensors, in dB.

    Args:
        img1: First image tensor.
        img2: Second image tensor, same (or broadcastable) shape as ``img1``.
        max_pixel: Largest possible pixel value (1.0 for images in [0, 1];
            use 255 for 8-bit data).

    Returns:
        PSNR as a Python float, or ``inf`` when the images are identical.
    """
    # Detach so calls on training outputs do not record autograd history, and
    # upcast so half-precision (autocast) outputs are compared in float32.
    mse = torch.mean((img1.detach().float() - img2.detach().float()) ** 2)
    if mse == 0:
        return float('inf')
    # 10*log10(MAX^2 / MSE), equivalent to 20*log10(MAX / sqrt(MSE)).
    return (10 * torch.log10(max_pixel ** 2 / mse)).item()

from torch.amp import autocast, GradScaler

# Mixed precision is only used on CUDA. On CPU, autocast and the GradScaler are
# disabled instead of being pointed at an unavailable device.
use_amp = device.type == 'cuda'

# Train and validate EDSR and SRResNet for each upscaling factor.
for scale in [2, 3, 4]:
    print(f'Training scale {scale}')
    train_lr_dir = f'dataset/Train/LR_{scale}'
    train_hr_dir = 'dataset/Train/HR'
    # Validation LR inputs are produced by downscaling the HR test images inside
    # ImageDataset (valid=True), so both validation paths point at the HR set.
    valid_lr_dir = 'dataset/Test/HR'
    valid_hr_dir = 'dataset/Test/HR'

    # torch.save() and open() do not create parent folders; create them up front.
    for out_dir in (f'outputs/weight_sr/x{scale}', f'outputs/path/x{scale}', 'outputs/train_log'):
        os.makedirs(out_dir, exist_ok=True)

    train_dataset = ImageDataset(train_lr_dir, train_hr_dir, scale=scale)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

    valid_dataset = ImageDataset(valid_lr_dir, valid_hr_dir, scale=scale, valid = True)
    valid_loader = DataLoader(valid_dataset)

    # Loss function and models
    criterion = nn.MSELoss()
    edsr = EDSR(scale=scale).to(device)
    srresnet = SRResNet(scale=scale).to(device)

    # One optimizer and StepLR schedule per model; schedulers are stepped per batch.
    optim_edsr = optim.Adam(edsr.parameters(), lr=1e-4, betas=(0.9, 0.999))
    scheduler_edsr = optim.lr_scheduler.StepLR(optim_edsr, step_size=10**5, gamma=0.5)

    optim_srresnet = optim.Adam(srresnet.parameters(), lr=1e-4, betas=(0.9, 0.999))
    scheduler_srresnet = optim.lr_scheduler.StepLR(optim_srresnet, step_size=10**5, gamma=0.5)

    # A fresh loss scaler per scale, so scale state from the previous run does not carry over.
    scaler = GradScaler(enabled=use_amp)

    num_epochs = 24

    best_psnr_edsr = float('-inf')
    best_psnr_srresnet = float('-inf')
    torch.cuda.empty_cache()

    # Per-epoch history: training loss, training PSNR, validation PSNR.
    losses_edsr, losses_srresnet = [], []
    avg_psnr_edsr, avg_psnr_srresnet = [], []
    val_avg_psnr_edsr, val_avg_psnr_srresnet = [], []

    # `with` guarantees the log is closed even if training raises.
    with open('outputs/train_log/edsr_srresnet.txt', 'a') as log_file:
        for epoch in range(num_epochs):
            edsr.train()
            srresnet.train()

            epoch_loss_edsr, psnr_values_edsr = 0, 0
            epoch_loss_srresnet, psnr_values_srresnet = 0, 0

            torch.cuda.empty_cache()
            for lr_images, hr_images in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch'):
                lr_images = lr_images.to(device)
                hr_images = hr_images.to(device)

                # EDSR
                optim_edsr.zero_grad()
                with autocast(device_type=device.type, enabled=use_amp):
                    outputs_edsr = edsr(lr_images)
                    loss_edsr = criterion(outputs_edsr, hr_images)
                psnr_edsr = calculate_psnr(outputs_edsr, hr_images)
                scaler.scale(loss_edsr).backward()
                scaler.step(optim_edsr)

                # SRResNet
                optim_srresnet.zero_grad()
                with autocast(device_type=device.type, enabled=use_amp):
                    outputs_srresnet = srresnet(lr_images)
                    loss_srresnet = criterion(outputs_srresnet, hr_images)
                psnr_srresnet = calculate_psnr(outputs_srresnet, hr_images)
                scaler.scale(loss_srresnet).backward()
                scaler.step(optim_srresnet)

                # Both optimizers share one scaler, so update() runs exactly once
                # per iteration, after every step(). Calling it after each model
                # advanced the scaler's growth/backoff schedule twice per batch.
                scaler.update()
                scheduler_edsr.step()
                scheduler_srresnet.step()

                epoch_loss_edsr += loss_edsr.item()
                psnr_values_edsr += psnr_edsr
                epoch_loss_srresnet += loss_srresnet.item()
                psnr_values_srresnet += psnr_srresnet

            # Average training losses and PSNRs over the batches of this epoch.
            avg_epoch_loss_edsr = epoch_loss_edsr / len(train_loader)
            avg_psnr_edsr_epoch = psnr_values_edsr / len(train_loader)
            losses_edsr.append(avg_epoch_loss_edsr)
            avg_psnr_edsr.append(avg_psnr_edsr_epoch)

            avg_epoch_loss_srresnet = epoch_loss_srresnet / len(train_loader)
            avg_psnr_srresnet_epoch = psnr_values_srresnet / len(train_loader)
            losses_srresnet.append(avg_epoch_loss_srresnet)
            avg_psnr_srresnet.append(avg_psnr_srresnet_epoch)

            # Validation: PSNR only, no gradients.
            edsr.eval()
            srresnet.eval()

            val_psnr_edsr, val_psnr_srresnet = 0, 0
            with torch.no_grad():
                for lr_images, hr_images in valid_loader:
                    lr_images = lr_images.to(device)
                    hr_images = hr_images.to(device)
                    val_psnr_edsr += calculate_psnr(edsr(lr_images), hr_images)
                    val_psnr_srresnet += calculate_psnr(srresnet(lr_images), hr_images)

            val_avg_psnr_edsr_epoch = val_psnr_edsr / len(valid_loader)
            val_avg_psnr_edsr.append(val_avg_psnr_edsr_epoch)

            val_avg_psnr_srresnet_epoch = val_psnr_srresnet / len(valid_loader)
            val_avg_psnr_srresnet.append(val_avg_psnr_srresnet_epoch)

            # Keep the weights with the best validation PSNR so far.
            if val_avg_psnr_edsr_epoch > best_psnr_edsr:
                best_psnr_edsr = val_avg_psnr_edsr_epoch
                torch.save(edsr.state_dict(), f'outputs/weight_sr/x{scale}/best_edsr.pth')
                print(f"Saved EDSR model with PSNR {best_psnr_edsr:.4f}")
            if val_avg_psnr_srresnet_epoch > best_psnr_srresnet:
                best_psnr_srresnet = val_avg_psnr_srresnet_epoch
                torch.save(srresnet.state_dict(), f'outputs/weight_sr/x{scale}/best_srresnet.pth')
                print(f"Saved SRResNet model with PSNR {best_psnr_srresnet:.4f}")

            # Per-epoch snapshots.
            torch.save(edsr.state_dict(), f'outputs/path/x{scale}/edsr_{epoch}.pth')
            torch.save(srresnet.state_dict(), f'outputs/path/x{scale}/srresnet_{epoch}.pth')

            print(f"Epoch [{epoch+1}/{num_epochs}] completed: EDSR Loss: {avg_epoch_loss_edsr:.4f}, PSNR: {avg_psnr_edsr_epoch:.4f}, Validation PSNR: {val_avg_psnr_edsr_epoch:.4f}, "
                  f"SRResNet Loss: {avg_epoch_loss_srresnet:.4f}, PSNR: {avg_psnr_srresnet_epoch:.4f}, Validation PSNR: {val_avg_psnr_srresnet_epoch:.4f}")

            log_file.write(f"Epoch {epoch+1}:  EDSR PSNR: {avg_psnr_edsr_epoch:.4f}, Validation PSNR: {val_avg_psnr_edsr_epoch:.4f}\n")
            log_file.write(f"              SRResNet PSNR: {avg_psnr_srresnet_epoch:.4f}, Validation PSNR: {val_avg_psnr_srresnet_epoch:.4f}\n")
            log_file.flush()



# SRCNN VDSR

In [None]:
import cv2
import numpy as np
import os
from PIL import Image
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch.nn.parallel import DataParallel
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
from torchvision.utils import save_image

from tqdm import tqdm
from models.vdsr import *
from models.srcnn import *
import time

# Run on the current CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class ImageDataset(Dataset):
    """Paired low-/high-resolution image dataset for super-resolution.

    Files in ``lr_dir`` and ``hr_dir`` are paired by position after sorting
    their names, so both directories must contain the same number of files.

    Args:
        lr_dir: Directory of low-resolution input images.
        hr_dir: Directory of high-resolution target images.
        valid: When True, the HR image is cropped so that its width and height
            are multiples of ``scale``, and the LR image is resized to the HR
            size integer-divided by ``scale`` (so an HR folder can serve as LR).
        scale: Upscaling factor; only used when ``valid`` is True.
        vdsr: When True, the LR image is finally resized to the (possibly
            cropped) HR size, for models that take a pre-upsampled input.

    Raises:
        ValueError: If the two directories hold different numbers of files.

    Each item is an ``(lr, hr)`` pair of tensors produced by ``ToTensor``.
    """

    def __init__(self, lr_dir, hr_dir, valid = False, scale=4, vdsr = False):
        self.lr_files = sorted(os.listdir(lr_dir))
        self.hr_files = sorted(os.listdir(hr_dir))
        # Pairs are matched by sorted index; a count mismatch would silently
        # pair the wrong images or raise IndexError in the middle of an epoch.
        if len(self.lr_files) != len(self.hr_files):
            raise ValueError(
                f'LR/HR file count mismatch: {len(self.lr_files)} files in '
                f'{lr_dir!r} vs {len(self.hr_files)} files in {hr_dir!r}'
            )
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.valid = valid
        self.scale = scale
        self.vdsr = vdsr
        # Built once here instead of on every __getitem__ call.
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.lr_files)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB PIL image and close the underlying file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        lr_image = self._load_rgb(os.path.join(self.lr_dir, self.lr_files[idx]))
        hr_image = self._load_rgb(os.path.join(self.hr_dir, self.hr_files[idx]))

        w, h = hr_image.size
        if self.valid:
            # Crop HR to a multiple of `scale`, as the other SR datasets in this
            # notebook do, so that all models are validated on the same target
            # region and LR/HR sizes line up exactly.
            w, h = w - w % self.scale, h - h % self.scale
            hr_image = hr_image.crop((0, 0, w, h))
            lr_image = lr_image.resize((w // self.scale, h // self.scale))
        if self.vdsr:
            # Pre-upsample the LR input to the (cropped) target size.
            lr_image = lr_image.resize((w, h))

        return self.transform(lr_image), self.transform(hr_image)
from torch.amp import autocast, GradScaler


def calculate_psnr(img1, img2, max_pixel=1.0):
    """Return the peak signal-to-noise ratio between two image tensors, in dB.

    Defined once here rather than being redefined on every pass of the scale loop.

    Args:
        img1: First image tensor.
        img2: Second image tensor, same (or broadcastable) shape as ``img1``.
        max_pixel: Largest possible pixel value (1.0 for images in [0, 1]).

    Returns:
        PSNR as a Python float, or ``inf`` when the images are identical.
    """
    # Detach so calls on training outputs do not record autograd history, and
    # upcast so half-precision (autocast) outputs are compared in float32.
    mse = torch.mean((img1.detach().float() - img2.detach().float()) ** 2)
    if mse == 0:
        return float('inf')
    # 10*log10(MAX^2 / MSE), equivalent to 20*log10(MAX / sqrt(MSE)).
    return (10 * torch.log10(max_pixel ** 2 / mse)).item()


# Mixed precision is only used on CUDA. On CPU, autocast and the GradScaler are
# disabled instead of being pointed at an unavailable device.
use_amp = device.type == 'cuda'

# Train and validate SRCNN and VDSR for each upscaling factor. Both take an LR
# input already resized to HR size (ImageDataset with vdsr=True).
for scale in [2, 3, 4]:
    print(f'train scale {scale}')
    train_lr_dir = f'dataset/Train/LR_{scale}'
    train_hr_dir = 'dataset/Train/HR'
    # Validation LR inputs are produced by downscaling the HR test images inside
    # ImageDataset (valid=True), so both validation paths point at the HR set.
    valid_lr_dir = 'dataset/Test/HR'
    valid_hr_dir = 'dataset/Test/HR'

    # torch.save() and open() do not create parent folders; create them up front.
    for out_dir in (f'outputs/weight_sr/x{scale}', f'outputs/path/x{scale}', 'outputs/train_log'):
        os.makedirs(out_dir, exist_ok=True)

    vdsr = VDSR().to(device)
    srcnn = SRCNN().to(device)
    train_dataset = ImageDataset(train_lr_dir, train_hr_dir, scale=scale, vdsr=True)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

    valid_dataset = ImageDataset(valid_lr_dir, valid_hr_dir, valid=True, scale=scale, vdsr=True)
    valid_loader = DataLoader(valid_dataset)

    # Loss function
    criterion = nn.MSELoss()

    # One optimizer and StepLR schedule per model; schedulers are stepped per batch.
    optim_srcnn = optim.Adam(srcnn.parameters(), lr=1e-5, betas=(0.9, 0.999))
    scheduler_srcnn = optim.lr_scheduler.StepLR(optim_srcnn, step_size=10**5, gamma=0.5)

    optim_vdsr = optim.Adam(vdsr.parameters(), lr=1e-4, betas=(0.9, 0.999))
    scheduler_vdsr = optim.lr_scheduler.StepLR(optim_vdsr, step_size=10**5, gamma=0.5)

    # A fresh loss scaler per scale, so scale state from the previous run does not carry over.
    scaler = GradScaler(enabled=use_amp)

    num_epochs = 24

    best_psnr_srcnn = float('-inf')
    best_psnr_vdsr = float('-inf')
    torch.cuda.empty_cache()

    # Per-epoch history: training loss, training PSNR, validation PSNR.
    losses_srcnn, losses_vdsr = [], []
    avg_psnr_srcnn, avg_psnr_vdsr = [], []
    val_avg_psnr_srcnn, val_avg_psnr_vdsr = [], []

    # `with` guarantees the log is closed even if training raises.
    with open('outputs/train_log/srcnnn_vdsr.txt', 'a') as log_file:
        for epoch in range(num_epochs):
            srcnn.train()
            vdsr.train()

            epoch_loss_srcnn, psnr_values_srcnn = 0, 0
            epoch_loss_vdsr, psnr_values_vdsr = 0, 0

            for lr_images, hr_images in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch'):
                lr_images = lr_images.to(device)
                hr_images = hr_images.to(device)

                # SRCNN
                optim_srcnn.zero_grad()
                with autocast(device_type=device.type, enabled=use_amp):
                    outputs_srcnn = srcnn(lr_images)
                    loss_srcnn = criterion(outputs_srcnn, hr_images)
                psnr_srcnn = calculate_psnr(outputs_srcnn, hr_images)
                scaler.scale(loss_srcnn).backward()
                scaler.step(optim_srcnn)

                # VDSR
                optim_vdsr.zero_grad()
                with autocast(device_type=device.type, enabled=use_amp):
                    outputs_vdsr = vdsr(lr_images)
                    loss_vdsr = criterion(outputs_vdsr, hr_images)
                psnr_vdsr = calculate_psnr(outputs_vdsr, hr_images)
                scaler.scale(loss_vdsr).backward()
                scaler.step(optim_vdsr)

                # Both optimizers share one scaler, so update() runs exactly once
                # per iteration, after every step(). Calling it after each model
                # advanced the scaler's growth/backoff schedule twice per batch.
                scaler.update()
                scheduler_srcnn.step()
                scheduler_vdsr.step()

                epoch_loss_srcnn += loss_srcnn.item()
                psnr_values_srcnn += psnr_srcnn
                epoch_loss_vdsr += loss_vdsr.item()
                psnr_values_vdsr += psnr_vdsr

            # Average training losses and PSNRs over the batches of this epoch.
            avg_epoch_loss_srcnn = epoch_loss_srcnn / len(train_loader)
            avg_psnr_srcnn_epoch = psnr_values_srcnn / len(train_loader)
            losses_srcnn.append(avg_epoch_loss_srcnn)
            avg_psnr_srcnn.append(avg_psnr_srcnn_epoch)

            avg_epoch_loss_vdsr = epoch_loss_vdsr / len(train_loader)
            avg_psnr_vdsr_epoch = psnr_values_vdsr / len(train_loader)
            losses_vdsr.append(avg_epoch_loss_vdsr)
            avg_psnr_vdsr.append(avg_psnr_vdsr_epoch)

            # Validation: PSNR only, no gradients.
            srcnn.eval()
            vdsr.eval()

            val_psnr_srcnn, val_psnr_vdsr = 0, 0
            with torch.no_grad():
                for lr_images, hr_images in valid_loader:
                    lr_images = lr_images.to(device)
                    hr_images = hr_images.to(device)
                    val_psnr_srcnn += calculate_psnr(srcnn(lr_images), hr_images)
                    val_psnr_vdsr += calculate_psnr(vdsr(lr_images), hr_images)

            val_avg_psnr_srcnn_epoch = val_psnr_srcnn / len(valid_loader)
            val_avg_psnr_srcnn.append(val_avg_psnr_srcnn_epoch)

            val_avg_psnr_vdsr_epoch = val_psnr_vdsr / len(valid_loader)
            val_avg_psnr_vdsr.append(val_avg_psnr_vdsr_epoch)

            # Keep the weights with the best validation PSNR so far.
            if val_avg_psnr_srcnn_epoch > best_psnr_srcnn:
                best_psnr_srcnn = val_avg_psnr_srcnn_epoch
                torch.save(srcnn.state_dict(), f'outputs/weight_sr/x{scale}/best_srcnn.pth')
                print(f"Saved SRCNN model with PSNR {best_psnr_srcnn:.4f}")
            if val_avg_psnr_vdsr_epoch > best_psnr_vdsr:
                best_psnr_vdsr = val_avg_psnr_vdsr_epoch
                torch.save(vdsr.state_dict(), f'outputs/weight_sr/x{scale}/best_vdsr.pth')
                print(f"Saved VDSR model with PSNR {best_psnr_vdsr:.4f}")

            # Per-epoch snapshots, kept per scale (previously x3/x4 runs
            # overwrote the x2 files) and numbered by the actual epoch.
            torch.save(srcnn.state_dict(), f'outputs/path/x{scale}/srcnn_{epoch}.pth')
            torch.save(vdsr.state_dict(), f'outputs/path/x{scale}/vdsr_{epoch}.pth')
            print(f"Epoch [{epoch+1}/{num_epochs}] completed: srcnn Loss: {avg_epoch_loss_srcnn:.4f}, PSNR: {avg_psnr_srcnn_epoch:.4f}, Validation PSNR: {val_avg_psnr_srcnn_epoch:.4f}")
            print(f"Epoch [{epoch+1}/{num_epochs}] completed: vdsr Loss: {avg_epoch_loss_vdsr:.4f}, PSNR: {avg_psnr_vdsr_epoch:.4f}, Validation PSNR: {val_avg_psnr_vdsr_epoch:.4f}")

            log_file.write(f"Epoch {epoch+1}: srcnn PSNR: {avg_psnr_srcnn_epoch:.4f}, Validation PSNR: {val_avg_psnr_srcnn_epoch:.4f}\n")
            log_file.write(f"Epoch {epoch+1}: vdsr PSNR: {avg_psnr_vdsr_epoch:.4f}, Validation PSNR: {val_avg_psnr_vdsr_epoch:.4f}\n")
            log_file.flush()
