# IMPORT

In [None]:
import glob
import os
import os.path as osp
import sys
import time

import cv2
import lpips
import numpy as np
import pyiqa
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".")))
from FinalModels.ESRGAN.model import ESRGAN


# Folder path for test images
folder_path = "d:\\workspace\\ThesisProject\\TESTS\\images\\LR"  # change this

# Device shared by the SR model and every metric model below; CUDA when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize IQA models from `pyiqa`
niqe_fn = pyiqa.create_metric('niqe', device=device)
# brisque_fn = pyiqa.create_metric('brisque', device=device)
nrqm_fn = pyiqa.create_metric('nrqm', device=device)
lpips_fn = pyiqa.create_metric('lpips', device=device)
ssim_fn = pyiqa.create_metric('ssim', device=device)
pi_fn = pyiqa.create_metric('pi', device=device)

**LOAD THE MODEL**

In [None]:
# Instantiate ESRGAN from the RRDB_ESRGAN_x4.pth checkpoint on `device`.
model = ESRGAN(
    model_path='./FinalModels/ESRGAN/model/RRDB_ESRGAN_x4.pth',
    device=device,
)

# FUNCTIONS

In [None]:


def enhance_image(image_path, model, device, target_size=(224, 224)):
    """Run the SR model on an image and return its output at ``target_size``.

    The loaded image is resized (bicubic) to ``target_size``, passed through
    ``model`` under ``torch.no_grad()``, and the model output is resized back
    to ``target_size`` (bicubic).

    Args:
        image_path: Path of the image file, loaded via ``preprocess_image``.
        model: Super-resolution model; moved to ``device`` before inference.
        device: Torch device for the input tensor and the model.
        target_size: (height, width) used for the model input and for the
            returned tensors. Defaults to 224x224.

    Returns:
        Tuple ``(hr_image, bicubic_image, enhanced_image, modes)``:
        the loaded input tensor, its bicubic resize, the resized model output,
        and a dict of resizes of the input keyed by interpolation-mode name.
    """
    model.to(device)
    start = time.time()
    hr_image = preprocess_image(image_path).to(device)

    # One bicubic resize serves as the model input, the returned baseline
    # and the "Bicubic" entry of ``modes``.
    bicubic_image = F.interpolate(hr_image, size=target_size, mode='bicubic', align_corners=False)
    modes = {
        "Bilinear": F.interpolate(hr_image, size=target_size, mode='bilinear', align_corners=False),
        "Nearest": F.interpolate(hr_image, size=target_size, mode='nearest'),
        "Bicubic": bicubic_image,
        "None": F.interpolate(hr_image, size=target_size),
        "Exact": F.interpolate(hr_image, size=target_size, mode='nearest-exact'),
    }

    with torch.no_grad():
        enhanced_image = model(bicubic_image)
    enhanced_image = F.interpolate(enhanced_image, size=target_size, mode='bicubic', align_corners=False)
    print(f"Time taken: {time.time() - start:.2f} seconds: enhanced shape: {enhanced_image.shape}, hr shape: {hr_image.shape}, bicubic shape: {bicubic_image.shape}, original shape: {hr_image.shape}")
    return hr_image, bicubic_image, enhanced_image, modes

def evaluate_performance(hr_image, sr_image):
    """Evaluate SR quality using multiple perceptual metrics from `pyiqa`.

    Full-reference metrics (PSNR, SSIM, LPIPS) compare ``sr_image``, resized
    (bicubic) to ``hr_image``'s spatial size, against ``hr_image``.
    No-reference metrics (NIQE, NRQM, PI) are computed on ``sr_image`` as-is.
    All images are clamped to [0, 1] first; PSNR uses a peak value of 1.0.

    Args:
        hr_image: Reference tensor; dims 2 and 3 are read as height/width.
        sr_image: Super-resolved tensor to score.

    Returns:
        Dict with keys "PSNR", "SSIM", ">LPIPS", "<NIQE", ">NRQM", "<PI".
        PSNR is ``float('inf')`` when the images are identical. Returns None
        (skipping every metric, not only NIQE) when either spatial side of
        ``sr_image`` is below 64 pixels.
    """
    # Check the size first so no work is spent on images that are skipped.
    if sr_image.shape[-2] < 64 or sr_image.shape[-1] < 64:
        print("Image too small for NIQE evaluation:", sr_image.shape)
        return None

    sr_resized_image = F.interpolate(sr_image, size=(hr_image.shape[2], hr_image.shape[3]), mode='bicubic', align_corners=False)

    hr_image = torch.clamp(hr_image, 0, 1)
    sr_image = torch.clamp(sr_image, 0, 1)
    sr_resized_image = torch.clamp(sr_resized_image, 0, 1)

    # Compute MSE once. Identical images yield inf, which is already a Python
    # float, so only the finite branch needs .item().
    mse = F.mse_loss(sr_resized_image.squeeze(0), hr_image.squeeze(0))
    psnr = (10 * torch.log10(1.0 / mse)).item() if mse > 0 else float('inf')

    ssim_score = ssim_fn(sr_resized_image, hr_image).item()
    lpips_score = lpips_fn(sr_resized_image, hr_image).item()
    niqe_score = niqe_fn(sr_image).item()
    # brisque_score = brisque_fn(sr_image).item()
    nrqm_score = nrqm_fn(sr_image).item()
    pi_score = pi_fn(sr_image).item()

    # NOTE(review): LPIPS is a distance (lower is better), so the ">" prefix on
    # its key looks inverted; the key is kept unchanged for existing callers.
    return {
        "PSNR": psnr,
        "SSIM": ssim_score,
        ">LPIPS": lpips_score,
        "<NIQE": niqe_score,
        # "BRISQUE": brisque_score,
        ">NRQM": nrqm_score,
        "<PI": pi_score,
    }

def perform_super_resolution2(image_path, model, device):
    """Run the 224 -> model super-resolution benchmark on one image.

    The loaded image is resized (bicubic) to 896x896 to form the reference,
    that reference is downsampled (bicubic) to 224x224 as the model input,
    and the 224x224 input is upsampled back to 896x896 as a bicubic baseline.

    Returns:
        Tuple ``(original, hr_image, bicubic_image, enhanced_image)``: the
        image as loaded, the 896x896 reference, the bicubic baseline and the
        raw model output for the 224x224 input.
    """
    model.to(device)
    t0 = time.time()
    source = preprocess_image(image_path).to(device)

    def _bicubic(tensor, side):
        # Square bicubic resize, matching every resize in this benchmark.
        return F.interpolate(tensor, size=(side, side), mode='bicubic', align_corners=False)

    reference = _bicubic(source, 896)  # with interpolation
    low_res = _bicubic(reference, 224)
    upsampled = _bicubic(low_res, 896)

    with torch.no_grad():
        sr_output = model(low_res)

    elapsed = time.time() - t0
    print(f"Time taken: {elapsed:.2f} seconds: enhanced shape: {sr_output.shape}, hr shape: {reference.shape}, bicubic shape: {upsampled.shape}, Original shape: {source.shape}")
    return source, reference, upsampled, sr_output


# LPIPS metric with a VGG backbone from the `lpips` package, placed on `device`.
# NOTE(review): not referenced elsewhere in this notebook; the metric
# functions use pyiqa's `lpips_fn` instead.
lpips_loss_fn = lpips.LPIPS(net='vgg')
lpips_loss_fn = lpips_loss_fn.to(device)


def evaluate_performance2(hr_image, sr_image):
    """Evaluate SR quality against a reference using `pyiqa` metrics.

    Unlike ``evaluate_performance``, the *reference* is resized (bicubic) to
    ``sr_image``'s spatial size when the shapes differ. Both images are
    clamped to [0, 1]; PSNR uses a peak value of 1.0.

    Args:
        hr_image: Reference tensor.
        sr_image: Tensor to score; dims 2 and 3 are read as height/width.

    Returns:
        Dict with keys "PSNR", "SSIM", ">LPIPS", "<NIQE", ">NRQM", "<PI".
        PSNR is ``float('inf')`` when the clamped images are identical.
        NIQE/NRQM/PI are reported as 0 (a placeholder, not a score) when
        either side of ``sr_image`` is 64 pixels or smaller.
        Returns None, after printing the error, if any step raises.
    """
    try:
        if hr_image.shape != sr_image.shape:
            hr_image = F.interpolate(hr_image, size=(sr_image.shape[2], sr_image.shape[3]), mode='bicubic', align_corners=False)

        hr_image = torch.clamp(hr_image, 0, 1)
        sr_image = torch.clamp(sr_image, 0, 1)

        # MSE computed once; identical images have infinite PSNR, not 0.
        mse = F.mse_loss(sr_image.squeeze(0), hr_image.squeeze(0))
        psnr = (10 * torch.log10(1.0 / mse)).item() if mse > 0 else float('inf')

        ssim_score = ssim_fn(sr_image, hr_image).item()
        lpips_score = lpips_fn(sr_image, hr_image).item()

        if sr_image.shape[-2] > 64 and sr_image.shape[-1] > 64:
            niqe_score = niqe_fn(sr_image).item()
            nrqm_score = nrqm_fn(sr_image).item()
            pi_score = pi_fn(sr_image).item()
        else:
            niqe_score = nrqm_score = pi_score = 0

        # NOTE(review): LPIPS is a distance (lower is better); the ">" prefix
        # on its key looks inverted but is kept for existing callers.
        return {
            "PSNR": psnr,
            "SSIM": ssim_score,
            ">LPIPS": lpips_score,
            "<NIQE": niqe_score,
            ">NRQM": nrqm_score,
            "<PI": pi_score,
        }
    except Exception as e:
        print(f"Error in evaluate_performance2: {e}")
        return None

def preprocess_image(image_path):
    """Load an image file as an RGB tensor with a leading batch dimension.

    The image is converted to RGB, passed through ``transforms.ToTensor()``
    and unsqueezed at dim 0.

    NOTE(review): ``Image`` (PIL) and ``transforms`` (torchvision) are not
    imported in this notebook as shown; confirm they are in scope.
    """
    # The context manager closes the file handle deterministically. convert()
    # returns a new, fully loaded image, so it stays valid after the block.
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")
    return transforms.ToTensor()(rgb_image).unsqueeze(0)  # Add batch dimension


def downscale_image(hr_image, scale=4):
    """Shrink a 4-D image batch by an integer factor with bicubic resampling.

    The output height and width are ``H // scale`` and ``W // scale``
    (floor division), where H and W are dims 2 and 3 of ``hr_image``.
    """
    height, width = hr_image.size(2), hr_image.size(3)
    target = (height // scale, width // scale)
    return F.interpolate(
        hr_image,
        size=target,
        mode='bicubic',
        align_corners=False,
    )

def save_image_esrgan(image, save_path, device, filename, size=(224, 224)):
    """Resize an image tensor and save it as ``<save_path>/<filename>.png``.

    Args:
        image: Image tensor with a batch dim (bilinear resize needs 4-D input);
            after ``squeeze()`` it is treated as (C, H, W). Values are scaled
            by 255 and clipped to [0, 255], so [0, 1] input is expected.
        save_path: Output directory (it is not created here).
        device: Device on which the bilinear resize runs.
        filename: Output file name without the extension.
        size: Output (height, width); defaults to 224x224.

    Errors raised while writing the file are printed, not propagated.
    """
    image = image.to(device)
    image = F.interpolate(image, size=size, mode='bilinear', align_corners=False)

    # (C, H, W) -> (H, W, C); detach so .numpy() also works on grad-tracking tensors.
    array = image.squeeze().permute(1, 2, 0).detach().cpu().numpy()
    array = np.clip(array * 255, 0, 255).astype(np.uint8)

    print("Shape before save:", array.shape)

    pil_image = Image.fromarray(array)
    try:
        pil_image.save(os.path.join(save_path, filename + '.png'), format='PNG')
        print(f"Saved {filename} Shape: {array.shape} at {save_path}")
    except Exception as e:
        print(f"Error saving {filename}: {e}")

def compare_images(hr_image, lr_image, sr_image, psnr):
    """Show the HR reference, bicubic baseline and SR result side by side.

    ``lr_image`` and ``sr_image`` are resized (bicubic) to ``hr_image``'s
    spatial size, every image is clipped to [0, 1] for display, and the
    caller-supplied SR PSNR is printed.

    Args:
        hr_image: Reference tensor; dims 2 onward give the display size.
        lr_image: Bicubic baseline tensor.
        sr_image: Super-resolved tensor.
        psnr: PSNR of ``sr_image`` computed by the caller.
    """
    size = hr_image.shape[2:]
    lr_image = F.interpolate(lr_image, size=size, mode='bicubic', align_corners=False)
    sr_image = F.interpolate(sr_image, size=size, mode='bicubic', align_corners=False)

    def to_display(img):
        # NumPy on the CPU, clipped to [0, 1], batch dim squeezed away, and
        # channel-first (C, H, W) moved to channel-last (H, W, C) for imshow.
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu().numpy()
        img = np.squeeze(np.clip(img, 0, 1))
        if img.ndim == 3 and img.shape[0] in (1, 3):
            img = np.transpose(img, (1, 2, 0))
        return img

    print(f"PSNR (ESRGAN): {psnr:.2f}")

    panels = (
        (to_display(hr_image), "HR Image"),
        (to_display(lr_image), "x4 Bicubic"),
        (to_display(sr_image), "Super Resolution (ESRGAN)"),
    )
    # Local figsize instead of mutating the global rcParams.
    fig, axes = plt.subplots(1, 3, figsize=(15, 10))
    for ax, (img, title) in zip(axes, panels):
        plt.sca(ax)  # plot_image draws on the current axes
        plot_image(img, title=title)
    fig.tight_layout()  # after plotting so titles are accounted for
    plt.show(block=False)
    

def plot_image(image, title=""):
    """Draw ``image`` on the current Matplotlib axes, titled and axis-free.

    Args:
        image: Array-like image accepted by ``plt.imshow``,
            e.g. (height, width, channels).
        title: Text shown above the image.
    """
    pixels = np.asarray(image)
    plt.imshow(pixels)
    plt.axis("off")
    plt.title(title)


# EVALUATION AND COMPARISON

**BICUBIC VS ESRGAN**

In [None]:
# Benchmark ESRGAN against a bicubic baseline (and the HR reference against the
# image as loaded) for every file in `folder_path`, then print the averages.

# Initialize accumulators
total_psnr_esrgan = 0
total_psnr_bicubic = 0
total_ssim_esrgan = 0
total_ssim_bicubic = 0
total_lpips_esrgan = 0
total_lpips_bicubic = 0
total_niqe_esrgan = 0
total_niqe_bicubic = 0
total_nrqm_esrgan = 0
total_nrqm_bicubic = 0
total_pi_esrgan = 0
total_pi_bicubic = 0
total_psnr_hr = 0
total_ssim_hr = 0
total_lpips_hr = 0
total_niqe_hr = 0
total_nrqm_hr = 0
total_pi_hr = 0

# Images whose metrics were accumulated; images whose evaluation failed are
# excluded so they do not distort the averages.
num_images = 0

idx = 0
for path in glob.glob(folder_path + "/*"):
    idx += 1
    base = osp.splitext(osp.basename(path))[0]
    print(f'{idx}: {base}')

    original, hr_image, bicubic, output = perform_super_resolution2(path, model, device)

    esrgan_metrics = evaluate_performance2(hr_image, output)
    bicubic_metrics = evaluate_performance2(hr_image, bicubic)
    hr_metrics = evaluate_performance2(original, hr_image)

    # evaluate_performance2 returns None when any step raises; skip the image
    # instead of crashing on the dictionary lookups below.
    if esrgan_metrics is None or bicubic_metrics is None or hr_metrics is None:
        print(f'Skipping {base}: metric evaluation failed')
        print('====================================================================================')
        continue

    # Extract metrics for ESRGAN
    psnr_esrgan = esrgan_metrics["PSNR"]
    ssim_esrgan = esrgan_metrics["SSIM"]
    lpips_esrgan = esrgan_metrics[">LPIPS"]
    niqe_esrgan = esrgan_metrics["<NIQE"]
    nrqm_esrgan = esrgan_metrics[">NRQM"]
    pi_esrgan = esrgan_metrics["<PI"]

    # Extract metrics for Bicubic
    psnr_bicubic = bicubic_metrics["PSNR"]
    ssim_bicubic = bicubic_metrics["SSIM"]
    lpips_bicubic = bicubic_metrics[">LPIPS"]
    niqe_bicubic = bicubic_metrics["<NIQE"]
    nrqm_bicubic = bicubic_metrics[">NRQM"]
    pi_bicubic = bicubic_metrics["<PI"]

    # Extract metrics for HR
    psnr_hr = hr_metrics["PSNR"]
    ssim_hr = hr_metrics["SSIM"]
    lpips_hr = hr_metrics[">LPIPS"]
    niqe_hr = hr_metrics["<NIQE"]
    nrqm_hr = hr_metrics[">NRQM"]
    pi_hr = hr_metrics["<PI"]

    print('ESRGAN: ', esrgan_metrics)
    print('BICUBIC: ', bicubic_metrics)
    print('HR: ', hr_metrics)

    # Accumulate totals for averaging
    total_psnr_esrgan += psnr_esrgan
    total_psnr_bicubic += psnr_bicubic
    total_psnr_hr += psnr_hr
    total_ssim_esrgan += ssim_esrgan
    total_ssim_bicubic += ssim_bicubic
    total_ssim_hr += ssim_hr
    total_lpips_esrgan += lpips_esrgan
    total_lpips_bicubic += lpips_bicubic
    total_lpips_hr += lpips_hr
    total_niqe_esrgan += niqe_esrgan
    total_niqe_bicubic += niqe_bicubic
    total_niqe_hr += niqe_hr
    total_nrqm_esrgan += nrqm_esrgan
    total_nrqm_bicubic += nrqm_bicubic
    total_nrqm_hr += nrqm_hr
    total_pi_esrgan += pi_esrgan
    total_pi_bicubic += pi_bicubic
    total_pi_hr += pi_hr
    num_images += 1

    compare_images(hr_image, bicubic, output, psnr_esrgan)
    # Save images (optional)
    # save_image_esrgan(bicubic, "images/HR", device, f'{idx}-cubic-{base}')
    # save_image_esrgan(output, "images/HR", device, f'{idx}-esr-{base}')
    print('====================================================================================')

# Calculate and print averages
if num_images > 0:
    avg_psnr_esrgan = total_psnr_esrgan / num_images
    avg_psnr_bicubic = total_psnr_bicubic / num_images
    avg_psnr_hr = total_psnr_hr / num_images
    avg_ssim_esrgan = total_ssim_esrgan / num_images
    avg_ssim_bicubic = total_ssim_bicubic / num_images
    avg_ssim_hr = total_ssim_hr / num_images
    avg_lpips_esrgan = total_lpips_esrgan / num_images
    avg_lpips_bicubic = total_lpips_bicubic / num_images
    avg_lpips_hr = total_lpips_hr / num_images
    avg_niqe_esrgan = total_niqe_esrgan / num_images
    avg_niqe_bicubic = total_niqe_bicubic / num_images
    avg_niqe_hr = total_niqe_hr / num_images
    avg_nrqm_esrgan = total_nrqm_esrgan / num_images
    avg_nrqm_bicubic = total_nrqm_bicubic / num_images
    avg_nrqm_hr = total_nrqm_hr / num_images
    avg_pi_esrgan = total_pi_esrgan / num_images
    avg_pi_bicubic = total_pi_bicubic / num_images
    avg_pi_hr = total_pi_hr / num_images

    print(f'\nAverage PSNR (ESRGAN): {avg_psnr_esrgan:.2f}')
    print(f'Average PSNR (BICUBIC): {avg_psnr_bicubic:.2f}')
    print(f'Average PSNR (HR): {avg_psnr_hr:.2f}')
    print(f'Average SSIM (ESRGAN): {avg_ssim_esrgan:.4f}')
    print(f'Average SSIM (BICUBIC): {avg_ssim_bicubic:.4f}')
    print(f'Average SSIM (HR): {avg_ssim_hr:.4f}')
    print(f'Average LPIPS (ESRGAN): {avg_lpips_esrgan:.4f}')
    print(f'Average LPIPS (BICUBIC): {avg_lpips_bicubic:.4f}')
    print(f'Average LPIPS (HR): {avg_lpips_hr:.4f}')
    print(f'Average NIQE (ESRGAN): {avg_niqe_esrgan:.2f}')
    print(f'Average NIQE (BICUBIC): {avg_niqe_bicubic:.2f}')
    print(f'Average NIQE (HR): {avg_niqe_hr:.2f}')
    print(f'Average NRQM (ESRGAN): {avg_nrqm_esrgan:.4f}')
    print(f'Average NRQM (BICUBIC): {avg_nrqm_bicubic:.4f}')
    print(f'Average NRQM (HR): {avg_nrqm_hr:.4f}')
    print(f'Average PI (ESRGAN): {avg_pi_esrgan:.4f}')
    print(f'Average PI (BICUBIC): {avg_pi_bicubic:.4f}')
    print(f'Average PI (HR): {avg_pi_hr:.4f}')
else:
    print("No images processed.")


# ENHANCING IMAGES AND SAVING THEM TO THE DATASET

**RUN TO RESUME ENHANCEMENT AFTER AN INTERRUPTED RUN**

In [None]:
# Resume enhancing one class folder after an interrupted run: every file
# listed up to and including `target_file` is skipped, and processing
# restarts with the file listed after it.
# NOTE(review): this relies on os.listdir returning the same order as in the
# interrupted run; listdir order is not guaranteed, so confirm it is stable here.
input_folder = './DATASET/TRAIN'
output_folder = './DATASET/TRAIN_ESRGAN2 - ESRGAN'

os.makedirs(output_folder, exist_ok=True)

found = False

target_file = 'fish_000000019596_03648.png'
target_folder = 'Myripristis Kuntee'

class_path = os.path.join(input_folder, target_folder)

if os.path.isdir(class_path):
    enhanced_class_path = os.path.join(output_folder, target_folder)
    os.makedirs(enhanced_class_path, exist_ok=True)

    for image_file in os.listdir(class_path):
        image_path = os.path.join(class_path, image_file)

        # Fast-forward to the last file finished before the interruption.
        if not found:
            if image_file == target_file:
                found = True
            else:
                print(f"Skipping file: {image_file}")
            continue

        if not image_file.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue

        print(f"Processing {image_file}...")
        original, _, enhance_image_result, _ = enhance_image(image_path, model, device)
        try:
            metrics = evaluate_performance(original, enhance_image_result)
            print(metrics)
        except Exception as e:
            print(f"Error evaluating {image_file}: {e}")
        print(f"FINISHED {image_file}: ")
        # splitext keeps dots inside the stem (split('.')[0] would truncate them).
        save_image_esrgan(enhance_image_result, enhanced_class_path, device, os.path.splitext(image_file)[0])

    if not found:
        print(f"Target file {target_file} not found in {class_path}; nothing was processed.")

print("Image enhancement complete. Enhanced images saved in:", output_folder)

**RUN ONLY ON CLASS FOLDERS THAT ARE NOT YET FINISHED**

In [None]:

# Enhance every class folder under `input_folder` that is not listed in
# `finished_folders`, writing results to a same-named folder under `output_folder`.
input_folder = './DATASET/TRAIN'
output_folder = './DATASET/TRAIN_ESRGAN2 - ESRGAN'
torch.cuda.empty_cache()

os.makedirs(output_folder, exist_ok=True)

# finished_folders = ['Acanthurus Nigrofuscus', 'Canthigaster Valentini', 'Chaetodon Trifascialis',
#                     'Hemigymnus Fasciatus', 'Neoniphon Sammara', 'Pomacentrus Moluccensis',
#                     'Myripristis Kuntee', 'Lutjanus fulvus'
#                     ,'Abudefduf Vaigiensis', 'Balistapus Undulatus', 'Hemigymnus Melapterus', 'Neoglyphidodon Nigroris', 'Pempheris Vanicolensis']

# Class folders already enhanced; they are skipped below.
finished_folders = ['Acanthurus Nigrofuscus', 'Abudefduf Vaigiensis', 'Balistapus Undulatus', 'Canthigaster Valentini',
                    'Chaetodon Trifascialis', 'Hemigymnus Fasciatus', 'Hemigymnus Melapterus', 'Lutjanus fulvus', 'Myripristis Kuntee',
                    ]

for class_name in os.listdir(input_folder):
    if class_name in finished_folders:
        continue

    class_path = os.path.join(input_folder, class_name)
    if not os.path.isdir(class_path):
        continue

    enhanced_class_path = os.path.join(output_folder, class_name)
    os.makedirs(enhanced_class_path, exist_ok=True)

    for image_file in os.listdir(class_path):
        if not image_file.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue

        image_path = os.path.join(class_path, image_file)
        print(f"Processing {image_file}...")
        original, _, enhance_image_result, _ = enhance_image(image_path, model, device)
        try:
            metrics = evaluate_performance(original, enhance_image_result)
            print(metrics)
        except Exception as e:
            print(f"Error evaluating {image_file}: {e}")
        print(f"FINISHED {image_file}: ")
        # splitext keeps dots inside the stem (split('.')[0] would truncate them).
        save_image_esrgan(enhance_image_result, enhanced_class_path, device, os.path.splitext(image_file)[0])

print("Image enhancement complete. Enhanced images saved in:", output_folder)