In [1]:
import time
import datetime
 
import matplotlib
import matplotlib.pyplot as plt
 
import tensorflow as tf
import numpy as np

# Helper Functions

In [2]:
class Timer:
    """
    Minimal stopwatch used to time training epochs.

    Call `start()` to begin a measurement and `stop()` to obtain the number of
    seconds elapsed since that `start()`.
    """
    def __init__(self):
        # time.perf_counter() timestamp of the running measurement; None when idle.
        self._start_time = None

    def start(self):
        """Begin (or restart) a measurement."""
        self._start_time = time.perf_counter()

    def stop(self):
        """
        End the current measurement.

        Returns:
            Seconds elapsed since the last `start()`, or 0 (after printing a
            warning) when the timer is not running.
        """
        started_at = self._start_time
        if started_at is None:
            print("Timer is not running. Use .start() to start it")
            return 0

        elapsed = time.perf_counter() - started_at
        self._start_time = None
        return elapsed
    
def eval_metrics(metrics, metric_lists):
    """
    Record the current value of each metric in its history list, then reset it.

    Args:
        metrics: sequence of Keras-style metric objects exposing `result()` and
            `reset_state()` (or the older `reset_states()`).
        metric_lists: list of per-metric history lists, in the same order as
            `metrics` (e.g. [train_history, test_history]). Missing entries
            are created, so passing `[]` yields `[[v0], [v1], ...]`.

    Returns:
        list: the value recorded for each metric in this call.
    """
    metrics = list(metrics)
    # zip() would silently skip metrics that have no history list (e.g. when
    # an empty list is passed), so create the missing lists first.
    while len(metric_lists) < len(metrics):
        metric_lists.append([])

    values = []
    for metric, history in zip(metrics, metric_lists):
        value = metric.result()
        history.append(value)
        values.append(value)
        # Keras metrics have no `reset()`; current releases expose
        # `reset_state()`, older TF 2.x ones only `reset_states()`.
        if hasattr(metric, 'reset_state'):
            metric.reset_state()
        else:
            metric.reset_states()
    return values

In [10]:
def data_gen(height=36, width=72, channels=2, low=0, high=100):
    """
    Infinite generator of placeholder data.

    Draws one random int32 array of shape (height, width, channels) with
    values in [low, high) and yields that same array object forever.

    Args:
        height, width: grid size. NOTE(review): the original locals were named
            `long = 36` and `lat = 72`, which looks swapped relative to the
            usual latitude x longitude layout; renamed to neutral height/width
            — TODO confirm the intended axes.
        channels: number of variables per grid cell.
        low, high: value range passed to np.random.randint (high exclusive).

    Yields:
        np.ndarray of dtype int32 and shape (height, width, channels).
    """
    # Design note: labels are not generated here. Daily time steps would be
    # 170*365; monthly labels (170*12) are not needed separately because they
    # would be interpolated/repeated to match the time steps.
    data = np.random.randint(low=low, high=high, size=(height, width, channels), dtype=np.int32)

    while True:
        yield data

In [13]:
# Scratch check: a Conv3D with padding='same' keeps the three spatial/depth
# dimensions and only changes the channel count (1 -> 5 filters).
input_shape = (4, 36, 72, 2, 1)
x = tf.random.normal(input_shape)
print(input_shape[1:])

y = tf.keras.layers.Conv3D(
    filters=5,
    kernel_size=3,
    activation='relu',
    input_shape=input_shape[1:],
    padding='same',
)(x)
print(y.shape)

(36, 72, 2, 1)
(4, 36, 72, 2, 5)


# Autoencoder

In [46]:
class CNN_Encoder(tf.keras.layers.Layer):
    """
    Convolutional encoder mapping an image-like grid to a latent vector.

    Architecture: two [Conv2D(stride 2, padding='same') -> BatchNorm -> ReLU]
    stages with 32 and 64 filters, then Flatten and Dense(latent_dim, ReLU).
    Each stride-2 stage halves the spatial size, e.g. (36, 72, 2) is reduced
    to (9, 18, 64) before flattening.

    Args:
        input_dim: forwarded as `input_shape` to the first Conv2D.
            NOTE(review): Keras only uses `input_shape` when the layer is the
            first layer of a Sequential model, so it is presumably ignored
            here — confirm.
        latent_dim: size of the latent vector returned by `call`.
    """
    def __init__(self, input_dim, latent_dim):
        super(CNN_Encoder, self).__init__()

        self.layers = []

        self.layers.append(tf.keras.layers.Conv2D(filters=32,
                                                  kernel_size=3,
                                                  strides=2,
                                                  padding='same',
                                                  input_shape=input_dim))
        self.layers.append(tf.keras.layers.BatchNormalization())
        self.layers.append(tf.keras.layers.Activation('relu'))

        self.layers.append(tf.keras.layers.Conv2D(filters=64,
                                                  kernel_size=3,
                                                  strides=2,
                                                  padding='same'))
        self.layers.append(tf.keras.layers.BatchNormalization())
        self.layers.append(tf.keras.layers.Activation('relu'))

        self.layers.append(tf.keras.layers.Flatten())
        self.layers.append(tf.keras.layers.Dense(latent_dim, activation='relu'))

    def call(self, x, training=False):
        """
        Encode `x` (presumably (batch, H, W, C)) into (batch, latent_dim).

        `training` is passed by keyword to every sub-layer: Keras drops it for
        layers whose `call` does not accept it, and BatchNormalization uses it
        to pick batch vs. moving statistics. Errors raised by a sub-layer now
        propagate instead of being swallowed by a bare `except`.
        """
        for layer in self.layers:
            x = layer(x, training=training)
        return x

    def make_untrainable(self):
        """Freeze every sub-layer so its weights are excluded from training."""
        for layer in self.layers:
            layer.trainable = False

class CNN_Decoder(tf.keras.layers.Layer):
    """
    Convolutional decoder mapping a latent vector back to an image-like grid.

    Architecture: Dense(prod(restore_shape)) -> BatchNorm -> ReLU ->
    Reshape(restore_shape), then Conv2DTranspose(32, stride 2) -> BatchNorm ->
    ReLU and Conv2DTranspose(2, stride 2) -> BatchNorm -> sigmoid. The two
    stride-2 transposed convolutions upsample 4x, e.g. (9, 18, 64) becomes
    (36, 72, 2).

    Args:
        latent_dim: size of the incoming latent vector.
        output_dim: currently unused. NOTE(review): the last layer always emits
            2 channels; consider deriving the filter count from output_dim[-1].
        restore_shape: (H, W, C) that the dense projection is reshaped to
            before upsampling.

    NOTE(review): the sigmoid output lies in (0, 1), while the inputs used in
    this notebook (randint 0..99 in data_gen, tf.random.normal) do not —
    confirm the data is scaled to [0, 1] before training on reconstruction.
    """
    def __init__(self, latent_dim, output_dim, restore_shape):
        super(CNN_Decoder, self).__init__()
        self.layers = []

        # Dense layer restoring the size of the flattened encoder feature map.
        self.layers.append(tf.keras.layers.Dense(units=int(np.prod(restore_shape)),
                                                 input_shape=(latent_dim,)))
        self.layers.append(tf.keras.layers.BatchNormalization())
        self.layers.append(tf.keras.layers.Activation('relu'))

        self.layers.append(tf.keras.layers.Reshape(target_shape=restore_shape))
        self.layers.append(tf.keras.layers.Conv2DTranspose(filters=32,
                                                           kernel_size=3,
                                                           strides=2,
                                                           padding='same'))
        self.layers.append(tf.keras.layers.BatchNormalization())
        self.layers.append(tf.keras.layers.Activation('relu'))

        self.layers.append(tf.keras.layers.Conv2DTranspose(filters=2,
                                                           kernel_size=3,
                                                           strides=2,
                                                           padding='same'))
        self.layers.append(tf.keras.layers.BatchNormalization())
        self.layers.append(tf.keras.layers.Activation('sigmoid'))

    def call(self, x, training=False):
        """
        Decode latent vectors `x` of shape (batch, latent_dim) into grids.

        `training` is passed by keyword to every sub-layer: Keras drops it for
        layers whose `call` does not accept it, and BatchNormalization uses it.
        Errors raised by a sub-layer now propagate instead of being swallowed.
        """
        for layer in self.layers:
            x = layer(x, training=training)
        return x

    def make_untrainable(self):
        """Freeze every sub-layer so its weights are excluded from training."""
        for layer in self.layers:
            layer.trainable = False
        
class CNN_Autoencoder(tf.keras.Model):
    """
    Convolutional autoencoder: a CNN_Encoder followed by a CNN_Decoder.

    Args:
        input_dim: passed to the encoder as `input_dim` and to the decoder as
            `output_dim` (encoder and decoder are symmetric).
        latent_dim: size of the bottleneck vector.
        restore_shape: shape the decoder reshapes its dense projection to; it
            should be (H/4, W/4, 64) of the input so the decoder's 4x
            upsampling reproduces the input's spatial size.
    """
    def __init__(self, input_dim, latent_dim, restore_shape):
        super(CNN_Autoencoder, self).__init__()
        self.encoder = CNN_Encoder(input_dim=input_dim,
                                   latent_dim=latent_dim)

        self.decoder = CNN_Decoder(latent_dim=latent_dim,
                                   output_dim=input_dim,
                                   restore_shape=restore_shape)

    def call(self, x, training=False):
        """Reconstruct `x` by encoding it and decoding the latent vector."""
        x = self.encoder(x, training=training)
        # Keep the latent representation for later analysis.
        # NOTE(review): when called inside a tf.function (e.g. train_step_AE)
        # this stores a symbolic graph tensor, not a concrete value.
        self.latent_repr = x
        x = self.decoder(x, training=training)

        return x

    def encode(self, x, training=False):
        """Map inputs to latent vectors with the encoder."""
        return self.encoder(x, training=training)

    def decode(self, z, training=False):
        """Map latent vectors back to grids with the decoder."""
        return self.decoder(z, training=training)

    def make_untrainable(self):
        """Freeze the whole autoencoder (encoder and decoder) against training."""
        self.trainable = False
        # Also freeze each sub-layer explicitly, without relying on the
        # encoder/decoder helper methods.
        for part in (self.encoder, self.decoder):
            part.trainable = False
            for layer in part.layers:
                layer.trainable = False

In [48]:
input_shape = (4, 36, 72, 2)
data = tf.random.normal(input_shape)
print(input_shape)

# `lat`/`long` only exist inside `data_gen`, so derive the per-sample shape
# from `input_shape` (batch dimension replaced by None) instead.
model_AE = CNN_Autoencoder(input_dim=(None,) + input_shape[1:],
                           latent_dim=1000,
                           restore_shape=(9, 18, 64))

latent = model_AE.encoder(data)
x = model_AE.decoder(latent)

# The reconstruction must have the same shape as the input for an MSE loss.
assert tuple(x.shape) == input_shape, x.shape

(4, 36, 72, 2)
(4, 18, 36, 32)
(4, 18, 36, 32)
(4, 18, 36, 32)
(4, 9, 18, 64)
(4, 9, 18, 64)
(4, 9, 18, 64)
(4, 10368)
(4, 1000)
(4, 10368)
(4, 10368)
(4, 10368)
(4, 9, 18, 64)
(4, 18, 36, 32)
(4, 18, 36, 32)
(4, 18, 36, 32)
(4, 36, 72, 2)
(4, 36, 72, 2)
(4, 36, 72, 2)


# Training the Autoencoder

In [None]:
@tf.function
def train_step_AE(model, train_ds, loss_function, optimizer, train_loss_metric):
    '''
    Train the autoencoder for one epoch over `train_ds`.

    The autoencoder reconstructs its input, so the loss compares the model
    output with the input batch itself; there is no accuracy metric.

    Args:
        model: Keras model to train.
        train_ds: iterable of batches (e.g. a tf.data.Dataset). Each element is
            either a bare input batch or a tuple whose first item is the input
            (such as (input, label)); any further items are ignored.
        loss_function: callable(y_true, y_pred) returning a scalar loss.
        optimizer: Keras optimizer applied to `model.trainable_variables`.
        train_loss_metric: metric updated with every batch loss (e.g. Mean).
    '''
    for batch in train_ds:
        # Elements may be bare inputs or (input, label) pairs; the dataset
        # structure is static, so this check is resolved while tracing.
        img = batch[0] if isinstance(batch, (tuple, list)) else batch
        with tf.GradientTape() as tape:
            # forward pass
            prediction = model(img, training=True)
            loss = loss_function(img, prediction)

        # backward pass
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # update metrics
        train_loss_metric.update_state(loss)

@tf.function
def eval_step_AE(model, ds, loss_function, loss_metric):
    '''
    Evaluate `model` on `ds` without updating its weights.

    Updates `loss_metric` with the reconstruction loss of every batch. It then
    returns the FIRST batch of `ds` together with the model's reconstruction of
    it, for visualisation. The function takes the model as an argument, so it
    works for any model whose output is compared with its input.

    Args:
        model: Keras model to evaluate (called with training=False).
        ds: non-empty iterable of batches (e.g. a tf.data.Dataset). Elements
            are bare inputs or tuples whose first item is the input.
        loss_function: callable(y_true, y_pred) returning a scalar loss.
        loss_metric: metric updated with every batch loss.

    Returns:
        (img, prediction) for the first batch of `ds`.
    '''
    for batch in ds:
        img = batch[0] if isinstance(batch, (tuple, list)) else batch
        # forward pass
        prediction = model(img, training=False)
        # update metrics
        loss_metric.update_state(loss_function(img, prediction))

    # Carrying img/prediction out of the loop would turn them into AutoGraph
    # loop variables, whose shape/dtype may not change between iterations
    # (they would start as the Python float 0.0). Fetch the sample separately.
    sample = next(iter(ds))
    sample_img = sample[0] if isinstance(sample, (tuple, list)) else sample
    return sample_img, model(sample_img, training=False)

In [None]:
# Hyperparameters
epochs = 25
learning_rate = 0.001

tf.keras.backend.clear_session()
timer = Timer()

# input_dim must describe the (36, 72, 2) grids this decoder reconstructs:
# restore_shape (9, 18, 64) is upsampled 4x by two stride-2 transposed
# convolutions to 36 x 72 with 2 output channels.
model_AE = CNN_Autoencoder(input_dim=(None, 36, 72, 2),
                           latent_dim=10,
                           restore_shape=(9, 18, 64))

loss_function = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam(learning_rate)

In [None]:
# Running means of the per-batch losses; eval_metrics records and resets them every epoch.
train_loss_metric = tf.keras.metrics.Mean('train_loss')
test_loss_metric = tf.keras.metrics.Mean('test_loss')
metrics_AE = [train_loss_metric, test_loss_metric]

# One history list per metric: losses_AE[0] - train, losses_AE[1] - test
losses_AE = [[] for _ in metrics_AE]

times = []

In [None]:
# The train/eval steps update these metric objects in place.
train_loss_metric, test_loss_metric = metrics_AE

# NOTE(review): `train_ds` and `test_ds` are not defined in any cell shown
# here — TODO create them (e.g. a tf.data pipeline) before running this cell.

print(f'[INFO] - Evaluating the Dataset on the {model_AE.name} before training.')
timer.start()

# evaluate once before training
eval_step_AE(model_AE,
             train_ds,
             loss_function,
             loss_metric=train_loss_metric)

eval_step_AE(model_AE,
             test_ds,
             loss_function,
             loss_metric=test_loss_metric)

# Read the averaged losses before eval_metrics records and resets the metrics.
train_loss = float(train_loss_metric.result())
test_loss = float(test_loss_metric.result())
eval_metrics(metrics=metrics_AE, metric_lists=losses_AE)

# Evaluate the timer
elapsed_time = timer.stop()
times.append(elapsed_time)

print(f'train_loss: {train_loss:0.4f}, test_loss: {test_loss:0.4f}')

for epoch in range(epochs):
    print(f'\n[EPOCH] ____________________{epoch}____________________')

    timer.start()

    # Training step
    train_step_AE(model_AE, train_ds, loss_function, optimizer, train_loss_metric)

    # Test step
    img_original, img_reconstructed = eval_step_AE(model_AE, test_ds, loss_function, test_loss_metric)

    # Evaluate the metrics
    train_loss = float(train_loss_metric.result())
    test_loss = float(test_loss_metric.result())
    eval_metrics(metrics=metrics_AE, metric_lists=losses_AE)

    elapsed_time = timer.stop()
    times.append(elapsed_time)

    print(f'[{epoch}] - Finished Epoch in {elapsed_time:0.2f} seconds - train_loss: {train_loss:0.4f}; test_loss: {test_loss:0.4f}')

    # print progress every while
    if epoch % 5 == 0:
        # times[0] is the pre-training evaluation; estimate from epoch timings only.
        remaining_min = np.mean(times[1:]) * (epochs - epoch - 1) / 60
        print(f'\n[INFO] - Total time elapsed: {np.sum(times)/60:0.4f} min. Total time remaining: {remaining_min:0.4f} min.')

        # Visualize reconstructed image
        fig, (ax_orig, ax_rec) = plt.subplots(1, 2, figsize=(9, 3))

        ax_orig.set_title('Original Image')
        ax_orig.imshow(img_original[0, :, :, 0], cmap='gray')

        ax_rec.set_title('Reconstructed Image')
        ax_rec.imshow(img_reconstructed[0, :, :, 0], cmap='gray')

        fig.tight_layout()
        plt.show()

print(f'\n[INFO] - Total run time: {np.sum(times)/60:0.4f} min.')

# Recurrent Model (GRU)

In [51]:
class RNN(tf.keras.Model):
    """
    Recurrent model on top of a frozen autoencoder.

    Inputs are encoded with the autoencoder's encoder, passed through a GRU and
    decoded with the autoencoder's decoder. Only the GRU is trainable.

    Args:
        rnn_units: number of GRU units.
        autoencoder: trained model exposing `.encoder` and `.decoder` (e.g. a
            CNN_Autoencoder); it is frozen here.

    NOTE(review): as written the shapes do not line up. The encoder returns
    (batch, latent_dim), while Keras GRU layers expect
    (batch, timesteps, features). The GRU output width (`rnn_units`) must also
    equal the decoder's latent_dim for its Dense/Reshape layers. TODO: encode
    each time step separately (fold time into batch) and project the GRU
    output back to latent_dim.
    """
    def __init__(self, rnn_units, autoencoder):
        super().__init__()
        self.autoencoder = autoencoder
        # Freeze the whole autoencoder; Keras propagates `trainable` to nested layers.
        self.autoencoder.trainable = False

        self.gru = tf.keras.layers.GRU(rnn_units,
                                       return_sequences=True)

    def call(self, x, training=False):
        # The frozen autoencoder always runs in inference mode
        # (BatchNormalization uses its moving statistics).
        x = self.autoencoder.encoder(x, training=False)
        x = self.gru(x, training=training)
        x = self.autoencoder.decoder(x, training=False)

        return x

In [None]:
# Hyperparameters for the recurrent model
epochs = 25
learning_rate = 0.001

tf.keras.backend.clear_session()
timer = Timer()

model_RNN = RNN(rnn_units=40, autoencoder=model_AE)

loss_function = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

In [None]:
# Running means of the per-batch losses for the RNN. These rebind the generic
# metric names used by the training loop below.
train_loss_metric = tf.keras.metrics.Mean('train_loss')
test_loss_metric = tf.keras.metrics.Mean('test_loss')
metrics_RNN = [train_loss_metric, test_loss_metric]

# One history list per metric: losses_RNN[0] - train, losses_RNN[1] - test
losses_RNN = [[] for _ in metrics_RNN]

times = []

In [None]:
@tf.function
def train_step_RNN(model, train_ds, loss_function, optimizer, train_loss_metric):
    '''
    Train the recurrent model for one epoch over `train_ds`.

    As in the autoencoder training step, the loss compares the model output
    with the input batch itself; there is no accuracy metric.

    Args:
        model: Keras model to train (only its trainable variables are updated).
        train_ds: iterable of batches (e.g. a tf.data.Dataset). Each element is
            either a bare input batch or a tuple whose first item is the input;
            any further items are ignored.
        loss_function: callable(y_true, y_pred) returning a scalar loss.
        optimizer: Keras optimizer applied to `model.trainable_variables`.
        train_loss_metric: metric updated with every batch loss (e.g. Mean).
    '''
    for batch in train_ds:
        # Elements may be bare inputs or (input, label) pairs; resolved at trace time.
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        with tf.GradientTape() as tape:
            # forward pass
            prediction = model(x, training=True)
            loss = loss_function(x, prediction)

        # backward pass
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # update metrics
        train_loss_metric.update_state(loss)

# No separate eval_step_RNN is defined: `eval_step_AE` takes the model as an
# argument, so it can evaluate the RNN as well.

In [None]:
# The train/eval steps update these metric objects in place.
train_loss_metric, test_loss_metric = metrics_RNN

# NOTE(review): `train_ds` and `test_ds` are not defined in any cell shown
# here — TODO create them before running this cell.

print(f'[INFO] - Evaluating the Dataset on the {model_RNN.name} before training.')
timer.start()

# Evaluate once before training. `eval_step_AE` takes the model as an
# argument, so it evaluates the RNN as well.
eval_step_AE(model_RNN,
             train_ds,
             loss_function,
             loss_metric=train_loss_metric)

eval_step_AE(model_RNN,
             test_ds,
             loss_function,
             loss_metric=test_loss_metric)

# Read the averaged losses before eval_metrics records and resets the metrics.
train_loss = float(train_loss_metric.result())
test_loss = float(test_loss_metric.result())
eval_metrics(metrics=metrics_RNN, metric_lists=losses_RNN)

# Evaluate the timer
elapsed_time = timer.stop()
times.append(elapsed_time)

print(f'train_loss: {train_loss:0.4f}, test_loss: {test_loss:0.4f}')

for epoch in range(epochs):
    print(f'\n[EPOCH] ____________________{epoch}____________________')

    timer.start()

    # Training step (RNN, not the frozen autoencoder)
    train_step_RNN(model_RNN, train_ds, loss_function, optimizer, train_loss_metric)

    # Test step
    img_original, img_reconstructed = eval_step_AE(model_RNN, test_ds, loss_function, test_loss_metric)

    # Evaluate the metrics
    train_loss = float(train_loss_metric.result())
    test_loss = float(test_loss_metric.result())
    eval_metrics(metrics=metrics_RNN, metric_lists=losses_RNN)

    elapsed_time = timer.stop()
    times.append(elapsed_time)

    print(f'[{epoch}] - Finished Epoch in {elapsed_time:0.2f} seconds - train_loss: {train_loss:0.4f}; test_loss: {test_loss:0.4f}')

    # print progress every while
    if epoch % 5 == 0:
        # times[0] is the pre-training evaluation; estimate from epoch timings only.
        remaining_min = np.mean(times[1:]) * (epochs - epoch - 1) / 60
        print(f'\n[INFO] - Total time elapsed: {np.sum(times)/60:0.4f} min. Total time remaining: {remaining_min:0.4f} min.')

        # Visualize reconstructed image
        fig, (ax_orig, ax_rec) = plt.subplots(1, 2, figsize=(9, 3))

        ax_orig.set_title('Original Image')
        ax_orig.imshow(img_original[0, :, :, 0], cmap='gray')

        ax_rec.set_title('Reconstructed Image')
        ax_rec.imshow(img_reconstructed[0, :, :, 0], cmap='gray')

        fig.tight_layout()
        plt.show()

print(f'\n[INFO] - Total run time: {np.sum(times)/60:0.4f} min.')