
jackyblackme/mxnet_center_loss


mxnet_center_loss

This is a simple implementation of the center loss introduced in the paper "A Discriminative Feature Learning Approach for Deep Face Recognition" (ECCV 2016) by Yandong Wen, Kaipeng Zhang, Zhifeng Li, and Yu Qiao, Shenzhen; check their site for details.
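To make the idea concrete, here is a minimal pure-Python sketch of the two pieces the paper defines. The first is the loss L_C = 1/2 * sum_i ||x_i - c_{y_i}||^2, whose gradient with respect to each feature x_i is x_i - c_{y_i}. The second is the center update c_j <- c_j - alpha * delta_c_j, where delta_c_j = sum_{i: y_i = j} (c_j - x_i) / (1 + number of samples of class j). In the paper this term is added to the softmax loss as L_S + lambda * L_C. This is not the code in center_loss.py: the function names, the list-of-lists representation and the default `alpha=0.5` are assumptions chosen for illustration.

```python
# Illustrative sketch of center loss (Wen et al.); NOT the repository's center_loss.py.
# Features and centers are plain lists of floats; labels are integer class ids.

def center_loss(features, labels, centers):
    """Return L_C = 1/2 * sum_i ||x_i - c_{y_i}||^2 and dL_C/dx_i for every sample."""
    loss = 0.0
    grads = []
    for x, y in zip(features, labels):
        diff = [xk - ck for xk, ck in zip(x, centers[y])]
        loss += 0.5 * sum(d * d for d in diff)
        grads.append(diff)  # gradient w.r.t. x_i is simply x_i - c_{y_i}
    return loss, grads


def update_centers(features, labels, centers, alpha=0.5):
    """Move each center toward the mini-batch samples of its class.

    delta_c_j = sum_{i: y_i = j} (c_j - x_i) / (1 + count_j)
    c_j      <- c_j - alpha * delta_c_j
    Classes absent from the batch get delta = 0 and stay unchanged.
    """
    new_centers = []
    for j, c in enumerate(centers):
        members = [x for x, y in zip(features, labels) if y == j]
        delta = [sum(ck - x[k] for x in members) / (1 + len(members))
                 for k, ck in enumerate(c)]
        new_centers.append([ck - alpha * d for ck, d in zip(c, delta)])
    return new_centers
```

The `1 +` in the denominator keeps the update finite and damped when a class has few samples in the mini-batch, which is why the paper updates centers per batch instead of recomputing them over the whole training set.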

Chinese version (中文)

Prerequisites

Install MXNet.

For visualization, you may also need to install seaborn and matplotlib:

sudo pip install seaborn matplotlib
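If you are unsure whether everything is installed in the interpreter you will run the scripts with, a small check like the following can help. It is a sketch, not part of the repository; it only asks Python whether each top-level package can be found, without importing it.

```python
import importlib.util


def check_dependencies(modules=("mxnet", "seaborn", "matplotlib")):
    """Return {module_name: True/False} depending on whether the package can be found."""
    # find_spec returns None when a top-level package is not installed.
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, ok in check_dependencies().items():
        print(f"{name:12s} {'found' if ok else 'MISSING'}")
```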

Code

  • center_loss.py: implementation of the center loss operator and a custom metric for the loss
  • data.py: custom MNIST iterator that outputs two labels (one for softmax and one for center loss)
  • train_model.py: copied from the MXNet example, with some modifications
  • train.py: script to train the model
  • vis.py: script to visualize the result
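The key point of data.py is that every batch carries the same class labels twice, once for the softmax output and once for the center-loss output. The framework-free generator below sketches that idea. The label names `softmax_label` and `center_label` are hypothetical, and in MXNet itself the iterator would more likely be a subclass of `mx.io.DataIter` whose `provide_label` lists two label descriptors; check data.py for the actual names.

```python
def two_label_batches(images, labels, batch_size):
    """Yield (data, label_dict) mini-batches in which the same class labels are
    exposed under two (hypothetical) names: one for softmax, one for center loss.
    The final incomplete batch is dropped to keep every batch the same size."""
    for start in range(0, len(images) - batch_size + 1, batch_size):
        batch_labels = list(labels[start:start + batch_size])
        yield images[start:start + batch_size], {
            "softmax_label": batch_labels,
            "center_label": list(batch_labels),  # same values, separate copy
        }
```

Both heads see identical class ids; they differ only in the loss they compute, so duplicating the label is enough to feed a network with two outputs.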

Running the tests

1. Set the MXNet path

In data.py, change mxnet_root to your MXNet root folder.
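How data.py uses `mxnet_root` is not shown here; a common pattern in older MXNet examples was to put the checkout's `python` folder on `sys.path` so that `import mxnet` picks up that build. The sketch below assumes that layout (`<mxnet_root>/python`) and uses a hypothetical path `~/mxnet`; adjust both to your setup, and skip this entirely if MXNet is installed with pip.

```python
import os
import sys

# Hypothetical location: replace with the folder where you cloned/built MXNet.
mxnet_root = os.path.expanduser("~/mxnet")

# Assumption: a source checkout keeps the Python package in <mxnet_root>/python.
# Inserting it first on sys.path makes `import mxnet` prefer this copy.
mxnet_python = os.path.join(mxnet_root, "python")
if mxnet_python not in sys.path:
    sys.path.insert(0, mxnet_python)
```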

2. Train

  • With CPU:

    python train.py --batch-size=128

  • With GPU:

    python train.py --gpus=0

    or with multiple devices (not a good idea for this MNIST example):

    python train.py --gpus=0,1 --batch-size=256
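The batch size is doubled in the multi-GPU command because MXNet's data-parallel training splits each batch across the listed devices, so `--gpus=0,1 --batch-size=256` gives each GPU 128 samples. The sketch below shows how the two flags could be parsed and turned into a device list and per-device batch size. It returns device names as strings to stay framework-free (MXNet would use `mx.cpu()` / `mx.gpu(i)` objects), and the default of 128 is an assumption, not necessarily train.py's default.

```python
import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="center loss training (sketch)")
    parser.add_argument("--gpus", type=str, default=None,
                        help="comma-separated GPU ids, e.g. 0,1; omit to use the CPU")
    parser.add_argument("--batch-size", type=int, default=128,  # hypothetical default
                        help="total batch size, split evenly across devices")
    return parser.parse_args(argv)


def devices_and_split(args):
    """Return (device names, per-device batch size) for even data-parallel splitting."""
    if args.gpus is None:
        devices = ["cpu(0)"]
    else:
        devices = [f"gpu({int(i)})" for i in args.gpus.split(",")]
    return devices, args.batch_size // len(devices)
```

For example, `devices_and_split(parse_args(["--gpus=0,1", "--batch-size=256"]))` gives two GPUs with 128 samples each, the same per-device load as the single-device commands above.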

Then you can follow the output with:

tail -f log.txt
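If you prefer to check progress from Python (for example in a notebook) instead of a shell, the helper below returns the last lines of the log, similar to `tail -n`. It is a sketch that makes no assumption about the log's line format; it only reads the file once and keeps the last n lines in memory.

```python
from collections import deque


def tail_lines(lines, n=10):
    """Return the last n items of an iterable of lines, without trailing newlines."""
    return [line.rstrip("\n") for line in deque(lines, maxlen=n)]


def last_lines(path, n=10):
    """Like `tail -n N path`: stream the file, keeping only the last n lines."""
    with open(path, encoding="utf-8") as f:
        return tail_lines(f, n)
```

Usage: `print("\n".join(last_lines("log.txt", 20)))`.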

3. Visualize the result

Run:

python vis.py

You will see something like the picture on the right. Compared with the 'softmax only' experiment on the left, all samples are well clustered, so we can expect better generalization performance. The difference is not decisive here, though (center loss does help with convergence; see the last figure), because the same set of classes appears during training and testing. In other applications such as face recognition, the number of possible classes is unknown, so a good embedding is essential.
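"Well clustered" can also be put into numbers rather than judged by eye. The hypothetical helper below is not part of vis.py; it is one simple measure you could apply to the features vis.py plots. It computes the mean distance from each sample to its class centroid and the mean distance between centroids. Center loss should shrink the first value relative to the second. It needs at least two classes.

```python
import math


def cluster_separation(points, labels):
    """Return (mean sample-to-own-centroid distance, mean centroid-to-centroid distance).

    A small first value relative to the second indicates compact, well separated
    clusters. Requires at least two distinct labels.
    """
    groups = {}
    for p, y in zip(points, labels):
        groups.setdefault(y, []).append(p)
    # Per-class centroid: coordinate-wise mean of that class's points.
    centroids = {y: [sum(c) / len(ps) for c in zip(*ps)] for y, ps in groups.items()}
    intra = [math.dist(p, centroids[y]) for p, y in zip(points, labels)]
    keys = sorted(centroids)
    inter = [math.dist(centroids[a], centroids[b])
             for i, a in enumerate(keys) for b in keys[i + 1:]]
    return sum(intra) / len(intra), sum(inter) / len(inter)
```

Running it on the 'softmax only' features and on the center-loss features gives two ratios that can be compared directly, which is easier to report than two scatter plots.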

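[Figure: center_loss, feature embeddings for the 'softmax only' experiment (left) and with center loss (right)]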

training log:

[Figure: train_log, training log]

About

Implements a center loss operator for MXNet.
