# 1. Import thư viện

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchsummary import summary
from torch.utils.data import DataLoader, Dataset
import cv2
import numpy as np
import os
from PIL import Image
from tqdm import tqdm
import time
import matplotlib.pyplot as plt
from models.esrpcb import *

In [None]:
import os
# Configure the CUDA caching allocator and pick the compute device.
# The original cell assigned PYTORCH_CUDA_ALLOC_CONF twice in a row, so the
# first value ('max_split_size_mb:128') was overwritten and never took effect;
# only the effective setting is kept here.
# NOTE(review): if both options are wanted, PyTorch expects them in a single
# comma-separated value, e.g. 'max_split_size_mb:128,expandable_segments:True'.
# NOTE(review): this presumably has to run before the first CUDA allocation
# to have any effect - confirm against the PyTorch CUDA memory docs.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 2. Tạo Dataset cho mô hình SR

In [None]:
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
import random
class ImageDataset(Dataset):
    """Super-resolution dataset that yields ``(lr, hr)`` tensor pairs.

    Only the HR images are read from disk. The LR input is synthesised by
    downscaling the HR image (validation) or a random HR patch (training) by
    ``scale`` with ``TF.resize``. ``lr_dir`` is still listed and stored so the
    constructor interface and attributes stay unchanged, but its images are
    never opened (the original code loaded them and immediately discarded
    them).

    Args:
        lr_dir: Directory whose file names are stored in ``lr_files``.
        hr_dir: Directory containing the HR images that are actually loaded.
        scale: Integer downscaling factor between HR and LR.
        patch_size: Side length of the square HR patch cropped in training
            mode. Must be a multiple of ``scale`` so that the LR patch maps
            back to exactly ``patch_size`` after upscaling.
        valid: When True, the whole image is used (no random crop).

    Raises:
        ValueError: If ``scale`` is smaller than 1, or (training mode only)
            ``patch_size`` is not a positive multiple of ``scale``.
    """

    def __init__(self, lr_dir, hr_dir, scale=4, patch_size=96, valid=False):
        if scale < 1:
            raise ValueError(f"scale must be >= 1, got {scale}")
        if not valid and (patch_size < scale or patch_size % scale != 0):
            raise ValueError(
                f"patch_size ({patch_size}) must be a positive multiple of "
                f"scale ({scale})"
            )
        self.lr_files = sorted(os.listdir(lr_dir))
        self.hr_files = sorted(os.listdir(hr_dir))
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.valid = valid
        self.scale = scale
        self.patch_size = patch_size  # side length of the training crop

    def __len__(self):
        # Samples are indexed through hr_files, so its length is the one
        # that guarantees every index is valid.
        return len(self.hr_files)

    def __getitem__(self, idx):
        """Return ``(lr_tensor, hr_tensor)`` for sample ``idx``.

        Both tensors are produced with ``TF.to_tensor``.
        """
        hr_path = os.path.join(self.hr_dir, self.hr_files[idx])
        # convert() returns a new, fully loaded image, so the file handle can
        # be released right away.
        with Image.open(hr_path) as source:
            hr_img = source.convert('RGB')

        if self.valid:
            hr_img = self._crop_to_multiple(hr_img)
        else:
            hr_img = self._random_patch(hr_img)

        lr_img = self._downscale(hr_img)

        # Convert to tensors
        return TF.to_tensor(lr_img), TF.to_tensor(hr_img)

    def _random_patch(self, hr_img):
        """Crop a random ``patch_size`` x ``patch_size`` square from ``hr_img``."""
        w, h = hr_img.size  # PIL reports (width, height)
        if w < self.patch_size or h < self.patch_size:
            raise ValueError(
                f"image of size {w}x{h} is smaller than patch_size "
                f"{self.patch_size}"
            )
        top = random.randint(0, h - self.patch_size)
        left = random.randint(0, w - self.patch_size)
        return TF.crop(hr_img, top, left, self.patch_size, self.patch_size)

    def _crop_to_multiple(self, hr_img):
        """Trim ``hr_img`` (from the top-left) so both sides divide by ``scale``.

        NOTE(review): assumes the model output is exactly ``scale`` times the
        LR size, so HR must be ``scale * (size // scale)`` for the PSNR
        comparison to have matching shapes - confirm against the model.
        """
        w, h = hr_img.size
        new_w = w - w % self.scale
        new_h = h - h % self.scale
        if new_w == 0 or new_h == 0:
            raise ValueError(
                f"image of size {w}x{h} is smaller than scale {self.scale}"
            )
        if (new_w, new_h) == (w, h):
            return hr_img
        return TF.crop(hr_img, 0, 0, new_h, new_w)

    def _downscale(self, hr_img):
        """Produce the LR counterpart of ``hr_img`` by resizing with ``scale``."""
        w, h = hr_img.size
        # TF.resize takes the target size as (height, width), the opposite of
        # PIL's (width, height) ordering.
        return TF.resize(hr_img, (h // self.scale, w // self.scale), antialias=True)
        
    

# 3. Tạo Hyperparameter

In [None]:
# Đường dẫn tới bộ dữ liệu

# test_hr_dir  = '/kaggle/input/srdataset/sr_data/test/HR'
# test_lr_dir  = '/kaggle/input/srdataset/sr_data/test/LR'

# print(torch.cuda.memory_allocated())
# print(torch.cuda.memory_reserved())

In [None]:
import os
# Set both CUDA debugging switches to "1" in one call.
# NOTE(review): these look like debugging aids (device-side assertions and
# synchronous kernel launches) - confirm they are still wanted for full
# training runs.
os.environ.update(TORCH_USE_CUDA_DSA="1", CUDA_LAUNCH_BLOCKING="1")

In [None]:
def calculate_psnr(img1, img2):
    """Return the PSNR in dB between two tensors as a Python float.

    The peak signal value is fixed at 1.0. When the two tensors are
    identical (mean squared error of exactly zero) ``float('inf')`` is
    returned instead of dividing by zero.
    """
    mse = (img1 - img2).pow(2).mean()

    # Identical inputs: PSNR is unbounded.
    if mse == 0:
        return float('inf')

    peak = 1.0
    rmse = mse.sqrt()
    return (20 * torch.log10(peak / rmse)).item()

In [None]:
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

def calculate_metrics(img1, img2, max_pixel_value=1.0):
    """Compute PSNR and SSIM between two image batches with torchmetrics.

    Args:
        img1: First tensor passed to the metrics (the callers in this file
            pass the model output here).
        img2: Second tensor passed to the metrics (the callers pass the HR
            reference here).
        max_pixel_value: Dynamic range of the pixel values, forwarded to both
            metrics as ``data_range``. The original implementation accepted
            this argument but always used 1.0.

    Returns:
        Tuple ``(psnr_value, ssim_value)`` as returned by the torchmetrics
        modules.
    """
    # Build the metrics on the same device as the data rather than relying on
    # a module-level ``device`` global.
    target_device = img1.device
    psnr = PeakSignalNoiseRatio(data_range=max_pixel_value).to(target_device)
    ssim = StructuralSimilarityIndexMeasure(data_range=max_pixel_value).to(target_device)
    return psnr(img1, img2), ssim(img1, img2)

# 4. Training

In [None]:
from torch.amp import autocast, GradScaler
from torchsummary import summary
import math
# Gradient scaler for mixed-precision training (created once; the original
# cell instantiated it a second time further down).
scaler = GradScaler()

# --- Hyperparameters ---------------------------------------------------------
# for scale in [2, 3, 4]:
scale = 4
batch_size = 16
learing_rate = 1e-4  # (sic) misspelled name kept: it is referenced below and possibly by other cells
num_iterations = 300_000  # target number of training batches, converted to epochs below

# --- Data --------------------------------------------------------------------
# The same folder is passed as LR and HR directory for each split.
train_lr_dir = 'dataset/train/images'
train_hr_dir = 'dataset/train/images'
valid_lr_dir = 'dataset/val/images'
valid_hr_dir = 'dataset/val/images'
train_dataset = ImageDataset(train_lr_dir, train_hr_dir, scale=scale)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Validation uses the DataLoader defaults (no batch_size / shuffle given).
valid_dataset = ImageDataset(valid_lr_dir, valid_hr_dir, scale=scale, valid=True)
valid_loader = DataLoader(valid_dataset)

# --- Models, optimizers, schedulers ------------------------------------------
# Two independent models are trained side by side on the same batches.
# StepLR multiplies the learning rate by 0.5 every 10**5 scheduler steps.
sobelsr = ESRPCB(scale_factor=scale, use_sobel=True).to(device)
optim_sobel = optim.Adam(sobelsr.parameters(), lr=learing_rate, betas=(0.9, 0.999))
scheduler_sobel = optim.lr_scheduler.StepLR(optim_sobel, step_size=10**5, gamma=0.5)

cannysr = ESRPCB(scale_factor=scale, use_canny=True).to(device)
optim_canny = optim.Adam(cannysr.parameters(), lr=learing_rate, betas=(0.9, 0.999))
scheduler_canny = optim.lr_scheduler.StepLR(optim_canny, step_size=10**5, gamma=0.5)

criterion = nn.MSELoss()

# Number of passes over the training loader needed to reach num_iterations batches.
num_epochs = math.ceil(num_iterations / len(train_loader))

# --- Bookkeeping -------------------------------------------------------------
best_psnr_sobel = float('-inf')
best_psnr_canny = float('-inf')
torch.cuda.empty_cache()

losses_sobel = []
losses_canny = []
avg_psnr_sobel = []
avg_psnr_canny = []

val_avg_psnr_sobel = []  # Validation PSNR
val_avg_psnr_canny = []

# Early stopping after a quarter of the planned epochs without improvement.
# Clamp to at least 1: with fewer than 4 epochs int(0.25 * num_epochs) is 0,
# which would stop training right after the first epoch.
patience = max(1, int(0.25 * num_epochs))
epochs_no_improve = 0

# Make sure the log directory exists before opening the log in append mode.
os.makedirs('outputs/train_log', exist_ok=True)
log_file = open('outputs/train_log/esrpcb.txt', 'a')

# Training / validation loop for the Sobel and Canny models.

# Mixed precision is only used on CUDA; elsewhere autocast and the scalers
# are disabled and act as pass-throughs.
use_amp = device.type == 'cuda'

# One gradient scaler per model: the two models are optimised independently,
# so an overflow in one must not change the loss scale of the other, and each
# scaler is updated exactly once per iteration.
scaler_sobel = GradScaler(enabled=use_amp)
scaler_canny = GradScaler(enabled=use_amp)

# Make sure both checkpoint directories exist before the first torch.save.
save_dir = f'outputs/path/x{scale}'
os.makedirs('outputs/weight_sr', exist_ok=True)
os.makedirs(save_dir, exist_ok=True)


def _train_step(model, optimizer, scheduler, grad_scaler, lr_batch, hr_batch):
    """Run one optimisation step and return ``(loss, psnr)`` as Python floats."""
    optimizer.zero_grad()
    with autocast(device_type=device.type, enabled=use_amp):
        outputs = model(lr_batch)
        loss = criterion(outputs, hr_batch)

    grad_scaler.scale(loss).backward()
    grad_scaler.step(optimizer)
    grad_scaler.update()
    # The scheduler is stepped once per batch, not once per epoch.
    scheduler.step()

    # detach(): the PSNR is a monitoring value and needs no autograd graph.
    return loss.item(), calculate_psnr(outputs.detach(), hr_batch)


try:
    for epoch in range(num_epochs):
        sobelsr.train()
        cannysr.train()

        epoch_loss_sobel = 0
        psnr_values_sobel = 0
        epoch_loss_canny = 0
        psnr_values_canny = 0
        start_time = time.time()

        # Training loop
        for (lr_images, hr_images) in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch'):
            lr_images = lr_images.to(device)
            hr_images = hr_images.to(device)

            # Sobel SR training
            loss_sobel, psnr_sobel = _train_step(
                sobelsr, optim_sobel, scheduler_sobel, scaler_sobel, lr_images, hr_images)

            # Canny SR training
            loss_canny, psnr_canny = _train_step(
                cannysr, optim_canny, scheduler_canny, scaler_canny, lr_images, hr_images)

            # Update metrics
            epoch_loss_sobel += loss_sobel
            psnr_values_sobel += psnr_sobel
            epoch_loss_canny += loss_canny
            psnr_values_canny += psnr_canny

        # Calculate average training metrics
        avg_epoch_loss_sobel = epoch_loss_sobel / len(train_loader)
        average_psnr_sobel = psnr_values_sobel / len(train_loader)
        losses_sobel.append(avg_epoch_loss_sobel)
        avg_psnr_sobel.append(average_psnr_sobel)

        avg_epoch_loss_canny = epoch_loss_canny / len(train_loader)
        average_psnr_canny = psnr_values_canny / len(train_loader)
        losses_canny.append(avg_epoch_loss_canny)
        avg_psnr_canny.append(average_psnr_canny)

        # Validation step (PSNR only, no loss)
        sobelsr.eval()
        cannysr.eval()

        val_psnr_values_sobel = 0
        val_psnr_values_canny = 0

        with torch.no_grad():  # No gradients during validation
            for (lr_images, hr_images) in tqdm(valid_loader, desc=f'Validation Epoch {epoch+1}/{num_epochs}', unit='batch'):
                lr_images = lr_images.to(device)
                hr_images = hr_images.to(device)

                val_psnr_values_sobel += calculate_psnr(sobelsr(lr_images), hr_images)
                val_psnr_values_canny += calculate_psnr(cannysr(lr_images), hr_images)

        # Calculate average validation PSNR
        val_average_psnr_sobel = val_psnr_values_sobel / len(valid_loader)
        val_avg_psnr_sobel.append(val_average_psnr_sobel)

        val_average_psnr_canny = val_psnr_values_canny / len(valid_loader)
        val_avg_psnr_canny.append(val_average_psnr_canny)

        # Note: the reported time covers training and validation of the epoch.
        end_time = time.time()

        # Logging results
        log_string = (f"Epoch {epoch+1}/{num_epochs}, Loss sobel: {avg_epoch_loss_sobel:.4f}, "
                    f"Loss canny: {avg_epoch_loss_canny:.4f}, Time training: {end_time - start_time:.4f}s, "
                    f"PSNR sobel: {average_psnr_sobel:.2f} dB, PSNR canny: {average_psnr_canny:.2f} dB, "
                    f"Val PSNR sobel: {val_average_psnr_sobel:.2f} dB, Val PSNR canny: {val_average_psnr_canny:.2f} dB")
        print(log_string)
        log_file.write(log_string + '\n')
        log_file.flush()

        # Save best models based on validation PSNR
        improved = False
        if val_average_psnr_sobel > best_psnr_sobel:
            best_psnr_sobel = val_average_psnr_sobel
            torch.save(sobelsr.state_dict(), 'outputs/weight_sr/best_esrpcb_sobel.pth')
            print(f"Saved Sobel SR model with PSNR {best_psnr_sobel:.4f}")
            improved = True

        if val_average_psnr_canny > best_psnr_canny:
            best_psnr_canny = val_average_psnr_canny
            torch.save(cannysr.state_dict(), 'outputs/weight_sr/best_esrpcb_canny.pth')
            print(f"Saved Canny SR model with PSNR {best_psnr_canny:.4f}")
            improved = True

        # An epoch counts as "no improvement" only when neither model set a
        # new best; an exact tie with the best value is not an improvement.
        epochs_no_improve = 0 if improved else epochs_no_improve + 1

        # Save a per-epoch checkpoint before the early-stopping check so the
        # weights of the final epoch are kept as well.
        torch.save(sobelsr.state_dict(), os.path.join(save_dir, f'esrpcb_sobel_{epoch}.pth'))
        torch.save(cannysr.state_dict(), os.path.join(save_dir, f'esrpcb_canny_{epoch}.pth'))

        if epochs_no_improve >= patience:
            print(f"PSNR did not improve for {patience} epochs. Early stopping at epoch {epoch+1}")
            break
finally:
    # Close the log file even if training is interrupted or raises.
    log_file.close()

# Plot the per-epoch training curves: loss in the top panel, PSNR below.
fig, (loss_ax, psnr_ax) = plt.subplots(2, 1, figsize=(12, 10))

# Top panel: average training loss of both models.
for history, legend_label in (
    (losses_sobel, 'Sobel SR Loss (Train)'),
    (losses_canny, 'Canny SR Loss (Train)'),
):
    loss_ax.plot(history, label=legend_label)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.set_title('Training Loss')

# Bottom panel: average training and validation PSNR of both models.
for history, legend_label in (
    (avg_psnr_sobel, 'Sobel SR PSNR (Train)'),
    (val_avg_psnr_sobel, 'Sobel SR PSNR (Val)'),
    (avg_psnr_canny, 'Canny SR PSNR (Train)'),
    (val_avg_psnr_canny, 'Canny SR PSNR (Val)'),
):
    psnr_ax.plot(history, label=legend_label)
psnr_ax.set_xlabel('Epoch')
psnr_ax.set_ylabel('PSNR (dB)')
psnr_ax.legend()
psnr_ax.set_title('Average PSNR (Train and Val)')

plt.tight_layout()
plt.show()

# 5. Testing

In [None]:
# Evaluate both models on valid_loader and report average SSIM / PSNR.
# NOTE(review): this uses the weights currently held in memory, not the
# "best" checkpoints written during training - confirm that is intended.
sobelsr = sobelsr.to(device)
cannysr = cannysr.to(device)
cannysr.eval()
sobelsr.eval()

# Running sums of the per-batch metrics.
val_psnr_values_sobel = val_psnr_values_canny = 0
val_ssim_values_sobel = val_ssim_values_canny = 0

torch.cuda.empty_cache()

with torch.no_grad():  # No gradients during evaluation
    for (lr_images, hr_images) in tqdm(valid_loader, desc='Validation', unit='batch'):
        lr_images = lr_images.to(device)
        hr_images = hr_images.to(device)

        # Sobel SR: PSNR and SSIM for this batch
        outputs_sobel = sobelsr(lr_images)
        psnr_sobel, ssim_sobel = calculate_metrics(outputs_sobel, hr_images)

        # Canny SR: PSNR and SSIM for this batch
        outputs_canny = cannysr(lr_images)
        psnr_canny, ssim_canny = calculate_metrics(outputs_canny, hr_images)

        # Accumulate the batch metrics
        val_psnr_values_sobel += psnr_sobel
        val_ssim_values_sobel += ssim_sobel
        val_psnr_values_canny += psnr_canny
        val_ssim_values_canny += ssim_canny

# Average over the number of batches in the loader.
num_batches = len(valid_loader)
val_average_psnr_sobel = val_psnr_values_sobel / num_batches
val_average_psnr_canny = val_psnr_values_canny / num_batches
val_average_ssim_sobel = val_ssim_values_sobel / num_batches
val_average_ssim_canny = val_ssim_values_canny / num_batches

# Printed as "SSIM / PSNR" for each model.
print(f'esrpcb canny: {val_average_ssim_canny:.4f} / {val_average_psnr_canny:.2f}')
print(f'esrpcb sobel: {val_average_ssim_sobel:.4f} / {val_average_psnr_sobel:.2f}')