# ShuffleNet in PyTorch

An implementation of ShuffleNet (https://arxiv.org/abs/1707.01083) in PyTorch. ShuffleNet is an efficient convolutional neural network architecture for mobile devices. According to the paper, it outperforms Google's MobileNet by a small margin.

## What is ShuffleNet?

In one sentence, ShuffleNet is a ResNet-like model built from residual blocks (called ShuffleUnits). Its main innovation is replacing ordinary pointwise (1x1) convolutions with pointwise group convolutions, combined with a channel shuffle operation that mixes information across the groups.
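
The channel shuffle that gives the network its name is just a reshape, transpose, and reshape over the channel dimension. A minimal standalone sketch of the idea (not necessarily the exact code in `model.py`):

```python
import torch

def channel_shuffle(x, groups):
    # x: (N, C, H, W) with C divisible by `groups`
    n, c, h, w = x.size()
    x = x.view(n, groups, c // groups, h, w)  # split channels into groups
    x = x.transpose(1, 2).contiguous()        # swap the group and per-group dims
    return x.view(n, c, h, w)                 # flatten back: channels are interleaved

x = torch.arange(8.0).view(1, 8, 1, 1)
print(channel_shuffle(x, groups=2).flatten())  # tensor([0., 4., 1., 5., 2., 6., 3., 7.])
```

This is what lets information flow between channel groups that would otherwise stay isolated when every convolution is a group convolution.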

## Usage

Clone the repo:

```bash
git clone https://github.com/jaxony/ShuffleNet.git
```

Use the model defined in `model.py`:

```python
from model import ShuffleNet

# running on MNIST
net = ShuffleNet(num_classes=10, in_channels=1)
```
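
As a quick sanity check, you can push a dummy batch through the network. The sketch below assumes a recent PyTorch version and uses ImageNet-style defaults (3-channel 224x224 inputs, 1000 classes); only the constructor arguments shown in the snippet above are taken from this repo.

```python
import torch
from model import ShuffleNet

# ImageNet-style configuration: 3-channel 224x224 inputs, 1000 classes.
net = ShuffleNet(num_classes=1000, in_channels=3)
net.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 224, 224)  # dummy batch of one RGB image
    logits = net(x)

print(logits.shape)  # expected: torch.Size([1, 1000])
```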

## Performance

Trained on ImageNet (using the PyTorch ImageNet example) with groups=3 and no channel multiplier, the model reached 62.2% top-1 and 84.2% top-5 accuracy on the validation set. Unfortunately, this isn't directly comparable to Table 5 of the paper, because the authors don't evaluate a network with exactly these settings; the result falls between their network with groups=3 and half the number of channels (42.8% top-1 error) and their network with the same number of channels but groups=8 (32.4% top-1 error). The pretrained state dictionary can be found here, in the following format:

```python
{
    'epoch': epoch + 1,
    'arch': args.arch,
    'state_dict': model.state_dict(),
    'best_prec1': best_prec1,
    'optimizer': optimizer.state_dict()
}
```
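
A minimal loading sketch under those assumptions (the checkpoint filename is a placeholder, and stripping the `module.` prefix is only needed if the weights were saved from an `nn.DataParallel` wrapper, as the PyTorch ImageNet example does on multi-GPU machines):

```python
import torch
from model import ShuffleNet

# Recreate the trained configuration (ImageNet: 3 input channels, 1000 classes).
# If the constructor exposes a `groups` argument, it must match the value used
# for training (groups=3 here).
net = ShuffleNet(num_classes=1000, in_channels=3)

# 'shufflenet.pth.tar' is a placeholder name for the downloaded checkpoint.
checkpoint = torch.load('shufflenet.pth.tar', map_location='cpu')

state_dict = checkpoint['state_dict']
# Drop a leading 'module.' from each key if present (artifact of nn.DataParallel).
state_dict = {(k[len('module.'):] if k.startswith('module.') else k): v
              for k, v in state_dict.items()}

net.load_state_dict(state_dict)
net.eval()
print('epoch:', checkpoint['epoch'], 'best prec@1:', checkpoint['best_prec1'])
```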

Note: the model was trained with the PyTorch ImageNet example's default settings, which differ from the training regime described in the paper. A re-run with the paper's settings (and groups=8) is pending.
