This repository is a PyTorch implementation of AlexNet.
- Based on PyTorch 1.5 and Python 3.7.
- Uses tensorboardX to record loss and accuracy.
- Pretrained on Imagenette (a 10-class subset of ImageNet).
- Supports both Batch Normalization and Local Response Normalization.
- Uses grouped convolution layers to simulate multi-GPU training, so the network structure is closer to the original one in the paper than to the official PyTorch implementation.
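In the paper, the second convolutional layer has 256 kernels of size 5x5x48 rather than 5x5x96, because each GPU only sees half of the 96 feature maps produced by the layer below. The sketch below shows how the `groups` argument of `nn.Conv2d` reproduces that split on a single device. It assumes this is the mechanism the repository relies on; which layers are grouped in its model code may differ, and `n_params` is a hypothetical helper.

```python
import torch
import torch.nn as nn

# groups=2 splits the 96 input channels into two halves of 48; each half is
# convolved by its own 128 filters, like the two GPUs in the original paper.
grouped = nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2)

# An ordinary convolution with the same input and output sizes, for comparison.
dense = nn.Conv2d(96, 256, kernel_size=5, padding=2)


def n_params(module: nn.Module) -> int:
    """Count the parameters (weights and biases) of a module."""
    return sum(p.numel() for p in module.parameters())


x = torch.randn(1, 96, 27, 27)
print(grouped(x).shape)       # output keeps the full 256 channels
print(grouped.weight.shape)   # each kernel only spans 48 input channels
print(n_params(grouped), n_params(dense))
```

The grouped layer has 307,456 parameters versus 614,656 for the dense one: grouping halves the weights because no filter connects across the two halves.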
python=3.7
tqdm
torch
torchvision
tensorboardx
CUDA support is recommended but not essential.
Please refer to this page to download the imagenette dataset.
Extract the imagenette2 folder to PyTorch-AlexNet/datasets/.
Run the following command from the PyTorch-AlexNet/ directory:
python train.py --name myalexnet
Options:
- --normalization: which normalization method to use, either bn or lrn.
- --activation: which activation function to use, either relu or tanh.
- --pooling: which pooling method to use, either max or avg.
- --epochs: how many epochs to train, a positive integer.
- --batch_size: how many images each batch contains, a positive integer.
- --num_classes: how many classes the dataset has; it is set automatically when using the imagenet or imagenette dataset.
- --dataset: the name of the dataset, either imagenet or imagenette.
- --starting_epoch: the epoch to start from, default 0; if set to a positive integer, the checkpoint from epoch starting_epoch-1 is loaded before training.
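The --normalization, --activation and --pooling choices each select a layer type. The sketch below maps them to standard `torch.nn` modules and assembles one convolution block in the order the paper describes (convolution, ReLU, response normalization, overlapping max pooling). The helpers `make_norm`, `make_act`, `make_pool`, `conv_block` and the `NUM_CLASSES` table are hypothetical, and the LRN and pooling hyperparameters are taken from the paper rather than from this repository's code.

```python
import torch.nn as nn

# Class counts for the datasets accepted by --dataset.
NUM_CLASSES = {"imagenet": 1000, "imagenette": 10}


def make_norm(kind: str, channels: int) -> nn.Module:
    """Map a --normalization choice to a layer."""
    if kind == "bn":
        return nn.BatchNorm2d(channels)
    if kind == "lrn":
        # Paper values: n=5, alpha=1e-4, beta=0.75, k=2. PyTorch's
        # LocalResponseNorm divides alpha by size, so alpha is scaled by 5
        # here to reproduce the paper's formula.
        return nn.LocalResponseNorm(size=5, alpha=5e-4, beta=0.75, k=2.0)
    raise ValueError(f"unknown normalization: {kind!r}")


def make_act(kind: str) -> nn.Module:
    """Map an --activation choice to a layer."""
    if kind == "relu":
        return nn.ReLU(inplace=True)
    if kind == "tanh":
        return nn.Tanh()
    raise ValueError(f"unknown activation: {kind!r}")


def make_pool(kind: str) -> nn.Module:
    """Map a --pooling choice to an overlapping 3x3, stride-2 pooling layer."""
    if kind == "max":
        return nn.MaxPool2d(kernel_size=3, stride=2)
    if kind == "avg":
        return nn.AvgPool2d(kernel_size=3, stride=2)
    raise ValueError(f"unknown pooling: {kind!r}")


def conv_block(in_ch, out_ch, kernel_size, stride, padding, norm, act, pool, groups=1):
    """Conv -> activation -> normalization -> pooling."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size, stride=stride, padding=padding, groups=groups),
        make_act(act),
        make_norm(norm, out_ch),
        make_pool(pool),
    )
```

With a 224x224 input, `conv_block(3, 96, 11, 4, 2, "lrn", "relu", "max")` produces 96 feature maps of size 27x27: the convolution yields 55x55 and the overlapping pooling reduces that to 27x27.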
I found 5e-3 to be a good learning rate for both the batch normalization and the local response normalization networks. The learning rate decreases automatically during training: it is multiplied by 0.1 every 30 epochs.
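The schedule described above (start at 5e-3, multiply by 0.1 every 30 epochs) is a step decay. A minimal sketch, assuming it is expressed with PyTorch's `StepLR`; the repository may adjust the rate differently, and the `nn.Linear` model is only a stand-in for AlexNet.

```python
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(10, 2)  # stand-in for the AlexNet model
optimizer = SGD(model.parameters(), lr=5e-3)
# Multiply the learning rate by gamma=0.1 every step_size=30 epochs.
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

lrs = []
for epoch in range(90):
    lrs.append(optimizer.param_groups[0]["lr"])
    # ... one epoch of training would go here ...
    optimizer.step()
    scheduler.step()  # advance the schedule once per epoch

print(lrs[0], lrs[29], lrs[30], lrs[60])
```

This prints values of approximately 5e-3, 5e-3, 5e-4 and 5e-5: the rate stays constant for epochs 0 to 29 and drops by a factor of 10 at epochs 30 and 60.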
The TensorBoard log directory is log/<name>, where <name> is the value passed to --name. Run the following command to start the TensorBoard server:
tensorboard --logdir log/myalexnet
Make sure tensorboard and tensorflow are installed before running it.
This is a team project for the course Algorithm Design and Analysis.
Krizhevsky, Alex & Sutskever, Ilya & Hinton, Geoffrey. (2012). ImageNet Classification with Deep Convolutional Neural Networks. Neural Information Processing Systems. 25. 10.1145/3065386.