
🔢 Handwritten Digit Recognition


A production-grade handwritten digit recognition system built from scratch using deep neural networks.
Trains on the MNIST dataset (28×28 pixel images) and achieves up to 99.6% test accuracy using a custom CNN architecture.

Features · Quickstart · Architecture · Results · Usage


📋 Table of Contents

  • 🧠 Overview
  • ✨ Features
  • 📁 Project Structure
  • 🏗 Model Architectures
  • 🚀 Quickstart
  • ⚙️ Usage
  • 🔧 Configuration
  • 📊 Results
  • 🔬 Technical Details
  • 🧪 Testing
  • 🤝 Contributing
  • 👤 Author

🧠 Overview

This project implements a complete end-to-end pipeline for recognising handwritten digits (0–9) drawn on a 28×28 pixel canvas — exactly the format used by the MNIST benchmark dataset.

Three neural network architectures are provided:

Model      Type                                  Parameters   Typical Accuracy
mlp        Multi-Layer Perceptron                ~669K        ~98.5%
cnn        Convolutional Neural Network          ~422K        ~99.3%
deep_cnn   Deep CNN with residual-style blocks   ~1.5M        ~99.6%

The pipeline includes data augmentation, mixed-precision training, learning-rate scheduling, early stopping, and full evaluation with confusion matrices, per-class metrics, and visualised misclassifications.


✨ Features

  • Three model architectures — MLP baseline, standard CNN, and a deeper variant
  • Data augmentation — Random affine transforms + perspective distortion for robust generalisation
  • Mixed-precision training — FP16 with PyTorch AMP for faster GPU training
  • OneCycleLR scheduler — Fast convergence with superconvergence policy
  • Early stopping — Prevents overfitting with configurable patience
  • Comprehensive evaluation — Accuracy, precision, recall, F1 (macro + per-class), confusion matrix
  • Visualisations — Training curves, misclassified grids, class probability charts
  • TensorBoard integration — Real-time loss / accuracy monitoring
  • YAML-driven config — All hyperparameters in one place, no code changes needed
  • Fully tested — Unit tests for all core modules with pytest
  • Reproducible — Seeded everywhere; deterministic training

📁 Project Structure

digit-recognition/
│
├── src/                        # Core library modules
│   ├── __init__.py
│   ├── model.py                # MLP, CNN, DeepCNN architectures
│   ├── data_loader.py          # MNIST loading + augmentation pipeline
│   ├── train.py                # Full training loop (AMP, OneCycleLR, checkpointing)
│   ├── evaluate.py             # Metrics, confusion matrix, misclassified samples
│   ├── predict.py              # DigitPredictor class + single-image inference
│   └── utils.py                # Seeding, device, AverageMeter, EarlyStopping, checkpointing
│
├── configs/
│   └── config.yaml             # All hyperparameters & paths
│
├── tests/
│   ├── __init__.py
│   └── test_model.py           # Pytest unit tests
│
├── notebooks/                  # Jupyter exploration notebooks (optional)
├── models/                     # Saved checkpoints (created at training time)
├── results/                    # Evaluation plots & reports (created at eval time)
├── runs/                       # TensorBoard logs (created at training time)
│
├── main.py                     # Unified CLI entry point
├── requirements.txt
├── setup.py
├── .gitignore
└── README.md

🏗 Model Architectures

MLP (Baseline)

A fully-connected network that flattens the 28×28 input and processes it through three hidden layers with BatchNorm, ReLU, and Dropout regularisation.

Input (784) → FC(512) → BN → ReLU → Drop → FC(256) → BN → ReLU → Drop → FC(128) → ... → FC(10)
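
For illustration, the diagram above maps to PyTorch roughly as follows. This is a sketch, not the repo's actual src/model.py; layer sizes follow the diagram, and the dropout rate is an assumption:

```python
import torch
import torch.nn as nn

# Illustrative MLP matching the diagram: 784 -> 512 -> 256 -> 128 -> 10,
# each hidden layer followed by BatchNorm, ReLU, and Dropout.
def make_mlp(dropout: float = 0.3) -> nn.Module:  # dropout rate is an assumption
    def block(n_in, n_out):
        return [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out),
                nn.ReLU(inplace=True), nn.Dropout(dropout)]
    return nn.Sequential(
        nn.Flatten(),                      # [B, 1, 28, 28] -> [B, 784]
        *block(784, 512), *block(512, 256), *block(256, 128),
        nn.Linear(128, 10),                # class logits
    )

model = make_mlp()
logits = model(torch.randn(4, 1, 28, 28))  # -> shape [4, 10]
```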

CNN (Recommended)

Two convolutional blocks, each stacking two 3×3 convolutions with BatchNorm and ReLU, followed by MaxPooling and Dropout. A fully-connected head produces the final class logits.

Input [1×28×28]
  ↓ Conv Block 1: Conv(1→32) → BN → ReLU → Conv(32→32) → BN → ReLU → MaxPool → Drop  → [32×14×14]
  ↓ Conv Block 2: Conv(32→64) → BN → ReLU → Conv(64→64) → BN → ReLU → MaxPool → Drop  → [64×7×7]
  ↓ Flatten → FC(3136→256) → BN → ReLU → Drop(0.5) → FC(256→10)
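
A sketch of this architecture in PyTorch, following the block diagram above (illustrative only; the dropout rates inside the conv blocks are assumptions, and the repo's code may differ):

```python
import torch
import torch.nn as nn

# One conv block from the diagram: two 3x3 convs with BN+ReLU, then pool + dropout.
def conv_block(c_in, c_out, dropout=0.25):  # 0.25 is an assumed rate
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, 3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU(inplace=True),
        nn.Conv2d(c_out, c_out, 3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU(inplace=True),
        nn.MaxPool2d(2), nn.Dropout(dropout),
    )

model = nn.Sequential(
    conv_block(1, 32),     # -> [B, 32, 14, 14]
    conv_block(32, 64),    # -> [B, 64, 7, 7]
    nn.Flatten(),          # 64 * 7 * 7 = 3136
    nn.Linear(3136, 256), nn.BatchNorm1d(256), nn.ReLU(inplace=True), nn.Dropout(0.5),
    nn.Linear(256, 10),    # class logits
)

logits = model(torch.randn(4, 1, 28, 28))  # -> shape [4, 10]
```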

DeepCNN

Three convolutional blocks with increasing channel widths (32→64→128), adaptive average pooling, and a three-layer fully-connected head. Achieves the highest accuracy.
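
One plausible reading of that description in code (channel widths from the text; the head sizes and block internals are assumptions, not the repo's actual implementation):

```python
import torch
import torch.nn as nn

# Assumed block structure for the deeper variant; only the 32 -> 64 -> 128
# channel progression and the adaptive pooling come from the description.
def block(c_in, c_out):
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, 3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU(inplace=True),
        nn.Conv2d(c_out, c_out, 3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU(inplace=True),
        nn.MaxPool2d(2),
    )

model = nn.Sequential(
    block(1, 32), block(32, 64), block(64, 128),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),        # -> [B, 128] regardless of input size
    nn.Linear(128, 256), nn.ReLU(inplace=True),   # three-layer FC head (sizes assumed)
    nn.Linear(256, 128), nn.ReLU(inplace=True),
    nn.Linear(128, 10),
)

logits = model(torch.randn(2, 1, 28, 28))  # -> shape [2, 10]
```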


🚀 Quickstart

1. Clone & install

git clone https://github.com/khallark/digit-recognition.git
cd digit-recognition
python -m venv venv && source venv/bin/activate  # Windows: venv\Scripts\activate
pip install -r requirements.txt

2. Train

python main.py train

MNIST (~11 MB) downloads automatically on first run. Training takes ~5 minutes on a GPU, ~25 minutes on CPU.

3. Evaluate

python main.py eval

Results (confusion matrix, per-class accuracy, classification report) are saved to results/.

4. Predict a custom image

python main.py predict --image samples/my_digit.png

5. Demo on test samples

python main.py demo

⚙️ Usage

Training

# Train with default config (CNN, 30 epochs)
python main.py train

# Train with a different model
# Edit configs/config.yaml → model: deep_cnn
python main.py train --config configs/config.yaml

What happens during training:

  • MNIST is downloaded to ./data/
  • The dataset is split into 90% train / 10% validation
  • Checkpoints are saved to models/ after every epoch
  • The best-performing checkpoint is kept as models/cnn_best.pth
  • Training logs are written to runs/experiment/train.log
  • TensorBoard metrics are written to runs/experiment/
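
The 90/10 split above can be sketched with torch.utils.data.random_split; a seeded generator keeps the split reproducible across runs (stand-in tensors here, where the real pipeline wraps torchvision's 60k-image MNIST train set):

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in dataset with MNIST's 60,000 training samples (contents irrelevant here).
full = TensorDataset(torch.empty(60_000, 1), torch.empty(60_000))

n_val = len(full) // 10                            # 10% -> 6,000 validation samples
train_set, val_set = random_split(
    full, [len(full) - n_val, n_val],
    generator=torch.Generator().manual_seed(42),   # seeded for a reproducible split
)
print(len(train_set), len(val_set))  # 54000 6000
```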

Monitor live with:

tensorboard --logdir runs/

Evaluation

python main.py eval

Generates in results/:

  • classification_report.txt — precision / recall / F1 per digit
  • confusion_matrix.png — normalised heatmap
  • misclassified.png — grid of hardest errors
  • training_curves.png — loss and accuracy over epochs

Inference

Programmatic API:

from src.predict import DigitPredictor

predictor = DigitPredictor(model_name="cnn", checkpoint="models/cnn_best.pth")

result = predictor.predict("my_digit.png")
print(f"Digit: {result['digit']}")
print(f"Confidence: {result['confidence']*100:.1f}%")
print(f"All probabilities: {result['probabilities']}")

# Save visualisation
predictor.visualise("my_digit.png", output_path="prediction.png")

Batch prediction:

paths = ["digit_0.png", "digit_1.png", "digit_2.png"]
results = predictor.predict_batch(paths)
for path, res in zip(paths, results):
    print(f"{path}: {res['digit']} ({res['confidence']*100:.1f}%)")

Demo

python main.py demo
# Saves results/demo.png — 16 test samples with predictions (green=correct, red=wrong)

🔧 Configuration

All settings are in configs/config.yaml:

model: cnn           # mlp | cnn | deep_cnn
epochs: 30
batch_size: 64
lr: 0.001
weight_decay: 0.0001
optimizer: adam      # adam | sgd
scheduler: onecycle  # onecycle | cosine
augment: true
use_amp: true        # mixed-precision (CUDA only)
patience: 7          # early stopping
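
Reading such a file takes only a few lines with PyYAML (illustrative; the repo's actual loader may differ, and the inline string stands in for opening configs/config.yaml):

```python
import yaml

# Inline stand-in for configs/config.yaml; in practice: yaml.safe_load(open(path)).
cfg_text = """
model: cnn
epochs: 30
batch_size: 64
lr: 0.001
use_amp: true
"""
cfg = yaml.safe_load(cfg_text)
print(cfg["model"], cfg["epochs"], cfg["use_amp"])  # cnn 30 True
```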

📊 Results

Test Accuracy

Model     Test Accuracy   Params   Training Time (GPU)
MLP       98.47%          669K     ~3 min
CNN       99.31%          422K     ~5 min
DeepCNN   99.58%          1.5M     ~8 min

Per-Class Accuracy (CNN)

Digit   Accuracy
0       99.80%
1       99.65%
2       99.13%
3       99.31%
4       99.29%
5       99.10%
6       99.48%
7       99.22%
8       98.87%
9       98.91%

Results generated with seed=42. Minor variation expected across runs.


🔬 Technical Details

Data Augmentation

Training images are randomly distorted to improve generalisation to real-world handwriting:

  • RandomAffine — ±10° rotation, ±10% translation, 90–110% scale, 5° shear
  • RandomPerspective — Projective distortion (30% probability)

Training Strategy

  • Loss: Cross-entropy with label smoothing (ε=0.05) to prevent overconfident predictions
  • Optimiser: Adam with weight decay (L2 regularisation)
  • Scheduler: OneCycleLR — linearly warms up to peak LR, then anneals via cosine
  • Gradient clipping: max-norm = 1.0 to prevent exploding gradients
  • AMP: PyTorch automatic mixed-precision (FP16 forward pass, FP32 master weights)
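
Wired together, one training step under this strategy looks roughly like the sketch below (not the repo's train.py; the stand-in model and total_steps value are placeholders, and AMP is enabled only when CUDA is available so the sketch also runs on CPU):

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"                         # mixed precision needs a GPU

model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)).to(device)  # stand-in model
criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=100)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

x = torch.randn(8, 1, 28, 28, device=device)       # one dummy batch
y = torch.randint(0, 10, (8,), device=device)

optimizer.zero_grad()
with torch.autocast(device_type=device, enabled=use_amp):  # FP16 forward on GPU
    loss = criterion(model(x), y)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)                         # so clipping sees true gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
scheduler.step()                                   # OneCycleLR steps per batch
```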

Reproducibility

All random sources are seeded: Python random, NumPy, PyTorch CPU/CUDA, and PYTHONHASHSEED. CuDNN is set to deterministic mode.
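
A sketch of that seeding routine (illustrative; note that PYTHONHASHSEED set at runtime only affects subprocesses, not the current interpreter):

```python
import os
import random

import numpy as np
import torch

# Seed every RNG source listed above.
def set_seed(seed: int = 42) -> None:
    os.environ["PYTHONHASHSEED"] = str(seed)   # takes effect in subprocesses
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)                    # seeds CPU and all CUDA generators
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True  # deterministic conv algorithms
    torch.backends.cudnn.benchmark = False

set_seed(42)
a = torch.rand(3)
set_seed(42)
b = torch.rand(3)
print(torch.equal(a, b))  # True — identical draws after re-seeding
```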


🧪 Testing

# Run all tests
pytest tests/ -v

# Run with coverage report
pytest tests/ -v --cov=src --cov-report=html

# Run a specific test class
pytest tests/test_model.py::TestCNN -v

Test coverage includes:

  • Forward pass output shapes for all architectures
  • Gradient flow verification
  • NaN detection
  • AverageMeter correctness
  • EarlyStopping trigger logic
  • get_model factory (valid and invalid names)
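
The shape and gradient checks follow a pattern like the one below. This is a self-contained stand-in: the real tests in tests/test_model.py exercise the repo's MLP/CNN/DeepCNN, while here a tiny inline model keeps the example runnable:

```python
import torch
import torch.nn as nn

# Stand-in model; the real tests instantiate the repo's architectures instead.
tiny = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def test_forward_shape():
    out = tiny(torch.randn(4, 1, 28, 28))
    assert out.shape == (4, 10)              # logits for 10 digit classes

def test_gradients_flow_and_are_finite():
    out = tiny(torch.randn(4, 1, 28, 28))
    out.sum().backward()
    for p in tiny.parameters():
        assert p.grad is not None            # gradient reached every parameter
        assert torch.isfinite(p.grad).all()  # no NaN/Inf gradients
```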

🤝 Contributing

Pull requests are welcome. For major changes, please open an issue first.

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/your-feature)
  3. Commit your changes (git commit -m 'Add your feature')
  4. Push to the branch (git push origin feature/your-feature)
  5. Open a Pull Request

Please make sure all tests pass before submitting.


👤 Author

Kunal Khallar
Software Engineer & Entrepreneur
Ludhiana, Punjab, India


Built with ❤️ and PyTorch

⭐ Star this repo if you found it helpful!
