# **Test it Yourself!!!** 😉

**Getting Started**

- Click your profile picture at the top right and log in to Google Colab with your Gmail account.

- Click *File > Save copy in Drive*. This will open a new copy of this notebook in a new tab.

- In the notebook copy that opened in the new tab, click the *Connect* button at the top right. This will connect you to a cloud runtime.

- Select the *T4 GPU* option under *Runtime > Change runtime type*. If you have Google Colab Pro, you can choose any other GPU you prefer.
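
If you want to confirm that the runtime really has a GPU attached, a small check like the one below can be run in a new code cell. It is only a sketch: `gpu_visible` is a hypothetical helper name, and it assumes the `nvidia-smi` tool is on the PATH, which is typically the case on Colab GPU runtimes.

```python
import shutil
import subprocess


def gpu_visible():
    """Return True if `nvidia-smi -L` runs successfully and lists at least one GPU."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        # No NVIDIA tooling found, so assume there is no GPU on this runtime
        return False
    try:
        result = subprocess.run([exe, "-L"], capture_output=True, text=True, timeout=30)
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0 and "GPU" in result.stdout


if gpu_visible():
    print("GPU runtime detected.")
else:
    print("No GPU detected. Check Runtime > Change runtime type.")
```

The notebook still runs on a CPU runtime, since the model code falls back to `'cpu'` when CUDA is unavailable, but enhancement will be much slower.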


Download either *RealESRGAN_x2plus.pth* or *RealESRGAN_x4plus.pth* (or both of them) from [this model zoo](https://github.com/NightmareAI/Real-ESRGAN/blob/master/docs/model_zoo.md).

If you would like to try out the face enhancement model too, download *GFPGANv1.4.pth* from [here](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth).

Upload the ***.pth PyTorch model files*** to this cloud runtime. You can click on the Folder icon on the left side panel, and then click on the Upload icon on the top of that left panel to upload your files.
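
After uploading, you may want to double-check that the files landed in the runtime's working directory under the expected names. The snippet below is a minimal sketch of such a check. `find_model_files` is a hypothetical helper, and it assumes the files were uploaded to the current working directory without being renamed.

```python
import os

# File names taken from the download steps above
UPSCALER_FILES = ("RealESRGAN_x2plus.pth", "RealESRGAN_x4plus.pth")
FACE_FILE = "GFPGANv1.4.pth"


def find_model_files(directory="."):
    """Hypothetical helper: map each expected model file name to True if it is present."""
    present = set(os.listdir(directory))
    return {name: name in present for name in (*UPSCALER_FILES, FACE_FILE)}


status = find_model_files()
for name, found in status.items():
    print(f"{name}: {'found' if found else 'missing'}")

# At least one Real-ESRGAN model is needed; the GFPGAN file is only for face enhancement
if not any(status[name] for name in UPSCALER_FILES):
    print("Upload at least one RealESRGAN .pth file before continuing.")
```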

The gray box with code below is called a 'code cell'. You can run the code in a code cell by clicking the play button on its left. Run the pip install code cell below; it should download all the necessary packages to your cloud runtime.

In [None]:
!pip install torch torchvision torchaudio basicsr gfpgan --quiet

Now run this next cell to import all necessary packages:

In [None]:
import cv2
import math
import numpy as np
import os
import queue
import threading
import torch

from math import ceil, floor, sqrt
from torch.nn import functional as F
from torch import nn as nn
from PIL import Image, ImageFilter
from basicsr.archs.rrdbnet_arch import RRDBNet
from IPython.display import Image as display_image

## Code [**Click the play button below to run all the hidden cells. You only need to open this section if you want to inspect the (long) code.**]

### Utilities

These are the utility functions and the helper classes that wrap the model for inference (padding, tiling, and image I/O):

In [None]:
# Utils
# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

def convert_to_jpg(input_path, output_path):
    # Open the PNG file
    with Image.open(input_path) as img:
        # Save the image in JPEG format
        img.convert("RGB").save(output_path, 'JPEG')


class RealESRGANer():
    """A helper class for upsampling images with RealESRGAN.

    Args:
        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
        model_path (str): The path to the pretrained model file. Automatic download from a URL is disabled in
            this notebook, so the file must already be present in the runtime.
        model (nn.Module): The defined network. Default: None.
        tile (int): Very large images can exhaust GPU memory, so this option first crops the input image into
            tiles and processes each of them. Finally, the tiles are merged into one image.
            0 disables tiling. Default: 0.
        tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
        half (bool): Whether to use half precision during inference. Default: False.
        device (torch.device): Device to run on. Default: None (CUDA if available, otherwise CPU).
        gpu_id (int): Index of the CUDA device to use when `device` is None. Default: None.
    """

    def __init__(self,
                 scale,
                 model_path,
                 model=None,
                 tile=0,
                 tile_pad=10,
                 pre_pad=10,
                 half=False,
                 device=None,
                 gpu_id=None):
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half

        # initialize model
        if gpu_id:
            self.device = torch.device(
                f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
        else:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
        # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
        # if model_path.startswith('https://'):
        #     model_path = load_file_from_url(
        #         url=model_path, model_dir='realesrgan/weights')
        loadnet = torch.load(model_path, map_location=torch.device('cpu'))
        # prefer to use params_ema
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        model.load_state_dict(loadnet[keyname], strict=True)
        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    def pre_process(self, img):
        """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
        """
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
        # mod pad for divisible borders
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if (h % self.mod_scale != 0):
                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
            if (w % self.mod_scale != 0):
                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')

    def process(self):
        # model inference
        self.output = self.model(self.img)

    def tile_process(self):
        """It will first crop input images to tiles, and then process each tile.
        Finally, all the processed tiles are merged into one images.

        Modified from: https://github.com/ata4/esrgan-launcher
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    print('Error', error)
                print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                       output_start_x_tile:output_end_x_tile]

    def post_process(self):
        # remove extra pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
        # remove prepad
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
        return self.output

    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
        h_input, w_input = img.shape[0:2]
        # img: numpy
        img = img.astype(np.float32)
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == 'realesrgan':
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        self.pre_process(img)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        output_img = self.post_process()
        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == 'RGBA':
            if alpha_upsampler == 'realesrgan':
                self.pre_process(alpha)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_alpha = self.post_process()
                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use the cv2 resize for alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output, (
                    int(w_input * outscale),
                    int(h_input * outscale),
                ), interpolation=cv2.INTER_LANCZOS4)

        return output, img_mode


class PrefetchReader(threading.Thread):
    """Prefetch images.

    Args:
        img_list (list[str]): List of image paths to be read.
        num_prefetch_queue (int): Maximum size of the prefetch queue.
    """

    def __init__(self, img_list, num_prefetch_queue):
        super().__init__()
        self.que = queue.Queue(num_prefetch_queue)
        self.img_list = img_list

    def run(self):
        for img_path in self.img_list:
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            self.que.put(img)

        self.que.put(None)

    def __next__(self):
        next_item = self.que.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self


class IOConsumer(threading.Thread):

    def __init__(self, opt, que, qid):
        super().__init__()
        self._queue = que
        self.qid = qid
        self.opt = opt

    def run(self):
        while True:
            msg = self._queue.get()
            if isinstance(msg, str) and msg == 'quit':
                break

            output = msg['output']
            save_path = msg['save_path']
            cv2.imwrite(save_path, output)
        print(f'IO worker {self.qid} is done.')
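
The tiling arithmetic in `tile_process` can be easier to follow in isolation. The sketch below reproduces only the index calculations in plain Python, with no model involved. `tile_boxes` is a hypothetical helper written for illustration and is not part of the notebook's classes.

```python
import math


def tile_boxes(height, width, tile_size, tile_pad):
    """Yield (core, padded) boxes as (y0, y1, x0, x1), mirroring the indexing in tile_process."""
    tiles_y = math.ceil(height / tile_size)
    tiles_x = math.ceil(width / tile_size)
    for ty in range(tiles_y):
        for tx in range(tiles_x):
            # Core tile area, clipped to the image borders
            y0, x0 = ty * tile_size, tx * tile_size
            y1, x1 = min(y0 + tile_size, height), min(x0 + tile_size, width)
            # Same area grown by tile_pad on every side, again clipped to the image
            padded = (max(y0 - tile_pad, 0), min(y1 + tile_pad, height),
                      max(x0 - tile_pad, 0), min(x1 + tile_pad, width))
            yield (y0, y1, x0, x1), padded


boxes = list(tile_boxes(height=5, width=7, tile_size=4, tile_pad=1))
for core, padded in boxes:
    print("core", core, "padded", padded)
```

The model sees the padded box, but only the core region of its output (scaled by the upsampling factor) is copied into the final image, which is how the padding suppresses seams between tiles.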

### Main Code

And this is the main method to use to enhance our images:

In [None]:
def enhance_image(input_file, layers=2, upscale=2, final_filename="", enhance_faces=False):
  ## Set models to use
  if layers == 4:
    # layers=4 selects the x4 model (RealESRGAN_x4plus)
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
    netscale = 4
    model_file = 'RealESRGAN_x4plus.pth'
  elif layers == 2:
    # layers=2 selects the x2 model (RealESRGAN_x2plus)
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
    netscale = 2
    model_file = 'RealESRGAN_x2plus.pth'
  else:
    print("Layers parameter must be either 2 or 4.")
    return

  # `upscale` is the final scale factor relative to the input image. When it differs from the
  # model's native scale, RealESRGANer.enhance resizes the output using LANCZOS4 resampling.

  # Input image (splitext also handles file names that contain more than one dot)
  imgname, org_extension = os.path.splitext(input_file)
  image = cv2.imread(input_file)
  org_height, org_width = image.shape[:2]  # OpenCV shape is (height, width, channels)

  # Convert image to JPG if need be.
  # The JPG file format reduces the file size, which makes enhancement with the model faster.
  if org_extension.lower() not in [".jpeg", ".jpg"]:
      convert_to_jpg(input_file, f"{imgname}.jpg")
      input_file = f"{imgname}.jpg"

  # Compute tile size
  if min(org_width, org_height) <= 800:
    tile_size = 0
    print("Small image, so tiling is not necessary.")
  else:
    tile_size = ceil(sqrt(min(org_width, org_height))) * 10
  if tile_size > 500:
    tile_size = 350
  print(f"Tile size being used: {tile_size}")

  # restorer
  upsampler = RealESRGANer(
      scale=netscale,
      model_path=model_file,
      model=model,
      tile=tile_size,
      tile_pad=2,
      half=False)

  # Use GFPGAN for face enhancement
  if enhance_faces:
    from gfpgan import GFPGANer
    face_enhancer = GFPGANer(
        model_path='GFPGANv1.4.pth',
        upscale=upscale,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=upsampler)

  img = cv2.imread(input_file, cv2.IMREAD_UNCHANGED)

  try:
    if enhance_faces:
      _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
    else:
      output, _ = upsampler.enhance(img, outscale=upscale)
  except RuntimeError as error:
      print('Error', error)
      print('If you encounter CUDA out of memory, try a smaller tile size.')
      print('Else, the file you are using may be too large.')
  else:
    if final_filename != "":
      if not (final_filename.endswith(".jpg") or final_filename.endswith(".jpeg")):
        print(
          "Your preferred final filename for the output image has a missing or wrong file extension. "
          "Append .jpg or .jpeg to your preferred filename."
        )
        return
      save_path = final_filename
    else:
      save_path = f'{imgname}_out.jpg'

    cv2.imwrite(save_path, output)
    print(f"Enhanced image has been saved to {save_path}.\nClick refresh button on the left panel to get latest version of {save_path}")
    return save_path
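
The tile-size rule used inside `enhance_image` can be tried out on its own. The sketch below copies that rule into a standalone function so you can see which tile size a given image would get. `compute_tile_size` is a hypothetical name used here for illustration only.

```python
from math import ceil, sqrt


def compute_tile_size(height, width):
    """Standalone copy of the tile-size rule from enhance_image (0 means no tiling)."""
    shortest = min(height, width)
    if shortest <= 800:
        return 0  # small image: process it in one pass
    tile_size = ceil(sqrt(shortest)) * 10
    if tile_size > 500:
        tile_size = 350  # cap for very large images
    return tile_size


for shape in [(480, 640), (1000, 2000), (2500, 4000), (3500, 7000)]:
    print(shape, "->", compute_tile_size(*shape))
```

With this rule, any image whose shorter side exceeds 2500 pixels ends up with a tile size of 350.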

### Metrics Code

These are some helper functions to compute the quality of images based on metrics such as resolution, sharpness, contrast, and noise:

In [None]:
# Metrics

def get_resolution(image):
    return image.shape[:2]

def get_noise_level(image):
    # Convert the image to grayscale
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Compute the Discrete Fourier Transform (DFT)
    f_transform = np.fft.fft2(gray)
    f_transform_shifted = np.fft.fftshift(f_transform)

    # Compute the magnitude spectrum
    magnitude_spectrum = np.abs(f_transform_shifted)

    # Calculate the noise level using the standard deviation of the magnitude spectrum
    noise_level = np.std(np.log1p(magnitude_spectrum))

    return round(noise_level, 2)

def get_sharpness(image):
    pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    # Apply an edge-enhancing filter (Laplacian) and compute variance as a measure of sharpness
    laplacian = cv2.Laplacian(cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2GRAY), cv2.CV_64F)
    return round(laplacian.var(), 2)

def get_contrast(image):
    # Using Michelson contrast measure
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    I_max = np.max(gray)
    I_min = np.min(gray)
    print(I_min, I_max)

    # Cast to float first: uint8 arithmetic would wrap around when I_max + I_min > 255
    contrast = (float(I_max) - float(I_min)) / (float(I_max) + float(I_min))
    return round(contrast, 5)


def get_filesize(image_file):
  file_size = os.path.getsize(image_file)
  return round(file_size / 1_000_000, 2)


import time
class Timer:
    def __init__(self) -> None:
        # Underscore-prefixed names so the attributes do not shadow the start()/end() methods
        self._start = 0.0
        self._end = 0.0

    def start(self):
        self._start = time.time()

    def end(self):
        self._end = time.time()
        elapsed_time = self._end - self._start
        print(f"Elapsed Time: {elapsed_time} seconds")


def print_quality(image_file):
    # Load the image
    image = cv2.imread(image_file)

    # Get image metrics
    resolution = get_resolution(image)
    noise_level = get_noise_level(image)
    sharpness = get_sharpness(image)
    try:
      contrast = get_contrast(image)
    except:
      contrast = "unknown"
    image_size = get_filesize(image_file)

    # Output the results
    print(f"Resolution: {resolution} pixels")
    print(f"Noise Level: {noise_level} dB")
    print(f"Sharpness: {sharpness}")
    print(f"Contrast: {contrast}")
    print(f"Size of image: {image_size} MB")
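
The Michelson contrast used in `get_contrast` is defined as (I_max - I_min) / (I_max + I_min), where I_max and I_min are the brightest and darkest grayscale intensities. The sketch below works through that formula on plain Python integers. `michelson_contrast` is a hypothetical helper, and returning 0.0 for an all-black input is an assumption made here to avoid dividing by zero.

```python
def michelson_contrast(pixels):
    """Michelson contrast of an iterable of grayscale intensities (0 to 255)."""
    values = [int(p) for p in pixels]  # plain ints cannot overflow the way 8-bit values do
    i_max, i_min = max(values), min(values)
    if i_max + i_min == 0:
        return 0.0  # assumption: treat an all-black image as zero contrast
    return (i_max - i_min) / (i_max + i_min)


print(michelson_contrast([0, 128, 255]))    # full range, so the contrast is 1.0
print(michelson_contrast([100, 150, 200]))  # (200 - 100) / (200 + 100)
```

Note that any image containing at least one pure black pixel (intensity 0) and one brighter pixel scores 1.0, so this metric saturates easily on real photos.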

## Here we go woohoo :)

Now run the code cell below and scroll down a little, to just below the cell. It will prompt you to upload an image file. Once the file is uploaded, the cell will enhance your image and display it next to the original.

You can change the values of "layers", "upscale" and "enhance_faces" in the call to `enhance_image` below, which has been fenced off with lines of # characters for you! ;)
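
To get a feel for what the "upscale" value does to the output resolution, here is a small sketch. It follows the `int(w_input * outscale)` / `int(h_input * outscale)` resize in `RealESRGANer.enhance`, which is the path used when "enhance_faces" is False. The assumption that the face-enhancement path produces the same size has not been verified here, and `expected_output_size` is a hypothetical helper.

```python
def expected_output_size(height, width, upscale):
    """Output (height, width) after the final resize, truncating like int() does."""
    return int(height * upscale), int(width * upscale)


print(expected_output_size(1000, 1500, 1.5))  # (1500, 2250)
print(expected_output_size(333, 500, 1.5))    # (499, 750): 499.5 is truncated to 499
```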

In [None]:
from google.colab import files

Image.MAX_IMAGE_PIXELS = 933120000

# Upload a file
uploaded = files.upload()

# Get the file name
filename = list(uploaded.keys())[0]

print(f"File uploaded: {filename}")

# Start timing
import time
start_time = time.time()

########################################################################################################################################################

# You can edit the layers, upscale and enhance_faces parameters if you want.
result = enhance_image(
    input_file=filename, # DO NOT CHANGE THIS LINE
    layers=2, # Choose either 2 or 4 as the value here.
    upscale=1.5, # Factor by which the output image's resolution is enlarged relative to the original.
    enhance_faces=True # Choose between 'True' or 'False' [First letter is capital!]
  )

########################################################################################################################################################

# End timing
print(f"Elapsed Time: {time.time() - start_time} seconds")

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np

# Image paths: the uploaded original and the enhanced file path returned by enhance_image
image_paths = [filename, result]

# Load images
images = [mpimg.imread(path) for path in image_paths]

# Display images in the same row
fig, axs = plt.subplots(1, len(images), figsize=(10, 5))  # Adjust the figsize as needed

for i, (img, path) in enumerate(zip(images, image_paths)):
    axs[i].imshow(img)
    axs[i].axis('off')

axs[0].set_title(f"Original Image")
axs[1].set_title(f"Enhanced Image")

plt.show()

You can download the enhanced image to your local computer by opening the Folder icon on the left side panel, and there you will find both your original file and an _out file.

The _out file is the enhanced image. You can right click on it and click Download.
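
If you are unsure which file to look for, the default output name can be worked out from the input name. The sketch below assumes a file name with a single extension and uses `os.path.splitext` to separate it; `default_output_name` is a hypothetical helper, not a function defined in this notebook.

```python
import os


def default_output_name(input_file):
    """Hypothetical helper: derive the '<name>_out.jpg' file name from the input file name."""
    stem, _extension = os.path.splitext(input_file)
    return f"{stem}_out.jpg"


print(default_output_name("photo.png"))   # photo_out.jpg
print(default_output_name("scan.jpeg"))   # scan_out.jpg
```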

You can also check out the metrics of your original and enhanced images by running the code cells below:

In [None]:
print_quality(filename)

In [None]:
print_quality(result)

0 255
Resolution: (5250, 10500) pixels
Noise Level: 1.57 dB
Sharpness: 91.58
Contrast: 1.0
Size of image: 14.79 MB
