Improving Deep Learning Accuracy by Data Augmentation Using Generative Adversarial Networks (GAN) is a project for the NTU module EE3080 Design & Innovation Project.
We won second place in the NTU EEE DIP competition within the Thematic Programme.
By Tan Kim Wai, Ng Woon Ting Elizabeth, Melvin Kok Xinwei, Lee Huei Min, Teo Jing Yu, Matthew Stanley, Carina Low Su Yun and Tian Siming.
The aim of the project was to determine whether data augmentation using GANs could improve deep learning accuracy, specifically classification accuracy. The improvement in performance is measured by comparing the accuracy of models trained on augmented datasets against that of models trained on a baseline dataset. For comparison with the currently predominant data augmentation methods, we also used the accuracies of models trained on noise-augmented datasets as an additional reference.
Our project goals were achieved: the collated results showed an improvement of up to 4.7 percentage points in classification accuracy. However, it is also of note that data augmentation using GANs did not significantly outperform data augmentation using noise. Furthermore, when a state-of-the-art classification model such as VGG19 was used, the accuracy improved only marginally over the baseline.
Using the Facial Expression Recognition (FER2013) dataset from Kaggle as our baseline dataset, we performed classification using three different CNN models, namely VGG19, ResNet101, and EfficientNet-B2. This initial set of test classification accuracies served as our baseline accuracy.
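Since `pytorchcv` is among the listed dependencies, one plausible way to instantiate the three classifiers is through its model registry. This is a minimal sketch under that assumption, not necessarily the project's exact code:

```python
# Sketch (assumption): loading the three classifier backbones via pytorchcv.
# pytorchcv uses underscores in model names ("efficientnet_b2b"), and
# num_classes=7 matches FER2013's seven emotion classes (folders 0-6 below).
from pytorchcv.model_provider import get_model

classifiers = {
    name: get_model(name, num_classes=7)
    for name in ("vgg19", "resnet101", "efficientnet_b2b")
}
```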
We then generated images using two methods:
- GAN models (ESRGAN & WGAN-GP)
- Adding noise to images (Gaussian, Laplace and Poisson noise; see the sketch below)
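A minimal sketch of the noise-augmentation method using `imgaug` (one of the listed dependencies). The three noise types come from the list above, but the `scale`/`lam` values are illustrative assumptions, not the project's exact settings:

```python
# Sketch (assumed parameters): adding Gaussian, Laplace and Poisson noise
# with imgaug. FER2013 images are 48x48 grayscale, hence the dummy shape.
import numpy as np
import imgaug.augmenters as iaa

augmenters = {
    "gaussian": iaa.AdditiveGaussianNoise(scale=(0, 0.05 * 255)),
    "laplace": iaa.AdditiveLaplaceNoise(scale=(0, 0.05 * 255)),
    "poisson": iaa.AdditivePoissonNoise(lam=(0, 16)),
}

images = np.random.randint(0, 256, (4, 48, 48, 1), dtype=np.uint8)
noisy_sets = {name: aug(images=images) for name, aug in augmenters.items()}
```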
We combined the generated images with the baseline dataset to create multiple augmented datasets. These datasets were then used to train the aforementioned CNN models, and the resulting test classification accuracies were compared to the baseline accuracy.
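One way this combination could be realised in code, sketched with torchvision's `ImageFolder` and `ConcatDataset` (the transform and the exact folder names are assumptions based on the folder structure shown later; the repo may instead merge images on disk):

```python
# Sketch: merging baseline and GAN-generated images into one training set.
import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Grayscale(),   # FER2013 images are grayscale
    transforms.ToTensor(),
])

baseline = datasets.ImageFolder("data/baseline/train", transform=transform)
generated = datasets.ImageFolder("data/esrgan_wgan_gp", transform=transform)

augmented = torch.utils.data.ConcatDataset([baseline, generated])
loader = torch.utils.data.DataLoader(augmented, batch_size=64, shuffle=True)
```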
We further improved the outputs of the GAN generated images by using two methods:
- Feeding enhanced resolution images from ESRGAN into WGAN-GP (instead of the normal resolution images)
- Using face detection to improve the output of WGAN-GP (see the sketch below)
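The face-detection step can be sketched with OpenCV's Haar cascade: keep a generated image only if at least one face is found in it. The cascade filename and the `detectMultiScale` parameters below are assumptions:

```python
# Sketch (assumed filename/parameters): filter WGAN-GP outputs with a Haar
# cascade, keeping only images in which OpenCV detects at least one face.
import cv2

cascade = cv2.CascadeClassifier("weights/haarcascade_frontalface_default.xml")

def contains_face(image_path: str) -> bool:
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    faces = cascade.detectMultiScale(img, scaleFactor=1.1, minNeighbors=3)
    return len(faces) > 0
```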
The tables below report validation and test accuracies, one table per classifier; figures in parentheses are the change in test accuracy, in percentage points, relative to the Original (baseline) dataset.

Dataset | Validation Accuracy | Test Accuracy |
---|---|---|
Original | 63.42% | 64.36% |
Noise | 65.98% | 69.10% (+4.74) |
WGAN-GP | 65.61% | 66.87% (+2.51) |
ESRGAN | 63.89% | 66.01% (+1.65) |
ESRGAN + WGAN-GP | 66.12% | 68.32% (+3.96) |
ESRGAN + WGAN-GP + Face Detection | 67.43% | 68.79% (+4.43) |

Dataset | Validation Accuracy | Test Accuracy |
---|---|---|
Original | 63.75% | 62.25% |
Noise | 65.09% | 67.54% (+5.29) |
WGAN-GP | 65.67% | 66.59% (+4.34) |
ESRGAN | 63.89% | 64.25% (+2.00) |
ESRGAN + WGAN-GP | 64.17% | 64.47% (+2.22) |
ESRGAN + WGAN-GP + Face Detection | 65.95% | 66.90% (+4.65) |

Dataset | Validation Accuracy | Test Accuracy |
---|---|---|
Original | 67.71% | 68.29% |
Noise | 66.95% | 67.67% (-0.62) |
WGAN-GP | 66.20% | 68.85% (+0.56) |
ESRGAN | 66.04% | 67.34% (-0.95) |
ESRGAN + WGAN-GP | 65.62% | 67.04% (-1.25) |
ESRGAN + WGAN-GP + Face Detection | 66.59% | 68.68% (+0.39) |
- Clone the repo

```bash
git clone https://github.com/melvinkokxw/improving-dl-accuracy-gan
cd improving-dl-accuracy-gan
```
- Install the required packages

The following packages are required:

```
torch
torchvision
Pillow
numpy
pandas
pytorchcv
opencv-python
tqdm
matplotlib
imgaug
```

Install each via pip:

```bash
python3 -m pip install <package-name>
```
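Alternatively, install them all in one command:

```bash
python3 -m pip install torch torchvision Pillow numpy pandas pytorchcv opencv-python tqdm matplotlib imgaug
```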
The following files are required to run the program:

- Image folders. Place the dataset images into `data/`, as specified in the folder structure below
- Haar cascade file. The Haar cascade file is required to run face detection and can be obtained from here. Once downloaded, place it into the `weights` folder
- ESRGAN weights file. Get the download link from here.
As a guideline, here is where to put the weights & data:

```
.
├── README.md
├── classifier
├── models
├── data
│   ├── baseline
│   │   ├── test
│   │   ├── train
│   │   └── val
│   ├── esrgan
│   │   ├── test
│   │   ├── train
│   │   └── val
│   └── {wgan_gp, esrgan_wgan_gp, esrgan_wgan_gp_fd}
│       ├── 0
│       ├── ...
│       └── 6
└── weights
    ├── esrgan
    │   ├── README.md
    │   └── RRDB_ESRGAN_x4.pth
    ├── {vgg19, resnet101, efficientnet-b2b} <-- Classifier weights go here
    │   ├── {baseline, wgan_gp, esrgan, esrgan_wgan_gp, esrgan_wgan_gp_fd}
    │   │   ├── weights_1.pth
    │   │   ├── weights_2.pth
    │   │   └── weights_3.pth
    └── {baseline, wgan_gp, esrgan, esrgan_wgan_gp, esrgan_wgan_gp_fd} <-- GAN weights go here
        ├── 0
        │   └── weights.pth
        ├── ...
        └── 6
            └── weights.pth
```
- Place image folders into `data/`, as specified in the folder structure
- Run the training file (an example invocation follows the usage below):

```
python3 models/wgan_gp/train.py [-h] [--n_epochs N_EPOCHS] [--batch_size BATCH_SIZE] [--lr LR]
                                [--b1 B1] [--b2 B2] [--n_cpu N_CPU] [--latent_dim LATENT_DIM]
                                [--img_size IMG_SIZE] [--channels CHANNELS]
                                [--n_critic N_CRITIC] [--clip_value CLIP_VALUE]
                                [--sample_interval SAMPLE_INTERVAL]
                                [--dataset {baseline,esrgan}]
```
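For example, a hypothetical run that overrides a few defaults (the flag values here are illustrative, not the project's tuned settings):

```bash
python3 models/wgan_gp/train.py --dataset esrgan --n_epochs 200 --batch_size 64
```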
- Place weights into the respective folders in `weights/`, as specified in the folder structure
- Run the relevant generator file:

For ESRGAN:

```
python3 models/esrgan/generate.py
```

For WGAN-GP & its variants:

```
python3 models/wgan_gp/generate.py [-h] [--quality {baseline,esrgan}] [--face_detection]
```
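For example, to generate images with the ESRGAN-quality generator and the face-detection filter enabled (an illustrative combination of the flags above):

```bash
python3 models/wgan_gp/generate.py --quality esrgan --face_detection
```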
- Place image folders into `data/`, as specified in the folder structure
- Run the classifier training file (example below):

```
python3 classifier/train.py [-h] [--dataset {baseline,esrgan,wgan_gp,esrgan_wgan_gp,esrgan_wgan_gp_fd}]
                            [--classifier {vgg19,resnet101,efficientnet-b2b}]
```

- Weights will be saved into `weights/<classifier>/<dataset>`, while training graphs will be saved into `classifier/graphs`
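For instance, to train the VGG19 classifier on the face-detection-filtered augmented dataset (an illustrative choice of the documented flags):

```bash
python3 classifier/train.py --dataset esrgan_wgan_gp_fd --classifier vgg19
```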
- Place weights into the respective folders in `weights/`, as specified in the folder structure
- Run the classifier testing file (example below):

```
python3 classifier/test.py [-h]
                           [--dataset {baseline,esrgan,wgan_gp,esrgan_wgan_gp,esrgan_wgan_gp_fd}]
                           [--classifier {vgg19,resnet101,efficientnet-b2b}]
```

- The highest accuracy and the corresponding weight file will be printed to `stdout`
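For instance, to evaluate the VGG19 classifier trained on the baseline dataset (an illustrative choice of the documented flags):

```bash
python3 classifier/test.py --dataset baseline --classifier vgg19
```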
- WGAN-GP model adapted from PyTorch-GAN
- ESRGAN model adapted from ESRGAN