In [18]:
import tensorflow as tf
import tensorflow_datasets as tfds
from matplotlib import pyplot as plt

In [2]:
# Load the MNIST train/test splits through TFDS, storing files under data/.
# as_supervised=True yields (image, label) tuples; with_info=True additionally
# returns the dataset metadata object.
(train_ds, test_ds), ds_info = tfds.load(
    name='mnist',
    split=['train', 'test'],
    data_dir='data/',
    as_supervised=True,
    with_info=True,
)

In [3]:
# Initializer objects; each is called later with a shape to build the
# starting value of a model parameter.
zeros = tf.initializers.Zeros()                # all-zero tensor
glorot_init = tf.initializers.GlorotUniform()  # Glorot (Xavier) uniform samples

In [4]:
@tf.function
def dense(features, weights, biases):
    """Apply an affine transform: ``features @ weights + biases``.

    Args:
        features: Left operand of the matrix product.
        weights: Right operand of the matrix product.
        biases: Tensor added to the product (TensorFlow broadcasting applies).

    Returns:
        The un-activated layer output (no non-linearity is applied here).
    """
    projected = tf.matmul(features, weights)
    return projected + biases

In [5]:
@tf.function
def norm_data(image, label):
    """Scale an image to [0, 1], drop its channel axis and one-hot the label.

    The trailing axis is squeezed with ``axis=-1`` rather than a fixed
    ``axis=2`` so the function accepts both a single ``(H, W, 1)`` image and a
    batch ``(B, H, W, 1)``; for a single image the result is unchanged.

    Args:
        image: Image tensor whose last axis has size 1 (assumed to hold pixel
            values in 0..255 -- TODO confirm against the dataset).
        label: Integer class id(s).

    Returns:
        Tuple ``(image, one_hot_label)``: ``image`` as float32 divided by 255
        with the last axis removed, and ``label`` one-hot encoded with depth 10.
    """
    image = tf.cast(image, tf.float32) / 255.
    # tf.squeeze raises if the selected axis is not of size 1, so a
    # multi-channel image is still rejected loudly rather than silently reshaped.
    image = tf.squeeze(image, axis=-1)
    return image, tf.one_hot(label, 10)

# Training input pipeline: shuffle with a buffer covering the whole split,
# normalise each example, group into batches of 64 and prefetch.
# NOTE: this rebinds `train_ds`, so re-running the cell without re-running the
# load cell would apply the pipeline to already-batched data.
train_ds = (
    train_ds
    .shuffle(len(train_ds))
    .map(norm_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(64)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
# Matching test pipeline, currently disabled (a single test example is
# normalised by hand further down instead):
# test_ds = test_ds.map(norm_data, num_parallel_calls=tf.data.experimental.AUTOTUNE).batch(1).prefetch(tf.data.experimental.AUTOTUNE)

In [6]:
@tf.function
def loss_fn(logits, labels):
    """Return the mean softmax cross-entropy between ``logits`` and ``labels``.

    Args:
        logits: Unnormalised class scores (softmax is applied internally).
        labels: Target distribution per example, e.g. one-hot vectors.

    Returns:
        Scalar tensor: the cross-entropy averaged over all examples.
    """
    per_example = tf.nn.softmax_cross_entropy_with_logits(
        logits=logits, labels=labels
    )
    return tf.reduce_mean(per_example)

In [7]:
# Reusable Keras layer that collapses every non-batch axis into one, used to
# turn image batches into row vectors before the dense layer.
flatten = tf.keras.layers.Flatten()

In [8]:
# Trainable parameters of the single dense layer.
n_classes = 10
# NOTE: despite its name, `n_inputs` is the weight-matrix shape
# (input features, classes); 784 is presumably a flattened 28x28 image.
n_inputs = (784, n_classes)

biases = tf.Variable(zeros(shape=n_classes), name='biases')
weights = tf.Variable(glorot_init(shape=n_inputs), name='weights')

In [9]:
# Plain stochastic gradient descent with a fixed step size.
# NOTE(review): 1e-4 looks small for the 10 epochs run below, and the sample
# prediction at the end of the notebook is wrong (6 vs. label 2) -- consider a
# larger rate; TODO confirm by tracking the loss/accuracy.
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-4)

In [10]:
# Custom training loop. `total_loss` ends up with ONE entry per epoch: the mean
# training loss over that epoch's batches. (Previously the running sum was
# appended after every batch and never divided, so the list held per-step
# partial sums rather than averages.)
total_loss = []
for epoch in range(10):
    loss_sum = 0.0    # sum of the batch losses seen so far in this epoch
    num_batches = 0
    for train_image, train_label in train_ds:
        with tf.GradientTape() as tape:
            flatten_train_image = flatten(train_image)
            logits = dense(flatten_train_image, weights, biases)  # Forward pass
            loss = loss_fn(logits, train_label)

        # Backward pass: differentiate the batch loss w.r.t. both parameters
        # and let the optimizer update them in place.
        grads = tape.gradient(loss, [weights, biases])
        optimizer.apply_gradients(zip(grads, [weights, biases]))

        loss_sum += loss
        num_batches += 1

    # Guard against an empty dataset so the division cannot fail.
    if num_batches:
        avg_loss = float(loss_sum / num_batches)
        total_loss.append(avg_loss)

In [11]:
# Take the first element of the test split. `test_ds` was not passed through
# the normalisation/batching pipeline, so this is a raw (image, label) pair.
image, label = next(iter(test_ds))

In [12]:
# Apply the same preprocessing used for training; the one-hot label returned
# by norm_data is discarded because the raw `label` is compared later.
norm_image, _ = norm_data(image, label)

In [13]:
# Add a leading batch axis of size 1, flatten, and run the trained layer to
# get the class logits for the single test image.
batched_image = tf.expand_dims(norm_image, axis=0)
preds = dense(flatten(batched_image), weights, biases)

In [22]:
# (predicted class, true class) for the single test example.
tf.argmax(preds, axis=1).numpy()[0], label.numpy()

(6, 2)