# Image Restoration

## Library Imports

In [1]:
import os
import cv2
import gradio as gr
import torch
from gfpgan.utils import GFPGANer
os.system("hub install deoldify==1.0.1")
import paddlehub as hub
from pathlib import Path
from datetime import datetime



## Image Colorization

#### Load the model checkpoint and define the colorization function

In [2]:
# Colorization settings and the PaddleHub DeOldify module, instantiated once
# when this cell runs and reused by every colorization request.
render_factor = 35  # NOTE(review): not passed to hub.Module or predict() anywhere visible — confirm it is actually used
model = hub.Module(name='deoldify')

load pretrained checkpoint success


In [32]:
def colorize_image(image):
    """Colorize the image file at *image* with the DeOldify module.

    Makes sure ``./output`` exists, runs ``model.predict`` on the path and
    returns the first two items of the prediction. The Gradio interface
    presents them as the output image and the downloadable file.
    """
    output_dir = "./output"
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    prediction = model.predict(image)
    colorized, saved_file = prediction[0], prediction[1]
    return colorized, saved_file

## Image Upscaling and Enhancement

#### Define the upscaling and enhancement function (it loads the GFPGAN checkpoint itself)

In [None]:
import threading

# GFPGAN enhancer shared by all requests. It is created lazily on first use so
# the checkpoint is loaded once instead of on every call.
_FACE_ENHANCER = None
# Serialises creation and use of the shared enhancer. NOTE(review): GFPGANer is
# presumably not safe to call concurrently (it appears to keep per-call face
# state on the instance), so concurrent Gradio requests must not interleave.
_FACE_ENHANCER_LOCK = threading.Lock()


def _enhance_faces(img):
    """Restore faces in a BGR image with GFPGAN (2x upscale) and return the result."""
    global _FACE_ENHANCER
    with _FACE_ENHANCER_LOCK:
        if _FACE_ENHANCER is None:
            _FACE_ENHANCER = GFPGANer(
                model_path='GFPGANv1.4.pth', upscale=2, arch='clean', channel_multiplier=2)
        _, _, restored = _FACE_ENHANCER.enhance(
            img, has_aligned=False, only_center_face=False, paste_back=True)
    return restored


def _load_bgr_image(path):
    """Read *path* unchanged with OpenCV.

    Returns ``(image, is_rgba)``. Grayscale images are expanded to 3-channel
    BGR. Returns ``(None, False)`` when OpenCV cannot read the file.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        return None, False
    if len(img.shape) == 2:  # gray input -> 3-channel BGR
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR), False
    return img, img.shape[2] == 4


def _rescale_to_input(output, orig_h, orig_w, scale):
    """Resize *output* to ``scale`` times the original input size (orig_w x orig_h).

    Returns *output* unchanged when it already has the target size.
    """
    target_w, target_h = int(orig_w * scale), int(orig_h * scale)
    out_h, out_w = output.shape[0:2]
    if (target_w, target_h) == (out_w, out_h):
        return output
    # Area averaging when shrinking, Lanczos when enlarging.
    shrinking = target_w * target_h < out_w * out_h
    interpolation = cv2.INTER_AREA if shrinking else cv2.INTER_LANCZOS4
    return cv2.resize(output, (target_w, target_h), interpolation=interpolation)


def inference(img, scale):
    """Enhance faces in the image at path *img* and rescale it by *scale*.

    Args:
        img: Path of the uploaded image (``None`` when nothing was uploaded).
        scale: Desired output size as a multiple of the input image size.

    Returns:
        ``(rgb_image, save_path)`` on success, or ``(None, None)`` on failure.
        The result is also written to ``output/out.png`` (RGBA inputs) or
        ``output/out.jpg``.
    """
    if img is None:
        return None, None
    try:
        image, is_rgba = _load_bgr_image(img)
        if image is None:
            print('Error', f'cannot read image {img!r}')
            return None, None
        # The target size is always computed from the ORIGINAL input size, so
        # the small-image pre-upscale below does not distort the final scale.
        orig_h, orig_w = image.shape[0:2]
        if orig_h < 300:
            image = cv2.resize(image, (orig_w * 2, orig_h * 2), interpolation=cv2.INTER_LANCZOS4)

        try:
            output = _enhance_faces(image)
        except RuntimeError as error:
            # Without an enhanced image there is nothing to rescale or save.
            print('Error', error)
            return None, None

        try:
            output = _rescale_to_input(output, orig_h, orig_w, scale)
        except Exception as error:
            # Best effort: keep the unscaled result if *scale* is unusable.
            print('wrong scale input.', error)

        # RGBA images should be saved in png format
        extension = 'png' if is_rgba else 'jpg'
        os.makedirs('output', exist_ok=True)
        save_path = f'output/out.{extension}'
        if not cv2.imwrite(save_path, output):
            print('Error', f'could not write {save_path}')
            return None, None

        return cv2.cvtColor(output, cv2.COLOR_BGR2RGB), save_path
    except Exception as error:
        print('global exception', error)
        return None, None

## Gradio Interface

#### Interface for colorization

In [53]:
# Colorization tab: takes the uploaded image as a file path and returns the
# colorized image plus a downloadable file. Uses the top-level gr.Image /
# gr.File components: the gr.inputs / gr.outputs namespaces are deprecated in
# Gradio 3.x and removed in 4.x.
interface1 = gr.Interface(
    colorize_image,
    [gr.Image(type="filepath", label="Input Image")],
    [
        gr.Image(type="numpy", label="Output Image"),
        gr.File(label="Download the output image"),
    ],
    description="Colorize B/W, Grayscale images",
)

#### Interface for upscaling/enhancement

In [54]:
# Enhancement tab: the image file path and a numeric rescaling factor
# (default 2) are passed to `inference`. Uses the top-level gr.Image /
# gr.Number / gr.File components: gr.inputs / gr.outputs are deprecated in
# Gradio 3.x and removed in 4.x (`default=` became `value=`).
interface2 = gr.Interface(
    inference,
    [
        gr.Image(type="filepath", label="Input Image"),
        gr.Number(label="Rescaling factor", value=2),
    ],
    [
        gr.Image(type="numpy", label="Output Image"),
        gr.File(label="Download the output image"),
    ],
    description="Upscale and enhance Images",
)

#### Combined Interface

In [87]:
# Combine both tools into a single tabbed app and open it in the local browser.
title = "Image Restoration"
final_interface = gr.TabbedInterface(
    [interface1, interface2],
    tab_names=["Colorization", "Enhancement"],
    title=title,
    theme=gr.themes.Default(),
)
final_interface.launch(inbrowser=True)

Running on local URL:  http://127.0.0.1:7904

To create a public link, set `share=True` in `launch()`.


