# Multi-Class Classification of CIFAR-10 with Transfer Learning (ResNet50)
---

In [None]:
import os
import tensorflow as tf, matplotlib.pyplot as plt, numpy as np, seaborn as sns
import sklearn

In [None]:
# load_data() returns a (train, test) pair of (images, labels) tuples; the
# test split is used as the validation set for the rest of this notebook.
cifar10_splits = tf.keras.datasets.cifar10.load_data()
(training_images, training_labels), (validation_images, validation_labels) = cifar10_splits

In [None]:
num_images = 8
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Sample WITHOUT replacement: np.random.choice defaults to replace=True,
# which could put the same training image in the grid more than once.
random_indices = np.random.choice(len(training_images), size=num_images, replace=False)
selected_images = training_images[random_indices]
# Flatten the labels so each entry is a scalar usable as a list index.
selected_labels = training_labels[random_indices].flatten()
selected_classes = [classes[i] for i in selected_labels]

cols = 4
# Ceiling division: floor division would leave too few rows (and make
# plt.subplot raise) whenever num_images is not a multiple of cols.
rows = -(-num_images // cols)

plt.figure(figsize=(cols * 3, rows * 3), dpi=100)
for i, (image, class_name) in enumerate(zip(selected_images, selected_classes)):
    plt.subplot(rows, cols, i + 1)
    plt.imshow(image)
    plt.title(class_name)
    plt.axis("off")

plt.suptitle("training_images")
plt.tight_layout()
plt.show()

In [None]:
def preprocess_images(images):
    """Prepare CIFAR-10 pixel arrays for the ImageNet-pretrained ResNet50.

    Note: this does NOT scale pixels to [0, 1]. ``resnet50.preprocess_input``
    applies ResNet50's own ("caffe"-style) preprocessing instead: it reorders
    channels RGB -> BGR and subtracts the per-channel ImageNet mean, with no
    rescaling. Applying it here keeps training inputs consistent with how the
    pretrained weights were produced.

    Args:
        images: Pixel array as returned by ``cifar10.load_data()``.

    Returns:
        A float32 array of the same shape, ready to feed to the model.
    """
    # astype() makes a float copy, so the in-place mean subtraction done by
    # preprocess_input runs in floating point and never touches the caller's
    # original array.
    return tf.keras.applications.resnet50.preprocess_input(images.astype("float32"))

# Model inputs: ResNet50-preprocessed float32 images.
train_X, validation_X = (
    preprocess_images(split_images)
    for split_images in (training_images, validation_images)
)

# Targets: one-hot vectors over the 10 CIFAR-10 classes, matching the
# categorical_crossentropy loss used at compile time.
train_Y, validation_Y = (
    tf.keras.utils.to_categorical(split_labels, num_classes=10)
    for split_labels in (training_labels, validation_labels)
)

In [None]:
def feature_extractor(inputs):
    """Run a frozen, ImageNet-pretrained ResNet50 backbone over ``inputs``.

    Args:
        inputs: Image tensor sized for the backbone's declared
            ``input_shape`` of (224, 224, 3).

    Returns:
        The backbone's spatial feature map (``include_top=False``, so there
        is no ImageNet classification head).
    """
    backbone = tf.keras.applications.ResNet50(
        input_shape=(224, 224, 3), include_top=False, weights="imagenet"
    )
    # Freeze the *model*. The previous code set `.trainable` on the tensor
    # returned by calling the model, which left every backbone weight
    # trainable. (To fine-tune the backbone instead, set this to True.)
    backbone.trainable = False
    # training=False keeps BatchNormalization layers in inference mode, so
    # their ImageNet statistics are not disturbed by CIFAR-10 batches.
    return backbone(inputs, training=False)

def classifier(inputs, num_classes=10):
    """Classification head: pool backbone features into class probabilities.

    Args:
        inputs: 4-D feature map (batch, height, width, channels), as required
            by GlobalAveragePooling2D.
        num_classes: Number of output classes (10 for CIFAR-10).

    Returns:
        Softmax probabilities of shape (batch, num_classes).
    """
    # GlobalAveragePooling2D already yields a flat (batch, channels) tensor,
    # so the Flatten layer that used to follow it did nothing and is removed.
    x = tf.keras.layers.GlobalAveragePooling2D()(inputs)
    x = tf.keras.layers.Dense(1024, activation="relu")(x)
    x = tf.keras.layers.Dense(512, activation="relu")(x)
    return tf.keras.layers.Dense(num_classes, activation="softmax", name="Classification")(x)

def final_model(inputs):
    """Wire the full graph: upscale the images, extract features, classify.

    UpSampling2D by a factor of 7 turns 32x32 CIFAR-10 images into the
    224x224 resolution the ResNet50 backbone is declared with (32 * 7 = 224).

    Args:
        inputs: Symbolic image input tensor.

    Returns:
        The classifier's output tensor.
    """
    upscaled = tf.keras.layers.UpSampling2D(size=(7, 7))(inputs)
    return classifier(feature_extractor(upscaled))

def define_compile_model():
    """Create and compile the CIFAR-10 transfer-learning model.

    Returns:
        A compiled ``tf.keras.Model`` named "CIFAR10_Classifier" that takes
        (32, 32, 3) images and is trained with the "SGD" optimizer on
        categorical cross-entropy, reporting accuracy.
    """
    image_input = tf.keras.Input(shape=(32, 32, 3))
    cifar_model = tf.keras.Model(
        image_input,
        final_model(image_input),
        name="CIFAR10_Classifier",
    )
    cifar_model.compile(
        optimizer="SGD",
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return cifar_model


model = define_compile_model()
model.summary()

In [None]:
EPOCHS = 5
# The held-out split is scored as validation data after every epoch; the
# per-epoch metrics end up in history.history for plotting later.
history = model.fit(
    x=train_X,
    y=train_Y,
    batch_size=32,
    epochs=EPOCHS,
    verbose=1,
    validation_data=(validation_X, validation_Y),
)

In [None]:
# Softmax scores for every validation image, reduced to the index of the
# highest-scoring class in each row.
preds = model.predict(validation_X, verbose=1)
pred_labels = preds.argmax(axis=1)

In [None]:
# evaluate() returns [loss, *metrics]; the only compiled metric is accuracy.
loss, accuracy = model.evaluate(validation_X, validation_Y, batch_size=32, verbose=1)
print(f"Validation Loss: {loss}\nValidation Accuracy: {accuracy}")

In [None]:
# Per-epoch curves recorded by model.fit(). Training loss is named
# `train_loss` so it no longer overwrites `loss`, the scalar validation loss
# produced by model.evaluate() in the evaluation cell.
train_acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
train_loss = history.history['loss']
val_loss = history.history['val_loss']

# Keras logs epochs as 1..N, so number the x-axis the same way.
epochs = range(1, len(train_acc) + 1)

plt.figure(figsize=(12, 4), dpi=300)

plt.subplot(1, 2, 1)
plt.plot(epochs, train_acc, label="Training Accuracy")
plt.plot(epochs, val_acc, label="Validation Accuracy")
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.grid()

plt.subplot(1, 2, 2)
plt.plot(epochs, train_loss, label="Training Loss")
plt.plot(epochs, val_loss, label="Validation Loss")
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid()

plt.tight_layout()
plt.show()

In [None]:
from sklearn.metrics import classification_report

# Flatten the 2-D labels to line up element-wise with the 1-D pred_labels.
# Passing `labels` explicitly keeps the 10 target_names tied to class indices
# 0..9 even if some class were absent from both y_true and y_pred.
report = classification_report(
    validation_labels.ravel(),
    pred_labels,
    labels=np.arange(len(classes)),
    target_names=classes,
)

print("📊 Classification Report")
print("=======================")
print(report)

In [None]:
from sklearn.metrics import confusion_matrix

# Fix the label set to every class index so the matrix is always
# len(classes) x len(classes) and matches the tick labels below, even if a
# class never occurs in y_true or y_pred. Rows = true, columns = predicted.
cm = confusion_matrix(
    validation_labels.ravel(), pred_labels, labels=np.arange(len(classes))
)

plt.figure(figsize=(6, 5), dpi=100)
sns.heatmap(cm, cmap='Blues', xticklabels=classes, yticklabels=classes, annot=True, fmt='d')
plt.title('Confusion Matrix', fontsize=12)
plt.xlabel('Predicted Label', fontsize=8)
plt.ylabel('True Label', fontsize=8)

plt.tight_layout()
plt.show()