# 1. Import thư viện

In [1]:
import cv2
import numpy as np
import os
from PIL import Image
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch.nn.parallel import DataParallel
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
from torchvision.utils import save_image
import torchsummary
from tqdm import tqdm
from models.e2dsr import *

import time



In [2]:
import os

# Stop the CUDA caching allocator from splitting blocks larger than 128 MB,
# which reduces fragmentation. Per the PyTorch docs this only takes effect if
# it is set before the first CUDA allocation.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:128')

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# 2. Tạo dataset cho mô hình SR

In [3]:
import os

# CUDA_LAUNCH_BLOCKING=1 makes every kernel launch synchronous. CUDA errors then
# surface at the Python line that caused them, but every launch is serialized,
# which slows training considerably. Keep it off for real runs and flip this flag
# only while debugging a CUDA error. It only takes effect if set before the CUDA
# context is created.
DEBUG_CUDA_SYNC = False
if DEBUG_CUDA_SYNC:
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [4]:
class ImageDataset(Dataset):
    """Paired low-/high-resolution images for super-resolution.

    LR and HR files are paired by position after sorting each directory
    listing. Both directories must therefore hold the same number of files,
    and their sorted orders must correspond (e.g. identical file names).

    Args:
        lr_dir: directory with the low-resolution inputs.
        hr_dir: directory with the high-resolution targets.
        valid: if True, the image read from ``lr_dir`` is resized to
            ``(W // scale, H // scale)``, where ``(W, H)`` is the HR image
            size. The HR image is cropped to
            ``(W // scale * scale, H // scale * scale)`` so both sizes agree
            exactly under a ``scale``-times upscale.
        scale: upscaling factor; only used when ``valid`` is True.

    Each item is an ``(lr, hr)`` pair of tensors produced by
    ``transforms.ToTensor``.

    Raises:
        ValueError: if ``lr_dir`` and ``hr_dir`` contain different numbers of
            entries.
    """

    def __init__(self, lr_dir, hr_dir, valid=False, scale=4):
        self.lr_files = sorted(os.listdir(lr_dir))
        self.hr_files = sorted(os.listdir(hr_dir))
        # Pairing is by index. A count mismatch would silently misalign pairs
        # (or raise IndexError mid-epoch), so fail early instead.
        if len(self.lr_files) != len(self.hr_files):
            raise ValueError(
                f"LR/HR file count mismatch: {len(self.lr_files)} in {lr_dir!r}, "
                f"{len(self.hr_files)} in {hr_dir!r}"
            )
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.valid = valid
        self.scale = scale

    def __len__(self):
        return len(self.lr_files)

    @staticmethod
    def _load_rgb(path):
        """Load ``path`` as an RGB PIL image, closing the underlying file."""
        with Image.open(path) as img:
            # convert() returns a new, fully loaded image, so it stays valid
            # after the file is closed.
            return img.convert('RGB')

    def __getitem__(self, idx):
        lr_image = self._load_rgb(os.path.join(self.lr_dir, self.lr_files[idx]))
        hr_image = self._load_rgb(os.path.join(self.hr_dir, self.hr_files[idx]))

        if self.valid:
            w, h = hr_image.size
            lr_w, lr_h = w // self.scale, h // self.scale
            lr_image = lr_image.resize((lr_w, lr_h))
            # Crop the HR image to an exact multiple of `scale`. Otherwise,
            # for sizes not divisible by `scale`, the model output (assuming it
            # upsamples by exactly `scale`) is smaller than the target and the
            # PSNR computation compares tensors of different shapes.
            hr_image = hr_image.crop((0, 0, lr_w * self.scale, lr_h * self.scale))

        to_tensor = transforms.ToTensor()
        return to_tensor(lr_image), to_tensor(hr_image)

# 3. Hàm đánh giá PSNR, hyperparameter và huấn luyện mô hình

In [5]:
def calculate_psnr(img1, img2, max_pixel=1.0):
    """Return the peak signal-to-noise ratio (in dB) between two image tensors.

    The MSE is averaged over *all* elements. For a batch this is the PSNR of
    the whole batch, not the mean of per-image PSNRs.

    Args:
        img1, img2: image tensors, subtracted elementwise (so they must have
            equal or broadcastable shapes).
        max_pixel: largest possible pixel value. Use 1.0 (the default) for
            images scaled to [0, 1], such as ``transforms.ToTensor`` output,
            and 255 for 8-bit images.

    Returns:
        The PSNR as a Python float, or ``float('inf')`` when the images are
        identical.
    """
    # Detach so calls on training outputs do not record autograd ops. Cast to
    # float32 so half-precision (autocast) inputs keep precision in the
    # squared error.
    diff = img1.detach().float() - img2.detach().float()
    mse = torch.mean(diff ** 2)
    if mse == 0:
        return float('inf')
    # 10*log10(max^2 / mse), equivalent to 20*log10(max / sqrt(mse)).
    return (10 * torch.log10(max_pixel ** 2 / mse)).item()

In [None]:
from torch.amp import autocast, GradScaler

# Mixed precision (autocast + GradScaler) is only enabled when training on CUDA.
# On CPU both become no-ops, so the same code path still runs.
use_amp = device.type == 'cuda'


def _train_step(model, optimizer, scheduler, scaler, criterion, lr_images, hr_images):
    """Run one optimisation step of ``model`` on one batch.

    Returns:
        ``(loss, psnr)`` for the batch, both as Python floats.
    """
    optimizer.zero_grad()
    with autocast(device_type=device.type, enabled=use_amp):
        outputs = model(lr_images)
        loss = criterion(outputs, hr_images)
    # The PSNR is bookkeeping only; detach so it does not extend the graph.
    psnr = calculate_psnr(outputs.detach(), hr_images)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    # One update() per scaler per iteration, as required by torch.amp.
    scaler.update()
    # Schedulers are stepped per iteration, not per epoch (step_size=10**5 iterations).
    scheduler.step()
    return loss.item(), psnr


def _validate(models, loader):
    """Return the mean PSNR of each model in ``models`` over ``loader``.

    Each batch is loaded once and fed to every model.
    """
    totals = [0.0] * len(models)
    with torch.no_grad():
        for lr_images, hr_images in loader:
            lr_images = lr_images.to(device)
            hr_images = hr_images.to(device)
            for k, model in enumerate(models):
                totals[k] += calculate_psnr(model(lr_images), hr_images)
    return [total / len(loader) for total in totals]


# Train one Canny-edge and one Sobel-edge E2DSR model for each upscaling factor.
for scale in [2, 3, 4]:
    # Datasets and dataloaders
    train_lr_dir = f'dataset/Train/LR_{scale}'
    train_hr_dir = 'dataset/Train/HR'
    # Validation LR images are generated on the fly by downscaling the HR images
    # (ImageDataset with valid=True), so both directories point at HR.
    valid_lr_dir = 'dataset/Test/HR'
    valid_hr_dir = 'dataset/Test/HR'

    train_dataset = ImageDataset(train_lr_dir, train_hr_dir)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

    valid_dataset = ImageDataset(valid_lr_dir, valid_hr_dir, valid=True, scale=scale)
    valid_loader = DataLoader(valid_dataset)

    # Loss function and models
    criterion = nn.MSELoss()
    e2dsr_sobel = E2DSR(edge_option='sobel', scale_factor=scale).to(device)
    e2dsr_canny = E2DSR(edge_option='canny', scale_factor=scale).to(device)

    # Optimizer, scheduler and gradient scaler for each model. The two models are
    # independent, so each gets its own GradScaler: an inf/NaN in one model must
    # not back off the loss scale of the other.
    optim_e2dsr_canny = optim.Adam(e2dsr_canny.parameters(), lr=1e-4, betas=(0.9, 0.999))
    scheduler_e2dsr_canny = optim.lr_scheduler.StepLR(optim_e2dsr_canny, step_size=10**5, gamma=0.5)
    scaler_e2dsr_canny = GradScaler('cuda', enabled=use_amp)

    optim_e2dsr_sobel = optim.Adam(e2dsr_sobel.parameters(), lr=1e-4, betas=(0.9, 0.999))
    scheduler_e2dsr_sobel = optim.lr_scheduler.StepLR(optim_e2dsr_sobel, step_size=10**5, gamma=0.5)
    scaler_e2dsr_sobel = GradScaler('cuda', enabled=use_amp)

    num_epochs = 24

    best_psnr_e2dsr_canny = float('-inf')
    best_psnr_e2dsr_sobel = float('-inf')
    torch.cuda.empty_cache()

    losses_e2dsr_canny = []
    losses_e2dsr_sobel = []

    avg_psnr_e2dsr_canny = []
    avg_psnr_e2dsr_sobel = []

    val_avg_psnr_e2dsr_canny = []
    val_avg_psnr_e2dsr_sobel = []

    # Checkpoint directories. They are created up front so torch.save cannot fail
    # after an epoch of training. Per-epoch checkpoints are kept per scale so that
    # later scales do not overwrite earlier ones.
    best_dir = f'outputs/weight_sr/x{scale}'
    epoch_dir = f'outputs/path/x{scale}'
    os.makedirs(best_dir, exist_ok=True)
    os.makedirs(epoch_dir, exist_ok=True)

    # `with` closes the log even if training is interrupted.
    with open('e2dsr.txt', 'a', encoding='utf-8') as log_file:
        log_file.write(f"Scale x{scale}\n")

        for epoch in range(num_epochs):
            e2dsr_canny.train()
            e2dsr_sobel.train()

            epoch_loss_e2dsr_canny, psnr_values_e2dsr_canny = 0, 0
            epoch_loss_e2dsr_sobel, psnr_values_e2dsr_sobel = 0, 0

            torch.cuda.empty_cache()
            # Training: both models see the same batches, one after the other.
            for lr_images, hr_images in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch'):
                lr_images = lr_images.to(device)
                hr_images = hr_images.to(device)

                loss, psnr = _train_step(e2dsr_canny, optim_e2dsr_canny, scheduler_e2dsr_canny,
                                         scaler_e2dsr_canny, criterion, lr_images, hr_images)
                epoch_loss_e2dsr_canny += loss
                psnr_values_e2dsr_canny += psnr

                loss, psnr = _train_step(e2dsr_sobel, optim_e2dsr_sobel, scheduler_e2dsr_sobel,
                                         scaler_e2dsr_sobel, criterion, lr_images, hr_images)
                epoch_loss_e2dsr_sobel += loss
                psnr_values_e2dsr_sobel += psnr

            # Average losses and PSNRs over the epoch
            avg_epoch_loss_e2dsr_canny = epoch_loss_e2dsr_canny / len(train_loader)
            avg_psnr_e2dsr_canny_epoch = psnr_values_e2dsr_canny / len(train_loader)
            losses_e2dsr_canny.append(avg_epoch_loss_e2dsr_canny)
            avg_psnr_e2dsr_canny.append(avg_psnr_e2dsr_canny_epoch)

            avg_epoch_loss_e2dsr_sobel = epoch_loss_e2dsr_sobel / len(train_loader)
            avg_psnr_e2dsr_sobel_epoch = psnr_values_e2dsr_sobel / len(train_loader)
            losses_e2dsr_sobel.append(avg_epoch_loss_e2dsr_sobel)
            avg_psnr_e2dsr_sobel.append(avg_psnr_e2dsr_sobel_epoch)

            # Validation
            e2dsr_canny.eval()
            e2dsr_sobel.eval()
            val_avg_psnr_e2dsr_canny_epoch, val_avg_psnr_e2dsr_sobel_epoch = _validate(
                [e2dsr_canny, e2dsr_sobel], valid_loader)
            val_avg_psnr_e2dsr_canny.append(val_avg_psnr_e2dsr_canny_epoch)
            val_avg_psnr_e2dsr_sobel.append(val_avg_psnr_e2dsr_sobel_epoch)

            # Save the best model (by validation PSNR) for each edge option
            if val_avg_psnr_e2dsr_canny_epoch > best_psnr_e2dsr_canny:
                best_psnr_e2dsr_canny = val_avg_psnr_e2dsr_canny_epoch
                torch.save(e2dsr_canny.state_dict(), f'{best_dir}/best_e2dsr_canny.pth')
                print(f"Saved e2dsr_canny model with PSNR {best_psnr_e2dsr_canny:.4f}")
            if val_avg_psnr_e2dsr_sobel_epoch > best_psnr_e2dsr_sobel:
                best_psnr_e2dsr_sobel = val_avg_psnr_e2dsr_sobel_epoch
                torch.save(e2dsr_sobel.state_dict(), f'{best_dir}/best_e2dsr_sobel.pth')
                print(f"Saved e2dsr_sobel model with PSNR {best_psnr_e2dsr_sobel:.4f}")

            # Per-epoch checkpoints
            torch.save(e2dsr_canny.state_dict(), f'{epoch_dir}/e2dsr_canny_{epoch}.pth')
            torch.save(e2dsr_sobel.state_dict(), f'{epoch_dir}/e2dsr_sobel_{epoch}.pth')

            print(f"Epoch [{epoch+1}/{num_epochs}] completed: e2dsr_canny Loss: {avg_epoch_loss_e2dsr_canny:.4f}, PSNR: {avg_psnr_e2dsr_canny_epoch:.4f}, Validation PSNR: {val_avg_psnr_e2dsr_canny_epoch:.4f}, "
                  f"e2dsr_sobel Loss: {avg_epoch_loss_e2dsr_sobel:.4f}, PSNR: {avg_psnr_e2dsr_sobel_epoch:.4f}, Validation PSNR: {val_avg_psnr_e2dsr_sobel_epoch:.4f}")

            log_file.write(f"Epoch {epoch+1}:  e2dsr_canny PSNR: {avg_psnr_e2dsr_canny_epoch:.4f}, Validation PSNR: {val_avg_psnr_e2dsr_canny_epoch:.4f}\n")
            log_file.write(f"              e2dsr_sobel PSNR: {avg_psnr_e2dsr_sobel_epoch:.4f}, Validation PSNR: {val_avg_psnr_e2dsr_sobel_epoch:.4f}\n")
            log_file.flush()



Epoch 1/24:   0%|          | 43/12500 [00:04<22:12,  9.35batch/s] 


KeyboardInterrupt: 