In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CIFAR-10 Classification Notebook\n",
    "\n",
    "This notebook demonstrates the evaluation and visualization of the trained CIFAR-10 CNN model.\n",
    "It loads the preprocessed data and the trained model, evaluates its performance on the test set,\n",
    "and visualizes example predictions.\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Standard-library and third-party dependencies\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Project-local helper that provides the preprocessed CIFAR-10 splits\n",
    "from src.data_preprocessing import load_and_preprocess_data\n",
    "\n",
    "# Fetch both splits first, then unpack each one into images and labels\n",
    "train_split, test_split = load_and_preprocess_data()\n",
    "x_train, y_train = train_split\n",
    "x_test, y_test = test_split\n",
    "\n",
    "# Sanity check: report the dimensions of the test images\n",
    "print(f\"Test data shape: {x_test.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Import the load_model function from Keras\n",
    "from tensorflow.keras.models import load_model\n",
    "\n",
    "# Location of the trained model. A relative path is resolved against the kernel's\n",
    "# current working directory, so either start the notebook from the directory that\n",
    "# contains the file or edit MODEL_PATH.\n",
    "MODEL_PATH = 'cifar10_cnn_model.h5'\n",
    "\n",
    "# Fail early with an actionable message instead of a low-level loader error.\n",
    "# (`os` is imported in the first code cell.)\n",
    "if not os.path.isfile(MODEL_PATH):\n",
    "    raise FileNotFoundError(\n",
    "        f\"Trained model not found at '{os.path.abspath(MODEL_PATH)}'. \"\n",
    "        \"Train the model first or update MODEL_PATH.\"\n",
    "    )\n",
    "model = load_model(MODEL_PATH)\n",
    "\n",
    "# Evaluate the model's performance on the test set.\n",
    "# NOTE(review): unpacking into two values assumes the model was compiled with exactly\n",
    "# one metric (presumably accuracy) -- confirm against the training script.\n",
    "test_loss, test_acc = model.evaluate(x_test, y_test)\n",
    "print(f\"Test Accuracy: {test_acc * 100:.2f}%\")\n",
    "print(f\"Test Loss: {test_loss:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Number of test images to predict on and display\n",
    "NUM_EXAMPLES = 5\n",
    "\n",
    "# Make predictions on a subset of test images\n",
    "predictions = model.predict(x_test[:NUM_EXAMPLES])\n",
    "\n",
    "# Define the CIFAR-10 class names for easier interpretation of predictions\n",
    "class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n",
    "\n",
    "# Visualize the test images with their predicted and true labels.\n",
    "# squeeze=False keeps `axes` two-dimensional even when NUM_EXAMPLES == 1,\n",
    "# so iterating over axes[0] always works.\n",
    "fig, axes = plt.subplots(1, NUM_EXAMPLES, figsize=(3 * NUM_EXAMPLES, 3), squeeze=False)\n",
    "for i, ax in enumerate(axes[0]):\n",
    "    # NOTE(review): imshow expects floats in [0, 1] or uint8 in [0, 255] -- confirm the preprocessing output\n",
    "    ax.imshow(x_test[i])\n",
    "    predicted_label = int(np.argmax(predictions[i]))\n",
    "    # Labels may be integer class indices (shape (N, 1) or (N,)) or one-hot vectors,\n",
    "    # depending on load_and_preprocess_data; normalise every form to a plain int index.\n",
    "    true_vec = np.ravel(y_test[i])\n",
    "    true_label = int(np.argmax(true_vec)) if true_vec.size > 1 else int(true_vec[0])\n",
    "    ax.set_title(f\"Pred: {class_names[predicted_label]}, True: {class_names[true_label]}\")\n",
    "    ax.axis('off')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}