# CIFAR-10 Image Classification with CNN (Optimized for Faster Training)
This notebook builds and trains a convolutional neural network (CNN) on the CIFAR-10 dataset using TensorFlow and Keras. Data augmentation and early stopping are used to improve generalization. For faster execution, 20% of the original training set is held out for validation, and the model is trained on a 60% subsample of the remaining training images.

## 1. Import Required Libraries

In [1]:

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


## 2. Load and Preprocess the CIFAR-10 Dataset

In [2]:

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Scale pixel values by 1/255 and cast to float32 (assumes 0-255 integer pixels).
x_train, x_test = (split.astype('float32') / 255.0 for split in (x_train, x_test))

# Collapse each label array to a 1-D vector of integer class indices, the form
# consumed by the sparse cross-entropy loss and the sklearn metrics used later.
y_train, y_test = (labels.flatten() for labels in (y_train, y_test))


## 3. Define Class Labels

In [3]:

# Position i holds the display name of class index i (used to label predictions).
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]


## 4. Build a Deep CNN Model

In [4]:

# Seed Python's `random`, NumPy and TensorFlow in one call so weight
# initialisation and dropout (and, presumably, the NumPy-based augmentation of
# ImageDataGenerator) are repeatable on every Restart & Run All.
# NOTE: bit-exact GPU determinism would additionally require
# tf.config.experimental.enable_op_determinism().
tf.keras.utils.set_random_seed(42)

# Three conv stages (32 -> 64 -> 128 filters). Each stage has two 3x3 'same'
# convolutions with batch norm, then 2x2 max-pooling and dropout. A small dense
# head ends in a 10-way softmax, one unit per entry of `class_names`.
model = models.Sequential([
    # Declare the input explicitly. Passing `input_shape=` to the first layer is
    # discouraged in recent Keras and appears to be what triggers the
    # constructor warning printed under this cell.
    layers.Input(shape=(32, 32, 3)),
    layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Dropout(0.25),

    layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Dropout(0.25),

    layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Dropout(0.25),

    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.BatchNormalization(),
    layers.Dropout(0.5),
    layers.Dense(10, activation='softmax')
])


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


## 5. Compile the Model

In [5]:

# The labels are flattened integer class indices, so the *sparse* variant of
# categorical cross-entropy is used (no one-hot encoding needed).
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)


## 6. Split and Subsample the Training Data

In [6]:

# Hold out 20% of the official training images for validation (fixed
# random_state so the split is repeatable).
x_train_new, x_val, y_train_new, y_val = train_test_split(
    x_train,
    y_train,
    random_state=42,
    test_size=0.2,
)

# Train on only the first 60% of the remaining images to shorten each epoch.
# train_test_split shuffles by default, so this prefix is a random subset.
sample_size = int(0.6 * len(x_train_new))
x_sample, y_sample = x_train_new[:sample_size], y_train_new[:sample_size]


## 7. Set Up Data Augmentation and Early Stopping

In [7]:

# On-the-fly random perturbations applied to every training batch.
augmentation_params = dict(
    horizontal_flip=True,
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
)
datagen = ImageDataGenerator(**augmentation_params)
# NOTE(review): fit() only computes statistics for featurewise / ZCA options,
# none of which are enabled above; the call is kept for parity.
datagen.fit(x_sample)

# Stop after 3 epochs without validation-loss improvement and roll the model
# back to the weights of its best epoch.
early_stop = EarlyStopping(
    monitor='val_loss',
    restore_best_weights=True,
    patience=3,
)


## 8. Train the Model (Up to 20 Epochs, Stopping Early When Validation Loss Plateaus)

In [8]:

# Training batches come from the augmenting generator; validation uses the
# untouched held-out split so early stopping sees un-augmented images.
train_flow = datagen.flow(x_sample, y_sample, batch_size=64)
history = model.fit(
    train_flow,
    epochs=20,
    validation_data=(x_val, y_val),
    callbacks=[early_stop],
    verbose=2,
)


  self._warn_if_super_not_called()


Epoch 1/20
375/375 - 86s - 230ms/step - accuracy: 0.3322 - loss: 2.0019 - val_accuracy: 0.2202 - val_loss: 2.5131
Epoch 2/20
375/375 - 78s - 209ms/step - accuracy: 0.4642 - loss: 1.4908 - val_accuracy: 0.5279 - val_loss: 1.3526
Epoch 3/20


KeyboardInterrupt: 

## 9. Evaluate the Model

In [None]:

# Final score on the official test split, which was not used for training or early stopping.
evaluation = model.evaluate(x_test, y_test, verbose=2)
test_loss, test_acc = evaluation
print("Test accuracy:", test_acc)


## 10. Classification Report and Confusion Matrix

In [None]:

# Predicted class = index of the highest softmax probability for each test image.
y_pred_probs = model.predict(x_test)
y_pred = y_pred_probs.argmax(axis=1)

report = classification_report(y_test, y_pred, target_names=class_names)
print("Classification Report:\n", report)

# Rows are true classes, columns are predicted classes (sklearn convention).
cm = confusion_matrix(y_test, y_pred)
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm,
    ax=ax,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix Heatmap')
plt.show()


## 11. Accuracy and Loss Plots

In [None]:

# One figure per metric: (train key, validation key, y-axis label,
# train legend label, validation legend label, title).
curve_specs = [
    ('accuracy', 'val_accuracy', 'Accuracy',
     'Train Accuracy', 'Validation Accuracy', 'Training vs Validation Accuracy'),
    ('loss', 'val_loss', 'Loss',
     'Train Loss', 'Validation Loss', 'Training vs Validation Loss'),
]

for train_key, val_key, y_label, train_label, val_label, title in curve_specs:
    plt.figure(figsize=(6, 4))
    plt.plot(history.history[train_key], label=train_label)
    plt.plot(history.history[val_key], label=val_label)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()
    plt.title(title)
    plt.show()


## 12. Sample Predictions

In [None]:

def show_sample_predictions(x_data, y_true, y_pred, class_names, num_samples=10):
    """Plot the first ``num_samples`` images titled with predicted and true labels.

    Args:
        x_data: Indexable sequence of images that ``plt.imshow`` can draw.
        y_true: Ground-truth class indices, aligned with ``x_data``.
        y_pred: Predicted class indices, aligned with ``x_data``.
        class_names: Sequence mapping a class index to its display name.
        num_samples: Maximum number of images to show. It is clipped to
            ``len(x_data)``; nothing is drawn when the count is zero.
    """
    num_samples = min(num_samples, len(x_data))
    if num_samples <= 0:
        return

    # Use two rows (one row for a single image) and ceil-divide for the column
    # count. The former ``num_samples // 2`` undercounted odd values (and gave 0
    # columns for 1 image), which made ``plt.subplot`` raise.
    nrows = 2 if num_samples > 1 else 1
    ncols = -(-num_samples // nrows)

    plt.figure(figsize=(15, 5))
    for i in range(num_samples):
        plt.subplot(nrows, ncols, i + 1)
        plt.imshow(x_data[i])
        plt.title(f"Pred: {class_names[y_pred[i]]}\nTrue: {class_names[y_true[i]]}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()

show_sample_predictions(
    x_data=x_test,
    y_true=y_test,
    y_pred=y_pred,
    class_names=class_names,
)


## 13. Show Misclassified Images

In [None]:

def show_misclassified_images(x_data, y_true, y_pred, class_names, num_images=10):
    """Plot up to ``num_images`` images whose predicted label differs from the true one.

    Images are taken in index order and titled in red with both labels.

    Args:
        x_data: Indexable sequence of images that ``plt.imshow`` can draw.
        y_true: Ground-truth class indices, aligned with ``x_data``.
        y_pred: Predicted class indices, aligned with ``x_data``.
        class_names: Sequence mapping a class index to its display name.
        num_images: Maximum number of misclassified images to show.

    If nothing is misclassified, a message is printed instead of drawing an
    empty figure.
    """
    # np.asarray makes the comparison element-wise even for plain Python lists
    # (list != list would otherwise yield a single bool).
    misclassified_indices = np.where(np.asarray(y_true) != np.asarray(y_pred))[0]
    selected = misclassified_indices[:max(num_images, 0)]
    if len(selected) == 0:
        print("No misclassified images to show.")
        return

    # Size the grid by the number of images actually shown. Ceil-divide so odd
    # counts fit; ``num_images // 2`` made ``plt.subplot`` raise for odd values
    # and for a single image.
    nrows = 2 if len(selected) > 1 else 1
    ncols = -(-len(selected) // nrows)

    plt.figure(figsize=(15, 5))
    for i, idx in enumerate(selected):
        plt.subplot(nrows, ncols, i + 1)
        plt.imshow(x_data[idx])
        plt.title(f"Pred: {class_names[y_pred[idx]]}\nTrue: {class_names[y_true[idx]]}", color='red')
        plt.axis('off')
    plt.tight_layout()
    plt.show()

show_misclassified_images(
    x_data=x_test,
    y_true=y_test,
    y_pred=y_pred,
    class_names=class_names,
)


## 14. Predict a Specific Image

In [None]:

# Pick one test image and compare the model's prediction with its true label.
index = 5
img, true_label = x_test[index], y_test[index]

# model.predict works on batches, so add a leading batch axis of size 1.
img_input = img[np.newaxis, ...]
pred_probs = model.predict(img_input)
predicted_label = pred_probs[0].argmax()

plt.imshow(img)
plt.title(f"True: {class_names[true_label]} | Predicted: {class_names[predicted_label]}")
plt.axis('off')
plt.show()
