hiram64/ESRGAN-tensorflow
ESRGAN (TensorFlow)

This repository provides a TensorFlow implementation of the paper "ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks" by X. Wang et al.

Dependencies

TensorFlow, OpenCV, scikit-learn, NumPy

Versions in my test environment:
Python==3.6.8, tensorflow-gpu==1.12.0, OpenCV==4.1.0, scikit-learn==0.20.3, numpy==1.16.2

How to Use

1. Prepare data for training

Prepare your data and place it in the directory specified by the "data_dir" flag of train.py (e.g. './data/LSUN'). Other necessary directories are created automatically as set in the script.

2. Train the model

Run the train.py script. The main steps are :

  • Data processing : create HR patches and LR patches (by downsampling the HR patches). The processed data can be saved to directories so that it can be reused in later runs.

  • Pre-training with pixel-wise loss : as described in the paper, the generator is pre-trained first. You can set the "pretrain_generator" flag to False to use an existing pre-trained checkpoint. (Training ESRGAN without a pre-trained model is not supported.)

  • Training ESRGAN : starting from the pre-trained model, the full ESRGAN training is performed.
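The data-processing step above can be sketched as follows. This is a minimal NumPy illustration, not the repository's code: the patch size of 128 is an assumed value, and block averaging stands in for whatever downsampling kernel the script actually uses; only the 4x scale factor matches the repository.

```python
import numpy as np

def make_hr_lr_pair(image, patch_size=128, scale=4):
    """Crop one random HR patch and downsample it by `scale` to get the LR input.

    Block averaging is used here as a stand-in for the repository's
    actual downsampling; patch_size=128 is an assumption.
    """
    h, w = image.shape[:2]
    top = np.random.randint(0, h - patch_size + 1)
    left = np.random.randint(0, w - patch_size + 1)
    hr = image[top:top + patch_size, left:left + patch_size]

    # Downsample by averaging each scale x scale block of the HR patch.
    lr = hr.reshape(patch_size // scale, scale,
                    patch_size // scale, scale, -1).mean(axis=(1, 3))
    return hr, lr

image = np.random.rand(256, 256, 3)
hr, lr = make_hr_lr_pair(image)
print(hr.shape, lr.shape)  # (128, 128, 3) (32, 32, 3)
```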

# python train.py

(data directory can be passed by the optional argument)
# python train.py --data_dir ./data/LSUN
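The pixel-wise loss used in the pre-training step can be sketched as a mean absolute error between generated and ground-truth patches. This is a NumPy illustration; the ESRGAN paper pre-trains with an L1 loss, and whether this repository uses L1 or L2 for its pixel-wise loss is an assumption here.

```python
import numpy as np

def pixelwise_l1_loss(generated, ground_truth):
    """Mean absolute error over all pixels and channels.

    This mirrors the L1 pixel-wise loss the ESRGAN paper uses for
    generator pre-training; the repository's exact variant is assumed.
    """
    return np.mean(np.abs(generated - ground_truth))

gt = np.ones((32, 32, 3))
gen = np.full((32, 32, 3), 0.75)
print(pixelwise_l1_loss(gen, gt))  # 0.25
```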

3. Inference LR data

After training is finished, super-resolution of LR images is available. Input data can be specified via the "data_dir" flag of the inference.py script.

# python inference.py

(data directory can be passed by the optional argument)
# python inference.py --data_dir ./data/inference

4. Inference via Network interpolation

The paper proposes a network interpolation method that linearly combines the weights of the pixel-wise pre-trained model and the ESRGAN generator. You can run this after both the pre-training and the ESRGAN training have finished. Input data can be specified via the "data_dir" flag of the network_interpolation.py script.

# python network_interpolation.py

(data directory can be passed by the optional argument)
# python network_interpolation.py --data_dir ./data/inference
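Network interpolation is a per-parameter linear blend of the two checkpoints. A minimal NumPy sketch follows; the parameter dictionaries are hypothetical stand-ins for the loaded checkpoint variables, and alpha = 0.2 matches the value reported under "Experiment condition".

```python
import numpy as np

def interpolate_weights(psnr_params, gan_params, alpha=0.2):
    """Blend each weight tensor: (1 - alpha) * pre-trained + alpha * ESRGAN.

    alpha = 0 recovers the pixel-wise pre-trained model,
    alpha = 1 the full ESRGAN generator.
    """
    return {name: (1.0 - alpha) * psnr_params[name] + alpha * gan_params[name]
            for name in psnr_params}

# Hypothetical single-layer checkpoints for illustration.
psnr = {"conv1/kernel": np.zeros((3, 3))}
gan = {"conv1/kernel": np.ones((3, 3))}
blended = interpolate_weights(psnr, gan, alpha=0.2)
print(blended["conv1/kernel"][0, 0])  # 0.2
```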

Experiment Result

DIV2K dataset

DIV2K is a collection of 2K resolution high quality images.
https://data.vision.ee.ethz.ch/cvl/DIV2K/

from left to right: bicubic interpolation, ESRGAN, ESRGAN with network interpolation, high resolution (GT). 4x super-resolution.

LSUN

LSUN is a collection of ordinary-resolution bedroom images.
https://www.kaggle.com/jhoward/lsun_bedroom/data

from left to right: bicubic interpolation, ESRGAN, ESRGAN with network interpolation, high resolution (GT). 4x super-resolution.

Experiment condition

  • training with 800 images for DIV2K, cropping 2 patches per image
  • training with about 5,000 images (a 20% subset of the collection) for LSUN, cropping 2 patches per image
  • data augmentation applied (horizontal flip and rotation by 90 degrees)
  • 15 RRDBs, batch size 32, 50,000 iterations per training phase; other parameters are the same as in the paper
  • network interpolation parameter is 0.2
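The augmentation listed above (horizontal flip plus 90-degree rotations) can be sketched like this. NumPy illustration only; whether the repository samples one variant at random or enumerates all combinations is an assumption.

```python
import numpy as np

def augment(patch):
    """Return all variants from horizontal flips and 90-degree rotations.

    An optional horizontal flip combined with rotations of
    0/90/180/270 degrees yields 8 views of the patch.
    """
    variants = []
    for flipped in (patch, patch[:, ::-1]):   # identity / horizontal flip
        for k in range(4):                    # rotate by k * 90 degrees
            variants.append(np.rot90(flipped, k))
    return variants

patch = np.arange(16).reshape(4, 4)
print(len(augment(patch)))  # 8
```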

Limitations

  • Only 4x super-resolution is supported
  • Grayscale images are not supported
  • Only Single GPU usage

To do list

The following features from the paper have not been implemented:

  • Perceptual loss using VGG19(currently pixel-wise loss is implemented instead)
  • Learning rate scheduling
  • Evaluation metrics

Notes

Some parameters, such as the number of RRDB blocks, the mini-batch size, and the number of iterations, were changed to fit my test environment. Please adjust them if you prefer the same conditions as the paper.
