diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
deleted file mode 100644
index 4e1cde0..0000000
--- a/CONTRIBUTING.md
+++ /dev/null
@@ -1,27 +0,0 @@
-# Contributing to fb.resnet.torch
-We want to make contributing to this project as easy and transparent as
-possible.
-
-## Pull Requests
-We actively welcome your pull requests.
-
-1. Fork the repo and create your branch from `master`.
-2. If you haven't already, complete the Contributor License Agreement ("CLA").
-
-## Contributor License Agreement ("CLA")
-In order to accept your pull request, we need you to submit a CLA. You only need
-to do this once to work on any of Facebook's open source projects.
-
-Complete your CLA here:
-
-## Issues
-We use GitHub issues to track public bugs. Please ensure your description is
-clear and has sufficient instructions to be able to reproduce the issue.
-
-## Coding Style
-* Use three spaces for indentation rather than tabs
-* 80 character line length
-
-## License
-By contributing to fb.resnet.torch, you agree that your contributions will be
-licensed under its BSD license.
diff --git a/INSTALL.md b/INSTALL.md
deleted file mode 100644
index 5b15965..0000000
--- a/INSTALL.md
+++ /dev/null
@@ -1,92 +0,0 @@
-Torch ResNet Installation
-=========================
-
-This is the suggested way to install the Torch ResNet dependencies on [Ubuntu 14.04+](http://www.ubuntu.com/):
-* NVIDIA CUDA 7.0+
-* NVIDIA cuDNN v4
-* Torch
-* ImageNet dataset
-
-## Requirements
-* NVIDIA GPU with compute capability 3.5 or above
-
-## Install CUDA
-1. Install the `build-essential` package:
-   ```bash
-   sudo apt-get install build-essential
-   ```
-
-2. If you are using a Virtual Machine (like Amazon EC2 instances), install:
-   ```bash
-   sudo apt-get update
-   sudo apt-get install linux-generic
-   ```
-
-3. Download the CUDA .deb file for Linux Ubuntu 14.04 64-bit from: https://developer.nvidia.com/cuda-downloads.
-The file will be named something like `cuda-repo-ubuntu1404-7-5-local_7.5-18_amd64.deb`
-
-4. Install CUDA from the .deb file:
-   ```bash
-   sudo dpkg -i cuda-repo-ubuntu1404-7-5-local_7.5-18_amd64.deb
-   sudo apt-get update
-   sudo apt-get install cuda
-   echo "export PATH=/usr/local/cuda/bin/:\$PATH; export LD_LIBRARY_PATH=/usr/local/cuda/lib64/:\$LD_LIBRARY_PATH; " >>~/.bashrc && source ~/.bashrc
-   ```
-
-4. Restart your computer
-
-## Install cuDNN v4
-1. Download cuDNN v4 from https://developer.nvidia.com/cuDNN (requires registration).
-   The file will be named something like `cudnn-7.0-linux-x64-v4.0-rc.tgz`.
-
-2. Extract the file to `/usr/local/cuda`:
-   ```bash
-   tar -xvf cudnn-7.0-linux-x64-v4.0-rc.tgz
-   sudo cp cuda/include/*.h /usr/local/cuda/include
-   sudo cp cuda/lib64/*.so* /usr/local/cuda/lib64
-   ```
-
-## Install Torch
-1. Install the Torch dependencies:
-   ```bash
-   curl -sk https://raw.githubusercontent.com/torch/ezinstall/master/install-deps | bash -e
-   ```
-
-2. Install Torch in a local folder:
-   ```bash
-   git clone https://github.com/torch/distro.git ~/torch --recursive
-   cd ~/torch; ./install.sh
-   ```
-
-If you want to uninstall torch, you can use the command: `rm -rf ~/torch`
-
-## Install the Torch cuDNN v4 bindings
-```bash
-git clone -b R4 https://github.com/soumith/cudnn.torch.git
-cd cudnn.torch; luarocks make
-```
-
-## Download the ImageNet dataset
-The ImageNet Large Scale Visual Recognition Challenge (ILSVRC) dataset has 1000 categories and 1.2 million images.
-The images do not need to be preprocessed or packaged in any database, but the validation images need to be moved into appropriate subfolders.
-
-1. Download the images from http://image-net.org/download-images
-
-2. Extract the training data:
-   ```bash
-   mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
-   tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
-   find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done
-   cd ..
-   ```
-
-3. Extract the validation data and move images to subfolders:
-   ```bash
-   mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar
-   wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash
-   ```
-
-## Download Torch ResNet
-```bash
-git clone https://github.com/facebook/fb.resnet.torch.git
-cd fb.resnet.torch
-```
diff --git a/PATENTS b/PATENTS
deleted file mode 100644
index 2ea150b..0000000
--- a/PATENTS
+++ /dev/null
@@ -1,33 +0,0 @@
-Additional Grant of Patent Rights Version 2
-
-"Software" means the fb.resnet.torch software distributed by Facebook, Inc.
-
-Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software
-("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable
-(subject to the termination provision below) license under any Necessary
-Claims, to make, have made, use, sell, offer to sell, import, and otherwise
-transfer the Software. For avoidance of doubt, no license is granted under
-Facebook’s rights in any patent claims that are infringed by (i) modifications
-to the Software made by you or any third party or (ii) the Software in
-combination with any software or other technology.
-
-The license granted hereunder will terminate, automatically and without notice,
-if you (or any of your subsidiaries, corporate affiliates or agents) initiate
-directly or indirectly, or take a direct financial interest in, any Patent
-Assertion: (i) against Facebook or any of its subsidiaries or corporate
-affiliates, (ii) against any party if such Patent Assertion arises in whole or
-in part from any software, technology, product or service of Facebook or any of
-its subsidiaries or corporate affiliates, or (iii) against any party relating
-to the Software. Notwithstanding the foregoing, if Facebook or any of its
-subsidiaries or corporate affiliates files a lawsuit alleging patent
-infringement against you in the first instance, and you respond by filing a
-patent infringement counterclaim in that lawsuit against that party that is
-unrelated to the Software, the license granted hereunder will not terminate
-under section (i) of this paragraph due to such counterclaim.
-
-A "Necessary Claim" is a claim of a patent owned by Facebook that is
-necessarily infringed by the Software standing alone.
-
-A "Patent Assertion" is any lawsuit or other action alleging direct, indirect,
-or contributory infringement or inducement to infringe any patent, including a
-cross-claim or counterclaim.
diff --git a/README.md b/README.md
index 6b4860e..386772c 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,50 @@
-GUNN-15 training in Torch
-============================
+## Gradually Updated Neural Networks for Large-Scale Image Recognition
+
+Torch implementation of gradually updated neural networks:
+[Gradually Updated Neural Networks for Large-Scale Image Recognition](http://www.cs.jhu.edu/~alanlab/Pubs18/qiao2018gunn.pdf)
+[Siyuan Qiao](http://www.cs.jhu.edu/~syqiao/), [Zhishuai Zhang](https://zhishuai.xyz/), [Wei Shen](http://wei-shen.weebly.com/), [Bo Wang](https://bowang87.weebly.com/), [Alan Yuille](http://www.cs.jhu.edu/~ayuille/)
+In the Thirty-fifth International Conference on Machine Learning (ICML), 2018.
+
+The code is built on [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch).
+
+```
+@inproceedings{Gunn,
+  title = {Gradually Updated Neural Networks for Large-Scale Image Recognition},
+  author = {Siyuan Qiao and Zhishuai Zhang and Wei Shen and Bo Wang and Alan L. Yuille},
+  booktitle = {International Conference on Machine Learning (ICML)},
+  year = {2018}
+}
+```
+
+### Introduction
+State-of-the-art network architectures usually increase their depth by cascading convolutional layers or building blocks.
+Gradually Updated Neural Networks (GUNN) offer an alternative way to increase depth.
+GUNN introduces computation orderings to the channels within convolutional
+layers or blocks, and gradually computes the outputs in a channel-wise manner based on these orderings.
+The added orderings increase the depths and learning capacities of the networks without any additional computation cost, and they also eliminate overlap singularities, so the networks converge faster and perform
+better.
+
+
+
+### Usage
+Install Torch and the required packages following the instructions [here](https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md).
+To train on CIFAR-10:
 ```bash
-th main.lua -netType gunn -dataset cifar10 -batchSize 64 -nGPU 4 -nThreads 8 -shareGradInput true
+th main.lua -netType gunn-15 -dataset cifar10 -batchSize 64 -nGPU 4 -nThreads 8 -shareGradInput true -nEpochs 300
+```
+For CIFAR-100, change `cifar10` to `cifar100` after `-dataset`. To train on ImageNet:
+```bash
+th main.lua -netType gunn-18 -dataset imagenet -batchSize 256 -nGPU 4 -nThreads 16 -shareGradInput true -nEpochs 120 -data [data folder]
 ```
+
+### Results
+Test error rates (%); lower is better.
+Model | Parameters | CIFAR-10 | CIFAR-100
+-------|:---------:|:---------:|:----------:
+GUNN-15 | 1.6M | 4.15 | 20.45
+GUNN-24 | 29.6M | 3.21 | 16.69
+
+Model | Parameters | ImageNet Top-1 | ImageNet Top-5
+-------|:---------:|:---------:|:----------:
+GUNN-18 | 28.9M | 21.65 | 5.87
+Wide GUNN-18 | 45.6M | 20.59 | 5.52
diff --git a/TRAINING.md b/TRAINING.md
deleted file mode 100644
index ce9a644..0000000
--- a/TRAINING.md
+++ /dev/null
@@ -1,68 +0,0 @@
-Training recipes
-----------------
-
-### CIFAR-10
-
-To train ResNet-20 on CIFAR-10 with 2 GPUs:
-
-```bash
-th main.lua -dataset cifar10 -nGPU 2 -batchSize 128 -depth 20
-```
-
-To train ResNet-110 instead just change the `-depth` flag:
-
-```bash
-th main.lua -dataset cifar10 -nGPU 2 -batchSize 128 -depth 110
-```
-
-To fit ResNet-1202 on two GPUs, you will need to use the [`-shareGradInput`](#sharegradinput) flag:
-
-```bash
-th main.lua -dataset cifar10 -nGPU 2 -batchSize 128 -depth 1202 -shareGradInput true
-```
-
-### ImageNet
-
-See the [installation instructions](INSTALL.md#download-the-imagenet-dataset) for ImageNet data setup.
-
-To train ResNet-18 on ImageNet with 4 GPUs and 8 data loading threads:
-
-```bash
-th main.lua -depth 18 -nGPU 4 -nThreads 8 -batchSize 256 -data [imagenet-folder]
-```
-
-To train ResNet-34 instead just change the `-depth` flag:
-
-```bash
-th main.lua -depth 34 -nGPU 4 -nThreads 8 -batchSize 256 -data [imagenet-folder]
-```
-To train ResNet-50 on 4 GPUs, you will need to use the [`-shareGradInput`](#sharegradinput) flag:
-
-```bash
-th main.lua -depth 50 -nGPU 4 -nThreads 8 -batchSize 256 -shareGradInput true -data [imagenet-folder]
-```
-
-To train ResNet-101 or ResNet-152 with batch size 256, you may need 8 GPUs:
-
-```bash
-th main.lua -depth 152 -nGPU 8 -nThreads 12 -batchSize 256 -shareGradInput true -data [imagenet-folder]
-```
-
-## Useful flags
-
-For a complete list of flags, run `th main.lua --help`.
-
-### shareGradInput
-
-The `-shareGradInput` flag enables sharing of `gradInput` tensors between modules of the same type. This reduces
-memory usage. It works correctly with the included ResNet models, but may not work for other network architectures. See
-[models/init.lua](models/init.lua#L42-L60) for the implementation.
-
-The `shareGradInput` implementation may not work with older versions of the `nn` package. Update your `nn` package by running `luarocks install nn`.
-
-### shortcutType
-
-The `-shortcutType` flag selects the type of shortcut connection. The [ResNet paper](http://arxiv.org/abs/1512.03385) describes three different shortcut types:
-- `A`: identity shortcut with zero-padding for increasing dimensions. This is used for all CIFAR-10 experiments.
-- `B`: identity shortcut with 1x1 convolutions for increasing dimesions. This is used for most ImageNet experiments.
-- `C`: 1x1 convolutions for all shortcut connections.
diff --git a/datasets/cifar10.lua b/datasets/cifar10.lua
index 5be6794..b051938 100644
--- a/datasets/cifar10.lua
+++ b/datasets/cifar10.lua
@@ -46,6 +46,7 @@ function CifarDataset:preprocess()
          t.ColorNormalize(meanstd),
          t.HorizontalFlip(0.5),
          t.RandomCrop(32, 4),
+         t.Jigsaw(),
       }
    elseif self.split == 'val' then
       return t.ColorNormalize(meanstd)
diff --git a/datasets/cifar100.lua b/datasets/cifar100.lua
index ef460e5..dcffa4e 100644
--- a/datasets/cifar100.lua
+++ b/datasets/cifar100.lua
@@ -57,6 +57,7 @@ function CifarDataset:preprocess()
         t.ColorNormalize(meanstd),
         t.HorizontalFlip(0.5),
         t.RandomCrop(32, 4),
+        t.Jigsaw(),
      }
   elseif self.split == 'val' then
      return t.ColorNormalize(meanstd)
diff --git a/datasets/transforms.lua b/datasets/transforms.lua
index 981e9f0..4d25fbd 100644
--- a/datasets/transforms.lua
+++ b/datasets/transforms.lua
@@ -289,4 +289,29 @@ function M.ColorJitter(opt)
    return M.RandomOrder(ts)
 end
 
+function M.Jigsaw()
+   return function(input)
+      local c, h, w = input:size(1), input:size(2), input:size(3)
+      if torch.uniform() < 1/3 then -- keep the image unchanged with probability 1/3
+         return input
+      end
+      if torch.uniform() < 0.5 then -- circularly shift the columns by a random offset
+         local d = torch.random(1, w)
+         if d < w then
+            local l, r = input:narrow(3, 1, d):clone(), input:narrow(3, d + 1, w - d):clone()
+            input:narrow(3, 1, w - d):copy(r)
+            input:narrow(3, w - d + 1, d):copy(l)
+         end
+      else -- circularly shift the rows by a random offset
+         local d = torch.random(1, h)
+         if d < h then
+            local u, b = input:narrow(2, 1, d):clone(), input:narrow(2, d + 1, h - d):clone()
+            input:narrow(2, 1, h - d):copy(b)
+            input:narrow(2, h - d + 1, d):copy(u)
+         end
+      end
+      return input
+   end
+end
+
 return M
diff --git a/intro.png b/intro.png
new file mode 100644
index 0000000..c48d45b
Binary files /dev/null and b/intro.png differ
diff --git a/models/GunnLayer.lua b/models/GunnLayer.lua
index b0efdf5..7a08c12 100644
--- a/models/GunnLayer.lua
+++ b/models/GunnLayer.lua
@@ -1,11 +1,10 @@
 require 'nn'
 require 'cunn'
 require 'cudnn'
-local nninit = require 'nninit'
 
 local GunnLayer, parent = torch.class('nn.GunnLayer', 'nn.Container')
 
-function GunnLayer:__init(nChannels, nSegments)
+function GunnLayer:__init(nChannels, nSegments, opt)
    parent.__init(self)
    self.train = true
    assert(nChannels % nSegments == 0)
@@ -21,13 +20,17 @@ function GunnLayer:__init(nChannels, nSegments)
       convLayer:add(cudnn.ReLU(true))
       convLayer:add(cudnn.SpatialConvolution(oChannels * 2, oChannels, 1, 1, 1, 1, 0, 0))
       convLayer:add(cudnn.SpatialBatchNormalization(oChannels))
-      local shortcut = nn.Sequential()
-      shortcut:add(cudnn.SpatialConvolution(nChannels, oChannels, 1, 1, 1, 1, 0, 0))
-      shortcut:add(cudnn.SpatialBatchNormalization(oChannels))
-      local module = nn.Sequential()
-      module:add(nn.ConcatTable():add(shortcut):add(convLayer))
-      module:add(nn.CAddTable(true))
-      table.insert(self.modules, module)
+      if opt.dataset == 'imagenet' then
+         table.insert(self.modules, convLayer)
+      else
+         local shortcut = nn.Sequential()
+         shortcut:add(cudnn.SpatialConvolution(nChannels, oChannels, 1, 1, 1, 1, 0, 0))
+         shortcut:add(cudnn.SpatialBatchNormalization(oChannels))
+         local module = nn.Sequential()
+         module:add(nn.ConcatTable():add(shortcut):add(convLayer))
+         module:add(nn.CAddTable(true))
+         table.insert(self.modules, module)
+      end
    end
    self.inputContiguous = torch.CudaTensor()
    self.inputTable = {}
diff --git a/models/gunn.lua b/models/gunn-15.lua
similarity index 54%
rename from models/gunn.lua
rename to models/gunn-15.lua
index 0a284a3..7788c4a 100644
--- a/models/gunn.lua
+++ b/models/gunn-15.lua
@@ -10,40 +10,36 @@ local function createModel(opt)
    local stp = {20, 25, 30}
    for i = 1, 3 do cfg[i] = cfg[i] * expansion end
    local model = nn.Sequential()
-   if opt.dataset == 'cifar10' or opt.dataset == 'cifar100' then
-      model:add(cudnn.SpatialConvolution(3, 64, 3, 3, 1, 1, 1, 1))
-      model:add(cudnn.SpatialBatchNormalization(64))
-      model:add(cudnn.ReLU(true))
-      --
-      model:add(cudnn.SpatialConvolution(64, cfg[1], 1, 1, 1, 1, 0, 0))
-      model:add(cudnn.SpatialBatchNormalization(cfg[1]))
-      model:add(cudnn.ReLU(true))
-      model:add(nn.GunnLayer(cfg[1], stp[1]))
-      --
-      model:add(cudnn.SpatialConvolution(cfg[1], cfg[2], 1, 1, 1, 1, 0, 0))
-      model:add(cudnn.SpatialBatchNormalization(cfg[2]))
-      model:add(cudnn.ReLU(true))
-      model:add(cudnn.SpatialAveragePooling(2, 2))
-      model:add(nn.GunnLayer(cfg[2], stp[2]))
-      --
-      model:add(cudnn.SpatialConvolution(cfg[2], cfg[3], 1, 1, 1, 1, 0, 0))
-      model:add(cudnn.SpatialBatchNormalization(cfg[3]))
-      model:add(cudnn.ReLU(true))
-      model:add(cudnn.SpatialAveragePooling(2, 2))
-      model:add(nn.GunnLayer(cfg[3], stp[3]))
-      --
-      model:add(cudnn.SpatialConvolution(cfg[3], cfg[3], 1, 1, 1, 1, 0, 0))
-      model:add(cudnn.SpatialBatchNormalization(cfg[3]))
-      model:add(cudnn.ReLU(true))
-      model:add(cudnn.SpatialAveragePooling(8, 8))
-      model:add(nn.Reshape(cfg[3]))
-   end
+   model:add(cudnn.SpatialConvolution(3, 64, 3, 3, 1, 1, 1, 1))
+   model:add(cudnn.SpatialBatchNormalization(64))
+   model:add(cudnn.ReLU(true))
+   --
+   model:add(cudnn.SpatialConvolution(64, cfg[1], 1, 1, 1, 1, 0, 0))
+   model:add(cudnn.SpatialBatchNormalization(cfg[1]))
+   model:add(cudnn.ReLU(true))
+   model:add(nn.GunnLayer(cfg[1], stp[1], opt))
+   --
+   model:add(cudnn.SpatialConvolution(cfg[1], cfg[2], 1, 1, 1, 1, 0, 0))
+   model:add(cudnn.SpatialBatchNormalization(cfg[2]))
+   model:add(cudnn.ReLU(true))
+   model:add(cudnn.SpatialAveragePooling(2, 2))
+   model:add(nn.GunnLayer(cfg[2], stp[2], opt))
+   --
+   model:add(cudnn.SpatialConvolution(cfg[2], cfg[3], 1, 1, 1, 1, 0, 0))
+   model:add(cudnn.SpatialBatchNormalization(cfg[3]))
+   model:add(cudnn.ReLU(true))
+   model:add(cudnn.SpatialAveragePooling(2, 2))
+   model:add(nn.GunnLayer(cfg[3], stp[3], opt))
+   --
+   model:add(cudnn.SpatialConvolution(cfg[3], cfg[3], 1, 1, 1, 1, 0, 0))
+   model:add(cudnn.SpatialBatchNormalization(cfg[3]))
+   model:add(cudnn.ReLU(true))
+   model:add(cudnn.SpatialAveragePooling(8, 8))
+   model:add(nn.Reshape(cfg[3]))
    if opt.dataset == 'cifar10' then
       model:add(nn.Linear(cfg[3], 10))
    elseif opt.dataset == 'cifar100' then
      model:add(nn.Linear(cfg[3], 100))
-   elseif opt.dataset == 'imagenet' then
-      model:add(nn.Linear(256, 1000))
   end
   --Initialization following ResNet
   local function ConvInit(name)
diff --git a/models/gunn-18.lua b/models/gunn-18.lua
new file mode 100644
index 0000000..d3e2ab7
--- /dev/null
+++ b/models/gunn-18.lua
@@ -0,0 +1,88 @@
+require 'nn'
+require 'cunn'
+require 'cudnn'
+require 'models/GunnLayer'
+
+local function createModel(opt)
+   -- Build GUNN-18
+   local cfg = {400, 800, 1600, 2000}
+   local stp = {10, 20, 40, 50}
+   local model = nn.Sequential()
+
+   model:add(cudnn.SpatialConvolution(3, 64, 7, 7, 2, 2, 3, 3))
+   model:add(cudnn.SpatialBatchNormalization(64))
+   model:add(cudnn.ReLU(true))
+   model:add(nn.SpatialMaxPooling(3, 3, 2, 2, 1, 1))
+   --
+   model:add(cudnn.SpatialConvolution(64, cfg[1], 1, 1, 1, 1, 0, 0))
+   model:add(cudnn.SpatialBatchNormalization(cfg[1]))
+   model:add(cudnn.ReLU(true))
+   model:add(nn.GunnLayer(cfg[1], stp[1], opt))
+   model:add(cudnn.SpatialAveragePooling(2, 2))
+   --
+   model:add(cudnn.SpatialConvolution(cfg[1], cfg[2], 1, 1, 1, 1, 0, 0))
+   model:add(cudnn.SpatialBatchNormalization(cfg[2]))
+   model:add(cudnn.ReLU(true))
+   model:add(nn.GunnLayer(cfg[2], stp[2], opt))
+   model:add(cudnn.SpatialAveragePooling(2, 2))
+   --
+   model:add(cudnn.SpatialConvolution(cfg[2], cfg[3], 1, 1, 1, 1, 0, 0))
+   model:add(cudnn.SpatialBatchNormalization(cfg[3]))
+   model:add(cudnn.ReLU(true))
+   model:add(nn.GunnLayer(cfg[3], stp[3], opt))
+   model:add(cudnn.SpatialAveragePooling(2, 2))
+   --
+   model:add(cudnn.SpatialConvolution(cfg[3], cfg[4], 1, 1, 1, 1, 0, 0))
+   model:add(cudnn.SpatialBatchNormalization(cfg[4]))
+   model:add(cudnn.ReLU(true))
+   model:add(nn.GunnLayer(cfg[4], stp[4], opt))
+   model:add(cudnn.SpatialAveragePooling(7, 7))
+   --
+   model:add(nn.Reshape(cfg[4]))
+   model:add(nn.Linear(cfg[4], 1000))
+
+   --Initialization following ResNet
+   local function ConvInit(name)
+      for k,v in pairs(model:findModules(name)) do
+         local n = v.kW*v.kH*v.nOutputPlane
+         v.weight:normal(0,math.sqrt(2/n))
+         if cudnn.version >= 4000 then
+            v.bias = nil
+            v.gradBias = nil
+         else
+            v.bias:zero()
+         end
+      end
+   end
+
+   local function BNInit(name)
+      for k,v in pairs(model:findModules(name)) do
+         v.weight:fill(1)
+         v.bias:zero()
+      end
+   end
+
+   ConvInit('cudnn.SpatialConvolution')
+   BNInit('cudnn.SpatialBatchNormalization')
+   for k,v in pairs(model:findModules('nn.Linear')) do
+      v.bias:zero()
+   end
+
+   model:type(opt.tensorType)
+
+   if opt.cudnn == 'deterministic' then
+      model:apply(function(m)
+         if m.setMode then m:setMode(1,1,1) end
+      end)
+   end
+
+   model:get(1).gradInput = nil
+
+   print(model)
+   local modelParam, np = model:parameters(), 0
+   for k, v in pairs(modelParam) do np = np + v:nElement() end
+   print(string.format('| number of parameters: %d', np))
+   return model
+end
+
+return createModel
diff --git a/train.lua b/train.lua
index 7f1c98d..318bc93 100644
--- a/train.lua
+++ b/train.lua
@@ -176,6 +176,7 @@ function Trainer:learningRate(epoch)
    local decay = 0
    if self.opt.dataset == 'imagenet' then
       decay = math.floor((epoch - 1) / 30)
+      if decay >= 3 then decay = decay + 1 end -- skip decay level 3: epochs 91+ jump straight to level 4
    elseif self.opt.dataset == 'cifar10' then
      decay = epoch >= 0.75*self.opt.nEpochs and 2 or epoch >= 0.5*self.opt.nEpochs and 1 or 0
   elseif self.opt.dataset == 'cifar100' then
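
### Notes on the changes

`GunnLayer` realizes the gradual update described in the README's introduction: the channels are split into `nSegments` segments, and each segment is recomputed from a feature map that already contains the previously updated segments. Below is a minimal CPU-only sketch of that idea; `gradualUpdate` and its per-segment 1x1 convolutions are hypothetical names for illustration, and the real `GunnLayer` additionally uses bottleneck convolutions and (on CIFAR) a residual shortcut, which this sketch omits.

```lua
require 'nn'

-- Update C channels in nSegments sequential steps. Each step recomputes one
-- channel segment from the current (partially updated) feature map, so later
-- segments see the outputs of earlier ones.
local function gradualUpdate(input, convs, nSegments)
   local c = input:size(2)            -- input is N x C x H x W
   local seg = c / nSegments          -- channels per segment
   local x = input:clone()
   for i = 1, nSegments do
      local first = (i - 1) * seg + 1
      local y = convs[i]:forward(x)   -- maps all C channels to seg channels
      x:narrow(2, first, seg):copy(y) -- overwrite segment i in place
   end
   return x
end

-- one 1x1 convolution per segment (random weights; illustration only)
local nSegments, C = 3, 12
local convs = {}
for i = 1, nSegments do
   convs[i] = nn.SpatialConvolution(C, C / nSegments, 1, 1)
end

local out = gradualUpdate(torch.randn(2, C, 8, 8), convs, nSegments)
print(out:size()) -- 2 x 12 x 8 x 8: same shape, channels updated in 3 stages
```

Because segment i reads the already-updated segments 1..i-1, one such layer behaves like a cascade of `nSegments` dependent computations, which is how the orderings add depth without changing the layer's total computation.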
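
The `t.Jigsaw()` transform added to both CIFAR pipelines leaves an image unchanged with probability 1/3 and otherwise applies a circular shift along a randomly chosen axis: the image is cut at a random row or column and the two pieces swap places. A small self-contained check of the column case (`shiftColumns` is a hypothetical helper mirroring the transform's horizontal branch):

```lua
require 'torch'

-- What Jigsaw does for a horizontal cut at column d: a circular column shift.
local function shiftColumns(img, d)
   local w = img:size(3)
   local l = img:narrow(3, 1, d):clone()         -- left piece (columns 1..d)
   local r = img:narrow(3, d + 1, w - d):clone() -- right piece
   img:narrow(3, 1, w - d):copy(r)               -- right piece moves left
   img:narrow(3, w - d + 1, d):copy(l)           -- left piece wraps around
   return img
end

local img = torch.range(1, 6):resize(1, 1, 6)    -- pixels 1..6 in one row
print(shiftColumns(img, 2))                      -- 3 4 5 6 1 2
```

The multiset of pixel values is preserved; only their positions change, so the label stays valid while the spatial layout is perturbed.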
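
The `train.lua` change skips decay level 3 on ImageNet. Assuming the surrounding fb.resnet.torch code keeps its usual rule of multiplying the base rate by 0.1 per decay level (an assumption about code outside this hunk), the rate steps down at epochs 31, 61, and 91, with a double-sized drop at epoch 91:

```lua
-- Worked check of the modified ImageNet schedule (assumes LR = baseLR * 0.1^decay).
local function imagenetLR(baseLR, epoch)
   local decay = math.floor((epoch - 1) / 30)
   if decay >= 3 then decay = decay + 1 end -- the one-line change in this diff
   return baseLR * math.pow(0.1, decay)
end

for _, e in ipairs({30, 31, 61, 91}) do
   print(e, imagenetLR(0.1, e)) -- 0.1, 0.01, 0.001, 1e-05
end
```

With `-nEpochs 120`, the final 30 epochs therefore run at a rate 100x smaller than the preceding block instead of the usual 10x.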