Training neural network models for image classification on the CIFAR dataset. Implementing deep learning compression methods.

CIFAR Image Classifier

Introduction

This project trains image classification models on the CIFAR-10 and CIFAR-100 datasets using several CNN architectures. It also explores model compression methods such as network pruning, which shrinks a network by removing redundant parameters and so reduces the computation needed at inference.
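To make the pruning idea concrete, here is a minimal sketch of L1-norm filter ranking, the criterion named in the Performance table. It is not the repository's implementation: the function names are hypothetical, each filter is represented as a flat list of weights rather than a PyTorch tensor, and it assumes the same pruning ratio is applied to whichever layer is passed in.

```python
def l1_norm(filter_weights):
    """Sum of absolute values of one filter's weights."""
    return sum(abs(w) for w in filter_weights)


def select_filters_to_keep(filters, pruning_ratio):
    """Return the sorted indices of the filters that survive pruning.

    filters: list of filters, each given as a flat list of weights.
    pruning_ratio: fraction of filters to remove, in [0, 1).
    (Both names are hypothetical, chosen for this illustration.)
    """
    if not 0.0 <= pruning_ratio < 1.0:
        raise ValueError("pruning_ratio must be in [0, 1)")
    num_pruned = int(len(filters) * pruning_ratio)
    # Rank filter indices from smallest to largest L1 norm.
    ranked = sorted(range(len(filters)), key=lambda i: l1_norm(filters[i]))
    # Drop the lowest-norm filters and keep the rest in their original order.
    return sorted(ranked[num_pruned:])


# Toy layer with four filters of three weights each.
toy_layer = [
    [0.9, -0.8, 0.7],     # large weights, likely important
    [0.01, 0.02, -0.01],  # near-zero weights, pruned first
    [0.5, 0.5, -0.5],
    [-0.1, 0.0, 0.1],
]
kept = select_filters_to_keep(toy_layer, pruning_ratio=0.5)
print(kept)  # [0, 2]
```

Filters with small L1 norms contribute little to the layer's output, so they are treated as the redundant ones. In a real network, removing a filter also requires removing the matching input channel of the following layer.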

Performance

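| Dataset  | Model       | Pruning Method  | Pruning Ratio (%) | Params (M) | Top-1 Acc (%) |
|----------|-------------|-----------------|-------------------|------------|---------------|
| CIFAR10  | VGG16       | -               | -                 | 33.65      | 94.71         |
| CIFAR10  | VGG16       | L1 Norm Uniform | 10                | 28.74      | 92.19         |
| CIFAR10  | MobileNetV1 | -               | -                 | 3.22       | 91.53         |
| CIFAR10  | MobileNetV2 | -               | -                 | 2.3        | 94.44         |
| CIFAR10  | ResNet18    | -               | -                 | 11.17      | 95.42         |
| CIFAR100 | VGG16       | -               | -                 | 34.02      | 73.95         |
| CIFAR100 | VGG16       | L1 Norm Uniform | 10                | 29.07      | 63.92         |
| CIFAR100 | MobileNetV1 | -               | -                 | 3.31       | 68.18         |
| CIFAR100 | MobileNetV2 | -               | -                 | 2.41       | 75.96         |
| CIFAR100 | ResNet18    | -               | -                 | 11.22      | 77.32         |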
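
The trade-off for the two pruned VGG16 rows can be worked out directly from the numbers above. The short sketch below computes the relative parameter reduction and the top-1 accuracy drop. The helper name `summarize` is hypothetical, and the figures are copied from the table, not re-measured.

```python
# (dataset, baseline params in M, pruned params in M, baseline top-1, pruned top-1)
vgg16_results = [
    ("CIFAR10", 33.65, 28.74, 94.71, 92.19),
    ("CIFAR100", 34.02, 29.07, 73.95, 63.92),
]


def summarize(base_params, pruned_params, base_acc, pruned_acc):
    """Return (parameter reduction in %, top-1 accuracy drop in points)."""
    param_reduction = 100.0 * (base_params - pruned_params) / base_params
    acc_drop = base_acc - pruned_acc
    return param_reduction, acc_drop


for dataset, bp, pp, ba, pa in vgg16_results:
    reduction, drop = summarize(bp, pp, ba, pa)
    print(f"{dataset}: {reduction:.1f}% fewer parameters, "
          f"{drop:.2f} points lower top-1 accuracy")
```

In both cases the parameter count falls by roughly 14.6%, which is more than the 10% pruning ratio. One plausible reason, not stated in the source, is that removing a filter also removes the corresponding input channels of the next layer. The accuracy cost differs sharply: about 2.5 points on CIFAR10 and about 10 points on CIFAR100.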

Getting Started

  1. Clone the repository to the desired directory:
git clone https://github.com/brondaeh/CIFAR-Image-Classifier.git
  2. Install the required libraries in your environment:
pip install -r requirements.txt
  3. Modify the configuration parameters in config.yaml.
    • Set enable_gpu to True if you have a CUDA-enabled GPU for training.
    • By default, L1 norm pruning is only applicable to VGG16. Keep pruning_flag set to False for other models.
  4. Start training and testing by running main.py.
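
The pruning constraint in the configuration step can be checked before training starts. The sketch below is an assumption-laden illustration, not code from the repository: it takes a dictionary as it might look after config.yaml has been parsed, and the `model` key and the `check_config` function are hypothetical names. Only `enable_gpu` and `pruning_flag` come from the instructions above.

```python
def check_config(config):
    """Return a list of warnings for a parsed config.yaml dictionary."""
    warnings = []
    # "model" is a hypothetical key; the real config.yaml may name it differently.
    model = str(config.get("model", "")).lower()
    if config.get("pruning_flag", False) and model != "vgg16":
        warnings.append(
            "pruning_flag is True, but L1 norm pruning is only "
            "applicable to VGG16 by default"
        )
    return warnings


# Example: pruning enabled for a model other than VGG16.
example = {"model": "resnet18", "enable_gpu": True, "pruning_flag": True}
for message in check_config(example):
    print("Warning:", message)
```

Running a check like this first catches a mismatched pruning_flag early instead of partway through a training run.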

File Structure

CIFAR-Image-Classifier
├── Model_Details
│   ├── mobilenetv1_details.md
│   ├── mobilenetv2_details.md
│   ├── resnet18_details.md
│   └── vgg16_details.md
├── Models
│   ├── mobilenetv1.py
│   ├── mobilenetv2.py
│   ├── resnet.py
│   └── vgg.py
├── Pruner
│   ├── Pruning_Criterion
│   │   ├── KMean
│   │   │   └── KMean_base.py
│   │   ├── L1norm
│   │   │   └── L1norm.py
│   │   └── Taylor
│   │       └── Taylor.py
│   ├── pruning_engine_base.py
│   └── pruning_engine.py
├── Pruning_Functions
│   └── prune_vgg.py
├── config.yaml
├── main.py
├── README.md
├── requirements.txt
└── utils.py

References

  1. CNN-Pruning-Engine
  2. pytorch-cifar
  3. MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications
  4. MobileNetV2: Inverted Residuals and Linear Bottlenecks
  5. Deep Residual Learning for Image Recognition
  6. Very Deep Convolutional Networks for Large-Scale Image Recognition
