### Homework 07
## Assignment 2: Implement LSTM
# 2.1 Prepare dataset
- MNIST
- divide the images into sequences that will be fed into the model; the shape should be (batch, sequence-length, features)
- alternate the signs of the targets (+, -, +, -, ...) and take their cumulative sum

In [73]:
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import Dense, Conv2D, AveragePooling2D, TimeDistributed, LSTM, GlobalAvgPool2D, AbstractRNNCell, MaxPooling2D, RNN
import numpy as np
import matplotlib.pyplot as plt
import datetime
import tqdm

# magic line only needed in jupyter notebooks!
%reload_ext tensorboard

In [74]:
# load MNIST as (image, label) pairs (as_supervised=True), one dataset per split
train_ds, test_ds = tfds.load(
    "mnist",
    split=["train", "test"],
    as_supervised=True,
    with_info=False,
)

In [75]:
def prepare_data(mnist, batch_size, sequence_length):
    """Turn an (image, label) dataset into batches of image sequences with
    alternating-sign cumulative-sum targets.

    Consecutive images are grouped into sequences of ``sequence_length``. For a
    sequence with digit labels d0, d1, d2, ... the target at step t is
    d0 - d1 + d2 - ... (+/-) dt: the cumulative sum of the labels with even
    positions taken positive and odd positions taken negative.

    Args:
        mnist: a ``tf.data.Dataset`` of ``(image, label)`` pairs.
        batch_size: number of sequences per batch.
        sequence_length: number of images per sequence.

    Returns:
        A cached, batched and prefetched ``tf.data.Dataset`` yielding
        ``(images, targets)``:
        - ``images`` has shape ``(batch, sequence_length, *image_shape)``, float32,
          standardized per image.
        - ``targets`` has shape ``(batch, sequence_length)``, float32.
        Images left over after the last complete sequence are dropped.
    """
    # cast uint8 -> float32 and standardize every image (per_image_standardization
    # already returns float32, so no second cast is needed)
    mnist = mnist.map(
        lambda img, target: (tf.image.per_image_standardization(tf.cast(img, tf.float32)), target),
        num_parallel_calls=tf.data.AUTOTUNE,
    )

    # group consecutive images into sequences; drop_remainder=True keeps every
    # sequence exactly `sequence_length` long, so it always matches the
    # length-`sequence_length` sign vector below
    mnist_sequence = mnist.shuffle(1000).batch(sequence_length, drop_remainder=True)

    # +1 for even positions, -1 for odd positions: [1, -1, 1, -1, ...]
    signs = 1.0 - 2.0 * tf.cast(tf.math.floormod(tf.range(sequence_length), 2), tf.float32)

    def _alternating_cumsum(images, target_digits):
        """Replace the per-image digit labels by their alternating-sign cumulative sum."""
        new_target = tf.math.cumsum(tf.cast(target_digits, tf.float32) * signs)
        return images, new_target

    # computed lazily inside the tf.data pipeline instead of an eager Python loop
    # that materializes every sequence in memory before rebuilding a dataset
    mnist_dataset = mnist_sequence.map(_alternating_cumsum, num_parallel_calls=tf.data.AUTOTUNE)

    # cache, batch and prefetch the new dataset
    return mnist_dataset.cache().batch(batch_size).prefetch(tf.data.AUTOTUNE)

In [None]:
# build the train / test pipelines: batches of 32 sequences, 4 images per sequence
train_dataset = prepare_data(train_ds, batch_size=32, sequence_length=4)
test_dataset = prepare_data(test_ds, batch_size=32, sequence_length=4)

# peek at a single batch to inspect what the pipeline yields
iterator = iter(train_dataset)
iterator.get_next()

# 2.2 CNN and LSTM Network
- first part: basic CNN structure
- should extract vector representations from each MNIST image using Conv2D layers as well as (global) pooling or Flatten layers
- Conv2D layer can be called on a batch of sequences of images, where the time dimension is in the second axis; time dimension will then be processed like a second batch dimension -> extended batch shape
- While Conv2D layers accept a (batch, sequence-length, image) data structure thanks to their extended-batch-shape functionality, the pooling layers need to be wrapped in TensorFlow's TimeDistributed layer to work correctly!
- Once all images are encoded as vectors, the shape of the tensor should be (batch, sequence-length, features)

In [77]:
class CNN(tf.keras.Model):
    """Convolutional encoder applied to every image of a sequence independently.

    Every layer is wrapped in ``TimeDistributed``, so the input is a batch of
    image sequences, ``(batch, time, height, width, channels)``. The output is
    ``(batch, time, 10)``: one 10-way softmax vector per image.
    """

    def __init__(self):
        super().__init__()

        # block 1: two 3x3 convolutions with 24 filters, then 2x2 max pooling
        self.conv1 = TimeDistributed(Conv2D(filters=24, kernel_size=3, padding='same', activation='relu'))
        self.conv2 = TimeDistributed(Conv2D(filters=24, kernel_size=3, padding='same', activation='relu'))
        self.maxpool = TimeDistributed(MaxPooling2D(pool_size=2, strides=2))

        # block 2: two 3x3 convolutions with 48 filters, then global average
        # pooling, which reduces each image to one value per filter
        self.conv3 = TimeDistributed(Conv2D(filters=48, kernel_size=3, padding='same', activation='relu'))
        self.conv4 = TimeDistributed(Conv2D(filters=48, kernel_size=3, padding='same', activation='relu'))
        self.globalpool = TimeDistributed(GlobalAvgPool2D())

        # per-image 10-way softmax used as the image's feature vector
        self.out = TimeDistributed(Dense(10, activation='softmax'))

    def call(self, x, training=False):
        """Encode a batch of image sequences into per-image feature vectors.

        Implemented as ``call`` rather than overriding ``__call__``, so that Keras'
        own ``__call__`` still runs: it builds the layers, handles the ``training``
        flag and lets ``fit``/``save`` trace the model. ``fit`` already compiles the
        computation into a graph, so no ``@tf.function`` is needed here.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.maxpool(x)

        x = self.conv3(x)
        x = self.conv4(x)
        x = self.globalpool(x)

        return self.out(x)

# 2.3 LSTM AbstractRNNCell layer
- subclass the AbstractRNNCell layer, implement its methods and define the required properties (`state_size` and `output_size`) as well as `get_initial_state`, which determines the initial hidden and cell states of the LSTM (usually tensors filled with zeros)
- the LSTM-cell layer's call method should take one (batch of) feature vector(s) as its input, along with the "states", a list containing the different state tensors of the LSTM cell (cell state and hidden state!)

In [78]:
class LSTMCell(tf.keras.layers.AbstractRNNCell):
    """A single LSTM step, meant to be wrapped in ``tf.keras.layers.RNN``.

    For input x_t and previous states (h_{t-1}, c_{t-1}), with z = [x_t, h_{t-1}]:

        f_t = sigmoid(W_f z + b_f)        forget gate
        i_t = sigmoid(W_i z + b_i)        input gate
        g_t = tanh(W_g z + b_g)           candidate cell state
        o_t = sigmoid(W_o z + b_o)        output gate
        c_t = f_t * c_{t-1} + i_t * g_t
        h_t = o_t * tanh(c_t)

    The cell returns h_t as its output and [h_t, c_t] as its new states.

    Args:
        units: size of the hidden and cell state (keyword, default 25).
        Other arguments are forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, trainable=True, name=None, dtype=None, dynamic=False, units=25, **kwargs):
        super().__init__(trainable=trainable, name=name, dtype=dtype, dynamic=dynamic, **kwargs)

        self.units = units
        # kept for backward compatibility; in an LSTM both states have the same size
        self.hidden_states = units
        self.cell_states = units

        # each gate sees the concatenation [input, previous hidden state]
        # forget-gate bias starts at 1 (common practice) so the cell initially keeps its memory
        self.forget_gate = Dense(units, activation="sigmoid", bias_initializer="ones")
        self.input_gate = Dense(units, activation="sigmoid")
        self.candidate_layer = Dense(units, activation="tanh")
        self.output_gate = Dense(units, activation="sigmoid")

    @property
    def state_size(self):
        # [hidden state size, cell state size]
        return [tf.TensorShape([self.units]),
                tf.TensorShape([self.units])]

    @property
    def output_size(self):
        # a single output tensor (the hidden state), not a list
        return tf.TensorShape([self.units])

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        """Return zero hidden and cell states of shape (batch_size, units)."""
        if batch_size is None:
            batch_size = tf.shape(inputs)[0]
        if dtype is None:
            dtype = tf.float32
        return [tf.zeros([batch_size, self.units], dtype=dtype),
                tf.zeros([batch_size, self.units], dtype=dtype)]

    def call(self, inputs, states):
        """Run one LSTM step.

        Args:
            inputs: feature vectors for the current time step, ``(batch, features)``.
            states: ``[hidden_state, cell_state]``, each ``(batch, units)``.

        Returns:
            ``(new_hidden_state, [new_hidden_state, new_cell_state])``.
        """
        hidden_state, cell_state = states[0], states[1]

        gate_input = tf.concat([inputs, hidden_state], axis=-1)

        forget = self.forget_gate(gate_input)
        write = self.input_gate(gate_input)
        candidate = self.candidate_layer(gate_input)
        read = self.output_gate(gate_input)

        new_cell_state = forget * cell_state + write * candidate
        new_hidden_state = read * tf.math.tanh(new_cell_state)

        return new_hidden_state, [new_hidden_state, new_cell_state]

    def get_config(self):
        # include the base Layer config (name, trainable, dtype) and only keys that
        # __init__ accepts, so LSTMCell.from_config(cell.get_config()) works
        config = super().get_config()
        config.update({"units": self.units})
        return config

# 2.4 Wrapping LSTM Cell layer with RNN layer
- tf.keras.layers.RNN takes an instance of your LSTM cell as the first argument in its constructor
- the "wrapper" RNN layer then takes the sequence of vector representations of the MNIST images as its input (batch, seq len, feature dim)
- need to specify whether the RNN wrapper layer should return the output of the LSTM cell for every time step or only for the last step (the argument `return_sequences=True` returns every step) -> generally task-dependent (so think about what makes sense here)
- for speed-ups (at the cost of memory usage), set the "unroll" argument to True
# 2.5 Computing model output
- could (if the task demands it) use the same Dense layer to predict targets for all time steps; you likely do not want a separate Dense layer for each time step's target prediction (potential for overfitting!)

In [79]:
class LSTMModel(tf.keras.Model):
    """CNN encoder followed by an LSTM that predicts one real value per time step.

    The CNN maps every image of a sequence to a feature vector. The LSTM cell,
    wrapped in ``tf.keras.layers.RNN``, runs over these vectors. A single Dense
    unit shared by all time steps maps every hidden state to a prediction, so the
    output has shape ``(batch, time)``.

    Args:
        cnn: the image encoder, either a class (instantiated without arguments)
            or an already constructed layer/model.
        lstm_cell: the RNN cell, either a class (instantiated without arguments)
            or an already constructed cell.
        optimizer: optimizer applied in ``training_step``.
        loss_function: callable ``loss_function(label, prediction)``.
    """

    def __init__(self, cnn, lstm_cell, optimizer, loss_function):
        super().__init__()

        # the training cell passes the classes themselves; instances are accepted too
        self.cnn = cnn() if isinstance(cnn, type) else cnn
        self.lstm_cell = lstm_cell() if isinstance(lstm_cell, type) else lstm_cell
        # return_sequences=True: every time step has its own target (the running
        # sum so far), so the RNN must emit an output for every step
        self.rnn = RNN(self.lstm_cell, return_sequences=True)
        # one Dense layer shared by all time steps (no per-step heads)
        self.output_layer = Dense(1)

        # regression metrics: the targets are real-valued running sums
        self.metrics_list = [
            tf.keras.metrics.MeanAbsoluteError(name="mae"),
            tf.keras.metrics.Mean(name="loss")]

        self.optimizer = optimizer
        self.loss_function = loss_function

    @property
    def metrics(self):
        # exposing our own metrics lets fit()/evaluate() reset them via reset_metrics
        return self.metrics_list

    def reset_metrics(self):
        for metric in self.metrics:
            metric.reset_state()

    def call(self, sequence, training=False):
        """Predict one value per time step for a batch of image sequences."""
        features = self.cnn(sequence, training=training)
        hidden = self.rnn(features, training=training)
        prediction = self.output_layer(hidden)
        # drop the trailing size-1 axis so predictions have shape (batch, time)
        return tf.squeeze(prediction, axis=-1)

    def train_step(self, data):
        """Hook called by ``model.fit``; delegates to ``training_step``."""
        image, label = data
        return self.training_step(image, label)

    # no @tf.function here or on test_step: fit()/evaluate() already compile
    # train_step/test_step into a graph themselves
    def training_step(self, image, label):
        """Run one optimization step and return the current metric values."""
        with tf.GradientTape() as tape:
            prediction = self(image, training=True)
            loss = self.loss_function(label, prediction)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        self.metrics[0].update_state(label, prediction)
        self.metrics[1].update_state(loss)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Evaluate one batch and return the current metric values."""
        image, label = data
        prediction = self(image, training=False)
        loss = self.loss_function(label, prediction)
        self.metrics[0].update_state(label, prediction)
        self.metrics[1].update_state(loss)
        return {m.name: m.result() for m in self.metrics}

# 2.6 Training
- own training loop or model.compile and model.fit methods
- track experiments properly, save configs (e.g. hyperparameters) of settings, save logs (e.g. with Tensorboard) and checkpoint the model’s weights (or even the complete model)
- visualize your results (e.g. with the default History callback returned by model.fit)

In [81]:
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
# prepare_data produces float32 running sums as targets (one real value per time
# step), so this is a regression task: mean squared error rather than categorical
# cross-entropy, which expects one-hot class distributions
loss = tf.keras.losses.MeanSquaredError()

model = LSTMModel(CNN, LSTMCell, optimizer, loss)

model.compile(optimizer=optimizer, loss=loss)

In [82]:
EXPERIMENT_NAME = "Run_1"
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# TensorBoard event files go to ./logs/<experiment>/<timestamp>, so every run
# of the same experiment gets its own sub-directory
logging_callback = tf.keras.callbacks.TensorBoard(
    log_dir="./logs/{}/{}".format(EXPERIMENT_NAME, current_time)
)

In [None]:
# train for 5 epochs, evaluating on the test set after every epoch and logging to TensorBoard
history = model.fit(
    train_dataset,
    validation_data=test_dataset,
    initial_epoch=0,
    epochs=5,
    callbacks=[logging_callback],
)

In [None]:
# save the whole model to the path "saved_model"
# (when running on Colab, ideally point this at Google Drive so it survives the session)
model.save(filepath="saved_model")

In [None]:
# load the saved model back (e.g. to resume training later); the custom classes
# are passed by name via custom_objects so Keras can resolve them
loaded_model = tf.keras.models.load_model(
    "saved_model",
    custom_objects={"LSTMCell": LSTMCell, "LSTMModel": LSTMModel},
)

In [None]:
# training vs. validation loss per epoch, as recorded in the History returned by model.fit
plt.plot(history.history["loss"], label="training")
plt.plot(history.history["val_loss"], label="validation")
plt.legend()
plt.xlabel("Epoch")
# label the axis with the loss the model was actually configured with instead of a
# hard-coded name (Keras Loss objects carry a `name`; fall back to a generic label)
plt.ylabel(getattr(model.loss_function, "name", "Loss"))
plt.show()

In [None]:
# magic line only needed in jupyter notebooks: load the TensorBoard extension
%load_ext tensorboard

In [None]:
%tensorboard --logdir="logs/Run_1"