# AlexNet in TensorFlow

In this notebook, we use an [AlexNet](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks)-like deep convolutional neural network to classify MNIST digits.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jonkrohn/DLTFpT/blob/master/notebooks/alexnet_in_tensorflow.ipynb)

#### Load dependencies

In [0]:
# Install TensorFlow using Colab's tensorflow_version command
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass

In [0]:
!pip freeze | grep tensorflow

mesh-tensorflow==0.1.9
tensorflow==2.1.0
tensorflow-addons==0.6.0
tensorflow-datasets==2.0.0
tensorflow-estimator==2.1.0
tensorflow-federated==0.11.0
tensorflow-gan==2.0.0
tensorflow-gcs-config==2.1.6
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-model-optimization==0.1.3
tensorflow-privacy==0.2.2
tensorflow-probability==0.9.0


In [0]:
import tensorflow
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.layers import BatchNormalization

#### Load data

In [0]:
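# MNIST provides 60,000 training images and 10,000 test images; the test split is used for validation here
(X_train, y_train), (X_valid, y_valid) = mnist.load_data()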

#### Preprocess data

In [0]:
# Add a trailing channel axis (grayscale) and convert to float32
X_train = X_train.reshape(60000, 28, 28, 1).astype('float32')
X_valid = X_valid.reshape(10000, 28, 28, 1).astype('float32')

In [0]:
# Scale pixel intensities from [0, 255] to [0, 1]
X_train /= 255
X_valid /= 255

In [0]:
# One-hot encode the integer labels (one column per digit class)
n_classes = 10
y_train = to_categorical(y_train, n_classes)
y_valid = to_categorical(y_valid, n_classes)
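
To make the label encoding concrete, here is a minimal pure-Python sketch of one-hot encoding. The `one_hot` helper is a hypothetical stand-in written for illustration, not the Keras implementation, and the three labels are made-up examples.

```python
def one_hot(labels, n_classes):
    """Return one row per label, with 1.0 at the label's index and 0.0 elsewhere."""
    rows = []
    for label in labels:
        row = [0.0] * n_classes
        row[label] = 1.0
        rows.append(row)
    return rows

# Three example digit labels, encoded over the 10 digit classes
encoded = one_hot([5, 0, 4], 10)
for row in encoded:
    print(row)
```

Each row sums to 1, which is the target format that a softmax output layer trained with categorical cross-entropy expects.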

#### Design neural network architecture

In [0]:
model = Sequential()

# first conv-pool block: 
model.add(Conv2D(96, kernel_size=(11, 11), strides=(1, 1), activation='relu', input_shape=(28, 28, 1)))
model.add(MaxPooling2D(pool_size=(3, 3), strides=(1, 1)))
model.add(BatchNormalization())

# second conv-pool block: 
model.add(Conv2D(256, kernel_size=(5, 5), activation='relu'))
model.add(MaxPooling2D(pool_size=(3, 3), strides=(1, 1)))
model.add(BatchNormalization())

# third conv-pool block: 
model.add(Conv2D(256, kernel_size=(3, 3), activation='relu'))
model.add(Conv2D(384, kernel_size=(3, 3), activation='relu'))
model.add(Conv2D(384, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(3, 3), strides=(1, 1)))
model.add(BatchNormalization())

# dense layers: 
model.add(Flatten())
model.add(Dense(4096, activation='tanh'))
model.add(Dropout(0.5))
model.add(Dense(4096, activation='tanh'))
model.add(Dropout(0.5))

# output layer: one softmax unit per digit class
model.add(Dense(n_classes, activation='softmax'))
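
The softmax activation on the output layer turns the ten raw scores into a probability distribution over the digit classes. The following is a minimal pure-Python sketch of that computation. The `softmax` helper and the example scores are ours, chosen for illustration, and subtracting the maximum score is a common numerical-stability trick rather than something stated in this notebook.

```python
import math

def softmax(logits):
    """Exponentiate each score and normalize so the outputs sum to 1."""
    largest = max(logits)
    # Subtracting the largest score leaves the result unchanged but avoids overflow
    exps = [math.exp(z - largest) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical raw scores for three classes
probs = softmax([2.0, 1.0, 0.1])
print(probs)
print(sum(probs))
```

The largest score receives the largest probability, and the predicted class is simply the index of the maximum.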

In [0]:
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 18, 18, 96)        11712     
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 16, 16, 96)        0         
_________________________________________________________________
batch_normalization (BatchNo (None, 16, 16, 96)        384       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 12, 12, 256)       614656    
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 10, 10, 256)       0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 10, 10, 256)       1024      
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 8, 8, 256)         5
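
The output shapes and parameter counts in the summary can be checked by hand. The sketch below assumes the Keras default of `padding='valid'` for `Conv2D` and `MaxPooling2D`, and counts four values per channel for `BatchNormalization` (scale, offset, moving mean, moving variance). The helper names are hypothetical and exist only for this illustration.

```python
def output_size(size, window, stride=1):
    """Spatial size after an unpadded ('valid') convolution or pooling window."""
    return (size - window) // stride + 1

def conv_params(kernel, in_channels, filters):
    """Kernel weights plus one bias per filter, for a square kernel."""
    return (kernel * kernel * in_channels + 1) * filters

def batchnorm_params(channels):
    """Scale, offset, moving mean, and moving variance for each channel."""
    return 4 * channels

size = 28                                # MNIST images are 28 x 28 x 1
size = output_size(size, 11)             # conv2d: 11 x 11 kernel
print(size, conv_params(11, 1, 96))      # 18 11712
size = output_size(size, 3)              # max_pooling2d: 3 x 3 window, stride 1
print(size, batchnorm_params(96))        # 16 384
size = output_size(size, 5)              # conv2d_1: 5 x 5 kernel
print(size, conv_params(5, 96, 256))     # 12 614656
size = output_size(size, 3)              # max_pooling2d_1
print(size, batchnorm_params(256))       # 10 1024
size = output_size(size, 3)              # conv2d_2: 3 x 3 kernel
print(size, conv_params(3, 256, 256))    # 8 590080
```

Because every pooling layer uses a stride of 1, each one trims only two pixels per side length, which is what keeps the 28 x 28 input from shrinking to nothing before the dense layers.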

#### Configure model

In [0]:
model.compile(loss='categorical_crossentropy', optimizer='nadam', metrics=['accuracy'])
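
For intuition about the loss chosen here, categorical cross-entropy for a single example is the negative log of the probability assigned to the true class. The following pure-Python sketch illustrates the formula. The function name, the `eps` clipping value, and the example predictions are assumptions made for this illustration and are not taken from the Keras source.

```python
import math

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Sum of -t * log(p) over classes; p is clipped below so log(0) cannot occur."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(y_true, y_pred))

y_true = [0.0, 0.0, 1.0]        # one-hot label: the true class is index 2
confident = [0.1, 0.2, 0.7]     # hypothetical prediction favoring the true class
unsure = [0.4, 0.4, 0.2]        # hypothetical prediction favoring the wrong classes

loss_confident = categorical_crossentropy(y_true, confident)
loss_unsure = categorical_crossentropy(y_true, unsure)
print(loss_confident, loss_unsure)
```

The loss shrinks toward zero as the predicted probability of the true class approaches 1, and grows quickly as that probability approaches 0.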

#### Train!

In [0]:
model.fit(X_train, y_train, batch_size=128, epochs=1, verbose=1, validation_data=(X_valid, y_valid))

Train on 60000 samples, validate on 10000 samples


<tensorflow.python.keras.callbacks.History at 0x7f9e870095f8>
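
As a rough guide to what one epoch involves with these settings, the number of weight updates follows directly from the training-set size and the batch size. This is plain arithmetic on the values passed to `model.fit` above; the variable names are ours.

```python
import math

n_train = 60000      # training samples reported by fit
batch_size = 128     # batch size passed to fit

# The final batch is smaller whenever batch_size does not divide n_train evenly
steps_per_epoch = math.ceil(n_train / batch_size)
last_batch_size = n_train - (steps_per_epoch - 1) * batch_size

print(steps_per_epoch)    # 469
print(last_batch_size)    # 96
```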