This repository demonstrates how to train a cat vs. dog recognition model with TensorFlow and export it as an optimized frozen graph that is easy to deploy. If you want to learn how to deploy a Flask app that recognizes cats and dogs with TensorFlow, please visit cat-recognition-app.
- Python3 (Tested on 3.6.8)
- TensorFlow (Tested on 1.12.0)
- NumPy (Tested on 1.15.1)
- tqdm (Tested on 4.29.1)
- Dogs vs. Cats dataset from https://www.kaggle.com/c/dogs-vs-cats
- PyTorch (Tested on 1.0.0 and 1.0.1; optional, only needed to run the tests)
We recommend using Anaconda3 / Miniconda3 to manage your Python environment.
If your machine does not have a GPU, you can install the requirements with either:
$ pip install -r requirements.txt
or
$ conda install --file requirements.txt
If you want to use a GPU to accelerate training, please see the TensorFlow GPU support page for more information.
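To compare your environment against the tested versions listed above, a small script along these lines may help. It is a hypothetical helper, not part of this repository, and it only reads each package's __version__ attribute.

```python
import importlib
import sys

# Versions this README reports as tested (copied from the requirements list).
TESTED_VERSIONS = {
    "tensorflow": "1.12.0",
    "numpy": "1.15.1",
    "tqdm": "4.29.1",
}


def installed_version(module_name):
    """Return the module's __version__ string, or None if the module cannot
    be imported or exposes no __version__ attribute."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, "__version__", None)


def report(tested=TESTED_VERSIONS):
    """Map each package name to (installed_version, tested_version)."""
    return {name: (installed_version(name), ver) for name, ver in tested.items()}


if __name__ == "__main__":
    print("Python", sys.version.split()[0], "(tested on 3.6.8)")
    for name, (have, want) in report().items():
        status = "missing" if have is None else have
        print(f"{name}: installed {status}, tested on {want}")
```

A mismatch is not necessarily a problem, but it is a good first thing to check if training fails to start.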
Train a Convolutional Neural Network
- Create training/validation sets (dataset.py)
  - Load, augment, resize and normalize the images using
- Define a CNN model (net.py)
  - We use the ShuffleNetV2 architecture, which strikes a good balance between speed and accuracy.
  - We apply transfer learning to ShuffleNetV2, initializing it with the pretrained weights from https://github.com/ericsun99/Shufflenet-v2-Pytorch.
  - If you want to know how to load PyTorch weights into a TensorFlow model graph, please check convert_pytorch_weight_test starting from line 44 in
- Train the CNN model (train.py)
- Serialize the model for deployment (train.py)
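For readers curious how PyTorch weights can be mapped onto a TensorFlow graph, the core step is reordering array axes, because the two frameworks store tensors in different layouts. The NumPy sketch below shows the standard reorderings for regular convolutions, depthwise convolutions (channel multiplier of 1) and fully connected layers, plus the channel shuffle operation that ShuffleNetV2 is built around, written for NHWC tensors. The function names are hypothetical; this is not the code in convert_pytorch_weight_test.

```python
import numpy as np


def conv_weight_torch_to_tf(w):
    """PyTorch Conv2d weight (out, in, kH, kW) -> TF conv2d filter (kH, kW, in, out)."""
    return np.transpose(w, (2, 3, 1, 0))


def depthwise_weight_torch_to_tf(w):
    """PyTorch depthwise Conv2d weight (C, 1, kH, kW) with groups == C
    -> TF depthwise filter (kH, kW, C, 1), i.e. channel multiplier 1."""
    return np.transpose(w, (2, 3, 0, 1))


def linear_weight_torch_to_tf(w):
    """PyTorch Linear weight (out, in) -> TF dense kernel (in, out)."""
    return w.T


def channel_shuffle_nhwc(x, groups):
    """ShuffleNet channel shuffle: interleave channels across `groups` groups."""
    n, h, w, c = x.shape
    assert c % groups == 0, "channel count must be divisible by groups"
    x = x.reshape(n, h, w, groups, c // groups)  # split channels into groups
    x = x.transpose(0, 1, 2, 4, 3)               # swap the group and per-group axes
    return x.reshape(n, h, w, c)                 # flatten back to C channels


if __name__ == "__main__":
    torch_conv = np.random.rand(24, 3, 3, 3)     # (out, in, kH, kW)
    print(conv_weight_torch_to_tf(torch_conv).shape)            # (3, 3, 3, 24)
    x = np.arange(6).reshape(1, 1, 1, 6)
    print(channel_shuffle_nhwc(x, groups=3)[0, 0, 0].tolist())  # [0, 2, 4, 1, 3, 5]
```

Each converted array would then be assigned to the matching TensorFlow variable. Note that transposing weights alone may not reproduce PyTorch outputs exactly for strided convolutions, because TensorFlow's 'SAME' padding can pad asymmetrically (extra padding on the bottom/right), whereas PyTorch pads symmetrically.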
To run the code, make sure all required packages are installed and the Dogs vs. Cats training set is placed in the datasets folder. The folder structure should look like this:
cat-recognition-train
+-- train.py
+-- net.py
+-- dataset.py
+-- datasets
    +-- train
    |   +-- cat.0.jpg
    |   +-- cat.1.jpg
    |   ...
    |   +-- cat.12499.jpg
    |   +-- dog.0.jpg
    |   +-- dog.1.jpg
    |   ...
    |   +-- dog.12499.jpg
    +-- ...
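The actual split logic lives in dataset.py and is not reproduced here, but with the file naming shown above, labels can be read from the file name prefix. Below is a minimal stdlib-only sketch under that assumption. The function names, the cat = 1 / dog = 0 label convention and the fixed seed are illustrative choices; valset_ratio simply mirrors the --valset_ratio flag.

```python
import os
import random


def label_from_filename(path):
    """Kaggle Dogs vs. Cats files are named like 'cat.0.jpg' or 'dog.12499.jpg'.
    Returns 1 for cat and 0 for dog (this label convention is an assumption)."""
    prefix = os.path.basename(path).split(".")[0]
    if prefix == "cat":
        return 1
    if prefix == "dog":
        return 0
    raise ValueError(f"unexpected file name: {path}")


def split_train_val(filenames, valset_ratio=0.1, seed=0):
    """Shuffle (path, label) pairs reproducibly and hold out `valset_ratio` of them."""
    samples = [(f, label_from_filename(f)) for f in sorted(filenames)]
    random.Random(seed).shuffle(samples)
    n_val = int(len(samples) * valset_ratio)
    return samples[n_val:], samples[:n_val]  # (train, validation)


if __name__ == "__main__":
    train_dir = os.path.join("datasets", "train")
    if os.path.isdir(train_dir):
        files = [os.path.join(train_dir, f) for f in os.listdir(train_dir) if f.endswith(".jpg")]
        train, val = split_train_val(files, valset_ratio=0.1)
        print(f"{len(train)} training / {len(val)} validation images")
```

Sorting before shuffling with a seeded generator makes the split independent of the order os.listdir returns files in.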
Once everything is set up, run the following command to train with the default arguments:
$ python train.py
Or pass your own arguments:
$ python train.py --epochs 30 --batch_size 32 --valset_ratio .1 --optim sgd --lr_decay_step 10
See train.py for the available arguments.
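As an illustration of how flags like these are typically wired up, here is a minimal argparse sketch with a step learning-rate decay, which is one common reading of --lr_decay_step. The default values, the adam choice and the --lr and --lr_decay_rate flags are hypothetical placeholders, not taken from train.py, and the script only prints the schedule rather than training anything.

```python
import argparse


def build_parser():
    # Flag names mirror the example command above; defaults and the extra
    # --lr / --lr_decay_rate flags are hypothetical placeholders.
    parser = argparse.ArgumentParser(description="Cat vs. dog training (sketch)")
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--valset_ratio", type=float, default=0.1)
    parser.add_argument("--optim", choices=["sgd", "adam"], default="sgd")
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--lr_decay_step", type=int, default=10)
    parser.add_argument("--lr_decay_rate", type=float, default=0.1)
    return parser


def step_decay_lr(base_lr, epoch, decay_step, decay_rate):
    """Multiply the learning rate by `decay_rate` once every `decay_step` epochs."""
    return base_lr * decay_rate ** (epoch // decay_step)


if __name__ == "__main__":
    args = build_parser().parse_args()
    for epoch in range(args.epochs):
        lr = step_decay_lr(args.lr, epoch, args.lr_decay_step, args.lr_decay_rate)
        print(f"epoch {epoch:2d}: optimizer={args.optim} lr={lr:.6f}")
```

With --lr_decay_step 10, epochs 0 to 9 use the base rate, epochs 10 to 19 use base_lr * decay_rate, and so on.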
Visualizing Learning Using TensorBoard
During training, you can monitor its progress by running:
$ tensorboard --logdir runs
Then open TensorBoard in your browser to view the summaries (by default, TensorBoard serves on http://localhost:6006).
Training and Validation Flow
Optimized Network Graph
Predict Using Optimized Frozen Graph
See predict.py for details and a demo.
For a demonstration, you can run:
$ python predict.py
The output should be:
Predicting catness on images/test.png using model from baseline_model/optimized_net_best_acc.pb
Catness: 16.460064
Cat Probability: 1.000000
It's a cat.
If you have your own cat or dog photo for testing, run:
$ python predict.py --path path/to/your/img.png
PNG, JPG and BMP images are supported.
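The example output above is consistent with Catness being a raw logit and Cat Probability being its logistic sigmoid: sigmoid(16.46) ≈ 0.99999993, which prints as 1.000000 with six decimals. This is an assumption, not something confirmed by predict.py. A small sketch of that post-processing, with a hypothetical 0.5 decision threshold and a hypothetical dog message:

```python
import math


def catness_to_probability(catness):
    """Logistic sigmoid (assumed mapping), split by sign to avoid overflow in exp."""
    if catness >= 0:
        return 1.0 / (1.0 + math.exp(-catness))
    z = math.exp(catness)
    return z / (1.0 + z)


def describe(catness, threshold=0.5):
    """Format a result like the demo output; threshold and dog message are hypothetical."""
    prob = catness_to_probability(catness)
    verdict = "It's a cat." if prob >= threshold else "It's a dog."
    return f"Catness: {catness:f}\nCat Probability: {prob:f}\n{verdict}"


print(describe(16.460064))
# prints:
# Catness: 16.460064
# Cat Probability: 1.000000
# It's a cat.
```

Reporting the logit alongside the probability is useful because the probability saturates at 1.000000 for any strongly confident prediction, while the logit still shows how confident the model is.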