In [44]:
import tensorflow as tf
import numpy as np

In [59]:
class NaiveDense:
     def __init__(self, input_shape, output_shape, activation):
         self.activation = activation

         self.W = tf.Variable(tf.random.uniform((input_shape, output_shape)))
         self.b = tf.Variable(tf.zeros((output_shape,)))

     def __call__(self, x):
         return self.activation(tf.matmul(x, self.W) + self.b)

     @property
     def weights(self):
         return [self.W, self.b]

In [60]:
class NaiveSequential:
    def __init__(self, layers: [NaiveDense]):
        self.layers = layers

    def __call__(self, inputs):
        x = inputs
        for l in self.layers:
            x = l(x)
        return x

    @property
    def weights(self):
        weights = []
        for l in self.layers:
            weights += l.weights
        return weights

In [61]:
model = NaiveSequential([
    NaiveDense(input_shape=784, output_shape=512, activation=tf.nn.relu),
    NaiveDense(input_shape=512, output_shape=10, activation=tf.nn.softmax),
])

assert len(model.weights) == 4

In [62]:
class BatchGenerator:
    def __init__(self, images, labels, batch_size=64):
        self.index = 0
        self.images = images
        self.labels = labels
        self.batch_size = batch_size

    def next(self):
        images = self.images[self.index: self.index + self.batch_size]
        labels = self.labels[self.index: self.index + self.batch_size]
        self.index += self.batch_size
        return images, labels

In [63]:
def update_weights(gradients, weights):
    lr = 0.001
    # or apply gradients from optimizer keras
    optim = tf.keras.optimizers.SGD(learning_rate=0.001)
    optim.apply_gradients(zip(gradients, weights))
#     for g, w in zip(gradients, weights):
#         w.assign_sub(g * lr)

In [64]:
def one_training_step(c_model, images_batch, labels_batch):
    with tf.GradientTape() as tape:
        preds = model(images_batch)
        per_sample_losses = tf.keras.losses.sparse_categorical_crossentropy(labels_batch
                                                                            , preds)
        average_loss = tf.reduce_mean(per_sample_losses)
    gradients = tape.gradient(average_loss, c_model.weights)
    update_weights(gradients, c_model.weights)
    return average_loss

In [65]:
def fit(c_model, images, labels, epochs, batch_size=64):
    for epoch in range(epochs):
        print(f'Epoch: {epoch}')
        batch_generator = BatchGenerator(images, labels)
        for batch in range(len(images) // batch_size):
            images_batch , labels_batch = batch_generator.next()
            loss = one_training_step(c_model,
                                     images_batch, labels_batch,
                                     )
            if batch % 100 == 0:
                print(f'loss at batch {batch}: {loss:.2f}')

In [66]:
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_images = train_images.reshape((60000, 28*28))
test_images = test_images.reshape((10000, 28*28))

train_images = train_images.astype('float32') / 255
test_images = test_images.astype('float32') / 255

In [67]:
fit(model, train_images, train_labels, epochs=10, batch_size=64)

Epoch: 0
loss at batch 0: 13.85
loss at batch 100: 15.11
loss at batch 200: 14.61
loss at batch 300: 14.36
loss at batch 400: 13.60
loss at batch 500: 14.61
loss at batch 600: 13.35
loss at batch 700: 14.86
loss at batch 800: 14.86
loss at batch 900: 14.10
Epoch: 1
loss at batch 0: 14.10
loss at batch 100: 13.85
loss at batch 200: 14.10
loss at batch 300: 15.36
loss at batch 400: 15.11
loss at batch 500: 14.10
loss at batch 600: 13.35
loss at batch 700: 14.86
loss at batch 800: 14.86
loss at batch 900: 14.10
Epoch: 2
loss at batch 0: 14.10
loss at batch 100: 13.85
loss at batch 200: 14.10
loss at batch 300: 15.36
loss at batch 400: 15.11
loss at batch 500: 14.10
loss at batch 600: 13.35
loss at batch 700: 14.86
loss at batch 800: 14.86
loss at batch 900: 14.10
Epoch: 3
loss at batch 0: 14.10
loss at batch 100: 13.85
loss at batch 200: 14.10
loss at batch 300: 15.36
loss at batch 400: 15.11
loss at batch 500: 14.10
loss at batch 600: 13.35
loss at batch 700: 14.86
loss at batch 800: 14.

In [70]:
preds = model(test_images)
preds = preds.numpy()
preds_labels = np.argmax(preds, axis=1)
matches = preds_labels == test_labels
f'accuracy {matches.mean():.4f}'

'accuracy 0.1135'