This project demonstrates the use of:
- PyTorch to write neural networks,
- the CIFAR10 dataset for image classification, and
- the Metal Performance Shaders (MPS) backend for GPU training acceleration (on Mac computers with Apple silicon).
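As a rough illustration of how the MPS backend is typically selected in PyTorch, the sketch below picks a device with a fallback. The helper name `pick_device` is hypothetical and is not taken from this repository's notebooks.

```python
import torch


def pick_device() -> torch.device:
    """Prefer Apple's MPS backend, then CUDA, then fall back to the CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


device = pick_device()
# Tensors are created on (or moved to) the chosen device; models use .to(device).
x = torch.ones(2, 3, device=device)
print(device, x.sum().item())
```

On a machine without Apple silicon or CUDA this simply runs on the CPU, so the same notebook code can stay device-agnostic.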
This project is organised as shown below:
.
├── Makefile
├── README.md
├── cifar10_playground.ipynb # Batch Normalization
├── cifar10_playground_layer_norm.ipynb # Layer Normalization
├── cifar10_playground_group_norm.ipynb # Group Normalization
├── data # This folder is created
│ ├── cifar-10-batches-py # and data downloaded
│ │ ├── batches.meta # when the notebook is run
│ │ ├── data_batch_1
│ │ ├── data_batch_2
│ │ ├── data_batch_3
│ │ ├── data_batch_4
│ │ ├── data_batch_5
│ │ ├── readme.html
│ │ └── test_batch
│ └── cifar-10-python.tar.gz
├── exploratory_analysis.ipynb
├── model_analysis.ipynb
├── models.py # Contains the 3 models NetBN, NetLN and NetGN
└── runs
    └── exploratory_analysis
6 directories, 18 files
- Make sure JupyterLab is installed:
$ jupyter --version
Selected Jupyter core packages...
IPython : 8.19.0
ipykernel : 6.28.0
ipywidgets : not installed
jupyter_client : 8.6.0
jupyter_core : 5.5.1
jupyter_server : 2.12.1
jupyterlab : 4.0.9
nbclient : 0.9.0
nbconvert : 7.13.1
nbformat : 5.9.2
notebook : not installed
qtconsole : not installed
traitlets : 5.14.0
If not, install it:
# Using pip:
$ pip install jupyterlab
# OR using Homebrew, a package manager for macOS and Linux
$ brew install jupyterlab
- Clone this repository to your local machine.
$ git clone https://github.com/bensooraj/pytorch-s8-CIFAR10
$ cd pytorch-s8-CIFAR10
- Start the lab!
$ make start-lab
This should automatically launch your default browser and open http://localhost:8888/lab.
All set!
The individual PRs contain more details; the results are summarised below:
| Normalization | Training Accuracy | Testing Accuracy |
|---|---|---|
| BatchNorm (BN) | 44.57% | 47.64% |
| LayerNorm (LN) | 42.66% | 45.36% |
| GroupNorm (GN) | 23.66% | 23.90% |
BatchNorm > LayerNorm >> GroupNorm
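For readers unfamiliar with the three layers being compared, here is a minimal sketch of how each can be built for a convolutional feature map in PyTorch. The factory `make_norm`, the group count of 4, and the use of `nn.GroupNorm(1, C)` as a layer-norm-style layer are assumptions made for illustration; the actual definitions live in `models.py` and may differ.

```python
import torch
from torch import nn


def make_norm(kind: str, channels: int, groups: int = 4) -> nn.Module:
    """Return a normalization layer for an (N, C, H, W) feature map."""
    if kind == "bn":
        return nn.BatchNorm2d(channels)        # statistics per channel, across the batch
    if kind == "ln":
        return nn.GroupNorm(1, channels)       # one group: statistics over (C, H, W) per sample
    if kind == "gn":
        return nn.GroupNorm(groups, channels)  # statistics per group of channels, per sample
    raise ValueError(f"unknown normalization kind: {kind!r}")


x = torch.randn(8, 16, 32, 32)  # stand-in batch of feature maps
for kind in ("bn", "ln", "gn"):
    y = make_norm(kind, channels=16)(x)
    print(kind, tuple(y.shape))
```

All three layers preserve the input shape, so they can be swapped in the same position of a network to compare their effect on accuracy.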
- The MPS backend doesn't work properly with `shuffle=True` for `torch.utils.data.DataLoader`.
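One possible way to sidestep this, sketched under the assumption that shuffling the sample order on the CPU is acceptable, is to permute the indices yourself and keep `shuffle=False`. The helper `shuffled_loader` is hypothetical and is not part of this repository; the tiny `TensorDataset` stands in for CIFAR10.

```python
import random

import torch
from torch.utils.data import DataLoader, Subset, TensorDataset


def shuffled_loader(dataset, batch_size: int, seed: int) -> DataLoader:
    """Return a DataLoader over a pre-shuffled view of `dataset` with shuffle=False."""
    indices = list(range(len(dataset)))
    random.Random(seed).shuffle(indices)  # permutation is built in plain Python on the CPU
    return DataLoader(Subset(dataset, indices), batch_size=batch_size, shuffle=False)


# Tiny stand-in dataset: 10 samples whose labels are 0..9.
dataset = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))

for epoch in range(2):
    # Use a different seed each epoch so the sample order changes between epochs.
    loader = shuffled_loader(dataset, batch_size=4, seed=epoch)
    for inputs, labels in loader:
        pass  # the training step would go here
```

Rebuilding the loader every epoch is cheap because `Subset` only stores the index list, not a copy of the data.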