In [1]:
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
# `keras.utils.np_utils` is a private module path that newer Keras releases no longer ship;
# the public `keras.utils` namespace exposes the same `to_categorical` helper.
from keras.utils import to_categorical

In [None]:
# Fetch the NeuralNetsWithNumpy sources and make the clone the working directory
# (presumably so the `nnet` package imported in the next cell resolves - confirm the repo layout).
# NOTE(review): re-running this cell in the same session will attempt to clone into an
# already existing directory and to `cd` into a nested path; run it once per fresh kernel.
!git clone https://github.com/burchim/NeuralNetsWithNumpy.git
%cd NeuralNetsWithNumpy

In [2]:
# Modules
from nnet.modules import (
  Identity,
  Sigmoid,
  Tanh,
  ReLU,
  PReLU,
  Swish,
  LayerNorm,
  BatchNorm
)

# Optimizers
from nnet.optimizers import (
  SGD,
  RMSprop,
  Adam,
  AdamW
)

# Losses
from nnet.losses import (
  MeanAbsoluteError,
  MeanSquaredError,
  SoftmaxCrossEntropy
)

# Schedulers
from nnet.schedulers import (
  ConstantScheduler,
  WarmupCosineAnnealingScheduler
)

# Model
from nnet.models import MLPMixerModel

In [3]:
# Fix NumPy's global random state so runs are repeatable
seed = 42
np.random.seed(seed)

# Floating-point precision used for every array prepared below
dtype = np.float32

# MNIST as shipped by Keras; the second split is used here as validation data
(x_train, y_train), (x_val, y_val) = mnist.load_data()

# Images: divide by 255, then cast to `dtype`
x_train, x_val = ((images / 255).astype(dtype) for images in (x_train, x_val))

# Labels: one-hot encode, then cast to `dtype`
y_train, y_val = (to_categorical(labels).astype(dtype) for labels in (y_train, y_val))

# Report the resulting array shapes, one per line, followed by an empty line
print('\n'.join(
  label + str(array.shape)
  for label, array in (
    ('x_train: ', x_train),
    ('y_train: ', y_train),
    ('x_val:  ', x_val),
    ('y_val:  ', y_val),
  )
))
print()

x_train: (60000, 28, 28)
y_train: (60000, 10)
x_val:  (10000, 28, 28)
y_val:  (10000, 10)



In [4]:
# Training setup (the optimizer / scheduler classes are instantiated once the model exists)
epochs = 5
batch_size = 32
optimizer = AdamW
scheduler = WarmupCosineAnnealingScheduler

# Input geometry
in_height = in_width = 28
in_dim = 1

# Mixer body
patch_size = 4
num_layers = 2
dim_feat = 32
dim_expand_feat = dim_expand_seq = 128

# Output head: one hidden layer of 128 units followed by the `out_dim` outputs
out_dim = 10
out_layers = [128, out_dim]
out_norm = LayerNorm

# Activations, dropout and loss
hidden_function = ReLU
out_function = Identity
drop_rate = 0.1
loss_function = SoftmaxCrossEntropy

In [5]:
# Create Model
model = MLPMixerModel(
  in_height=in_height,
  in_width=in_width,
  in_dim=in_dim,
  patch_size=patch_size,
  num_layers=num_layers,
  dim_feat=dim_feat,
  dim_expand_feat=dim_expand_feat,
  dim_expand_seq=dim_expand_seq,
  out_layers=out_layers,
  out_norm=out_norm,
  hidden_function=hidden_function,
  out_function=out_function,
  drop_rate=drop_rate,
  loss_function=loss_function,
  dtype=dtype
)

# Optimizer steps per epoch. The sample count is derived from the total number of
# elements instead of len(x_train): once the batching cell has reshaped x_train,
# len(x_train) counts batches rather than samples, and re-running this cell would
# otherwise build a far too short learning-rate schedule.
num_train_samples = x_train.size // (in_height * in_width)
steps_per_epoch = num_train_samples // batch_size

# The optimizer is created with lr=0; the scheduler receives the lr_min / lr_max range.
model.optimizer = optimizer(
  model.get_parameters(),
  lr=0,
  betas=(0.9, 0.999),
  eps=1e-8,
  weight_decay=0.1
)

# Warm up during the first epoch and end the schedule on the last training step.
# `end_step` follows `epochs` instead of a hard-coded 5 so both stay in sync.
model.scheduler = scheduler(
  model.optimizer,
  warmup_steps=steps_per_epoch,
  lr_max=0.005,
  lr_min=0.00005,
  end_step=epochs * steps_per_epoch
)

# Model Summary
model.summary()

245324 parameters
embedding.linear.weight                  shape (16, 32)         mean 0.0016       std 0.2578       dtype float32     
embedding.linear.bias                    shape (32,)            mean 0.0000       std 0.0000       dtype float32     
layers.0.MLPMixer.layernorm1.gamma       shape (32,)            mean 1.0000       std 0.0000       dtype float32     
layers.0.MLPMixer.layernorm1.beta        shape (32,)            mean 0.0000       std 0.0000       dtype float32     
layers.0.MLPMixer.mlp1.0.Linear.weight   shape (49, 128)        mean -0.0019      std 0.1428       dtype float32     
layers.0.MLPMixer.mlp1.0.Linear.bias     shape (128,)           mean 0.0000       std 0.0000       dtype float32     
layers.0.MLPMixer.mlp1.1.Linear.weight   shape (128, 49)        mean -0.0012      std 0.0879       dtype float32     
layers.0.MLPMixer.mlp1.1.Linear.bias     shape (49,)            mean 0.0000       std 0.0000       dtype float32     
layers.0.MLPMixer.layernorm2.gamma    

In [6]:
def make_batches(inputs, targets, batch_size):
  """Group samples into fixed-size batches, dropping the incomplete last batch.

  Args:
    inputs: array whose first axis indexes samples.
    targets: array with the same number of samples along its first axis.
    batch_size: number of samples per batch.

  Returns:
    Tuple (batched_inputs, batched_targets) where each array has shape
    (num_batches, batch_size) followed by the original per-sample shape.

  Raises:
    ValueError: if inputs and targets hold a different number of samples.
  """
  num_samples = len(inputs)
  if len(targets) != num_samples:
    raise ValueError('inputs and targets must contain the same number of samples')

  # Samples that do not fill a complete batch are discarded
  usable = num_samples - num_samples % batch_size
  batched_inputs = np.reshape(inputs[:usable], (-1, batch_size) + inputs.shape[1:])
  batched_targets = np.reshape(targets[:usable], (-1, batch_size) + targets.shape[1:])
  return batched_inputs, batched_targets


# Batch Training set
# Only arrays that are still laid out as (samples, height, width) are batched, so
# re-running this cell does not reshape already batched data and silently drop batches.
if x_train.shape[1:] == (in_height, in_width):
  x_train, y_train = make_batches(x_train, y_train, batch_size)

print('x_train: ' + str(x_train.shape))
print('y_train: ' + str(y_train.shape))

# Batch validation set (same guard as for the training set)
if x_val.shape[1:] == (in_height, in_width):
  x_val, y_val = make_batches(x_val, y_val, batch_size)

print('x_val: ' + str(x_val.shape))
print('y_val: ' + str(y_val.shape))

x_train: (1875, 32, 28, 28)
y_train: (1875, 32, 10)
x_val: (312, 32, 28, 28)
y_val: (312, 32, 10)


In [7]:
# Fit the model for `epochs` epochs on the batched training arrays,
# handing the batched validation arrays to `fit` as well
train_batches = (x_train, y_train)
val_batches = (x_val, y_val)
model.fit(dataset_train=train_batches, dataset_val=val_batches, epochs=epochs)

Epoch 1/5:


mean loss: 0.3331 - batch loss: 0.0367 - mean acc: 89.92 - batch acc: 100.00 - lr: 0.005000 - step: 1875: 100%|█████████| 1875/1875 [01:33<00:00, 20.15it/s]                
mean loss: 0.1251 - batch loss: 0.1848 - mean acc: 96.08 - batch acc: 93.75: 100%|████████████████████████████████████████| 312/312 [00:05<00:00, 60.46it/s]


validation loss: 0.1251
validation accuracy: 96.08%

Epoch 2/5:


mean loss: 0.1359 - batch loss: 0.0275 - mean acc: 95.75 - batch acc: 100.00 - lr: 0.004275 - step: 3750: 100%|█████████| 1875/1875 [01:28<00:00, 21.27it/s]
mean loss: 0.0867 - batch loss: 0.1380 - mean acc: 97.32 - batch acc: 93.75: 100%|████████████████████████████████████████| 312/312 [00:07<00:00, 42.95it/s]


validation loss: 0.0867
validation accuracy: 97.32%

Epoch 3/5:


mean loss: 0.0923 - batch loss: 0.0212 - mean acc: 97.18 - batch acc: 100.00 - lr: 0.002525 - step: 5625: 100%|█████████| 1875/1875 [01:24<00:00, 22.11it/s]
mean loss: 0.0630 - batch loss: 0.0279 - mean acc: 98.07 - batch acc: 100.00: 100%|███████████████████████████████████████| 312/312 [00:05<00:00, 62.04it/s]


validation loss: 0.0630
validation accuracy: 98.07%

Epoch 4/5:


mean loss: 0.0530 - batch loss: 0.0310 - mean acc: 98.38 - batch acc: 96.88 - lr: 0.000775 - step: 7500: 100%|██████████| 1875/1875 [01:17<00:00, 24.05it/s]
mean loss: 0.0368 - batch loss: 0.0503 - mean acc: 98.80 - batch acc: 96.88: 100%|████████████████████████████████████████| 312/312 [00:04<00:00, 64.01it/s]


validation loss: 0.0368
validation accuracy: 98.80%

Epoch 5/5:


mean loss: 0.0239 - batch loss: 0.0205 - mean acc: 99.30 - batch acc: 100.00 - lr: 0.000050 - step: 9375: 100%|█████████| 1875/1875 [01:14<00:00, 25.02it/s]
mean loss: 0.0282 - batch loss: 0.0057 - mean acc: 99.09 - batch acc: 100.00: 100%|███████████████████████████████████████| 312/312 [00:04<00:00, 62.49it/s]


validation loss: 0.0282
validation accuracy: 99.09%

