In [1]:
import os
import subprocess

import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from keras.utils.np_utils import to_categorical

In [None]:
# Fetch the `nnet` package imported in the next cell and make the repository the working
# directory so that `from nnet...` resolves (same effect as `!git clone` + `%cd`).
# Guarded so that re-running the cell is a no-op. With the bare magics, a second run from
# inside the already-entered repo clones a *nested* copy and `%cd`s into it.
# `check=True` turns a failed clone into an error instead of a printed message that the
# notebook silently continues past.
# NOTE(review): this clones the repo's current HEAD; pin a commit for full reproducibility.
repo_url = "https://github.com/burchim/NeuralNetsWithNumpy.git"
repo_dir = "NeuralNetsWithNumpy"

if os.path.basename(os.getcwd()) != repo_dir:
    if not os.path.isdir(repo_dir):
        subprocess.run(["git", "clone", repo_url], check=True)
    os.chdir(repo_dir)
print(os.getcwd())

In [2]:
# Modules
from nnet.modules import (
  Identity,
  Sigmoid,
  Tanh,
  ReLU,
  PReLU,
  Swish,
  LayerNorm,
  BatchNorm
)

# Optimizers
from nnet.optimizers import (
  SGD,
  RMSprop,
  Adam,
  AdamW
)

# Losses
from nnet.losses import (
  MeanAbsoluteError,
  MeanSquaredError,
  SoftmaxCrossEntropy
)

# Schedulers
from nnet.schedulers import (
  ConstantScheduler,
  WarmupCosineAnnealingScheduler
)

# Model
from nnet.models import CNNModel

In [3]:
# Seed numpy's global RNG early so that anything drawing from it later is reproducible.
seed = 42
np.random.seed(seed)

# dtype used for every array fed to the model
dtype = np.float32

# MNIST has 10 classes (digits 0-9). This is passed explicitly so the one-hot width never
# depends on which labels happen to occur in a split.
num_classes = 10

# Load Dataset.
# NOTE: mnist.load_data() returns the official train/test split, so the "val" arrays below
# are the MNIST *test* set.
(x_train, y_train), (x_val, y_val) = mnist.load_data()

# Prepare Dataset: scale pixels to [0, 1] and one-hot encode the labels.
# Casting to float32 *before* dividing avoids materialising a float64 copy of every image
# (~376 MB for the 60000x28x28 training set). The resulting float32 values are the same as
# dividing in float64 and then casting.
x_train = x_train.astype(dtype) / 255
y_train = to_categorical(y_train, num_classes=num_classes).astype(dtype)
x_val = x_val.astype(dtype) / 255
y_val = to_categorical(y_val, num_classes=num_classes).astype(dtype)

# Sanity-check the prepared arrays before they are batched and fed to the model.
assert len(x_train) == len(y_train) and len(x_val) == len(y_val), "images/labels length mismatch"
assert x_train.dtype == dtype and x_val.dtype == dtype
assert 0.0 <= x_train.min() and x_train.max() <= 1.0, "pixels should be scaled to [0, 1]"

# shapes
print('x_train: ' + str(x_train.shape))
print('y_train: ' + str(y_train.shape))
print('x_val:  '  + str(x_val.shape))
print('y_val:  '  + str(y_val.shape))
print()

x_train: (60000, 28, 28)
y_train: (60000, 10)
x_val:  (10000, 28, 28)
y_val:  (10000, 10)



In [4]:
# ---- Training setup ----
epochs = 5
batch_size = 32
optimizer = AdamW
scheduler = WarmupCosineAnnealingScheduler
loss_function = SoftmaxCrossEntropy

# ---- Input geometry: 28x28 single-channel images ----
in_height, in_width, in_dim = 28, 28, 1

# ---- Convolutional trunk ----
# One entry per conv layer (output channels and per-layer strides).
# kernel_size is a single tuple and is presumably shared by every conv layer (TODO confirm in CNNModel).
dim_cnn_layers = [16, 32]
strides = [[2, 2], [2, 2]]
kernel_size = (3, 3)

# ---- MLP head ----
# The last MLP width must equal out_dim (number of classes). An Identity output paired with
# SoftmaxCrossEntropy means the head presumably emits raw logits.
out_dim = 10
dim_mlp_layers = [128, out_dim]
out_function = Identity
hidden_function = ReLU

# ---- Regularisation / normalisation ----
norm = BatchNorm
drop_rate = 0

In [5]:
# Re-seed so that re-running this cell on its own recreates the same initial weights.
# Assumes CNNModel initialises its parameters from numpy's global RNG (TODO confirm).
np.random.seed(seed)

# Create Model
model = CNNModel(
  in_height=in_height, 
  in_width=in_width, 
  in_dim=in_dim, 
  dim_cnn_layers=dim_cnn_layers,
  kernel_size=kernel_size,
  strides=strides,
  dim_mlp_layers=dim_mlp_layers,
  hidden_function=hidden_function,
  out_function=out_function,
  norm=norm, 
  drop_rate=drop_rate,
  loss_function=loss_function,
  dtype=dtype
)

# Optimisation hyper-parameters (previously inline magic numbers).
lr_max = 0.005
lr_min = 0.00005
weight_decay = 0.1

# One optimiser step per batch.
# On a fresh run x_train is still per-sample (N, H, W). The batching cell later reshapes it to
# (num_batches, batch_size, H, W, C). Both layouts are handled so that re-running this cell
# after batching does not shrink the schedule by a factor of batch_size.
if x_train.ndim == 5:
    steps_per_epoch = len(x_train)
else:
    steps_per_epoch = len(x_train) // batch_size

# lr=0 is only the initial value; presumably the scheduler overwrites it every step (TODO confirm).
model.optimizer = optimizer(model.get_parameters(), lr=0, betas=(0.9, 0.999), eps=1e-8, weight_decay=weight_decay)
# Warm up for one epoch, then anneal until the last step of the last epoch.
# end_step is tied to `epochs` instead of a hard-coded 5.
model.scheduler = scheduler(
  model.optimizer,
  warmup_steps=steps_per_epoch,
  lr_max=lr_max,
  lr_min=lr_min,
  end_step=epochs * steps_per_epoch
)

# Model Summary
model.summary()

207274 parameters
cnn.0.Conv2d.weight                      shape (3, 3, 1, 16)    mean -0.0227      std 0.3420       dtype float32     
cnn.0.Conv2d.bias                        shape (16,)            mean 0.0000       std 0.0000       dtype float32     
cnn.0.BatchNorm.gamma                    shape (16,)            mean 1.0000       std 0.0000       dtype float32     
cnn.0.BatchNorm.beta                     shape (16,)            mean 0.0000       std 0.0000       dtype float32     
cnn.0.BatchNorm.moving_var               shape (16,)            mean 1.0000       std 0.0000       dtype float32     
cnn.0.BatchNorm.moving_mean              shape (16,)            mean 0.0000       std 0.0000       dtype float32     
cnn.1.Conv2d.weight                      shape (3, 3, 16, 32)   mean -0.0007      std 0.0834       dtype float32     
cnn.1.Conv2d.bias                        shape (32,)            mean 0.0000       std 0.0000       dtype float32     
cnn.1.BatchNorm.gamma                 

In [6]:
# Batch the datasets into fixed-size batches for model.fit.

def batch_dataset(x, y, batch_size):
    """Reshape per-sample arrays into fixed-size batches, dropping the trailing partial batch.

    Args:
        x: images, expected shape (N, in_height, in_width).
        y: one-hot labels, expected shape (N, out_dim).
        batch_size: number of samples per batch.

    Returns:
        (x_batched, y_batched) with shapes
        (N // batch_size, batch_size, in_height, in_width, in_dim) and
        (N // batch_size, batch_size, out_dim). The last N % batch_size samples are dropped.
    """
    assert len(x) == len(y), "images and labels must have the same length"
    usable = len(x) - len(x) % batch_size
    x_batched = np.reshape(x[:usable], (-1, batch_size, in_height, in_width, in_dim))
    y_batched = np.reshape(y[:usable], (-1, batch_size, out_dim))
    return x_batched, y_batched


# Batch Training set.
# Guarded on ndim: per-sample images are 3-D, batched ones are 5-D. Re-running this cell
# therefore leaves already-batched arrays alone instead of re-slicing them, which would
# silently drop whole batches.
if x_train.ndim == 3:
    x_train, y_train = batch_dataset(x_train, y_train, batch_size)

print('x_train: ' + str(x_train.shape))
print('y_train: ' + str(y_train.shape))

# Batch validation set.
# NOTE: the trailing len(x_val) % batch_size samples are dropped (16 of the 10000 test images
# for batch_size=32), so validation metrics cover slightly fewer images than the full test set.
if x_val.ndim == 3:
    x_val, y_val = batch_dataset(x_val, y_val, batch_size)

print('x_val: ' + str(x_val.shape))
print('y_val: ' + str(y_val.shape))

x_train: (1875, 32, 28, 28, 1)
y_train: (1875, 32, 10)
x_val: (312, 32, 28, 28, 1)
y_val: (312, 32, 10)


In [7]:
# Train the model for `epochs` epochs on the batched training set.
# In the log below, each epoch makes one pass over all training batches, then one pass over
# the validation batches, and reports the validation loss/accuracy.
model.fit(
    epochs=epochs,
    dataset_train=(x_train, y_train),
    dataset_val=(x_val, y_val),
)

Epoch 1/5:


mean loss: 0.2653 - batch loss: 0.0741 - mean acc: 92.20 - batch acc: 96.88 - lr: 0.005000 - step: 1875: 100%|██████████| 1875/1875 [06:38<00:00,  4.70it/s]
mean loss: 0.1007 - batch loss: 0.1191 - mean acc: 96.72 - batch acc: 93.75: 100%|████████████████████████████████████████| 312/312 [00:15<00:00, 19.72it/s]


validation loss: 0.1007
validation accuracy: 96.72%

Epoch 2/5:


mean loss: 0.0893 - batch loss: 0.0520 - mean acc: 97.20 - batch acc: 96.88 - lr: 0.004275 - step: 3750: 100%|██████████| 1875/1875 [07:13<00:00,  4.33it/s]
mean loss: 0.0743 - batch loss: 0.0776 - mean acc: 97.58 - batch acc: 93.75: 100%|████████████████████████████████████████| 312/312 [00:17<00:00, 18.20it/s]


validation loss: 0.0743
validation accuracy: 97.58%

Epoch 3/5:


mean loss: 0.0632 - batch loss: 0.0278 - mean acc: 98.00 - batch acc: 100.00 - lr: 0.002525 - step: 5625: 100%|█████████| 1875/1875 [06:41<00:00,  4.67it/s]
mean loss: 0.0439 - batch loss: 0.0157 - mean acc: 98.58 - batch acc: 100.00: 100%|███████████████████████████████████████| 312/312 [00:13<00:00, 22.54it/s]


validation loss: 0.0439
validation accuracy: 98.58%

Epoch 4/5:


mean loss: 0.0316 - batch loss: 0.0222 - mean acc: 99.07 - batch acc: 100.00 - lr: 0.000775 - step: 7500: 100%|█████████| 1875/1875 [06:52<00:00,  4.54it/s]
mean loss: 0.0342 - batch loss: 0.0044 - mean acc: 98.97 - batch acc: 100.00: 100%|███████████████████████████████████████| 312/312 [00:16<00:00, 19.39it/s]


validation loss: 0.0342
validation accuracy: 98.97%

Epoch 5/5:


mean loss: 0.0108 - batch loss: 0.0014 - mean acc: 99.75 - batch acc: 100.00 - lr: 0.000050 - step: 9375: 100%|█████████| 1875/1875 [06:28<00:00,  4.83it/s]
mean loss: 0.0235 - batch loss: 0.0014 - mean acc: 99.23 - batch acc: 100.00: 100%|███████████████████████████████████████| 312/312 [00:15<00:00, 20.51it/s]

validation loss: 0.0235
validation accuracy: 99.23%




