# Your first TensorFlow model

In this exercise you will train a fully connected model to recognize handwritten digits (the MNIST data set). The function ``get_mnist`` below prepares the data for you. You can ignore the warnings it prints; they do not indicate any errors.

## (1A) Write a fully connected layer by subclassing tf.Module.

The layer should take two arguments: the number of inputs and the number of units. Its constructor (``__init__``) should create the weights and biases, with shapes that match these two numbers, as ``tf.Variable`` objects.
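To illustrate the pattern (this is not a solution to 1A), here is a hypothetical ``ElementwiseScale`` module that creates its ``tf.Variable`` objects in the constructor and uses them in ``__call__``. ``tf.Module`` automatically collects variables assigned as attributes, which is why they appear in the module's ``trainable_variables``.

```python
import tensorflow as tf


class ElementwiseScale(tf.Module):
    """Hypothetical toy module (not the exercise solution): y = x * scale + shift."""

    def __init__(self, n_features, name=None):
        super().__init__(name=name)
        # Variables created in the constructor are tracked automatically by tf.Module
        self.scale = tf.Variable(tf.ones([n_features]), name="scale")
        self.shift = tf.Variable(tf.zeros([n_features]), name="shift")

    def __call__(self, x):
        # Broadcasting applies the per-feature scale and shift to every row of the batch
        return x * self.scale + self.shift


module = ElementwiseScale(n_features=3)
print(len(module.trainable_variables))  # 2
print(module(tf.ones([4, 3])).shape)    # (4, 3)
```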

## (1B) Write a model by subclassing tf.keras.Model

The model should use the fully connected layer class you defined in (1A). Because that layer needs to know its number of inputs, the model takes the input shape as an argument in its constructor.

The model should have one ``tf.keras.layers.Flatten()`` layer, followed by a single fully connected layer with ``10 units`` and ``784 inputs``: an MNIST digit is a 28 × 28 grayscale image, so flattening it yields 28 × 28 = 784 values. After this, we apply an activation function that turns the linear output into a categorical probability distribution over the 10 digit classes; the softmax function does exactly that.
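A quick shape check for the ``Flatten`` layer: it keeps the batch dimension and collapses everything else into one axis. The batch below is a hypothetical placeholder of 32 single-channel 28 × 28 images, not real MNIST data.

```python
import tensorflow as tf

# Hypothetical batch: 32 single-channel 28 x 28 images filled with zeros
images = tf.zeros([32, 28, 28, 1])

flatten = tf.keras.layers.Flatten()
flat = flatten(images)
print(flat.shape)  # (32, 784): each image becomes a vector of 28 * 28 = 784 values
```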

The softmax function exponentiates its input and normalizes it along one (or more) axes, so that all values are positive and sum to 1. Because of the exponentiation, large values dominate: small differences between entries matter disproportionately less when large differences are also present.

$$\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{k} e^{x_j}}, \quad i = 1, \dots, k$$
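As a worked example with $k = 3$ and $x = (1, 2, 3)$, the largest entry receives about two thirds of the probability mass:

$$\text{softmax}(1, 2, 3) = \frac{(e^{1},\ e^{2},\ e^{3})}{e^{1} + e^{2} + e^{3}} \approx \frac{(2.718,\ 7.389,\ 20.086)}{30.193} \approx (0.090,\ 0.245,\ 0.665)$$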

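Here $k$ is the number of entries along the normalized axis (the 10 classes). In Keras the function is available as the layer ``tf.keras.layers.Softmax(axis=-1)``, where ``axis=-1`` normalizes over the last axis, i.e. over the 10 class scores of each example in a batch.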
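The sketch below checks that ``tf.keras.layers.Softmax(axis=-1)`` matches the formula and normalizes each row of a batch independently. The two rows of scores are hypothetical; the second row shows the dominance effect, where the gap between 1.0 and 1.1 barely registers next to 10.0.

```python
import tensorflow as tf

# Hypothetical scores for a batch of 2 examples with 3 classes each
logits = tf.constant([[1.0, 2.0, 3.0],
                      [1.0, 1.1, 10.0]])

softmax = tf.keras.layers.Softmax(axis=-1)
probs = softmax(logits)

# The same computation written out from the formula: exponentiate, then divide by the row sum
manual = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis=-1, keepdims=True)

print(probs.numpy().round(3))
print(tf.reduce_sum(probs, axis=-1).numpy())  # each row sums to 1 (up to float rounding)
```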

In addition, the model should store a loss function as an attribute. For 10-way classification with one-hot labels we use categorical cross-entropy, ``tf.keras.losses.CategoricalCrossentropy()``.
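For a single example with a one-hot target, categorical cross-entropy reduces to $-\ln p$, where $p$ is the predicted probability of the true class. By default ``CategoricalCrossentropy`` expects probabilities (``from_logits=False``), which matches a model that ends in a softmax, and it averages over the batch. The target and prediction below are hypothetical.

```python
import tensorflow as tf

loss_fn = tf.keras.losses.CategoricalCrossentropy()

# Hypothetical one-hot target (true class is 1) and predicted probabilities for 3 classes
y_true = tf.constant([[0.0, 1.0, 0.0]])
y_pred = tf.constant([[0.1, 0.8, 0.1]])

loss = loss_fn(y_true, y_pred)
print(float(loss))  # approximately 0.223, i.e. -ln(0.8)
```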

The ``call`` method should return the softmax output of the model, not the loss; we compute the loss in the training loop.

## (1C) Write a training_loop function 

Arguments: ``(model, train_ds, val_ds, epochs, learning_rate)``

Returns: ``loss_history, val_loss_history``

For a given number of epochs we iterate over the MNIST data set. In each epoch we iterate over the entire training set, which yields batches as tuples ``(inputs, targets)``. For each batch, we compute the forward pass and the loss inside a ``tf.GradientTape`` context, and then use the tape to get the gradients of the loss with respect to ``model.trainable_variables``.
The optimizer that applies the gradients to the variables is already implemented for you as the function ``optimization_step``. The training loop should call it once you have the gradients. After that, append the loss value to ``loss_history``.
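The tape-then-update pattern can be seen on a hypothetical toy problem unrelated to MNIST: fitting ``w * x + b`` to the line ``y = 2x + 1`` with plain gradient descent. The provided ``optimization_step`` performs this kind of update (variable minus learning rate times gradient); the loop below writes it out inline so the sketch is self-contained.

```python
import tensorflow as tf

# Hypothetical toy problem: find w, b so that w * x + b fits y = 2x + 1
w = tf.Variable(0.0)
b = tf.Variable(0.0)
variables = [w, b]

x = tf.constant([0.0, 1.0, 2.0, 3.0])
y = 2.0 * x + 1.0
learning_rate = 0.05

for _ in range(500):
    with tf.GradientTape() as tape:
        # Forward pass and loss are computed inside the tape context so they are recorded
        y_hat = w * x + b
        loss = tf.reduce_mean(tf.square(y_hat - y))
    # One gradient per variable, in the same order as `variables`
    grads = tape.gradient(loss, variables)
    # Plain gradient descent: variable <- variable - learning_rate * gradient
    for var, grad in zip(variables, grads):
        var.assign_sub(learning_rate * grad)

print(w.numpy(), b.numpy())  # close to 2.0 and 1.0
```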

After each epoch, evaluate your model on data it was not trained on; this is what ``val_ds`` is for. As with ``train_ds``, you iterate over it and compute the loss, but without a gradient tape and without the optimizer. Append the validation loss to ``val_loss_history``.
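A minimal sketch of such an evaluation pass, assuming the validation loss is reported as the mean over all batches. The helper ``mean_loss``, the toy dataset, and the ``uniform_model`` stand-in (which predicts 1/3 for each of 3 classes) are hypothetical.

```python
import tensorflow as tf


def mean_loss(model_fn, loss_fn, dataset):
    """Hypothetical helper: average the loss over all batches of a dataset.

    No GradientTape and no optimizer are involved, so no variables are changed.
    """
    batch_losses = []
    for inputs, targets in dataset:
        predictions = model_fn(inputs)
        batch_losses.append(loss_fn(targets, predictions))
    return float(tf.reduce_mean(batch_losses))


# Toy check: a dataset of (inputs, targets) tuples and a fixed "model"
inputs = tf.random.normal([10, 3])
targets = tf.one_hot(tf.zeros(10, dtype=tf.int32), depth=3)
toy_ds = tf.data.Dataset.from_tensor_slices((inputs, targets)).batch(4)

uniform_model = lambda x: tf.fill([tf.shape(x)[0], 3], 1.0 / 3.0)
print(mean_loss(uniform_model, tf.keras.losses.CategoricalCrossentropy(), toy_ds))  # approximately 1.0986 (= ln 3)
```

Averaging the batch means gives the smaller last batch the same weight as full batches; if you need the exact per-example mean, weight each batch loss by its batch size instead.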

## (1D) Train the model for a few epochs and plot the (val) loss history with matplotlib
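For the plotting part of this step, here is a possible helper, sketched under the assumption that ``loss_history`` holds one value per training batch and ``val_loss_history`` one value per epoch (as described in 1C). The name ``plot_loss_histories`` is hypothetical.

```python
import matplotlib.pyplot as plt


def plot_loss_histories(loss_history, val_loss_history):
    """Hypothetical helper: per-batch training loss and per-epoch validation loss side by side."""
    fig, (ax_train, ax_val) = plt.subplots(1, 2, figsize=(10, 4))

    ax_train.plot(loss_history)
    ax_train.set_xlabel("training step (batch)")
    ax_train.set_ylabel("loss")
    ax_train.set_title("Training loss")

    # Epochs are numbered from 1 so that each point sits at the end of its epoch
    ax_val.plot(range(1, len(val_loss_history) + 1), val_loss_history, marker="o")
    ax_val.set_xlabel("epoch")
    ax_val.set_ylabel("loss")
    ax_val.set_title("Validation loss")

    fig.tight_layout()
    return fig
```

In the notebook you would call ``plot_loss_histories(loss_history, val_loss_history)`` after training, followed by ``plt.show()``. Two separate axes are used because the x-axes count different things (batches versus epochs); on a shared axis the few validation points would be squeezed into the first steps of the training curve.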

```python
import tensorflow as tf
import tensorflow_datasets as tfds

def get_mnist(batch_size):
    """
    Load and prepare MNIST as TensorFlow datasets.
    Returns a train and a validation dataset.

    Args:
    batch_size (int): number of examples per batch
    """
    # load data from tensorflow-datasets
    train_ds, val_ds = tfds.load('mnist', split=['train', 'test'], shuffle_files=True)

    # function for one-hot encoding labels
    one_hot = lambda x: tf.one_hot(x, 10)

    # function to normalize inputs to [0, 1] and apply one_hot to labels
    # (tfds MNIST images already have shape (28, 28, 1), so no extra channel axis is added)
    map_func = lambda x, y: (tf.cast(x, dtype=tf.float32) / 255.,
                             tf.cast(one_hot(y), tf.float32))

    # turn dictionary data set into tuple data set (inputs, labels)
    map_func_2 = lambda x: (x["image"], x["label"])

    # map the defined functions to the data set (they will be applied to each element while the data is loaded)
    train_ds = train_ds.map(map_func_2).map(map_func)
    val_ds   = val_ds.map(map_func_2).map(map_func)

    # shuffle, then create batches, then prefetch (preparing a batch while the model still computes the previous one)
    train_ds = train_ds.shuffle(4096).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    val_ds   = val_ds.shuffle(4096).batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return (train_ds, val_ds)

train_ds, val_ds = get_mnist(batch_size=32)
```

```text
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
```
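To see what the mapping steps in ``get_mnist`` do to a single example without downloading anything, you can apply equivalent transformations to a tiny synthetic dataset. This sketch assumes the dictionary format that ``tfds.load`` returns for MNIST (``"image"`` as uint8 with shape (28, 28, 1), ``"label"`` as int64); the two all-white images and their labels are placeholders, not real data.

```python
import tensorflow as tf

# Two hypothetical "MNIST-like" examples in the dictionary format returned by tfds.load
fake = tf.data.Dataset.from_tensor_slices({
    "image": tf.cast(tf.fill([2, 28, 28, 1], 255), tf.uint8),
    "label": tf.constant([3, 7], dtype=tf.int64),
})

# Mirrors the two mapping steps in get_mnist: dict -> tuple, then normalize and one-hot encode
to_tuple = lambda x: (x["image"], x["label"])
normalize_and_encode = lambda x, y: (tf.cast(x, tf.float32) / 255., tf.one_hot(y, 10))

image, label = next(iter(fake.map(to_tuple).map(normalize_and_encode)))
print(image.shape, float(tf.reduce_max(image)))  # (28, 28, 1) 1.0
print(label.numpy())  # one-hot vector with a 1 at index 3
```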



```python
def optimization_step(variables, grads, learning_rate=1e-3):
    # plain gradient descent: each variable is paired with its gradient
    # (variables and grads are flat lists, e.g. model.trainable_variables and tape.gradient(...))
    for variable, grad in zip(variables, grads):
        variable.assign_sub(grad * learning_rate)


class LinearLayer(tf.Module):
    def __init__(self, ):
        super().__init__()
        # YOUR CODE HERE
        pass

    def __call__(self, x):
        # YOUR CODE HERE
        pass


class DigitClassifier(tf.keras.Model):
    def __init__(self, ):
        super().__init__()
        # YOUR CODE HERE
        pass

    def call(self, x):
        # YOUR CODE HERE
        pass


def training_loop(model, train_ds, val_ds, epochs, learning_rate=1e-3):
    loss_history = []
    val_loss_history = []
    for e in range(epochs):
        # YOUR CODE HERE
        ...
    return loss_history, val_loss_history
```