In [1]:
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from keras.utils.np_utils import to_categorical

In [None]:
# Fetch the NeuralNetsWithNumpy repository and make the checkout the working
# directory for the cells that follow.
# NOTE(review): the `nnet` imports in the next cell presumably resolve from this
# checkout — confirm. This cell carries no execution count in the saved run, and
# because %cd changes the kernel's working directory, re-running the cell in the
# same kernel would clone again from inside the checkout.
!git clone https://github.com/burchim/NeuralNetsWithNumpy.git
%cd NeuralNetsWithNumpy

In [2]:
# Modules
from nnet.modules import (
  Identity,
  Sigmoid,
  Tanh,
  ReLU,
  PReLU,
  Swish,
  LayerNorm,
  BatchNorm
)

# Optimizers
from nnet.optimizers import (
  SGD,
  RMSprop,
  Adam,
  AdamW
)

# Losses
from nnet.losses import (
  MeanAbsoluteError,
  MeanSquaredError,
  SoftmaxCrossEntropy
)

# Schedulers
from nnet.schedulers import (
  ConstantScheduler,
  WarmupCosineAnnealingScheduler
)

# Model
from nnet.models import VisionTransformer

In [3]:
# Seed
# Seed NumPy's global random generator with a fixed value.
seed = 42
np.random.seed(seed)

# dtype
# Every array prepared in this cell is cast to this type.
dtype = np.float32

# Load Dataset
# mnist.load_data() yields two (images, labels) pairs: training, then validation.
(x_train, y_train), (x_val, y_val) = mnist.load_data()

# Prepare Dataset
# Images: divide by 255, then cast. Labels: one-hot encode, then cast.
x_train, x_val = (
  np.true_divide(images, 255).astype(dtype) for images in (x_train, x_val)
)
y_train, y_val = (
  to_categorical(labels).astype(dtype) for labels in (y_train, y_val)
)

# shapes
# One line per array, followed by a blank line.
print(
  f'x_train: {x_train.shape}',
  f'y_train: {y_train.shape}',
  f'x_val:  {x_val.shape}',
  f'y_val:  {y_val.shape}',
  '',
  sep='\n'
)

x_train: (60000, 28, 28)
y_train: (60000, 10)
x_val:  (10000, 28, 28)
y_val:  (10000, 10)



In [4]:
# Params

# -- Training run: number of epochs, batch size, optimizer and scheduler classes --
epochs, batch_size = 5, 32
optimizer, scheduler = AdamW, WarmupCosineAnnealingScheduler

# -- Input images: the MNIST arrays loaded above are 28 x 28 --
in_height = in_width = 28
in_dim = 1

# -- VisionTransformer size settings (forwarded as keyword arguments of the same name) --
patch_size, dim_model = 4, 32
ff_ratio, num_heads, num_blocks = 4, 4, 2
drop_rate = 0.1

# -- MLP / output settings --
out_dim = 10  # equals the number of one-hot columns in y_train / y_val
dim_mlp_layers = [128, out_dim]
mlp_norm = LayerNorm
hidden_function, out_function = ReLU, Identity
loss_function = SoftmaxCrossEntropy

In [5]:
# Create Model
# Build the VisionTransformer from the hyper-parameters set in the params cell.
model = VisionTransformer(
  in_height=in_height,
  in_width=in_width,
  in_dim=in_dim,
  patch_size=patch_size,
  num_blocks=num_blocks,
  dim_model=dim_model,
  ff_ratio=ff_ratio,
  num_heads=num_heads,
  drop_rate=drop_rate,
  dim_mlp_layers=dim_mlp_layers,
  mlp_norm=mlp_norm,
  hidden_function=hidden_function,
  out_function=out_function,
  loss_function=loss_function,
  dtype=dtype
)

# Scheduler horizon, derived from the configuration instead of a hard-coded
# epoch count. Flooring per epoch (rather than flooring epochs * samples) keeps
# the total equal to the number of optimisation steps actually taken, because
# the batching cell drops the incomplete trailing batch.
# NOTE: len(x_train) must still be the number of samples here; the batching
# cell further down reshapes x_train so that len() becomes the number of batches.
steps_per_epoch = len(x_train) // batch_size
total_steps = epochs * steps_per_epoch

# The optimizer starts at lr=0. In the training log the learning rate equals
# lr_max at step `warmup_steps` and lr_min at step `end_step`.
model.optimizer = optimizer(model.get_parameters(), lr=0, betas=(0.9, 0.999), eps=1e-8)
model.scheduler = scheduler(
  model.optimizer,
  warmup_steps=steps_per_epoch,  # warm up over the first epoch
  lr_max=0.005,
  lr_min=0.00005,
  end_step=total_steps
)

# Model Summary
model.summary()

229898 parameters
embedding.linear.weight                                        shape (16, 32)         mean 0.0016       std 0.2578       dtype float32     
embedding.linear.bias                                          shape (32,)            mean 0.0000       std 0.0000       dtype float32     
pos_encoding.embeddings                                        shape (49, 32)         mean 0.0491       std 0.9994       dtype float32     
blocks.0.TransformerBlock.mhsa_module.layernorm.gamma          shape (32,)            mean 1.0000       std 0.0000       dtype float32     
blocks.0.TransformerBlock.mhsa_module.layernorm.beta           shape (32,)            mean 0.0000       std 0.0000       dtype float32     
blocks.0.TransformerBlock.mhsa_module.mhsa.query_layer.weight  shape (32, 32)         mean 0.0007       std 0.1780       dtype float32     
blocks.0.TransformerBlock.mhsa_module.mhsa.query_layer.bias    shape (32,)            mean 0.0000       std 0.0000       dtype float32     
bl

In [6]:
def make_batches(x, y, batch_size):
  """Group paired samples into fixed-size batches, dropping the incomplete tail.

  Args:
    x: NumPy array of inputs, indexed by sample along axis 0.
    y: NumPy array of targets with the same length as `x`.
    batch_size: number of samples per batch.

  Returns:
    Tuple `(x_batched, y_batched)` with shapes
    `(num_batches, batch_size) + x.shape[1:]` and
    `(num_batches, batch_size) + y.shape[1:]`.

  Raises:
    ValueError: if `x` and `y` do not contain the same number of samples.
  """
  if len(x) != len(y):
    raise ValueError(
      'x and y must have the same number of samples, got '
      + str(len(x)) + ' and ' + str(len(y))
    )
  num_batches = len(x) // batch_size
  kept = num_batches * batch_size  # samples past this index do not fill a batch
  x_batched = np.reshape(x[:kept], (num_batches, batch_size) + x.shape[1:])
  y_batched = np.reshape(y[:kept], (num_batches, batch_size) + y.shape[1:])
  return x_batched, y_batched


# Shape of one unbatched image. Batching is applied only while the arrays still
# have this per-sample shape, so re-running the cell does not treat the batch
# axis as samples and silently discard data.
sample_shape = (in_height, in_width)

# Batch Training set
if x_train.shape[1:] == sample_shape:
  x_train, y_train = make_batches(x_train, y_train, batch_size)

print('x_train: ' + str(x_train.shape))
print('y_train: ' + str(y_train.shape))

# Batch validation set
if x_val.shape[1:] == sample_shape:
  x_val, y_val = make_batches(x_val, y_val, batch_size)

print('x_val: ' + str(x_val.shape))
print('y_val: ' + str(y_val.shape))

x_train: (1875, 32, 28, 28)
y_train: (1875, 32, 10)
x_val: (312, 32, 28, 28)
y_val: (312, 32, 10)


In [7]:
# Train Model
# Pair the batched inputs with their targets, then fit for `epochs` epochs;
# the validation pair is handed to fit() alongside the training pair.
train_data = (x_train, y_train)
val_data = (x_val, y_val)
model.fit(dataset_train=train_data, dataset_val=val_data, epochs=epochs)

Epoch 1/5:


mean loss: 0.3800 - batch loss: 0.2495 - mean acc: 88.40 - batch acc: 93.75 - lr: 0.005000 - step: 1875: 100%|█████| 1875/1875 [01:48<00:00, 17.22it/s]
mean loss: 0.1257 - batch loss: 0.1067 - mean acc: 96.14 - batch acc: 93.75: 100%|███████████████████████████████████| 312/312 [00:05<00:00, 58.55it/s]


validation loss: 0.1257
validation accuracy: 96.14%

Epoch 2/5:


mean loss: 0.1215 - batch loss: 0.0974 - mean acc: 96.25 - batch acc: 96.88 - lr: 0.004275 - step: 3750: 100%|█████| 1875/1875 [01:31<00:00, 20.47it/s]
mean loss: 0.0821 - batch loss: 0.0508 - mean acc: 97.49 - batch acc: 100.00: 100%|██████████████████████████████████| 312/312 [00:05<00:00, 56.23it/s]


validation loss: 0.0821
validation accuracy: 97.49%

Epoch 3/5:


mean loss: 0.0747 - batch loss: 0.2189 - mean acc: 97.66 - batch acc: 93.75 - lr: 0.002525 - step: 5625: 100%|█████| 1875/1875 [01:36<00:00, 19.34it/s]
mean loss: 0.0473 - batch loss: 0.0642 - mean acc: 98.49 - batch acc: 96.88: 100%|███████████████████████████████████| 312/312 [00:08<00:00, 35.33it/s]


validation loss: 0.0473
validation accuracy: 98.49%

Epoch 4/5:


mean loss: 0.0448 - batch loss: 0.0026 - mean acc: 98.61 - batch acc: 100.00 - lr: 0.000775 - step: 7500: 100%|████| 1875/1875 [01:41<00:00, 18.45it/s]
mean loss: 0.0372 - batch loss: 0.0178 - mean acc: 98.86 - batch acc: 100.00: 100%|██████████████████████████████████| 312/312 [00:04<00:00, 64.09it/s]


validation loss: 0.0372
validation accuracy: 98.86%

Epoch 5/5:


mean loss: 0.0236 - batch loss: 0.0201 - mean acc: 99.32 - batch acc: 100.00 - lr: 0.000050 - step: 9375: 100%|████| 1875/1875 [01:43<00:00, 18.14it/s]
mean loss: 0.0327 - batch loss: 0.0165 - mean acc: 98.90 - batch acc: 100.00: 100%|██████████████████████████████████| 312/312 [00:09<00:00, 33.84it/s]

validation loss: 0.0327
validation accuracy: 98.90%




