# ESRGAN: Image super-resolution and enhancement

## [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml4a/ml4a-guides/blob/ml4a.net/examples/models/ESRGAN.ipynb)

Upscales an image's pixel resolution by 4x. See the [original code](https://github.com/xinntao/ESRGAN) and [paper](https://arxiv.org/abs/1809.00219).
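
Upscaling by 4x multiplies each side by four, so a 128x128 input becomes 512x512, with 16 times as many pixels. Plain bicubic interpolation to the same size is a useful reference point, since that is what learned super-resolution aims to improve on. The sketch below is that non-learned baseline, not ESRGAN. `bicubic_upscale` is a hypothetical helper built on Pillow's `Image.resize`.

```python
from PIL import Image


def bicubic_upscale(img, scale=4):
    """Resize img by an integer factor with bicubic interpolation (non-learned baseline, not ESRGAN)."""
    width, height = img.size
    return img.resize((width * scale, height * scale), Image.BICUBIC)
```

For example, `bicubic_upscale(Image.open('my_file.jpg').convert('RGB')).save('my_file_bicubic.png')` (file names hypothetical) produces an image of the size ESRGAN would output for that input, which makes a side-by-side comparison easy.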

## Set up ml4a and enable GPU

If you don't already have ml4a installed, or you are opening this in Colab, first enable the GPU (`Runtime` > `Change runtime type`), then run the following cells to install ml4a and its dependencies.

In [4]:
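!pip3 install tensorflow
!pip3 install --quiet torch numpy Pillow alive-progress ftfy regex tqdm
!pip3 install --quiet git+https://github.com/openai/CLIP.git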


In [6]:
!pip3 install --quiet ml4a

ERROR: Cannot install ml4a==0.1.0, ml4a==0.1.2 and ml4a==0.1.3 because these package versions have conflicting dependencies.
ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts
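
The ESRGAN code in this notebook falls back to the CPU when no GPU is found. That works, but it is slower. After switching the runtime type, a quick check like the one below reports which device will be used. It is a sketch that assumes PyTorch is the framework running ESRGAN, as in the code further down. `describe_device` is a hypothetical helper, while `torch.cuda.is_available` and `torch.cuda.get_device_name` are standard PyTorch calls.

```python
def describe_device():
    """Report which device PyTorch would use: a CUDA GPU if available, else the CPU."""
    try:
        import torch
    except ImportError:
        return "cpu (PyTorch is not installed)"
    if torch.cuda.is_available():
        # name of the first visible GPU, e.g. the one assigned by Colab
        return "cuda (" + torch.cuda.get_device_name(0) + ")"
    return "cpu"


print(describe_device())
```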

## Upsample an image

We start with small images placed in `./input_images` and upsample each one 4x using ESRGAN, saving the results to `./output_images`.
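
ESRGAN's generator takes a batch of float images in channels-first (N x 3 x H x W) layout with values in [0, 1], which is how the original repository's test script prepares its inputs. It returns a batch four times larger in H and W. The NumPy sketch below shows that conversion and its inverse. `to_nchw`, `from_nchw` and `fake_upsample_4x` are hypothetical helpers, and the nearest-neighbour stand-in only reproduces the output shape, not ESRGAN's result.

```python
import numpy as np


def to_nchw(image_u8):
    """H x W x 3 uint8 array -> 1 x 3 x H x W float32 array with values in [0, 1]."""
    x = image_u8.astype(np.float32) / 255.0
    return np.transpose(x, (2, 0, 1))[np.newaxis, ...]


def from_nchw(batch):
    """1 x 3 x H x W float array -> H x W x 3 uint8 array (values clamped to [0, 1] first)."""
    x = np.clip(batch[0], 0.0, 1.0)
    x = np.transpose(x, (1, 2, 0))
    return np.round(x * 255.0).astype(np.uint8)


def fake_upsample_4x(batch):
    """Stand-in for the network (NOT ESRGAN): nearest-neighbour 4x along H and W."""
    return batch.repeat(4, axis=2).repeat(4, axis=3)


# usage: a random 32 x 48 RGB image comes back as 128 x 192
image = np.random.randint(0, 256, size=(32, 48, 3), dtype=np.uint8)
out = from_nchw(fake_upsample_4x(to_nchw(image)))
print(out.shape)  # (128, 192, 3)
```

In the notebook cell, the same steps are done with `torch.from_numpy(...).permute(2, 0, 1).unsqueeze(0)` on the way in and `.squeeze(0).permute(1, 2, 0)` on the way out.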

In [2]:
# from ml4a import image
# from ml4a.models import esrgan

# Instead of using ml4a, this cell runs ESRGAN directly from a local copy of the
# original code. CLIP is loaded for image comparison and selection, PIL handles
# image manipulation, and os handles files.

import os
import sys
import glob
import subprocess

import numpy as np
import torch
import clip
from PIL import Image
from alive_progress import alive_bar

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# download the original ESRGAN code, which defines the RRDBNet generator (RRDBNet_arch.py)
if not os.path.isdir("ESRGAN"):
    print("Downloading ESRGAN code")
    subprocess.run(["git", "clone", "https://github.com/xinntao/ESRGAN"], check=True)
sys.path.append("ESRGAN")
import RRDBNet_arch as arch

# the pretrained 4x weights are not downloaded automatically: get RRDB_ESRGAN_x4.pth
# from the links in the ESRGAN README and place it in ESRGAN/models/
model_path = os.path.join("ESRGAN", "models", "RRDB_ESRGAN_x4.pth")
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Pretrained weights not found at {model_path}")

# load the ESRGAN model into memory (same architecture settings as the repo's test.py)
print("Loading ESRGAN model")
esrgan_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
esrgan_model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
esrgan_model.eval()
esrgan_model = esrgan_model.to(device)

# load the CLIP model under its own name so it does not overwrite the ESRGAN model
print("Loading CLIP model")
clip_model, preprocess = clip.load("ViT-B/32", device=device)


input_images_directory = './input_images'
output_images_directory = './output_images'

# create the input and output directories if they don't exist
os.makedirs(input_images_directory, exist_ok=True)
os.makedirs(output_images_directory, exist_ok=True)

# collect every file in the input directory (place the low-resolution images there)
input_images = sorted(glob.glob(os.path.join(input_images_directory, '*')))
number_of_images = len(input_images)

# paths and file names of the upsampled images
output_images = []
output_image_names = []

with alive_bar(number_of_images) as bar:
    # take each image, magnify it 4x with the ESRGAN model and save the output
    for input_image in input_images:
        # load the image as 3-channel RGB, the format ESRGAN expects
        image = Image.open(input_image).convert('RGB')

        # shrink large images (keeping the aspect ratio) so the 4x output fits in
        # GPU memory; unlike an in-place resize, the files on disk are not overwritten
        image.thumbnail((256, 256))

        # HxWx3 uint8 -> 1x3xHxW float tensor with values in [0, 1]
        array = np.asarray(image, dtype=np.float32) / 255.0
        tensor = torch.from_numpy(array).permute(2, 0, 1).unsqueeze(0).to(device)

        # run the image through the ESRGAN model
        with torch.no_grad():
            output = esrgan_model(tensor)

        # 1x3xHxW -> HxWx3, clamp to [0, 1] and convert back to an 8-bit image
        output = output.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
        output = Image.fromarray(np.round(output * 255.0).astype(np.uint8))

        # save the output as <original name>_esrgan.png
        output_image_name = os.path.splitext(os.path.basename(input_image))[0] + '_esrgan.png'
        output_image_path = os.path.join(output_images_directory, output_image_name)
        output.save(output_image_path)

        output_images.append(output_image_path)
        output_image_names.append(output_image_name)

        bar()


# original ml4a version of this cell, kept for reference:
# # load image from the web
# img1 = image.load_image('https://raw.githubusercontent.com/xinntao/ESRGAN/master/LR/baboon.png')

# # or you can load an image directly from disk
# #img1 = image.load_image('my_file.jpg')

# # run ESRGAN
# img2 = esrgan.run(img1)

# image.display(img1, title="original image")
# image.display(img2, title="upsampled 4x")

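
The cell loads CLIP for image comparison and selection but does not use it yet. One possible selection rule, assumed here rather than taken from the original, is to embed a reference image and several candidates (for example, outputs of different upscaling runs) and keep the candidate whose embedding has the highest cosine similarity to the reference. The sketch keeps the math in NumPy. `cosine_similarity` and `select_most_similar` are hypothetical helpers, and the docstring shows how CLIP embeddings could be obtained with the model and `preprocess` function returned by `clip.load`.

```python
import numpy as np


def cosine_similarity(a, b):
    """Cosine similarity between two embedding vectors (any shape, flattened)."""
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


def select_most_similar(reference, candidates):
    """Return (index, scores) for the candidate embedding closest to the reference.

    Hypothetical selection rule: the highest cosine similarity wins.
    With CLIP, each embedding could be computed as
        clip_model, preprocess = clip.load("ViT-B/32", device=device)
        clip_model.encode_image(preprocess(img).unsqueeze(0).to(device))
    and converted with .detach().cpu().numpy() before calling this helper.
    """
    scores = [cosine_similarity(reference, c) for c in candidates]
    return int(np.argmax(scores)), scores
```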