A production-grade handwritten digit recognition system built from scratch using deep neural networks.
Trains on the MNIST dataset (28×28 pixel images) and achieves up to 99.6% test accuracy using a custom CNN architecture.
Features · Quickstart · Architecture · Results · Usage
- Overview
- Features
- Project Structure
- Model Architectures
- Quickstart
- Usage
- Configuration
- Results
- Technical Details
- Testing
- Contributing
- Author
- License
This project implements a complete end-to-end pipeline for recognising handwritten digits (0–9) drawn on a 28×28 pixel canvas — exactly the format used by the MNIST benchmark dataset.
Three neural network architectures are provided:
| Model | Type | Parameters | Typical Accuracy |
|---|---|---|---|
| `mlp` | Multi-Layer Perceptron | ~669K | ~98.5% |
| `cnn` | Convolutional Neural Network | ~422K | ~99.3% |
| `deep_cnn` | Deep CNN with residual-style blocks | ~1.5M | ~99.6% |
The pipeline includes data augmentation, mixed-precision training, learning-rate scheduling, early stopping, and full evaluation with confusion matrices, per-class metrics, and visualised misclassifications.
- Three model architectures — MLP baseline, standard CNN, and a deeper variant
- Data augmentation — Random affine transforms + perspective distortion for robust generalisation
- Mixed-precision training — FP16 with PyTorch AMP for faster GPU training
- OneCycleLR scheduler — Fast convergence with superconvergence policy
- Early stopping — Prevents overfitting with configurable patience
- Comprehensive evaluation — Accuracy, precision, recall, F1 (macro + per-class), confusion matrix
- Visualisations — Training curves, misclassified grids, class probability charts
- TensorBoard integration — Real-time loss / accuracy monitoring
- YAML-driven config — All hyperparameters in one place, no code changes needed
- Fully tested — Unit tests for all core modules with `pytest`
- Reproducible — Seeded everywhere; deterministic training
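The early-stopping behaviour listed above can be pictured with a short sketch. This is an illustration only, assuming a loss-based patience counter; the project's actual `EarlyStopping` class lives in `src/utils.py` and may differ in detail:

```python
class EarlyStopping:
    """Stop training when validation loss hasn't improved for `patience` epochs.

    Illustrative sketch; the real implementation is in src/utils.py.
    """
    def __init__(self, patience: int = 7, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.counter = 0
        self.should_stop = False

    def step(self, val_loss: float) -> bool:
        if val_loss < self.best - self.min_delta:
            self.best = val_loss      # improvement: remember it, reset the counter
            self.counter = 0
        else:
            self.counter += 1         # no improvement this epoch
            if self.counter >= self.patience:
                self.should_stop = True
        return self.should_stop
```

With `patience=2`, two consecutive epochs without a new best loss trigger the stop.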
```text
digit-recognition/
│
├── src/                    # Core library modules
│   ├── __init__.py
│   ├── model.py            # MLP, CNN, DeepCNN architectures
│   ├── data_loader.py      # MNIST loading + augmentation pipeline
│   ├── train.py            # Full training loop (AMP, OneCycleLR, checkpointing)
│   ├── evaluate.py         # Metrics, confusion matrix, misclassified samples
│   ├── predict.py          # DigitPredictor class + single-image inference
│   └── utils.py            # Seeding, device, AverageMeter, EarlyStopping, checkpointing
│
├── configs/
│   └── config.yaml         # All hyperparameters & paths
│
├── tests/
│   ├── __init__.py
│   └── test_model.py       # Pytest unit tests
│
├── notebooks/              # Jupyter exploration notebooks (optional)
├── models/                 # Saved checkpoints (created at training time)
├── results/                # Evaluation plots & reports (created at eval time)
├── runs/                   # TensorBoard logs (created at training time)
│
├── main.py                 # Unified CLI entry point
├── requirements.txt
├── setup.py
├── .gitignore
└── README.md
```
A fully-connected network that flattens the 28×28 input and processes it through three hidden layers with BatchNorm, ReLU, and Dropout regularisation.
```text
Input (784) → FC(512) → BN → ReLU → Drop → FC(256) → BN → ReLU → Drop → FC(128) → ... → FC(10)
```
Two convolutional blocks each with double 3×3 convolutions followed by BatchNorm, MaxPooling, and Dropout. A fully-connected head produces the final class logits.
```text
Input [1×28×28]
 ↓ Conv Block 1: Conv(1→32) → BN → ReLU → Conv(32→32) → BN → ReLU → MaxPool → Drop → [32×14×14]
 ↓ Conv Block 2: Conv(32→64) → BN → ReLU → Conv(64→64) → BN → ReLU → MaxPool → Drop → [64×7×7]
 ↓ Flatten → FC(3136→256) → BN → ReLU → Drop(0.5) → FC(256→10)
```
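The feature-map sizes in the diagram follow from simple arithmetic: a 3×3 convolution with padding 1 preserves spatial size, and each 2×2 max-pool halves it. A quick sanity check (assuming 'same' padding, as the unchanged sizes in the diagram imply):

```python
def conv2d_out(size: int, kernel: int = 3, padding: int = 1, stride: int = 1) -> int:
    """Standard output-size formula for a square convolution."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                  # MNIST input is 28×28
size = conv2d_out(size)    # Conv(1→32), 3×3, pad 1 → 28
size = conv2d_out(size)    # Conv(32→32) → 28
size //= 2                 # MaxPool 2×2 → 14
size = conv2d_out(size)    # Conv(32→64) → 14
size = conv2d_out(size)    # Conv(64→64) → 14
size //= 2                 # MaxPool 2×2 → 7
flat = 64 * size * size    # channels × H × W
print(flat)                # → 3136, matching FC(3136→256)
```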
Three convolutional blocks with increasing channel widths (32→64→128), adaptive average pooling, and a three-layer fully-connected head. Achieves the highest accuracy.
```bash
git clone https://github.com/khallark/digit-recognition.git
cd digit-recognition
python -m venv venv && source venv/bin/activate   # Windows: venv\Scripts\activate
pip install -r requirements.txt
```

```bash
python main.py train
```

MNIST (~11 MB) downloads automatically on first run. Training takes ~5 minutes on a GPU, ~25 minutes on CPU.

```bash
python main.py eval
```

Results (confusion matrix, per-class accuracy, classification report) are saved to `results/`.
```bash
python main.py predict --image samples/my_digit.png
```

```bash
python main.py demo
```

```bash
# Train with default config (CNN, 30 epochs)
python main.py train

# Train with a different model
# Edit configs/config.yaml → model: deep_cnn
python main.py train --config configs/config.yaml
```

What happens during training:
- MNIST is downloaded to `./data/`
- The dataset is split into 90% train / 10% validation
- Checkpoints are saved to `models/` after every epoch
- The best-performing checkpoint is kept as `models/cnn_best.pth`
- Training logs are written to `runs/experiment/train.log`
- TensorBoard metrics are written to `runs/experiment/`
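The keep-best logic behind `cnn_best.pth` can be sketched in a few lines. This is a hypothetical illustration: the real checkpointing in `src/utils.py` saves model state dicts, not bare accuracy values:

```python
def track_best(val_accuracies):
    """Return the epoch index whose checkpoint would be kept as *_best.pth."""
    best_epoch, best_acc = 0, float("-inf")
    for epoch, acc in enumerate(val_accuracies):
        # save_checkpoint(model, f"models/cnn_epoch{epoch}.pth") would run here
        if acc > best_acc:
            best_acc, best_epoch = acc, epoch
            # ...and this checkpoint would be copied to models/cnn_best.pth
    return best_epoch

print(track_best([0.981, 0.990, 0.987, 0.993, 0.991]))  # → 3
```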
Monitor live with:

```bash
tensorboard --logdir runs/
```

```bash
python main.py eval
```

Generates in `results/`:

- `classification_report.txt` — precision / recall / F1 per digit
- `confusion_matrix.png` — normalised heatmap
- `misclassified.png` — grid of hardest errors
- `training_curves.png` — loss and accuracy over epochs
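A normalised confusion-matrix heatmap is typically row-normalised, so each row sums to 1 and the diagonal reads as that digit's recall. A minimal sketch with a hypothetical 2-class matrix:

```python
def normalise_rows(cm):
    """Row-normalise a confusion matrix: divide each row by its true-class count."""
    out = []
    for row in cm:
        total = sum(row)
        out.append([cell / total if total else 0.0 for cell in row])
    return out

cm = [[95, 5],   # true class 0: 95 correct, 5 confused with class 1
      [2, 98]]   # true class 1: 98 correct
print(normalise_rows(cm))  # → [[0.95, 0.05], [0.02, 0.98]]
```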
Programmatic API:

```python
from src.predict import DigitPredictor

predictor = DigitPredictor(model_name="cnn", checkpoint="models/cnn_best.pth")
result = predictor.predict("my_digit.png")

print(f"Digit: {result['digit']}")
print(f"Confidence: {result['confidence']*100:.1f}%")
print(f"All probabilities: {result['probabilities']}")

# Save visualisation
predictor.visualise("my_digit.png", output_path="prediction.png")
```

Batch prediction:

```python
paths = ["digit_0.png", "digit_1.png", "digit_2.png"]
results = predictor.predict_batch(paths)
for path, res in zip(paths, results):
    print(f"{path}: {res['digit']} ({res['confidence']*100:.1f}%)")
```

```bash
python main.py demo
# Saves results/demo.png — 16 test samples with predictions (green=correct, red=wrong)
```

All settings are in `configs/config.yaml`:
```yaml
model: cnn            # mlp | cnn | deep_cnn
epochs: 30
batch_size: 64
lr: 0.001
weight_decay: 0.0001
optimizer: adam       # adam | sgd
scheduler: onecycle   # onecycle | cosine
augment: true
use_amp: true         # mixed-precision (CUDA only)
patience: 7           # early stopping
```

| Model | Test Accuracy | Params | Training Time (GPU) |
|---|---|---|---|
| MLP | 98.47% | 669K | ~3 min |
| CNN | 99.31% | 422K | ~5 min |
| DeepCNN | 99.58% | 1.5M | ~8 min |
| Digit | Accuracy |
|---|---|
| 0 | 99.80% |
| 1 | 99.65% |
| 2 | 99.13% |
| 3 | 99.31% |
| 4 | 99.29% |
| 5 | 99.10% |
| 6 | 99.48% |
| 7 | 99.22% |
| 8 | 98.87% |
| 9 | 98.91% |
Results generated with seed=42. Minor variation expected across runs.
Training images are randomly distorted to improve generalisation to real-world handwriting:
- RandomAffine — ±10° rotation, ±10% translation, 90–110% scale, 5° shear
- RandomPerspective — Projective distortion (30% probability)
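With torchvision these bullets typically map onto `transforms.RandomAffine` and `transforms.RandomPerspective`. The parameter sampling itself can be illustrated in plain Python (a hedged sketch of the stated ranges; the real pipeline in `src/data_loader.py` operates on image tensors):

```python
import random

def sample_affine_params(rng: random.Random) -> dict:
    """Sample one set of affine-augmentation parameters within the ranges above."""
    return {
        "angle": rng.uniform(-10, 10),        # rotation, degrees
        "translate": rng.uniform(-0.1, 0.1),  # fraction of image width/height
        "scale": rng.uniform(0.9, 1.1),       # 90–110 %
        "shear": rng.uniform(-5, 5),          # degrees
    }

params = sample_affine_params(random.Random(0))
assert -10 <= params["angle"] <= 10 and 0.9 <= params["scale"] <= 1.1
```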
- Loss: Cross-entropy with label smoothing (ε=0.05) to prevent overconfident predictions
- Optimiser: Adam with weight decay (L2 regularisation)
- Scheduler: OneCycleLR — linearly warms up to peak LR, then anneals via cosine
- Gradient clipping: max-norm = 1.0 to prevent exploding gradients
- AMP: PyTorch automatic mixed-precision (FP16 forward pass, FP32 master weights)
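Label smoothing with ε=0.05 over 10 classes replaces the one-hot target with 1−ε+ε/10 = 0.955 on the true class and ε/10 = 0.005 everywhere else (in PyTorch this is usually `CrossEntropyLoss(label_smoothing=0.05)`). A quick sketch of the target distribution:

```python
def smooth_labels(true_class: int, num_classes: int = 10, eps: float = 0.05):
    """Return the smoothed target distribution used by a smoothed cross-entropy loss."""
    uniform = eps / num_classes           # mass spread evenly over all classes
    target = [uniform] * num_classes
    target[true_class] += 1.0 - eps       # remaining mass on the true class
    return target

t = smooth_labels(3)
assert abs(t[3] - 0.955) < 1e-9 and abs(t[0] - 0.005) < 1e-9
assert abs(sum(t) - 1.0) < 1e-9           # still a valid distribution
```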
All random sources are seeded: Python random, NumPy, PyTorch CPU/CUDA, and PYTHONHASHSEED. CuDNN is set to deterministic mode.
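A stdlib-only sketch of such a seeding helper (the project's version in `src/utils.py` additionally seeds NumPy and PyTorch and enables deterministic cuDNN):

```python
import os
import random

def seed_everything(seed: int = 42) -> None:
    """Seed stdlib randomness sources; a full version would also call
    numpy.random.seed(seed), torch.manual_seed(seed), and
    torch.cuda.manual_seed_all(seed), and set cuDNN to deterministic mode."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)

seed_everything(42)
a = [random.random() for _ in range(3)]
seed_everything(42)
b = [random.random() for _ in range(3)]
assert a == b  # identical draws after reseeding
```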
```bash
# Run all tests
pytest tests/ -v

# Run with coverage report
pytest tests/ -v --cov=src --cov-report=html

# Run a specific test class
pytest tests/test_model.py::TestCNN -v
```

Test coverage includes:
- Forward pass output shapes for all architectures
- Gradient flow verification
- NaN detection
- `AverageMeter` correctness
- `EarlyStopping` trigger logic
- `get_model` factory (valid and invalid names)
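The `AverageMeter` check can be pictured against a minimal reference implementation, a common pattern for tracking running loss or accuracy (the project's own class in `src/utils.py` may differ in detail):

```python
class AverageMeter:
    """Track a running sum and count, exposing the weighted mean as .avg."""
    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value: float, n: int = 1) -> None:
        self.sum += value * n   # value is typically a batch mean, n the batch size
        self.count += n

    @property
    def avg(self) -> float:
        return self.sum / self.count if self.count else 0.0

m = AverageMeter()
m.update(0.5, n=10)   # batch of 10 with mean loss 0.5
m.update(1.0, n=30)   # batch of 30 with mean loss 1.0
print(m.avg)          # → 0.875
```

Weighting by `n` matters: averaging the two batch means directly would give 0.75 and over-weight the smaller batch.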
Pull requests are welcome. For major changes, please open an issue first.
- Fork the repository
- Create a feature branch (`git checkout -b feature/your-feature`)
- Commit your changes (`git commit -m 'Add your feature'`)
- Push to the branch (`git push origin feature/your-feature`)
- Open a Pull Request
Please make sure all tests pass before submitting.
Kunal Khallar
Software Engineer & Entrepreneur
Ludhiana, Punjab, India
- GitHub: @khallark
- LinkedIn: linkedin.com/in/kunal-khallar-6a55b7359
- Email: k.khallar@op.iitg.ac.in
Built with ❤️ and PyTorch
⭐ Star this repo if you found it helpful!