# Simplified ResNet implementation

## Loading CIFAR10 

In [1]:
import tensorflow as tf
import tensorflow.keras as keras

In [2]:
from tensorflow.keras.datasets import cifar10

# load_data() yields one (images, labels) pair per split: train first, then test.
train_split, test_split = cifar10.load_data()
X_train, y_train = train_split
X_test, y_test = test_split

In [3]:
# Sanity-check the array shapes, in the order: train images, test images,
# train labels, test labels.
for split_array in (X_train, X_test, y_train, y_test):
    print(split_array.shape)

(50000, 32, 32, 3)
(10000, 32, 32, 3)
(50000, 1)
(10000, 1)


In [4]:
# normalization
# Scale pixel values from [0, 255] to [0, 1].
#
# The integer-dtype guard makes this cell idempotent: once the arrays have
# been converted to float, re-running the cell leaves them untouched instead
# of dividing by 255 a second time. Casting to float32 first keeps the result
# in float32 rather than the float64 that `int_array / 255.0` would produce,
# halving the memory needed for the image tensors.
if X_train.dtype.kind in "ui":
    X_train = X_train.astype("float32") / 255.0
if X_test.dtype.kind in "ui":
    X_test = X_test.astype("float32") / 255.0

## Construct ResNet

In [5]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import ReLU
from tensorflow.keras.layers import Conv2D, MaxPool2D
from tensorflow.keras.layers import Add, AveragePooling2D 
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dense

In [6]:
def _conv_bn(x, filters):
    """Apply a 3x3, stride-1, 'same'-padded convolution followed by batch norm.

    The convolution is created without a bias term; it is immediately
    followed by BatchNormalization.

    Args:
        x: Input Keras tensor.
        filters: Number of output channels of the convolution.

    Returns:
        The batch-normalized feature map (no activation applied).
    """
    x = Conv2D(filters, kernel_size=(3,3), strides=(1,1),
               padding='same', use_bias=False)(x)
    return BatchNormalization()(x)


def _conv_bn_relu(x, filters):
    """Conv -> BatchNorm -> ReLU unit; see `_conv_bn` for the conv settings."""
    return ReLU()(_conv_bn(x, filters))


def _residual_block(x, filters):
    """Two conv units whose output is added back onto the block input.

    Computes ReLU(BN(Conv(ReLU(BN(Conv(x))))) + x). `filters` must equal the
    channel count of `x`, because Add() needs both operands to share a shape.

    Args:
        x: Input Keras tensor; also used as the skip branch.
        filters: Number of channels for both convolutions.

    Returns:
        The activated sum of the residual branch and the skip branch.
    """
    skip = x
    x = _conv_bn_relu(x, filters)
    x = _conv_bn(x, filters)   # activation is applied after the addition
    x = Add()([x, skip])
    return ReLU()(x)


def _downsample(x):
    """Halve the spatial resolution with a 2x2, stride-2 max pool."""
    return MaxPool2D(pool_size=(2,2), strides=(2,2),
                     padding='valid')(x)


inputs = Input(shape=(32,32,3))

# Stem: two 32-channel conv units at full 32x32 resolution
x = _conv_bn_relu(inputs, 32)
x = _conv_bn_relu(x, 32)

# Stage 1 (16x16): widen to 64 channels so the skip branch matches, then
# 1st skip connection
x = _downsample(x)
x = _conv_bn_relu(x, 64)
x = _residual_block(x, 64)

# Stage 2 (8x8): widen to 128 channels, then 2nd skip connection
x = _downsample(x)
x = _conv_bn_relu(x, 128)
x = _residual_block(x, 128)

# Average Pooling: the 8x8 window covers the whole 8x8 feature map, leaving
# one value per channel.
x = AveragePooling2D(pool_size=(8,8))(x)
x = Flatten()(x)
# No ReLU here: the pooled values are averages of post-ReLU activations and
# are therefore already non-negative, so another ReLU would be a no-op.
outputs = Dense(10, activation='softmax')(x)

Model_resnet = Model(inputs = inputs, outputs = outputs)

Model_resnet.summary()


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 32, 32, 3)]  0           []                               
                                                                                                  
 conv2d (Conv2D)                (None, 32, 32, 32)   864         ['input_1[0][0]']                
                                                                                                  
 batch_normalization (BatchNorm  (None, 32, 32, 32)  128         ['conv2d[0][0]']                 
 alization)                                                                                       
                                                                                                  
 re_lu (ReLU)                   (None, 32, 32, 32)   0           ['batch_normalization[0][0]']

## Compile, Train, and Evaluate

In [7]:
from tensorflow.keras.optimizers import Adam

# Adam optimizer with an explicit step size and moment-decay rates.
opt = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999)

# The sparse loss is used because the labels are integer class ids of shape
# (N, 1), not one-hot vectors. Accuracy is tracked under the name 'acc'.
Model_resnet.compile(
    optimizer=opt,
    loss='sparse_categorical_crossentropy',
    metrics=['acc'],
)

In [8]:
# training
# Keep the History object returned by fit(): its `.history` dict records the
# per-epoch loss and 'acc' values, which would otherwise be lost once the
# progress output is cleared. Binding it also stops the cell from ending in
# a bare `<keras.callbacks.History at 0x...>` repr.
history = Model_resnet.fit(X_train, y_train, epochs=20)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x18e9b520880>

In [24]:
# Evaluation
# evaluate() returns the compiled loss followed by each compiled metric, so
# the printed list reads [test loss, test 'acc'].
test_performance = Model_resnet.evaluate(x=X_test, y=y_test)
print(test_performance)

[0.8322671055793762, 0.8138999938964844]
