CanReader/MnistPY

MNIST Digit Classifier

A PyTorch implementation of a convolutional neural network (CNN) for handwritten digit classification on the MNIST dataset.

Architecture

| Layer | Output Shape | Parameters |
|---|---|---|
| Input | 1 x 28 x 28 | - |
| Conv2d(1, 32, 3) + BN + ReLU + MaxPool | 32 x 14 x 14 | 384 |
| Conv2d(32, 64, 3) + BN + MaxPool | 64 x 7 x 7 | 18,624 |
| Flatten + Linear(3136, 128) + ReLU | 128 | 401,536 |
| Dropout(0.25) + Linear(128, 10) | 10 | 1,290 |
| Total | | ~422K |
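
The shapes and parameter counts above can be checked by hand. The sketch below is plain Python arithmetic, not the repository's code. It assumes each convolution uses a padding of 1 and a bias term, each max-pool is 2 x 2 with stride 2, and each batch-norm layer contributes one learnable scale and one shift per channel. None of these settings is stated in the table; they are inferred from the listed output shapes.

```python
def conv2d_params(in_ch, out_ch, kernel):
    """Kernel weights plus one bias per output channel (bias assumed)."""
    return in_ch * out_ch * kernel * kernel + out_ch


def batchnorm_params(channels):
    """Learnable scale and shift per channel (running stats not counted)."""
    return 2 * channels


def linear_params(in_features, out_features):
    """Weight matrix plus one bias per output feature."""
    return in_features * out_features + out_features


def conv_out_size(size, kernel, padding=1, stride=1):
    """Spatial size after a convolution (padding=1 is an assumption)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out_size(size, window=2):
    """Spatial size after a non-overlapping max-pool (2 x 2 assumed)."""
    return size // window


# Spatial sizes: 28 -> 14 -> 7
size1 = pool_out_size(conv_out_size(28, 3))
size2 = pool_out_size(conv_out_size(size1, 3))
flat_features = 64 * size2 * size2

# Parameter counts per row of the table
block1 = conv2d_params(1, 32, 3) + batchnorm_params(32)
block2 = conv2d_params(32, 64, 3) + batchnorm_params(64)
fc1 = linear_params(flat_features, 128)
fc2 = linear_params(128, 10)
total = block1 + block2 + fc1 + fc2

print(size1, size2, flat_features)
print(block1, block2, fc1, fc2, total)
```

Under these assumptions the first convolutional block comes to 320 + 64 = 384 parameters, the flattened feature size is 64 x 7 x 7 = 3136, and the total is 421,834, which rounds to the ~422K figure.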

Features

  • CNN architecture with batch normalisation and dropout
  • Data augmentation (random rotation, translation) to improve generalisation
  • Train / Validation / Test split (54K / 6K / 10K)
  • Early stopping on validation loss with best-model checkpointing
  • Learning rate scheduling (StepLR)
  • Full evaluation report with per-class precision, recall, and F1
  • Visualisations — training curves, confusion matrix, sample predictions
  • Single-image inference script for custom images
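
Early stopping with best-model checkpointing and a step learning-rate schedule can be sketched in plain Python as follows. This is an illustration of the two techniques, not the code in train.py: the `EarlyStopping` class, the `step_lr` helper, and the `patience`, `step_size`, and `gamma` values are all hypothetical, since this README does not state them. The validation losses are made up for the demonstration.

```python
class EarlyStopping:
    """Track validation loss and signal a stop after `patience` bad epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience      # hypothetical value, not from the repo
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_epoch = None
        self.bad_epochs = 0

    def step(self, epoch, val_loss):
        """Record one epoch; return True when this epoch is a new best."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.bad_epochs = 0
            return True
        self.bad_epochs += 1
        return False

    @property
    def should_stop(self):
        return self.bad_epochs >= self.patience


def step_lr(base_lr, epoch, step_size=5, gamma=0.5):
    """Step schedule: multiply the rate by `gamma` every `step_size` epochs.

    step_size and gamma are placeholder values chosen for illustration.
    """
    return base_lr * gamma ** (epoch // step_size)


# Made-up validation losses, for illustration only (not project results).
val_losses = [0.20, 0.10, 0.08, 0.09, 0.085, 0.09, 0.07]

stopper = EarlyStopping(patience=3)
checkpoints = []
for epoch, loss in enumerate(val_losses):
    lr = step_lr(0.001, epoch)
    if stopper.step(epoch, loss):
        # Real training code would save the model weights here.
        checkpoints.append(epoch)
    if stopper.should_stop:
        break

print("stopped after epoch", epoch, "best epoch", stopper.best_epoch)
```

Checkpointing only on improvement means the saved weights always correspond to the lowest validation loss seen, even if training continues for a few more epochs before the patience runs out.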

Quick Start

# 1. Install dependencies
pip install -r requirements.txt

# 2. Train and evaluate
python main.py

# 3. Predict on a custom image
python predict.py path/to/digit.png

CLI Options

python main.py [OPTIONS]

  --epochs N       Number of training epochs (default: 15)
  --batch-size N   Batch size (default: 128)
  --lr F           Learning rate (default: 0.001)
  --eval-only      Skip training, evaluate a saved checkpoint
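
For reference, the options above map naturally onto Python's standard `argparse` module. The following is a minimal sketch of how such a parser could be declared, using the flag names and defaults listed above; the `build_parser` function is a hypothetical name, and the actual parsing code in main.py may differ.

```python
import argparse


def build_parser():
    """Build a parser mirroring the documented flags (hypothetical helper)."""
    parser = argparse.ArgumentParser(
        prog="main.py",
        description="Train and evaluate the MNIST digit classifier.",
    )
    parser.add_argument("--epochs", type=int, default=15, metavar="N",
                        help="Number of training epochs (default: 15)")
    parser.add_argument("--batch-size", type=int, default=128, metavar="N",
                        help="Batch size (default: 128)")
    parser.add_argument("--lr", type=float, default=0.001, metavar="F",
                        help="Learning rate (default: 0.001)")
    parser.add_argument("--eval-only", action="store_true",
                        help="Skip training, evaluate a saved checkpoint")
    return parser


# Parsing explicit argument lists keeps the example independent of sys.argv.
defaults = build_parser().parse_args([])
custom = build_parser().parse_args(["--epochs", "5", "--lr", "0.01", "--eval-only"])

print(defaults)
print(custom)
```

Note that argparse exposes `--batch-size` as the attribute `batch_size` and `--eval-only` as `eval_only`.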

Project Structure

MnistPY/
├── main.py           # Entry point — train, evaluate, visualise
├── config.py         # All hyperparameters and paths
├── model.py          # CNN architecture (MNISTNet)
├── dataset.py        # Data loading, transforms, splits
├── train.py          # Training loop with validation
├── evaluate.py       # Metrics and classification report
├── predict.py        # Single-image inference CLI
├── visualize.py      # Plotting utilities
├── requirements.txt  # Python dependencies
└── outputs/          # Generated after training
    ├── best_model.pth
    ├── training_curves.png
    ├── confusion_matrix.png
    └── sample_predictions.png

Expected Results

With the default hyperparameters (at most 15 epochs, since early stopping may end training sooner; batch size 128; learning rate 0.001), the model reaches about 99.2% test accuracy.

About

This is my first AI model, which uses TensorFlow/Keras.
