# 1. Import thư viện

In [1]:
import cv2
import numpy as np
import os
from PIL import Image
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch.nn.parallel import DataParallel
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
from torchvision.utils import save_image

from tqdm import tqdm
from models.vdsr import *
from models.srcnn import *
import time

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
import os
# Configure the CUDA caching allocator. Both options must live in ONE
# comma-separated value: assigning the variable twice (as this cell used to)
# silently discards the first setting, so `max_split_size_mb` never applied.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128,expandable_segments:True'
# Make CUDA kernel launches synchronous so errors surface at the failing call.
# NOTE(review): these variables are set after `torch` was imported in the
# previous cell — confirm they are still picked up in your PyTorch version.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 2. Tạo Dataset cho mô hình SR

In [5]:
class ImageDataset(Dataset):
    """Paired low-/high-resolution image dataset read from two directories.

    Files are paired by position after sorting each directory listing, so
    ``lr_dir`` and ``hr_dir`` are expected to hold matching files in the same
    order (this is not verified here). The dataset length is the number of
    entries in ``lr_dir``.

    Args:
        lr_dir: Directory containing the low-resolution images.
        hr_dir: Directory containing the high-resolution images.
        valid: When truthy, the LR image is resized to the HR image size
            integer-divided by ``scale`` before any further processing.
        scale: Divisor applied to the HR width/height when ``valid`` is set.
        vdsr: When truthy, the LR image is finally resized to the full HR
            size, so the LR and HR tensors share the same spatial size.
    """

    def __init__(self, lr_dir, hr_dir, valid = False, scale=4, vdsr = False):
        self.lr_files = sorted(os.listdir(lr_dir))
        self.hr_files = sorted(os.listdir(hr_dir))
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.valid = valid
        self.scale = scale
        self.vdsr = vdsr

    def __len__(self):
        """Return the number of entries listed in ``lr_dir``."""
        return len(self.lr_files)

    def __getitem__(self, idx):
        """Load the ``idx``-th pair and return ``(lr_tensor, hr_tensor)``.

        Both images are converted to RGB and passed through ``ToTensor``.
        """
        lr_path = os.path.join(self.lr_dir, self.lr_files[idx])
        hr_path = os.path.join(self.hr_dir, self.hr_files[idx])

        # convert() returns a new image, so the source file handles can be
        # closed right away instead of lingering until garbage collection.
        with Image.open(lr_path) as lr_source:
            lr_image = lr_source.convert('RGB')
        with Image.open(hr_path) as hr_source:
            hr_image = hr_source.convert('RGB')

        w, h = hr_image.size

        if self.valid:
            # Derive the LR input from the HR dimensions and the scale factor.
            lr_image = lr_image.resize((w // self.scale, h // self.scale))
        if self.vdsr:
            # Bring the LR image up to the HR size before it reaches the model.
            lr_image = lr_image.resize((w, h))

        # A single ToTensor is equivalent to the former one-element Compose.
        to_tensor = transforms.ToTensor()
        return to_tensor(lr_image), to_tensor(hr_image)

# 3. Tạo Hyperparameter

In [7]:
from torch.amp import autocast, GradScaler
# Module-level scaler kept so any other cell that references `scaler` still
# works; the training loop below uses one dedicated scaler per model.
scaler = GradScaler()


def calculate_psnr(img1, img2):
    """Return the PSNR (in dB) between two tensors, assuming a peak value of 1.0.

    Returns ``float('inf')`` when the two tensors are identical (MSE of 0).
    The computation runs under ``torch.no_grad()`` because it is only a
    metric and must not extend the autograd graph.
    """
    with torch.no_grad():
        mse = torch.mean((img1 - img2) ** 2)
        if mse == 0:
            return float('inf')
        max_pixel = 1.0
        psnr = 20 * torch.log10(max_pixel / torch.sqrt(mse))
        return psnr.item()


# Mixed precision is only enabled on CUDA; on CPU the loop runs in full precision.
use_amp = device.type == 'cuda'

# Make sure the output locations exist before anything is written to them.
os.makedirs('outputs/train_log', exist_ok=True)
os.makedirs('outputs/path', exist_ok=True)

# Train one SRCNN and one VDSR model per upscaling factor.
for scale in [2, 3, 4]:
    os.makedirs(f'outputs/weight_sr/x{scale}', exist_ok=True)

    # Build the datasets and data loaders.
    train_lr_dir = f'dataset/Train/LR_{scale}'
    train_hr_dir = 'dataset/Train/HR'
    # Validation reads the HR folder for both sides; with valid=True the
    # dataset derives the LR input by resizing down by `scale`.
    valid_lr_dir = 'dataset/Test/HR'
    valid_hr_dir = 'dataset/Test/HR'
    vdsr = VDSR().to(device)
    srcnn = SRCNN().to(device)
    # `scale` must be passed by keyword: positionally it lands in the `valid`
    # parameter, which switched the training set into validation-style
    # resizing with the default scale of 4 for every loop iteration.
    train_dataset = ImageDataset(train_lr_dir, train_hr_dir, scale=scale, vdsr=True)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

    valid_dataset = ImageDataset(valid_lr_dir, valid_hr_dir, valid=True, scale=scale, vdsr=True)
    valid_loader = DataLoader(valid_dataset)

    # Loss function shared by both models.
    criterion = nn.MSELoss()

    # The schedulers are stepped once per batch (see the training loop), so
    # step_size counts iterations, not epochs.
    optim_srcnn = optim.Adam(srcnn.parameters(), lr=1e-5, betas=(0.9, 0.999))
    scheduler_srcnn = optim.lr_scheduler.StepLR(optim_srcnn, step_size=10**5, gamma=0.5)

    optim_vdsr = optim.Adam(vdsr.parameters(), lr=1e-4, betas=(0.9, 0.999))
    scheduler_vdsr = optim.lr_scheduler.StepLR(optim_vdsr, step_size=10**5, gamma=0.5)

    # One gradient scaler per model so an overflow in one network does not
    # change the loss scale used for the other.
    scaler_srcnn = GradScaler(enabled=use_amp)
    scaler_vdsr = GradScaler(enabled=use_amp)

    num_epochs = 24

    best_psnr_srcnn = float('-inf')
    best_psnr_vdsr = float('-inf')
    torch.cuda.empty_cache()

    # Per-epoch history: training loss, training PSNR and validation PSNR.
    losses_srcnn = []
    losses_vdsr = []

    avg_psnr_srcnn = []
    avg_psnr_vdsr = []

    val_avg_psnr_srcnn = []
    val_avg_psnr_vdsr = []

    # NOTE(review): `patience` / `epochs_no_improve` are set but never read in
    # this cell — early stopping is not implemented here.
    patience = 5
    epochs_no_improve = 0

    # The context manager closes (and flushes) the log even if an epoch raises.
    with open('outputs/train_log/srcnnn_vdsr.txt', 'a') as log_file:
        for epoch in range(num_epochs):
            srcnn.train()
            vdsr.train()

            epoch_loss_srcnn, psnr_values_srcnn = 0, 0
            epoch_loss_vdsr, psnr_values_vdsr = 0, 0
            start_time = time.time()

            # Both models are trained on the same batches within one pass.
            for (lr_images, hr_images) in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch'):
                # Follow `device` instead of hard-coding CUDA so the CPU
                # fallback chosen at the top of the notebook actually works.
                lr_images = lr_images.to(device)
                hr_images = hr_images.to(device)

                # --- SRCNN update ---
                optim_srcnn.zero_grad()
                with autocast(device_type=device.type, enabled=use_amp):
                    outputs_srcnn = srcnn(lr_images)
                    loss_srcnn = criterion(outputs_srcnn, hr_images)
                psnr_srcnn = calculate_psnr(outputs_srcnn, hr_images)

                scaler_srcnn.scale(loss_srcnn).backward()
                scaler_srcnn.step(optim_srcnn)
                scaler_srcnn.update()
                scheduler_srcnn.step()

                epoch_loss_srcnn += loss_srcnn.item()
                psnr_values_srcnn += psnr_srcnn

                # --- VDSR update ---
                optim_vdsr.zero_grad()
                with autocast(device_type=device.type, enabled=use_amp):
                    outputs_vdsr = vdsr(lr_images)
                    loss_vdsr = criterion(outputs_vdsr, hr_images)
                psnr_vdsr = calculate_psnr(outputs_vdsr, hr_images)

                scaler_vdsr.scale(loss_vdsr).backward()
                scaler_vdsr.step(optim_vdsr)
                scaler_vdsr.update()
                scheduler_vdsr.step()

                epoch_loss_vdsr += loss_vdsr.item()
                psnr_values_vdsr += psnr_vdsr

            # Average the training loss and PSNR over the batches of this epoch.
            avg_epoch_loss_srcnn = epoch_loss_srcnn / len(train_loader)
            avg_psnr_srcnn_epoch = psnr_values_srcnn / len(train_loader)
            losses_srcnn.append(avg_epoch_loss_srcnn)
            avg_psnr_srcnn.append(avg_psnr_srcnn_epoch)

            avg_epoch_loss_vdsr = epoch_loss_vdsr / len(train_loader)
            avg_psnr_vdsr_epoch = psnr_values_vdsr / len(train_loader)
            losses_vdsr.append(avg_epoch_loss_vdsr)
            avg_psnr_vdsr.append(avg_psnr_vdsr_epoch)

            # Validation: both models are evaluated on the same batches.
            srcnn.eval()
            vdsr.eval()

            val_psnr_srcnn, val_psnr_vdsr = 0, 0

            with torch.no_grad():
                for (lr_images, hr_images) in valid_loader:
                    lr_images = lr_images.to(device)
                    hr_images = hr_images.to(device)

                    outputs_srcnn = srcnn(lr_images)
                    psnr_srcnn = calculate_psnr(outputs_srcnn, hr_images)
                    val_psnr_srcnn += psnr_srcnn

                    outputs_vdsr = vdsr(lr_images)
                    psnr_vdsr = calculate_psnr(outputs_vdsr, hr_images)
                    val_psnr_vdsr += psnr_vdsr

            val_avg_psnr_srcnn_epoch = val_psnr_srcnn / len(valid_loader)
            val_avg_psnr_srcnn.append(val_avg_psnr_srcnn_epoch)

            val_avg_psnr_vdsr_epoch = val_psnr_vdsr / len(valid_loader)
            val_avg_psnr_vdsr.append(val_avg_psnr_vdsr_epoch)

            # Keep the best SRCNN weights for this scale (by validation PSNR).
            if val_avg_psnr_srcnn_epoch > best_psnr_srcnn:
                best_psnr_srcnn = val_avg_psnr_srcnn_epoch
                torch.save(srcnn.state_dict(), f'outputs/weight_sr/x{scale}/best_srcnn.pth')
                print(f"Saved SRCNN model with PSNR {best_psnr_srcnn:.4f}")
            # Keep the best VDSR weights for this scale (by validation PSNR).
            if val_avg_psnr_vdsr_epoch > best_psnr_vdsr:
                best_psnr_vdsr = val_avg_psnr_vdsr_epoch
                torch.save(vdsr.state_dict(), f'outputs/weight_sr/x{scale}/best_vdsr.pth')
                print(f"Saved VDSR model with PSNR {best_psnr_vdsr:.4f}")

            # NOTE(review): these per-epoch checkpoint names do not include
            # `scale`, so each scale overwrites the previous scale's files;
            # the `+10` offset is kept as-is — confirm both are intended.
            torch.save(srcnn.state_dict(), f'outputs/path/srcnn_{epoch+10}.pth')
            torch.save(vdsr.state_dict(), f'outputs/path/vdsr_{epoch+10}.pth')
            print(f"Epoch [{epoch+1}/{num_epochs}] completed: srcnn Loss: {avg_epoch_loss_srcnn:.4f}, PSNR: {avg_psnr_srcnn_epoch:.4f}, Validation PSNR: {val_avg_psnr_srcnn_epoch:.4f}")
            print(f"Epoch [{epoch+1}/{num_epochs}] completed: vdsr Loss: {avg_epoch_loss_vdsr:.4f}, PSNR: {avg_psnr_vdsr_epoch:.4f}, Validation PSNR: {val_avg_psnr_vdsr_epoch:.4f}")

            # The first line reports SRCNN metrics; it used to be mislabelled "WDSRA".
            log_file.write(f"Epoch {epoch+1}: srcnn PSNR: {avg_psnr_srcnn_epoch:.4f}, Validation PSNR: {val_avg_psnr_srcnn_epoch:.4f}\n")
            log_file.write(f"Epoch {epoch+1}: vdsr PSNR: {avg_psnr_vdsr_epoch:.4f}, Validation PSNR: {val_avg_psnr_vdsr_epoch:.4f}\n")


12500


# 4. Training