# 1. Import thư viện

In [11]:
import cv2
import numpy as np
import os
from PIL import Image
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch.nn.parallel import DataParallel
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
from torchvision.utils import save_image
import torchsummary
from tqdm import tqdm
from models.srresnet import *
from models.edsr import *

import time



In [12]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 2. Tạo Mô hình SR

In [13]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# device = torch.device('cpu')

In [14]:

# edsr.load_state_dict(torch.load('weight/best_edsr.pth', map_location=device))
# edsr.load_state_dict(torch.load('weight/best_edsrx4_orig_model.pth', map_location=device))
# edrn_sobel.load_state_dict(torch.load('weight/best_sobel_srx4_model.pth', map_location=device))
# edrn_canny.load_state_dict(torch.load('weight/best_canny_srx4_model.pth', map_location=device))
# srresnet.load_state_dict(torch.load('best_srresnet.pth', map_location=device))
# vdsr.load_state_dict(torch.load('weight/best_vdsr.pth', map_location=device))

In [15]:
class ImageDataset(Dataset):
    def __init__(self, lr_dir, hr_dir, valid = False, scale=4):
        self.lr_files = sorted(os.listdir(lr_dir))
        self.hr_files = sorted(os.listdir(hr_dir))
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.valid = valid
        self.scale = scale
    def __len__(self):
        return len(self.lr_files)

    def __getitem__(self, idx):
        lr_image = Image.open(os.path.join(self.lr_dir, self.lr_files[idx])).convert('RGB')
        hr_image = Image.open(os.path.join(self.hr_dir, self.hr_files[idx])).convert('RGB')
        
        w, h = hr_image.size
        if self.valid:
            lr_image = lr_image.resize((w//self.scale, h//self.scale))
            
        transform = transforms.Compose([
            # transforms.ToPILImage(),
            transforms.ToTensor()
        ])
        
        lr_image = transform(lr_image)
        hr_image = transform(hr_image)
        return lr_image, hr_image

# 3. Tạo Hyperparameter

In [16]:
def calculate_psnr(img1, img2):
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    max_pixel = 1.0
    psnr = 20 * torch.log10(max_pixel / torch.sqrt(mse))
    return psnr.item()

In [17]:
from torch.amp import autocast, GradScaler
scaler = GradScaler()

# Khởi tạo dataset và dataloader
for scale in [2, 3, 4]:
    print(scale)
    train_lr_dir = f'dataset/Train/LR_{scale}'
    train_hr_dir = 'dataset/Train/HR'
    valid_lr_dir = 'dataset/Test/HR'
    valid_hr_dir = 'dataset/Test/HR'
    train_dataset = ImageDataset(train_lr_dir, train_hr_dir, scale=scale)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

    valid_dataset = ImageDataset(valid_lr_dir, valid_hr_dir, scale=scale, valid = True)
    valid_loader = DataLoader(valid_dataset)

    # Khởi tạo loss function
    criterion = nn.MSELoss()
    edsr = EDSR(scale=scale).to(device)
    srresnet = SRResNet(scale=scale).to(device)
    
    # Khởi tạo optimizers, schedulers cho từng mô hình
    optim_edsr = optim.Adam(edsr.parameters(), lr=1e-4, betas=(0.9, 0.999))
    scheduler_edsr = optim.lr_scheduler.StepLR(optim_edsr, step_size=10**5, gamma=0.5)

    optim_srresnet = optim.Adam(srresnet.parameters(), lr=1e-4, betas=(0.9, 0.999))
    scheduler_srresnet = optim.lr_scheduler.StepLR(optim_srresnet, step_size=10**5, gamma=0.5)
    
    num_epochs = 24

    best_psnr_edsr = float('-inf')
    best_psnr_srresnet = float('-inf')

    torch.cuda.empty_cache()

    losses_edsr = []
    losses_srresnet = []

    avg_psnr_edsr = []
    avg_psnr_srresnet = []

    val_avg_psnr_edsr = []
    val_avg_psnr_srresnet = []

    patience = 5
    epochs_no_improve = 0
    log_file = open('outputs/train_log/edsr_srresnet.txt', 'a')

    for epoch in range(num_epochs):
        edsr.train()
        srresnet.train()

        epoch_loss_edsr, psnr_values_edsr = 0, 0
        epoch_loss_srresnet, psnr_values_srresnet = 0, 0

        start_time = time.time()
        torch.cuda.empty_cache()
        # Training loop for each model
        for i, (lr_images, hr_images) in enumerate(tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch')):
            lr_images = lr_images.to(device)
            hr_images = hr_images.to(device)

            # # Train EDSR model
            optim_edsr.zero_grad()
            with autocast(device_type='cuda'):
                outputs_edsr = edsr(lr_images)
                print(lr_images.shape)
                loss_edsr = criterion(outputs_edsr, hr_images)
            psnr_edsr = calculate_psnr(outputs_edsr, hr_images)

            scaler.scale(loss_edsr).backward()
            scaler.step(optim_edsr)
            scaler.update()
            scheduler_edsr.step()

            epoch_loss_edsr += loss_edsr.item()
            
            psnr_values_edsr += psnr_edsr
            # Train SRResNet model
            optim_srresnet.zero_grad()
            with autocast(device_type='cuda'):
                outputs_srresnet = srresnet(lr_images)
                loss_srresnet = criterion(outputs_srresnet, hr_images)
            psnr_srresnet = calculate_psnr(outputs_srresnet, hr_images)

            scaler.scale(loss_srresnet).backward()
            scaler.step(optim_srresnet)
            scaler.update()
            scheduler_srresnet.step()

            epoch_loss_srresnet += loss_srresnet.item()
            psnr_values_srresnet += psnr_srresnet


        # Average losses and PSNRs
        avg_epoch_loss_edsr = epoch_loss_edsr / len(train_loader)
        avg_psnr_edsr_epoch = psnr_values_edsr / len(train_loader)
        losses_edsr.append(avg_epoch_loss_edsr)
        avg_psnr_edsr.append(avg_psnr_edsr_epoch)

        avg_epoch_loss_srresnet = epoch_loss_srresnet / len(train_loader)
        avg_psnr_srresnet_epoch = psnr_values_srresnet / len(train_loader)
        losses_srresnet.append(avg_epoch_loss_srresnet)
        avg_psnr_srresnet.append(avg_psnr_srresnet_epoch)

        # Validation for all models
        edsr.eval()
        srresnet.eval()
    

        val_psnr_edsr, val_psnr_srresnet = 0, 0

        with torch.no_grad():
            for (lr_images, hr_images) in valid_loader:
                lr_images = lr_images.cuda()
                hr_images = hr_images.cuda()

                # # Validate EDSR
                outputs_edsr = edsr(lr_images)
                psnr_edsr = calculate_psnr(outputs_edsr, hr_images)
                val_psnr_edsr += psnr_edsr

                # Validate SRResNet
                outputs_srresnet = srresnet(lr_images)
                psnr_srresnet = calculate_psnr(outputs_srresnet, hr_images)
                val_psnr_srresnet += psnr_srresnet

        val_avg_psnr_edsr_epoch = val_psnr_edsr / len(valid_loader)
        val_avg_psnr_edsr.append(val_avg_psnr_edsr_epoch)

        val_avg_psnr_srresnet_epoch = val_psnr_srresnet / len(valid_loader)
        val_avg_psnr_srresnet.append(val_avg_psnr_srresnet_epoch)

    
        # Save best model
        if val_avg_psnr_edsr_epoch > best_psnr_edsr:
            best_psnr_edsr = val_avg_psnr_edsr_epoch
            torch.save(edsr.state_dict(), f'outputs/weight_sr/x{scale}/best_edsr.pth')
            print(f"Saved EDSRR model with PSNR {best_psnr_edsr:.4f}")
        if val_avg_psnr_srresnet_epoch > best_psnr_srresnet:
            best_psnr_srresnet = val_avg_psnr_srresnet_epoch
            torch.save(srresnet.state_dict(), f'outputs/weight_sr/x{scale}/best_srresnet.pth')
            print(f"Saved SRResNet model with PSNR {best_psnr_srresnet:.4f}")

        torch.save(edsr.state_dict(), f'outputs/path/edsr_{epoch}.pth')
        torch.save(srresnet.state_dict(), f'outputs/path/srresnet_{epoch}.pth')
        

        print(f"Epoch [{epoch+1}/{num_epochs}] completed: EDSR Loss: {avg_epoch_loss_edsr:.4f}, PSNR: {avg_psnr_edsr_epoch:.4f}, Validation PSNR: {val_avg_psnr_edsr_epoch:.4f},"
            "SRResNEt Loss: {avg_epoch_loss_srresnet:.4f}, PSNR: {avg_psnr_srresnet_epoch:.4f}, Validation PSNR: {val_avg_psnr_srresnet_epoch:.4f}")

        log_file.write(f"Epoch {epoch+1}:  EDSR PSNR: {avg_psnr_edsr_epoch:.4f}, Validation PSNR: {val_avg_psnr_edsr_epoch:.4f}\n")
        log_file.write(f"              SRResNet PSNR: {avg_psnr_srresnet_epoch:.4f}, Validation PSNR: {val_avg_psnr_srresnet_epoch:.4f}\n")
        
        log_file.flush()

    log_file.close()



2


Epoch 1/24:   0%|          | 1/12500 [00:00<42:16,  4.93batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 4/12500 [00:00<21:47,  9.56batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 6/12500 [00:00<19:36, 10.62batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 10/12500 [00:00<18:09, 11.47batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 12/12500 [00:01<17:44, 11.73batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 16/12500 [00:01<17:10, 12.12batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 18/12500 [00:01<17:10, 12.11batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 22/12500 [00:01<17:00, 12.22batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 24/12500 [00:02<17:03, 12.18batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 28/12500 [00:02<16:38, 12.49batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 30/12500 [00:02<16:46, 12.39batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 34/12500 [00:02<16:52, 12.32batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 36/12500 [00:03<17:03, 12.18batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 40/12500 [00:03<16:45, 12.39batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 42/12500 [00:03<16:52, 12.31batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 46/12500 [00:03<16:59, 12.22batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 48/12500 [00:04<17:02, 12.18batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 52/12500 [00:04<17:30, 11.85batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 54/12500 [00:04<17:29, 11.86batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 58/12500 [00:04<17:14, 12.02batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   0%|          | 60/12500 [00:05<17:09, 12.09batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 64/12500 [00:05<16:54, 12.25batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 66/12500 [00:05<16:52, 12.28batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 70/12500 [00:05<16:57, 12.21batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 72/12500 [00:06<17:01, 12.17batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 76/12500 [00:06<17:00, 12.18batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 78/12500 [00:06<16:45, 12.36batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 82/12500 [00:06<16:47, 12.33batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 84/12500 [00:07<16:53, 12.25batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 88/12500 [00:07<16:53, 12.25batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 90/12500 [00:07<16:45, 12.35batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 94/12500 [00:07<16:58, 12.18batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 96/12500 [00:08<16:58, 12.18batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 100/12500 [00:08<16:58, 12.18batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 102/12500 [00:08<16:38, 12.41batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 106/12500 [00:08<16:47, 12.30batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 108/12500 [00:08<16:49, 12.27batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 112/12500 [00:09<17:17, 11.94batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 114/12500 [00:09<16:52, 12.24batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 118/12500 [00:09<16:55, 12.20batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 120/12500 [00:09<16:58, 12.15batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 124/12500 [00:10<17:04, 12.08batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 126/12500 [00:10<16:58, 12.15batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 130/12500 [00:10<16:50, 12.24batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 132/12500 [00:10<16:44, 12.31batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 136/12500 [00:11<16:54, 12.18batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])


Epoch 1/24:   1%|          | 139/12500 [00:11<17:08, 12.02batch/s]

torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])
torch.Size([16, 3, 48, 48])





KeyboardInterrupt: 