In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, Add, GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical

# Residual Block (Basic)
def residual_block(x, filters, stride=1, use_projection=False):
    """Build a basic (two 3x3 convolutions) residual block on top of ``x``.

    Main path: Conv3x3(stride) -> BatchNorm -> ReLU -> Conv3x3 -> BatchNorm.
    The main path is added to the shortcut and passed through a final ReLU.

    Args:
        x: Input Keras tensor.
        filters: Number of output channels of both 3x3 convolutions (and of
            the projection, when one is used).
        stride: Stride of the first convolution and of the projection
            shortcut. An int, or a tuple/list of ints.
        use_projection: Force a 1x1 convolution + BatchNorm on the shortcut.
            A projection is also inserted automatically whenever the identity
            shortcut could not be added to the main path (``stride`` is not 1,
            or the input channel count differs from ``filters``); such calls
            previously failed inside ``Add`` with a shape mismatch.

    Returns:
        The output Keras tensor of the block.
    """
    shortcut = x

    # Decide whether the identity shortcut is shape-compatible with the main
    # path. The channel axis follows the globally configured data format.
    channel_axis = 1 if tf.keras.backend.image_data_format() == 'channels_first' else -1
    in_channels = shortcut.shape[channel_axis]
    strides = stride if isinstance(stride, (tuple, list)) else (stride,)
    changes_resolution = any(s != 1 for s in strides)
    changes_channels = in_channels is not None and in_channels != filters
    needs_projection = use_projection or changes_resolution or changes_channels

    # First convolution (carries the stride, i.e. any spatial downsampling)
    x = Conv2D(filters, kernel_size=(3, 3), strides=stride, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Second convolution (no activation before the addition)
    x = Conv2D(filters, kernel_size=(3, 3), strides=1, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)

    # Projection shortcut: 1x1 convolution with the same stride and filter
    # count as the main path so both branches have matching shapes.
    if needs_projection:
        shortcut = Conv2D(filters, kernel_size=(1, 1), strides=stride, use_bias=False)(shortcut)
        shortcut = BatchNormalization()(shortcut)

    # Add shortcut
    x = Add()([x, shortcut])
    x = Activation('relu')(x)

    return x

# ResNet Model
def ResNet18(input_shape=(32, 32, 3), num_classes=10):
    """Assemble the residual classifier as a Keras ``Model``.

    Layout: a 3x3/stride-1 convolution stem with 64 filters, four stages of
    two residual blocks each (64, 128, 256 and 512 filters), global average
    pooling and a softmax ``Dense`` head.

    Args:
        input_shape: Shape of a single input image; defaults to (32, 32, 3).
        num_classes: Number of units in the softmax output layer.

    Returns:
        An uncompiled ``Model`` mapping the image input to class probabilities.
    """
    image = Input(shape=input_shape)

    # Stem: one 3x3 convolution at stride 1, then BatchNorm and ReLU.
    features = Conv2D(64, kernel_size=(3, 3), strides=1, padding='same', use_bias=False)(image)
    features = BatchNormalization()(features)
    features = Activation('relu')(features)

    # Four stages of two blocks each. Every stage except the first opens with
    # a stride-2 block that uses a projection shortcut.
    for stage_index, width in enumerate((64, 128, 256, 512)):
        if stage_index == 0:
            features = residual_block(features, filters=width)
        else:
            features = residual_block(features, filters=width, stride=2, use_projection=True)
        features = residual_block(features, filters=width)

    # Classification head
    pooled = GlobalAveragePooling2D()(features)
    probabilities = Dense(num_classes, activation='softmax')(pooled)

    return Model(image, probabilities)

# Load CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Normalize pixel values to [0, 1]. Cast to float32 first: dividing the uint8
# image arrays by a Python float would promote them to float64 and double the
# memory footprint.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
y_train, y_test = to_categorical(y_train, 10), to_categorical(y_test, 10)  # One-hot encoding

# Instantiate and compile the model
model = ResNet18(input_shape=(32, 32, 3), num_classes=10)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Train the model. Only 5 epochs are run to keep the run short.
# NOTE: the test split is also used as validation data here, so the val_*
# metrics are not an independent estimate if they are used to tune the model.
# Keep the returned History so the per-epoch metrics can be inspected later.
history = model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=5, batch_size=64)


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 82s 0us/step
Epoch 1/5
782/782 ━━━━━━━━━━━━━━━━━━━━ 1301s 2s/step - accuracy: 0.4536 - loss: 1.5673 - val_accuracy: 0.5453 - val_loss: 1.4001
Epoch 2/5
782/782 ━━━━━━━━━━━━━━━━━━━━ 1505s 2s/step - accuracy: 0.7243 - loss: 0.7797 - val_accuracy: 0.6821 - val_loss: 0.9243
Epoch 3/5
782/782 ━━━━━━━━━━━━━━━━━━━━ 1494s 2s/step - accuracy: 0.8038 - loss: 0.5644 - val_accuracy: 0.7394 - val_loss: 0.7648
Epoch 4/5
782/782 ━━━━━━━━━━━━━━━━━━━━ 1520s 2s/step - accuracy: 0.8509 - loss: 0.4281 - val_accuracy: 0.7646 - val_loss: 0.7510
Epoch 5/5
782/782 ━━━━━━━━━━━━━━━━━━━━ 1196s 2s/step - accuracy: 0.8896 - loss: 0.3137 - val_accuracy: 0.7530 - val_loss: 0.8334


<keras.src.callbacks.history.History at 0x167e84ff9e0>