# MNIST Classification with Transfer Learning (Final Notebook)

This notebook runs the complete pipeline using the updated script `mnist_transfer_learning_comparison.py`.

It will:
- Train LeNet-5 on MNIST
- Train transfer learning models (VGG16, ResNet50, MobileNet)
- Load local images `0.jpeg` and `1.jpeg`, preprocess and predict
- Compare results across models

## 1. Environment and Imports

In [1]:
%pip -q install tensorflow matplotlib opencv-python --upgrade
import os, sys, pathlib
from IPython.display import display

# Ensure the script directory is on sys.path
project_dir = pathlib.Path(r'/').resolve()
if str(project_dir) not in sys.path:
    sys.path.append(str(project_dir))
print('Project dir set to:', project_dir)

Note: you may need to restart the kernel to use updated packages.
Project dir set to: C:\


## 2. Import pipeline functions from the script

In [2]:
from mnist_transfer_learning_comparison import (
    load_and_prepare_mnist,
    create_lenet5_model,
    train_lenet5,
    train_transfer_learning_models,
    get_local_test_images,
    predict_external_images,
    compare_results
)
print('Functions imported from mnist_transfer_learning_comparison.py')

Functions imported from mnist_transfer_learning_comparison.py


## 3. Verify local test images (0.jpeg and 1.jpeg)

In [6]:
import cv2, matplotlib.pyplot as plt
img_names = ['0.jpeg', '1.jpeg']
for name in img_names:
    path = project_dir / name
    print(f'Checking {path} ->', path.exists())
    if path.exists():
        img = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        plt.figure(figsize=(3,3))
        plt.imshow(img, cmap='gray')
        plt.title(name)
        plt.axis('off')
        plt.show()
    else:
        print('WARNING: Image not found. Place the file in:', project_dir)

Checking C:\0.jpeg -> False
Checking C:\1.jpeg -> False


## 4. Load MNIST and train LeNet-5

In [None]:
(x_train, y_train, y_train_cat), (x_test, y_test, y_test_cat) = load_and_prepare_mnist()
lenet = create_lenet5_model()
lenet_history, lenet_acc = train_lenet5(lenet, x_train, y_train_cat, x_test, y_test_cat)
print(f'LeNet-5 accuracy: {lenet_acc:.4f}')

## 5. Train transfer learning models (VGG16, ResNet50, MobileNet)

In [None]:
transfer_results = train_transfer_learning_models(x_train, y_train_cat, x_test, y_test_cat)
all_models = {'LeNet-5': {'model': lenet, 'test_accuracy': lenet_acc, 'history': lenet_history}}
all_models.update(transfer_results)
print('Transfer learning training completed.')

## 6. Predict on local images 0.jpeg and 1.jpeg

In [None]:
test_images = get_local_test_images()
if not test_images:
    print('No local images detected. Ensure 0.jpeg and 1.jpeg are in:', project_dir)
else:
    all_predictions = predict_external_images(all_models, test_images)
    print('Predictions completed.')

## 7. Compare results

In [None]:
compare_results(lenet_acc, transfer_results)
print('Done.')