In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, datasets # type: ignore
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Load CIFAR-10 as (images, labels) pairs for the train and test splits.
(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()

# Normalize the data to range of [0, 1].
# Cast to float32 *before* dividing: dividing the raw integer pixel arrays by
# a Python float would produce float64 arrays, doubling memory use for no
# benefit (Keras computes in float32 by default).
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Convert class vectors to binary class matrices (one-hot rows), which is the
# label format 'categorical_crossentropy' expects.
NUM_CLASSES = 10
y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)
y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)

# Print the shape of the datasets
print(f'x_train shape: {x_train.shape}')
print(f'y_train shape: {y_train.shape}')
print(f'x_test shape: {x_test.shape}')
print(f'y_test shape: {y_test.shape}')

In [None]:
# Define the resnet architecture
def resnet_block(input_layer, filters, kernal_size= 3, stride= 1):
    """Build one basic residual block: two conv/BN layers plus a shortcut.

    Main path: Conv -> BatchNorm -> ReLU -> Conv -> BatchNorm. The result is
    added to the shortcut and passed through a final ReLU.

    Args:
        input_layer: Keras tensor feeding the block.
        filters: Number of output channels for both convolutions.
        kernal_size: Convolution kernel size (the misspelled name is kept so
            existing keyword callers keep working).
        stride: Stride of the first convolution; values other than 1
            downsample the feature map.

    Returns:
        The Keras tensor produced by the block.
    """
    shortcut = input_layer

    x = layers.Conv2D(filters, kernal_size, strides=stride, padding='same')(input_layer)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters, kernal_size, strides=1, padding='same')(x)
    x = layers.BatchNormalization()(x)

    # The element-wise add needs both branches to have identical shapes, so
    # project the shortcut with a 1x1 convolution whenever the block changes
    # the spatial size (stride != 1) OR the channel count. Previously only the
    # stride was checked, so a stride-1 block that widened the channels failed
    # inside layers.add.
    # Assumes channels-last tensors (the Keras default); with channels-first
    # data the check may add a projection that was not strictly required.
    if stride != 1 or input_layer.shape[-1] != filters:
        shortcut = layers.Conv2D(filters, 1, strides=stride)(input_layer)

    x = layers.add([x, shortcut])
    x = layers.Activation('relu')(x)
    return x


In [None]:
def create_resnet(input_shape=(32, 32, 3), num_classes=10):
    """Assemble a small ResNet-style image classifier.

    Layout: a 3x3 convolutional stem with 64 filters, then three stages of
    two residual blocks each (64, 128 and 256 filters; the 128 and 256 stages
    open with a stride-2 block), global average pooling and a softmax head.

    Args:
        input_shape: Shape of a single input image, without the batch axis.
            Defaults to (32, 32, 3), the shape this notebook trains on.
        num_classes: Number of output classes (units in the softmax layer).

    Returns:
        An uncompiled Keras Model mapping images to class probabilities.
    """
    input_layer = layers.Input(shape=input_shape)
    x = layers.Conv2D(64, 3, padding='same')(input_layer)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    # Each stage is (filters, stride of its first block); the second block of
    # every stage keeps the resolution (default stride of 1).
    for filters, first_stride in ((64, 1), (128, 2), (256, 2)):
        x = resnet_block(x, filters, stride=first_stride)
        x = resnet_block(x, filters)

    x = layers.GlobalAveragePooling2D()(x)
    output_layer = layers.Dense(num_classes, activation='softmax')(x)

    return models.Model(inputs=input_layer, outputs=output_layer)

# Build the network and print a per-layer summary as a quick check of the
# architecture before training.
model = create_resnet()
model.summary()

In [None]:
# Configure training: Adam with its default settings, and categorical
# cross-entropy to match the one-hot labels and the softmax output layer.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# Training hyperparameters, named so they are easy to find and tune.
EPOCHS = 10
BATCH_SIZE = 64

# NOTE(review): the test split doubles as validation data here, so the
# 'val_*' metrics recorded in `history` are test-set metrics. That is fine
# for monitoring, but they should not drive model selection or early stopping.
history = model.fit(
    x_train,
    y_train,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_data=(x_test, y_test),
)

In [None]:
# Score the trained model on the test split. evaluate() is unpacked into the
# loss and the single 'accuracy' metric the model was compiled with.
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'Test accuracy: {test_acc}')


In [None]:
# Plot the per-epoch training and validation accuracy recorded by model.fit.
fig, ax = plt.subplots()
for metric in ('accuracy', 'val_accuracy'):
    ax.plot(history.history[metric], label=metric)
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_ylim([0, 1])  # fixed full-range axis so runs are visually comparable
ax.legend(loc='lower right')
plt.show()
