In [1]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
import numpy as np
from tqdm import tqdm
from model import model
from weights import Weights

### Get mnist dataset

In [2]:
(x_train,y_train),(x_test,y_test) = mnist.load_data()
x_train = np.array(x_train,dtype=np.float32)
x_test = np.array(x_test,dtype=np.float32)


## Normalize ,reshape input and one-hot coding label

In [3]:
x_train = x_train/255.
x_test = x_test/255.
x_train = x_train.reshape(x_train.shape[0],x_train.shape[1],x_train.shape[2],1)
x_test = x_test.reshape(x_test.shape[0],x_test.shape[1],x_test.shape[2],1)
y_train = tf.keras.utils.to_categorical(y_train,num_classes = 10)
y_test = tf.keras.utils.to_categorical(y_test,num_classes = 10)

## Get dataset with batchsize

In [4]:
batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((x_train,y_train))
train_dataset = train_dataset.batch(batch_size)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_dataset = test_dataset.batch(batch_size)

## Define weight_configs , initializers and get trainable_params

In [5]:
weight_configs = [
        [5,5,1,32],
        [3,3,32,64],
        [1600,1024],
        [1024,10]
]
init_method = tf.initializers.HeNormal()
w = Weights(weight_configs,init_method)
trainable_params = w.get_trainable_params()

### Define loss and optimizer

In [6]:
learning_rate = 0.001
def loss(pred,target):
    return tf.reduce_mean(tf.losses.categorical_crossentropy(target,pred))

optimizer = tf.optimizers.Adam(learning_rate)

### Define train_step and test step function

In [7]:
@tf.function
def train_step(data,target):
    with tf.GradientTape() as tape:
        pred = model(data,trainable_params)
        train_loss = loss(pred,target)
    grads = tape.gradient(train_loss,trainable_params)
    optimizer.apply_gradients(zip(grads,trainable_params))
    pred_label = tf.argmax(pred,axis=1)
    target_label = tf.argmax(target,axis=1)
    check_equal = (tf.cast(pred_label,tf.int64)) == (tf.cast(target_label,tf.int64))
    correct = tf.reduce_sum(tf.cast(check_equal,tf.float32))
    return train_loss,correct

@tf.function
def test_step(data,target):
    pred = model(data,trainable_params)
    test_loss = loss(pred,target)
    pred_label = tf.argmax(pred,axis=1)
    target_label = tf.argmax(target,axis=1)
    check_equal = (tf.cast(pred_label,tf.int64)) == (tf.cast(target_label,tf.int64))
    correct = tf.reduce_sum(tf.cast(check_equal,tf.float32))
    return test_loss,correct

## Train and test

In [8]:
epochs = 10
for epoch in range(epochs):
    nums_train = 0 
    nums_test = 0
    train_corrects = 0 
    test_corrects =0
    train_losses = []
    test_losses = []
    for data,target in tqdm(train_dataset):
        train_loss,correct = train_step(data,target)
        nums_train += len(data)
        train_corrects += correct
        train_losses.append(train_loss)
#         break
    
    for data,target in tqdm(test_dataset):
        test_loss,correct = test_step(data,target)
        nums_test += len(data)
        test_corrects += correct
        test_losses.append(test_loss)
#         break
    print(test_corrects,train_corrects)
    print("EPOCH : {}/{}, train_loss : {},train_acc : {}, test_loss :{}, test_acc : {}".format(epoch+1,epochs,sum(train_losses)/len(train_losses),train_corrects/nums_train*100.,sum(test_losses)/len(test_losses),test_corrects/nums_test*100.))
    

100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [01:15<00:00, 12.47it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 44.89it/s]
  0%|                                                                                          | 0/938 [00:00<?, ?it/s]

tf.Tensor(9771.0, shape=(), dtype=float32) tf.Tensor(57724.0, shape=(), dtype=float32)
EPOCH : 1/10, train_loss : 0.13088291883468628,train_acc : 96.2066650390625, test_loss :0.07201249897480011, test_acc : 97.70999908447266


100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [01:15<00:00, 12.38it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 39.89it/s]
  0%|                                                                                          | 0/938 [00:00<?, ?it/s]

tf.Tensor(9872.0, shape=(), dtype=float32) tf.Tensor(59197.0, shape=(), dtype=float32)
EPOCH : 2/10, train_loss : 0.042802970856428146,train_acc : 98.66166687011719, test_loss :0.036214008927345276, test_acc : 98.72000122070312


100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [01:13<00:00, 12.73it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 47.67it/s]
  0%|                                                                                          | 0/938 [00:00<?, ?it/s]

tf.Tensor(9893.0, shape=(), dtype=float32) tf.Tensor(59481.0, shape=(), dtype=float32)
EPOCH : 3/10, train_loss : 0.028163446113467216,train_acc : 99.13500213623047, test_loss :0.03260624781250954, test_acc : 98.93000030517578


100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [01:15<00:00, 12.45it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 44.69it/s]
  0%|                                                                                          | 0/938 [00:00<?, ?it/s]

tf.Tensor(9894.0, shape=(), dtype=float32) tf.Tensor(59599.0, shape=(), dtype=float32)
EPOCH : 4/10, train_loss : 0.020761532709002495,train_acc : 99.3316650390625, test_loss :0.03458651527762413, test_acc : 98.94000244140625


100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [01:18<00:00, 11.97it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 45.50it/s]
  0%|                                                                                          | 0/938 [00:00<?, ?it/s]

tf.Tensor(9891.0, shape=(), dtype=float32) tf.Tensor(59674.0, shape=(), dtype=float32)
EPOCH : 5/10, train_loss : 0.017063552513718605,train_acc : 99.4566650390625, test_loss :0.04115865007042885, test_acc : 98.90999603271484


100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [01:13<00:00, 12.82it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 48.36it/s]
  0%|                                                                                          | 0/938 [00:00<?, ?it/s]

tf.Tensor(9900.0, shape=(), dtype=float32) tf.Tensor(59707.0, shape=(), dtype=float32)
EPOCH : 6/10, train_loss : 0.014477337710559368,train_acc : 99.51166534423828, test_loss :0.037884801626205444, test_acc : 99.0


100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [01:11<00:00, 13.19it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 47.90it/s]
  0%|                                                                                          | 0/938 [00:00<?, ?it/s]

tf.Tensor(9906.0, shape=(), dtype=float32) tf.Tensor(59731.0, shape=(), dtype=float32)
EPOCH : 7/10, train_loss : 0.013900073245167732,train_acc : 99.55166625976562, test_loss :0.04029586538672447, test_acc : 99.05999755859375


100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [01:11<00:00, 13.16it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 48.36it/s]
  0%|                                                                                          | 0/938 [00:00<?, ?it/s]

tf.Tensor(9898.0, shape=(), dtype=float32) tf.Tensor(59804.0, shape=(), dtype=float32)
EPOCH : 8/10, train_loss : 0.009947559796273708,train_acc : 99.67333221435547, test_loss :0.045926161110401154, test_acc : 98.97999572753906


100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [01:24<00:00, 11.10it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 43.88it/s]
  0%|                                                                                          | 0/938 [00:00<?, ?it/s]

tf.Tensor(9912.0, shape=(), dtype=float32) tf.Tensor(59778.0, shape=(), dtype=float32)
EPOCH : 9/10, train_loss : 0.01039937324821949,train_acc : 99.62999725341797, test_loss :0.04023106396198273, test_acc : 99.1199951171875


100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [01:15<00:00, 12.49it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 157/157 [00:03<00:00, 48.12it/s]


tf.Tensor(9890.0, shape=(), dtype=float32) tf.Tensor(59835.0, shape=(), dtype=float32)
EPOCH : 10/10, train_loss : 0.009041407145559788,train_acc : 99.7249984741211, test_loss :0.04737474396824837, test_acc : 98.9000015258789
