Step 1: Data Preprocessing & Loading 
        Visualizing Images and Labels, and Adding Grayscale Conversion

In [72]:
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import datasets, layers, models
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [73]:
# Fetch CIFAR-10 as (train, test) pairs of image and label arrays.
# `cifar10` is the same module the file imports from tensorflow.keras.datasets.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

In [74]:
# Report "<images shape> <labels shape>" for the training split, then the test split.
print(*(split.shape for split in (x_train, y_train)))
print(*(split.shape for split in (x_test, y_test)))

In [75]:
# Visualize the images in the CIFAR-10 dataset

# Class names, positioned so that the list index matches the integer label in y_train
classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

# Initialize the figure
plt.figure(figsize=(6, 6))

image_count = 0

# Loop through class labels to pick 10 images per class (one grid row per class)
for class_index, class_name in enumerate(classes):
    class_images = x_train[y_train.flatten() == class_index][:10]

    # Loop through the images, arranging them in 10 x 10
    for img in class_images:
        plt.subplot(10, 10, image_count + 1)
        plt.imshow(img)
        # Hide only the ticks and the frame. plt.axis('off') would also suppress
        # the axis labels, so the class name set below would never be drawn.
        plt.xticks([])
        plt.yticks([])
        plt.box(False)
        if image_count % 10 == 0:
            # First image of the row carries the class name as a row label
            plt.ylabel(class_name, rotation=0, size='large', labelpad=50)
        image_count += 1

# Show the images
plt.show()

In [76]:
#Data Augmentation:

# Grayscale helper; passed to ImageDataGenerator as its preprocessing_function
def rgb_to_grayscale(x):
    """Return `x` converted to grayscale by `tf.image.rgb_to_grayscale`."""
    return tf.image.rgb_to_grayscale(x)

# Generator configuration: random geometric transforms plus per-image
# grayscale conversion and pixel rescaling.
datagen = ImageDataGenerator(
    # --- random geometric augmentation ---
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
    # --- pixel-level processing ---
    preprocessing_function=rgb_to_grayscale,  # grayscale conversion hook
    rescale=1.0 / 255,  # multiply pixel values by 1/255
)


In [77]:
# Function to collect augmented data
def collect_augmented_data(datagen, x_data, y_data, batch_size=32, seed=None):
    """Run one full pass of `datagen.flow` over the data and stack the batches.

    Args:
        datagen: Object exposing `flow(x, y, batch_size=...)` that returns an
            iterator of `(images_batch, labels_batch)` pairs (here an
            `ImageDataGenerator`).
        x_data: Image array; `len(x_data)` is the number of samples.
        y_data: Labels aligned with `x_data`.
        batch_size: Number of samples requested per batch.
        seed: Optional seed forwarded to `datagen.flow`. It is only passed when
            not None, so generators whose `flow` has no `seed` argument keep
            working. Intended to make repeated runs reproducible.

    Returns:
        Tuple `(augmented_images, augmented_labels)` with the batches
        concatenated along the first axis. Images whose last axis has size 3
        are reduced to a single channel by averaging over that axis.

    Raises:
        ValueError: If `x_data` contains no samples.
    """
    total_samples = len(x_data)
    if total_samples == 0:
        # Without this guard np.concatenate fails later with an unrelated-looking message.
        raise ValueError("x_data is empty; there is nothing to augment")

    flow_kwargs = {"batch_size": batch_size}
    if seed is not None:
        flow_kwargs["seed"] = seed
    iterator = datagen.flow(x_data, y_data, **flow_kwargs)

    # ceil() so that the final, possibly smaller, batch is included
    batches_to_process = int(np.ceil(total_samples / batch_size))

    augmented_images = []
    augmented_labels = []
    for _ in range(batches_to_process):
        augmented_batch, labels_batch = next(iterator)
        augmented_images.append(augmented_batch)
        augmented_labels.append(labels_batch)

    augmented_images = np.concatenate(augmented_images)
    augmented_labels = np.concatenate(augmented_labels)

    # Ensure images have a single channel: if the batches still come back with
    # 3 channels, collapse them by averaging over the channel axis.
    if augmented_images.shape[-1] == 3:
        augmented_images = np.mean(augmented_images, axis=-1, keepdims=True)

    return augmented_images, augmented_labels

# Collect augmented training data (random transforms + grayscale + rescale)
augmented_x_train, augmented_y_train = collect_augmented_data(datagen, x_train, y_train)

# Collect testing data.
# The test split must not go through random rotations/shifts/flips: evaluation
# data should only receive the deterministic preprocessing (grayscale + rescale)
# so that metrics describe the real images. The variable names are kept so that
# later cells continue to work.
test_datagen = ImageDataGenerator(
    rescale=1./255,
    preprocessing_function=rgb_to_grayscale,
)
augmented_x_test, augmented_y_test = collect_augmented_data(test_datagen, x_test, y_test)

print("Augmented Training Images Shape:", augmented_x_train.shape)
print("Augmented Training Labels Shape:", augmented_y_train.shape)
print("Augmented Testing Images Shape:", augmented_x_test.shape)
print("Augmented Testing Labels Shape:", augmented_y_test.shape)

In [78]:
# Function to visualize augmented images
def visualize_augmented_images(images, labels, classes, title="Augmented Images", images_per_class=10):
    """Plot a grid with one row per class and up to `images_per_class` images per row.

    Args:
        images: Array of images indexable by a boolean mask over the first axis.
            Images whose last axis has size 1 are squeezed to 2-D before plotting.
        labels: Array of integer class indices; it is flattened before comparison.
        classes: Sequence of class names; position `i` names label value `i`.
        title: Figure-level title.
        images_per_class: Number of grid columns, i.e. the maximum number of
            images shown for each class.
    """
    n_rows = len(classes)
    flat_labels = labels.flatten()  # loop-invariant, so computed once

    plt.figure(figsize=(6, 6))

    # One grid row per class
    for class_index, class_name in enumerate(classes):
        class_images = images[flat_labels == class_index][:images_per_class]

        for column, img in enumerate(class_images):
            if img.shape[-1] == 1:  # Handle grayscale images
                img = img.reshape(img.shape[0], img.shape[1])
            # Position derived from (row, column) so that a class with fewer
            # images than columns cannot shift the following rows.
            plt.subplot(n_rows, images_per_class, class_index * images_per_class + column + 1)
            plt.imshow(img, cmap='gray')
            # Hide only ticks and frame. plt.axis('off') would also suppress
            # the axis labels, so the class name below would never be drawn.
            plt.xticks([])
            plt.yticks([])
            plt.box(False)
            if column == 0:
                # First image of each row carries the class name as a row label
                plt.ylabel(class_name, rotation=0, size='large', labelpad=50)

    plt.suptitle(title)
    plt.show()

# Show a per-class sample grid of the augmented training set
visualize_augmented_images(
    images=augmented_x_train,
    labels=augmented_y_train,
    classes=classes,
    title="Augmented Training Images",
)


In [9]:
# Display the shape of the data (both lines are emitted by a single print call)
print(
    f"Training data shape: {x_train.shape}, Training labels shape: {y_train.shape}",
    f"Test data shape: {x_test.shape}, Test labels shape: {y_test.shape}",
    sep="\n",
)

Training data shape: (50000, 32, 32, 3), Training labels shape: (50000, 1)
Test data shape: (10000, 32, 32, 3), Test labels shape: (10000, 1)
