This script trains a convolutional neural network (CNN) on the CIFAR-10 image classification dataset using PyTorch.
Python 3.6+
PyTorch
torchvision
matplotlib
Instructions:
Bash
pip install torch torchvision matplotlib
Bash
python script.py
Overview:
Loads and preprocesses the CIFAR-10 dataset. Defines a CNN architecture with convolutional, pooling, and fully connected layers. Trains the model using Adam optimizer and CrossEntropy loss. Evaluates and reports test accuracy, with the option to save the best performing model. Includes functions to visualize predictions on a batch of images and analyze performance per class.
Data Loading and Preprocessing: Loads the CIFAR-10 dataset using torchvision.datasets. Applies data transformations (e.g., ToTensor, normalization) using torchvision.transforms. Defines batch sizes and data loaders for training and testing.
-
Defines a CNN architecture with convolutional layers, activation functions (ReLU), pooling layers, and a final fully connected layer.
-
Includes BatchNorm layers for regularization.
-
Modify hyperparameters (learning rate, epochs, architecture) to improve performance.
-
Explore advanced data augmentation techniques.
-
Implement early stopping to prevent overfitting.
-
Consider hyperparameter tuning using libraries like Optuna.
-
Analyze training behavior with tensorboard pvisualizations.
-
Experiment with different architectures and techniques.
LCMD
Feel free to modify the script and README for your specific needs. For more details on PyTorch and CIFAR-10, refer to their official documentation.