In [9]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import numpy as np
import cv2
from tensorflow.keras import backend as K


# ----------------------------- Hyper-parameters -----------------------------
# Collected in one place so a re-run only needs edits here.
epochs: int = 20                        # number of full passes over the training split
batch_size: int = 64                    # samples per gradient update
validation_split: float = 0.2           # fraction of the training data held out for validation
optimizer: str = 'adam'                 # Keras optimizer identifier
loss: str = 'categorical_crossentropy'  # pairs with the one-hot encoded labels
metrics: str = 'accuracy'               # metric name tracked by fit/evaluate
# -----------------------------------------------------------------------------

# Load the CIFAR-10 dataset: 32x32 RGB images with integer class labels 0-9.
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

# Scale pixel values from [0, 255] to [0, 1]. Casting to float32 first (instead
# of the float64 that `uint8 / 255.0` produces) halves the memory held by these
# arrays; the model's layers compute in float32 by default, so training is
# unaffected.
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0

# One-hot encode the labels to match the categorical_crossentropy loss.
train_labels = to_categorical(train_labels, num_classes=10)
test_labels = to_categorical(test_labels, num_classes=10)

# Build the CNN: three 3x3 convolutional stages (the first two followed by
# 2x2 max-pooling), then a small fully connected head whose softmax output is
# a probability distribution over the 10 classes.
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax'),
])

# Compile the model. Keyword arguments are used because the positional order
# of `Model.compile` differs between Keras versions (in Keras 3 the third
# positional parameter is `loss_weights`, not `metrics`). `metrics` is
# documented as a list, so a single metric name is wrapped in one.
model.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=[metrics] if isinstance(metrics, str) else list(metrics),
)

# Print a per-layer summary of the architecture (output shapes, parameter counts).
model.summary()

# Train the model. ReduceLROnPlateau multiplies the learning rate by 0.2
# whenever the validation loss has not improved for 3 epochs, never going
# below 1e-6. The returned History object feeds the plots further down.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=3,
    min_lr=1e-6,
)
history = model.fit(
    train_images,
    train_labels,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=validation_split,
    callbacks=[reduce_lr],
)

# Measure generalisation on the held-out test set.
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f'Test accuracy: {test_acc}')

# Spot-check: compare predicted and true class indices for the first five test images.
predictions = model.predict(test_images[:5])
predicted_classes = np.argmax(predictions, axis=1)
true_classes = np.argmax(test_labels[:5], axis=1)
print('Predictions:', predicted_classes)
print('True labels:', true_classes)

# Confusion matrix over the whole test set: rows are true classes, columns are
# predicted classes. `labels=` pins the matrix to all 10 classes so it always
# matches the 10 display labels, even if a class never occurs in either array.
cifar10_class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck']
test_predictions = model.predict(test_images)
conf_matrix = confusion_matrix(
    test_labels.argmax(axis=1),
    test_predictions.argmax(axis=1),
    labels=list(range(10)),
)
disp = ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=cifar10_class_names)
# Vertical x tick labels keep the class names from overlapping.
disp.plot(cmap='Blues', values_format='d', xticks_rotation='vertical')
plt.title('Confusion Matrix')
plt.show()

# Show a few randomly chosen test images with predicted vs. true class indices.
n_samples = 5
sample_indices = np.random.choice(len(test_images), n_samples, replace=False)
sample_images = test_images[sample_indices]
sample_labels = test_labels[sample_indices]

sample_predictions = model.predict(sample_images)
sample_predictions_labels = sample_predictions.argmax(axis=1)
sample_true_labels = sample_labels.argmax(axis=1)

# Lay the samples out in a single row and show the figure right away, rather
# than leaving it open until some later plt.show() call.
plt.figure(figsize=(12, 3))
for i, (image, pred_label, true_label) in enumerate(
        zip(sample_images, sample_predictions_labels, sample_true_labels)):
    plt.subplot(1, n_samples, i + 1)
    plt.imshow(image)
    plt.title(f"Pred: {pred_label}\nTrue: {true_label}")
    plt.axis('off')
plt.tight_layout()
plt.show()

# Plot the learning rate recorded each epoch (adjusted by ReduceLROnPlateau).
# Depending on the Keras version the value is logged under 'lr' or under
# 'learning_rate', so use whichever key the History actually contains.
lr_key = 'lr' if 'lr' in history.history else 'learning_rate'
plt.figure(figsize=(8, 4))
plt.plot(history.history[lr_key])
plt.title('Learning Rate')
plt.xlabel('Epoch')
plt.ylabel('LR')
plt.show()

# Visualise the feature maps (activations) the first convolutional layer
# produces for one test image. They show where each learned filter responds
# on that image; they are not the filter weights themselves.
layer_outputs = [layer.output for layer in model.layers[:3]]  # outputs of the first three layers
activation_model = models.Model(inputs=model.input, outputs=layer_outputs)  # maps an input image to those outputs

# Select an image from the test set and add a batch dimension.
img = test_images[0]
img = np.expand_dims(img, axis=0)

# activations[k] is the batched output of model.layers[k]; index 0 is the first Conv2D.
activations = activation_model.predict(img)
first_conv_maps = activations[0][0]  # drop the batch axis -> (height, width, channels)

# One panel per channel, 8 per row. The channel count is read from the
# activation itself instead of being hard-coded to 32.
n_maps = first_conv_maps.shape[-1]
n_cols = 8
n_rows = -(-n_maps // n_cols)  # ceiling division
plt.figure(figsize=(12, 1.5 * n_rows))
for i in range(n_maps):
    plt.subplot(n_rows, n_cols, i + 1)
    plt.imshow(first_conv_maps[:, :, i], cmap='viridis')
    plt.axis('off')
plt.suptitle('Feature Maps from the First Convolutional Layer')
plt.show()

# Plot a one-page summary of the run: architecture, learning curves,
# confusion matrix, sample predictions and the learning-rate schedule.
fig = plt.figure(figsize=(12, 8))

# Panel 1 - model architecture, listed layer by layer. The text is drawn in
# black; white text would be invisible on the default white background.
ax = fig.add_subplot(2, 3, 1)
ax.set_title('Model Architecture', fontsize=14)
arch_text = '\n'.join(f"{layer.name} ({type(layer).__name__})" for layer in model.layers)
ax.text(0.5, 0.5, arch_text, fontsize=9, ha='center', va='center',
        color='black', family='monospace')
ax.axis('off')

# Panel 2 - training & validation accuracy.
ax = fig.add_subplot(2, 3, 2)
ax.plot(history.history['accuracy'], label='Train Accuracy')
ax.plot(history.history['val_accuracy'], label='Validation Accuracy')
ax.set_title('Accuracy Over Epochs')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.legend()

# Panel 3 - training & validation loss.
ax = fig.add_subplot(2, 3, 3)
ax.plot(history.history['loss'], label='Train Loss')
ax.plot(history.history['val_loss'], label='Validation Loss')
ax.set_title('Loss Over Epochs')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()

# Panel 4 - confusion matrix. Passing `ax` makes ConfusionMatrixDisplay draw
# into this panel; without it, plot() opens a separate figure and the panel
# stays empty.
ax = fig.add_subplot(2, 3, 4)
disp.plot(ax=ax, cmap='Blues', values_format='d', colorbar=False,
          xticks_rotation='vertical')
ax.set_title('Confusion Matrix')

# Panel 5 - sample predictions. The images are joined side by side into one
# strip, because drawing them one after another on the same axes would leave
# only the last one visible. Each tile is labelled above its own column.
ax = fig.add_subplot(2, 3, 5)
ax.imshow(np.hstack(list(sample_images)))
tile_width = sample_images.shape[2]
for i, (pred_label, true_label) in enumerate(zip(sample_predictions_labels, sample_true_labels)):
    ax.text((i + 0.5) * tile_width, -1, f"Pred: {pred_label}\nTrue: {true_label}",
            fontsize=8, ha='center', va='bottom')
ax.set_title('Sample Predictions', pad=30)
ax.axis('off')

# Panel 6 - learning rate. Depending on the Keras version it is logged under
# 'lr' or under 'learning_rate'.
lr_key = 'lr' if 'lr' in history.history else 'learning_rate'
ax = fig.add_subplot(2, 3, 6)
ax.plot(history.history[lr_key])
ax.set_title('Learning Rate Over Epochs')
ax.set_xlabel('Epoch')
ax.set_ylabel('LR')

fig.tight_layout()
plt.show()


Model: "sequential_7"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_21 (Conv2D)          (None, 30, 30, 32)        896       
                                                                 
 max_pooling2d_14 (MaxPooli  (None, 15, 15, 32)        0         
 ng2D)                                                           
                                                                 
 conv2d_22 (Conv2D)          (None, 13, 13, 64)        18496     
                                                                 
 max_pooling2d_15 (MaxPooli  (None, 6, 6, 64)          0         
 ng2D)                                                           
                                                                 
 conv2d_23 (Conv2D)          (None, 4, 4, 128)         73856     
                                                                 
 flatten_7 (Flatten)         (None, 2048)             