# GFPGAN Inference Demo
### (No colorization; No CUDA extensions required)

[![arXiv](https://img.shields.io/badge/arXiv-Paper-b31b1b.svg)](https://arxiv.org/abs/2101.04061)
[![GitHub Stars](https://img.shields.io/github/stars/TencentARC/GFPGAN?style=social)](https://github.com/TencentARC/GFPGAN)
[![download](https://img.shields.io/github/downloads/TencentARC/GFPGAN/total.svg)](https://github.com/TencentARC/GFPGAN/releases)

## GFPGAN - Towards Real-World Blind Face Restoration with Generative Facial Prior

GFPGAN is a blind face restoration algorithm aimed at real-world face images. <br>
It leverages the generative face prior in a pre-trained GAN (*e.g.*, StyleGAN2) to restore realistic faces while preserving fidelity. <br>

If you want to use the paper model, please go to this [Colab Demo](https://colab.research.google.com/drive/1Oa1WwKB4M4l1GmR7CtswDVgOCOeSLChA?usp=sharing) for GFPGAN <a href="https://colab.research.google.com/drive/1Oa1WwKB4M4l1GmR7CtswDVgOCOeSLChA?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>.

**Limitations**: GFPGAN cannot handle every low-quality face in the real world, so it may fail on some of your own images.

### Enjoy! :-)

<img src="https://xinntao.github.io/projects/GFPGAN_src/gfpgan_teaser.jpg" width="800">


# 1. Preparations
Before starting, make sure that you choose
* Runtime Type = Python 3
* Hardware Accelerator = GPU

in the **Runtime** menu -> **Change runtime type**

Then, we clone the repository, set up the environment, and download the pre-trained model.


In [None]:
# Clone GFPGAN and enter the GFPGAN folder
# (rm -rf wipes any previous clone, including its inputs/ and results/ folders)
%cd /content
!rm -rf GFPGAN
!git clone https://github.com/TencentARC/GFPGAN.git
%cd GFPGAN

# Set up the environment
# Install basicsr - https://github.com/xinntao/BasicSR
# We use BasicSR for both training and inference
!pip install basicsr
# Install facexlib - https://github.com/xinntao/facexlib
# We use face detection and face restoration helper in the facexlib package
!pip install facexlib
# Install other dependencies
!pip install -r requirements.txt
# Install the cloned gfpgan package in development mode so `from gfpgan import GFPGANer` works
!python setup.py develop
!pip install realesrgan  # used for enhancing the background (non-face) regions
# Download the pre-trained model
# !wget https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth -P experiments/pretrained_models
# Now we use the V1.3 model for the demo
# (the restorer setup below looks in experiments/pretrained_models first, then falls back to the URL)
!wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P experiments/pretrained_models


/content
Cloning into 'GFPGAN'...
remote: Enumerating objects: 523, done.[K
remote: Total 523 (delta 0), reused 0 (delta 0), pack-reused 523[K
Receiving objects: 100% (523/523), 5.39 MiB | 27.34 MiB/s, done.
Resolving deltas: 100% (264/264), done.
/content/GFPGAN
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting basicsr
  Downloading basicsr-1.4.2.tar.gz (172 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m172.5/172.5 KB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting addict
  Downloading addict-2.4.0-py3-none-any.whl (3.8 kB)
Collecting lmdb
  Downloading lmdb-1.4.0-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (305 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m305.9/305.9 KB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
Collecting tb-nightly
  Downloading tb_nightly-2.13.0a20230310-py3-none-any.whl (5.6 MB)
[2K

In [None]:
# Reset the working folders (paths are relative to the GFPGAN checkout):
#   inputs/upload - receives the frames extracted from the video below
#   inputs/video  - where the source video is expected; since this cell empties
#                   it, upload the video *after* running this cell
import os
from google.colab import files
import shutil

for upload_folder in ('inputs/upload', 'inputs/video'):
    # Start from an empty folder so files from a previous run are not mixed in.
    if os.path.isdir(upload_folder):
        shutil.rmtree(upload_folder)
    # makedirs (unlike mkdir) also creates the parent 'inputs' folder if needed.
    os.makedirs(upload_folder)

In [None]:
import cv2
import os

# Split the uploaded video into one JPEG per frame, written to the folder that
# the GFPGAN cells below read their input images from.
video_path = '/content/GFPGAN/inputs/video/001_without_audio.mp4'
frames_dir = '/content/GFPGAN/inputs/upload'

video = cv2.VideoCapture(video_path)
# VideoCapture does not raise for a missing/unsupported file; it just yields no
# frames. Fail loudly instead of silently producing an empty input folder.
if not video.isOpened():
    raise FileNotFoundError(f'Could not open video file: {video_path}')

# counter for the frames written so far
count = 0
try:
    while True:
        # read the next frame; ret is False at end of stream (or on a decode error)
        ret, frame = video.read()
        if not ret:
            break

        # Zero-pad the index so that lexicographic order (sorted(glob(...)) below)
        # matches frame order: frame000000.jpg, frame000001.jpg, ...
        # (un-padded names sort as frame0, frame1, frame10, ..., frame2).
        filename = f'frame{count:06d}.jpg'
        if not cv2.imwrite(os.path.join(frames_dir, filename), frame):
            raise OSError(f'Failed to write frame {count} to {frames_dir}')

        count += 1
finally:
    # release the video file even if writing a frame failed
    video.release()

print(f'Extracted {count} frames to {frames_dir}')


## 2. Setting the parameters

In [None]:
import argparse
import cv2
import glob
import numpy as np
import os
import torch
from basicsr.utils import imwrite
from gfpgan import GFPGANer

import argparse

class MyArgs:
    """Attribute-style holder for the GFPGAN inference options.

    Every option is looked up in ``args_dict``; when a key is absent the
    fallback listed below is used instead. Attributes are assigned in the
    order listed.
    """

    def __init__(self, args_dict):
        fallbacks = (
            ('input', 'inputs/whole_imgs'),
            ('output', 'results'),
            ('version', '1.3'),
            ('upscale', 2),
            ('bg_upsampler', 'realesrgan'),
            ('bg_tile', 400),
            ('suffix', None),
            ('only_center_face', False),
            ('aligned', False),
            ('ext', 'auto'),
            ('weight', 0.5),
        )
        for option_name, fallback in fallbacks:
            setattr(self, option_name, args_dict.get(option_name, fallback))

# Dictionary of arguments; keys match the attributes MyArgs reads.
args_dict = {
    # No leading whitespace: ' inputs/upload/' names a folder that does not
    # exist, so the image glob would find nothing.
    'input': 'inputs/upload/',
    'output': 'results',
    'version': '1.3',
    'upscale': 2,
    'bg_upsampler': 'realesrgan',
    'bg_tile': 400,
    'suffix': None,
    'only_center_face': False,
    'aligned': False,
    'ext': 'auto',
    'weight': 0.5,
}

# Create instance of MyArgs class
args = MyArgs(args_dict)

In [None]:
# Point inference at the folder filled by the frame-extraction cell above.
args.input = 'inputs/upload/'

In [None]:
# Drop a single trailing '/' from the input path, then collect the images:
# a single file is used as-is, otherwise every entry of the folder, sorted.
args.input = args.input[:-1] if args.input.endswith('/') else args.input
img_list = (
    [args.input]
    if os.path.isfile(args.input)
    else sorted(glob.glob(os.path.join(args.input, '*')))
)


In [None]:

def main():
    """Restore every image under ``args.input`` with GFPGAN, using a thread pool.

    Options are read from the module-level ``args`` object. For every image,
    cropped faces, restored faces and side-by-side comparisons are written to
    ``args.output``/{cropped_faces,restored_faces,cmp}, and the restored full
    image to ``args.output``/restored_imgs. Per-image failures are printed and
    do not stop the other images.

    Raises:
        ValueError: if ``args.version`` is not a supported model version.
    """
    # Imported locally so main() does not depend on a later cell having run.
    import concurrent.futures
    import threading

    # ------------------------ input & output ------------------------
    if args.input.endswith('/'):
        args.input = args.input[:-1]
    if os.path.isfile(args.input):
        img_list = [args.input]
    else:
        img_list = sorted(glob.glob(os.path.join(args.input, '*')))

    os.makedirs(args.output, exist_ok=True)
    print(img_list)

    # ------------------------ set up background upsampler ------------------------
    if args.bg_upsampler == 'realesrgan':
        if not torch.cuda.is_available():  # CPU
            import warnings
            warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
                          'If you really want to use it, please modify the corresponding codes.')
            bg_upsampler = None
        else:
            from basicsr.archs.rrdbnet_arch import RRDBNet
            from realesrgan import RealESRGANer
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
            bg_upsampler = RealESRGANer(
                scale=2,
                model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
                model=model,
                tile=args.bg_tile,
                tile_pad=10,
                pre_pad=0,
                half=True)  # need to set False in CPU mode
    else:
        bg_upsampler = None

    # ------------------------ set up GFPGAN restorer ------------------------
    # version -> (arch, channel_multiplier, model_name, download url)
    model_specs = {
        '1': ('original', 1, 'GFPGANv1',
              'https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth'),
        '1.2': ('clean', 2, 'GFPGANCleanv1-NoCE-C2',
                'https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth'),
        '1.3': ('clean', 2, 'GFPGANv1.3',
                'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'),
        '1.4': ('clean', 2, 'GFPGANv1.4',
                'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'),
        'RestoreFormer': ('RestoreFormer', 2, 'RestoreFormer',
                          'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'),
    }
    try:
        arch, channel_multiplier, model_name, url = model_specs[args.version]
    except KeyError:
        raise ValueError(f'Wrong model version {args.version}.') from None

    # determine model paths: local copies first, otherwise let GFPGANer download from url
    model_path = os.path.join('experiments/pretrained_models', model_name + '.pth')
    if not os.path.isfile(model_path):
        model_path = os.path.join('gfpgan/weights', model_name + '.pth')
    if not os.path.isfile(model_path):
        model_path = url

    restorer = GFPGANer(
        model_path=model_path,
        upscale=args.upscale,
        arch=arch,
        channel_multiplier=channel_multiplier,
        bg_upsampler=bg_upsampler)

    # GFPGANer.enhance keeps per-call state on the instance (its face helper, and
    # the background upsampler it calls), so running it concurrently on this one
    # shared restorer could mix faces between images. Serialise the model call;
    # image decoding and file writing still overlap across threads.
    enhance_lock = threading.Lock()

    def process_image(img_path):
        """Restore one image and write its face crops, comparisons and full result."""
        img_name = os.path.basename(img_path)
        print(f'Processing {img_name} ...')
        basename, ext = os.path.splitext(img_name)
        input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if input_img is None:
            # cv2.imread signals an unreadable / non-image file by returning None.
            print(f'Skipping {img_name}: could not be read as an image.')
            return

        # restore faces and background if necessary
        with enhance_lock:
            cropped_faces, restored_faces, restored_img = restorer.enhance(
                input_img,
                has_aligned=args.aligned,
                only_center_face=args.only_center_face,
                paste_back=True,
                weight=args.weight)

        # save faces
        for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)):
            # save cropped face
            save_crop_path = os.path.join(args.output, 'cropped_faces', f'{basename}_{idx:02d}.png')
            imwrite(cropped_face, save_crop_path)
            # save restored face
            if args.suffix is not None:
                save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png'
            else:
                save_face_name = f'{basename}_{idx:02d}.png'
            save_restore_path = os.path.join(args.output, 'restored_faces', save_face_name)
            imwrite(restored_face, save_restore_path)
            # save comparison image (cropped | restored, side by side)
            cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
            imwrite(cmp_img, os.path.join(args.output, 'cmp', f'{basename}_{idx:02d}.png'))

        # save restored img
        if restored_img is not None:
            # 'auto' keeps the input file's extension (without the leading dot)
            extension = ext[1:] if args.ext == 'auto' else args.ext

            if args.suffix is not None:
                save_restore_path = os.path.join(args.output, 'restored_imgs', f'{basename}_{args.suffix}.{extension}')
            else:
                save_restore_path = os.path.join(args.output, 'restored_imgs', f'{basename}.{extension}')
            imwrite(restored_img, save_restore_path)

        print(f'Finished processing {img_name}.')

    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        # submit the tasks to the thread pool
        futures = [executor.submit(process_image, img_path) for img_path in img_list]

        # wait for the tasks to complete; report failures without stopping the rest
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                print(f'Exception occurred: {e}')



In [None]:
import time
import concurrent.futures


In [None]:

# Split img_list into at most num_parts contiguous chunks (one per worker).
n = len(img_list)
num_parts = 5
# Ceiling division. Plain n // num_parts gives a step of 0 when n < num_parts
# (range() then raises ValueError), and an extra, smaller chunk whenever n is
# not a multiple of num_parts. max(1, ...) keeps the step valid when n == 0.
part_size = max(1, -(-n // num_parts))

split_list = [img_list[i:i + part_size] for i in range(0, n, part_size)]

In [None]:

def enhance_list(split_list):
    """Restore a list of images with GFPGAN and save the results under ``args.output``.

    Every call builds its own background upsampler and GFPGANer, so calls made
    from different worker threads do not share model instances. Options are read
    from the module-level ``args`` object.

    Args:
        split_list: paths of the images to restore (e.g. one chunk of ``img_list``).
            Files that cannot be read as images are reported and skipped instead
            of aborting the rest of the list.

    Raises:
        ValueError: if ``args.version`` is not a supported model version.
    """
    img_list = split_list

    os.makedirs(args.output, exist_ok=True)
    print(img_list)

    # ------------------------ set up background upsampler ------------------------
    if args.bg_upsampler == 'realesrgan':
        if not torch.cuda.is_available():  # CPU
            import warnings
            warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
                          'If you really want to use it, please modify the corresponding codes.')
            bg_upsampler = None
        else:
            from basicsr.archs.rrdbnet_arch import RRDBNet
            from realesrgan import RealESRGANer
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
            bg_upsampler = RealESRGANer(
                scale=2,
                model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
                model=model,
                tile=args.bg_tile,
                tile_pad=10,
                pre_pad=0,
                half=True)  # need to set False in CPU mode
    else:
        bg_upsampler = None

    # ------------------------ set up GFPGAN restorer ------------------------
    # version -> (arch, channel_multiplier, model_name, download url)
    model_specs = {
        '1': ('original', 1, 'GFPGANv1',
              'https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth'),
        '1.2': ('clean', 2, 'GFPGANCleanv1-NoCE-C2',
                'https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth'),
        '1.3': ('clean', 2, 'GFPGANv1.3',
                'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'),
        '1.4': ('clean', 2, 'GFPGANv1.4',
                'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'),
        'RestoreFormer': ('RestoreFormer', 2, 'RestoreFormer',
                          'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'),
    }
    try:
        arch, channel_multiplier, model_name, url = model_specs[args.version]
    except KeyError:
        raise ValueError(f'Wrong model version {args.version}.') from None

    # determine model paths: local copies first, otherwise let GFPGANer download from url
    model_path = os.path.join('experiments/pretrained_models', model_name + '.pth')
    if not os.path.isfile(model_path):
        model_path = os.path.join('gfpgan/weights', model_name + '.pth')
    if not os.path.isfile(model_path):
        model_path = url

    restorer = GFPGANer(
        model_path=model_path,
        upscale=args.upscale,
        arch=arch,
        channel_multiplier=channel_multiplier,
        bg_upsampler=bg_upsampler)

    # ------------------------ restore ------------------------
    for img_path in img_list:
        img_name = os.path.basename(img_path)
        print(f'Processing {img_name} ...')
        basename, ext = os.path.splitext(img_name)
        input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if input_img is None:
            # cv2.imread returns None (no exception) for unreadable files; handing
            # that to restorer.enhance would fail and abandon the rest of the list.
            print(f'Skipping {img_name}: could not be read as an image.')
            continue

        # restore faces and background if necessary
        cropped_faces, restored_faces, restored_img = restorer.enhance(
            input_img,
            has_aligned=args.aligned,
            only_center_face=args.only_center_face,
            paste_back=True,
            weight=args.weight)

        # save faces
        for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)):
            # save cropped face
            save_crop_path = os.path.join(args.output, 'cropped_faces', f'{basename}_{idx:02d}.png')
            imwrite(cropped_face, save_crop_path)
            # save restored face
            if args.suffix is not None:
                save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png'
            else:
                save_face_name = f'{basename}_{idx:02d}.png'
            save_restore_path = os.path.join(args.output, 'restored_faces', save_face_name)
            imwrite(restored_face, save_restore_path)
            # save comparison image (cropped | restored, side by side)
            cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
            imwrite(cmp_img, os.path.join(args.output, 'cmp', f'{basename}_{idx:02d}.png'))

        # save restored img
        if restored_img is not None:
            # 'auto' keeps the input file's extension (without the leading dot)
            extension = ext[1:] if args.ext == 'auto' else args.ext

            if args.suffix is not None:
                save_restore_path = os.path.join(args.output, 'restored_imgs', f'{basename}_{args.suffix}.{extension}')
            else:
                save_restore_path = os.path.join(args.output, 'restored_imgs', f'{basename}.{extension}')
            imwrite(restored_img, save_restore_path)

    print(f'Results are in the [{args.output}] folder.')



## 3. Inference with multi-threading

In [None]:
import concurrent.futures
import traceback

# Split the image list into at most num_parts contiguous chunks. Each chunk runs
# in its own thread, and enhance_list builds a separate restorer per call, so
# the threads do not share model instances.
n = len(img_list)
num_parts = 5
# Ceiling division: n // num_parts is 0 for n < num_parts (range() step 0 raises
# ValueError) and yields an extra chunk when n is not a multiple of num_parts.
part_size = max(1, -(-n // num_parts))

split_list = [img_list[i:i + part_size] for i in range(0, n, part_size)]


def process_list(img_list):
    """Thread entry point: restore every image in one chunk."""
    enhance_list(img_list)


with concurrent.futures.ThreadPoolExecutor(max_workers=num_parts) as executor:
    futures = [executor.submit(process_list, chunk) for chunk in split_list]
    for future in concurrent.futures.as_completed(futures):
        try:
            future.result()
        except Exception as e:
            # Report the failure with its traceback, but let the other chunks finish.
            print(f'Exception occurred: {e}')
            traceback.print_exception(type(e), e, e.__traceback__)


['inputs/upload/frame0.jpg', 'inputs/upload/frame1.jpg', 'inputs/upload/frame10.jpg', 'inputs/upload/frame11.jpg', 'inputs/upload/frame12.jpg', 'inputs/upload/frame13.jpg', 'inputs/upload/frame14.jpg', 'inputs/upload/frame15.jpg', 'inputs/upload/frame16.jpg', 'inputs/upload/frame17.jpg', 'inputs/upload/frame18.jpg', 'inputs/upload/frame19.jpg', 'inputs/upload/frame2.jpg']
['inputs/upload/frame20.jpg', 'inputs/upload/frame21.jpg', 'inputs/upload/frame22.jpg', 'inputs/upload/frame23.jpg', 'inputs/upload/frame24.jpg', 'inputs/upload/frame25.jpg', 'inputs/upload/frame26.jpg', 'inputs/upload/frame27.jpg', 'inputs/upload/frame28.jpg', 'inputs/upload/frame29.jpg', 'inputs/upload/frame3.jpg', 'inputs/upload/frame30.jpg', 'inputs/upload/frame31.jpg']
['inputs/upload/frame32.jpg', 'inputs/upload/frame33.jpg', 'inputs/upload/frame34.jpg', 'inputs/upload/frame35.jpg', 'inputs/upload/frame36.jpg', 'inputs/upload/frame37.jpg', 'inputs/upload/frame38.jpg', 'inputs/upload/frame39.jpg', 'inputs/upload/



inputs/upload/frame44.jpg
Processing frame44.jpg ...
inputs/upload/frame0.jpg
Processing frame0.jpg ...
inputs/upload/frame56.jpg
Processing frame56.jpg ...
inputs/upload/frame20.jpg
Processing frame20.jpg ...
inputs/upload/frame32.jpg
Processing frame32.jpg ...
	Tile 1/35
	Tile 1/35
	Tile 1/35
	Tile 1/35
	Tile 2/35
	Tile 2/35
	Tile 2/35
	Tile 1/35
	Tile 2/35
	Tile 3/35
	Tile 3/35
	Tile 3/35
	Tile 3/35
	Tile 2/35
	Tile 4/35
	Tile 4/35
	Tile 4/35
	Tile 4/35	Tile 3/35

	Tile 5/35
	Tile 5/35
	Tile 5/35
	Tile 5/35
	Tile 4/35
	Tile 6/35
	Tile 6/35
	Tile 6/35
	Tile 5/35
	Tile 6/35
	Tile 7/35
	Tile 7/35
	Tile 7/35
	Tile 7/35
	Tile 6/35
	Tile 8/35
	Tile 8/35
	Tile 8/35
	Tile 8/35
	Tile 7/35
	Tile 9/35
	Tile 9/35
	Tile 9/35
	Tile 9/35
	Tile 8/35
	Tile 10/35
	Tile 10/35
	Tile 10/35
	Tile 10/35
	Tile 9/35
	Tile 11/35
	Tile 11/35
	Tile 11/35
	Tile 11/35
	Tile 10/35
	Tile 12/35
	Tile 12/35
	Tile 12/35
	Tile 12/35
	Tile 11/35
	Tile 13/35
	Tile 13/35
	Tile 13/35
	Tile 12/35
	Tile 13/35
	Tile 14/35
	T

In [None]:
import datetime

# Note left by the author: "5 43" (meaning not recorded; possibly a timing note).
# Bind the current wall-clock time to current_time and print it, e.g. to see
# when the run above finished.
print("Current date and time: ", current_time := datetime.datetime.now())


Current date and time:  2023-03-11 12:52:56.649598


## 4. Inference without multi-processing

In [None]:
# Now we use the GFPGAN to restore the above low-quality images
# We use [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) for enhancing the background (non-face) regions
# You can find the different models in https://github.com/TencentARC/GFPGAN#european_castle-model-zoo
# Sequential baseline: restores every image in inputs/upload/ in a single process.
# NOTE: `rm -rf results` deletes any results written by the multi-threaded run above.
!rm -rf results
!python inference_gfpgan.py -i inputs/upload/ -o results -v 1.3 -s 2 --bg_upsampler realesrgan

# Usage: python inference_gfpgan.py -i inputs/whole_imgs -o results -v 1.3 -s 2 [options]...
#
#  -h                   show this help
#  -i input             Input image or folder. Default: inputs/whole_imgs
#  -o output            Output folder. Default: results
#  -v version           GFPGAN model version. Option: 1 | 1.2 | 1.3 | 1.4 | RestoreFormer. Default: 1.3
#  -s upscale           The final upsampling scale of the image. Default: 2
#  --bg_upsampler       background upsampler. Default: realesrgan
#  --bg_tile            Tile size for background sampler, 0 for no tile during testing. Default: 400
#  --suffix             Suffix of the restored faces
#  --only_center_face   Only restore the center face
#  --aligned            Input are aligned faces
#  --ext                Image extension. Options: auto | jpg | png, auto means using the same extension as inputs. Default: auto
#  -w weight            Adjustable weight passed to the restorer. Default: 0.5

!ls results/cmp

Downloading: "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth" to /usr/local/lib/python3.9/dist-packages/weights/RealESRGAN_x2plus.pth

100% 64.0M/64.0M [00:01<00:00, 39.3MB/s]
Downloading: "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth" to /content/GFPGAN/gfpgan/weights/detection_Resnet50_Final.pth

100% 104M/104M [00:00<00:00, 299MB/s] 
Downloading: "https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth" to /content/GFPGAN/gfpgan/weights/parsing_parsenet.pth

100% 81.4M/81.4M [00:03<00:00, 23.4MB/s]
Processing frame0.jpg ...
	Tile 1/35
	Tile 2/35
	Tile 3/35
	Tile 4/35
	Tile 5/35
	Tile 6/35
	Tile 7/35
	Tile 8/35
	Tile 9/35
	Tile 10/35
	Tile 11/35
	Tile 12/35
	Tile 13/35
	Tile 14/35
	Tile 15/35
	Tile 16/35
	Tile 17/35
	Tile 18/35
	Tile 19/35
	Tile 20/35
	Tile 21/35
	Tile 22/35
	Tile 23/35
	Tile 24/35
	Tile 25/35
	Tile 26/35
	Tile 27/35
	Tile 28/35
	Tile 29/35
	Tile 30/35
	Tile 

## 5. Download results

In [None]:
# download the result
import os
import shutil
from google.colab import files

print('Download results')
# Build downloadimgs.zip in the current folder (/content/GFPGAN after the setup
# cell) from the restored full-size images; entries are stored as
# restored_imgs/<name>. shutil.make_archive overwrites an existing archive,
# whereas `zip -r` would add to it and keep stale images from earlier runs. It
# also raises if the folder is missing instead of failing silently via os.system.
archive_path = shutil.make_archive(
    'downloadimgs', 'zip',
    root_dir='/content/GFPGAN/results',
    base_dir='restored_imgs')
files.download(archive_path)

Download results


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# Copy the results archive to Google Drive.
# NOTE(review): requires Drive to be mounted at /content/drive beforehand
# (e.g. google.colab.drive.mount('/content/drive')); no cell in this notebook mounts it.
!cp /content/GFPGAN/downloadimgs.zip  /content/drive/MyDrive