In [None]:
import numpy as np
import tensorflow as tf
import mdn
import time

In [None]:
import tensorflow.keras.backend as K

from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import LSTM, Dense, InputLayer, Attention
from tensorflow.keras.layers import (Conv2D, Input, Reshape, 
                                     Lambda, Dense, Conv2DTranspose)

# Let TensorFlow allocate GPU memory on demand instead of reserving it all
# up front.
# `tf.test.is_gpu_available()` is deprecated; listing the physical devices
# answers the same question (an empty list means "no GPU").
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for gpu in physical_devices:
    # Memory growth has to be configured the same way on every visible GPU,
    # so apply it to all of them rather than only to the first one.
    tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
import matplotlib.pyplot as plt
import os
import utils

from tensorflow.keras.callbacks import (EarlyStopping, ModelCheckpoint, 
                                       TensorBoard, Callback)
import datetime
from time import time
from utils import TrainTimeCallback

In [None]:
import matplotlib.pyplot as plt

In [None]:
import models

In [None]:
# Hyper-parameters shared by the data pipeline and the model.
output_dims = 32  # presumably the latent width (the target slice keeps 32 columns) - TODO confirm
n_mixtures = 5    # presumably the number of MDN mixture components - TODO confirm
act_len = 3       # presumably the width of one action vector - TODO confirm
seq_len = 128     # rows per training sequence; chunks of seq_len + 1 rows are cut further down

In [None]:
!free -m

In [None]:
# Count the entries in ./sausage/states. os.listdir returns every entry name
# (files and sub-directories alike), so this is only a file count if the
# folder holds nothing else.
num_instances = len(os.listdir('./sausage/states'))

In [None]:
def load_folder(path):
    """Load every array file in ``path`` into one stacked float64 array.

    Entries are read with ``np.load`` in sorted filename order; the array
    loaded from the k-th entry becomes ``result[k]``.

    Args:
        path: Directory holding the array files. Every entry must be
            loadable by ``np.load`` and have the same shape as the first.

    Returns:
        ``np.ndarray`` of dtype float64 with shape ``(n_entries, *item_shape)``.

    Raises:
        ValueError: If ``path`` contains no entries.
    """
    files = sorted(os.listdir(path))
    if not files:
        raise ValueError('no files to load in {!r}'.format(path))

    # The first array fixes the per-item shape; keep it instead of reading
    # the same file a second time inside the loop.
    first = np.load(os.path.join(path, files[0]))
    data = np.zeros((len(files), *first.shape))
    data[0] = first
    for i, fname in enumerate(files[1:], start=1):
        data[i] = np.load(os.path.join(path, fname))
    return data

In [None]:
# Stack every file under ./sausage/z_states into one array
# (one row per file, in sorted filename order).
z_states = load_folder('./sausage/z_states')

In [None]:
# Sanity check of the stacked shape.
z_states.shape

In [None]:
# Same loading procedure for the files under ./sausage/actions.
actions = load_folder('./sausage/actions')

In [None]:
# Sanity check of the stacked shape.
actions.shape

In [None]:
# Scratch arithmetic left over from exploration.
# NOTE(review): presumably a row count divided by two factors of 128 - confirm or delete.
154624 / 128 / 128

In [None]:
# Join the two arrays column-wise: the columns of z_states come first, then
# the columns of actions. np.concatenate requires both arrays to agree on
# every axis except axis 1.
pair = np.concatenate((z_states, actions), axis=1)

In [None]:
# Sanity check of the joined shape.
pair.shape

In [None]:
# One dataset element per row (axis 0) of `pair`.
latent_dataset = tf.data.Dataset.from_tensor_slices(pair)

In [None]:
# Group consecutive rows into chunks of seq_len + 1 rows; the extra row lets
# split_input_target build targets shifted by one step. A trailing chunk with
# fewer rows is dropped.
sequences = latent_dataset.batch(seq_len + 1, drop_remainder=True)

In [None]:
# Display the dataset's element spec.
sequences

In [None]:
def split_input_target(chunk, latent_size=32):
    """Split one chunk of consecutive rows into (input, next-step target).

    Args:
        chunk: Array or tensor indexed as ``[row, column]``.
        latent_size: Number of leading columns kept in the target. Defaults
            to 32, so single-argument callers (e.g. ``Dataset.map``) behave
            exactly as before.

    Returns:
        Tuple ``(input_z, target_z)``: ``input_z`` is every row except the
        last, with all columns; ``target_z`` is every row except the first,
        restricted to the first ``latent_size`` columns.
    """
    input_z = chunk[:-1]
    target_z = chunk[1:, :latent_size]
    return input_z, target_z

In [None]:
# Turn every (seq_len + 1)-row chunk into an (input, target) pair.
dataset = sequences.map(split_input_target)

In [None]:
# Display the dataset's element spec.
dataset

In [None]:
# Shuffle the sequences, then group them into batches.
# reshuffle_each_iteration=False pins the shuffled order: the dataset is later
# split with take()/skip(), and each of those re-iterates this dataset. With
# the default (reshuffle on every iteration) the two splits would be drawn
# from different orders and validation batches would leak into training.
# Trade-off: batch order is identical in every epoch; shuffle the training
# split separately if per-epoch reshuffling is wanted.
dataset = (
    dataset
    .shuffle(10000, reshuffle_each_iteration=False)
    .batch(utils.BATCH_SIZE, drop_remainder=True)
)

In [None]:
dataset

In [None]:
# Walk the whole dataset once purely to count its batches. `i` ends up holding
# the total number of batches (0 if the dataset is empty); `a` and `b` are left
# bound to the last (input, target) batch.
i = 0
for i, (a, b) in enumerate(dataset, start=1):
    pass

In [None]:
# Hold out one fifth of the batches (rounded down) for validation:
# `val` is the first `percent_20` batches yielded by `dataset`, `train` is
# everything after them.
# NOTE(review): take() and skip() each re-iterate `dataset`; if `dataset`
# reshuffles on every iteration the two splits are not guaranteed disjoint.
percent_20 = i // 5
val, train = dataset.take(percent_20), dataset.skip(percent_20)

In [None]:
val

In [None]:
train

In [None]:
# Prepare ./logs/fit for TensorBoard and clear stale files from it.
# Done in Python rather than with `!mkdir` / `!rm "./logs/fit/*"`: inside
# double quotes the shell does not expand `*`, so that `rm` looked for a file
# literally named "*", and `mkdir` reports an error when the directory exists.
fit_log_root = os.path.join('.', 'logs', 'fit')
os.makedirs(fit_log_root, exist_ok=True)
for entry in os.scandir(fit_log_root):
    # Mirror plain `rm ./logs/fit/*` (no -r): remove regular, non-hidden
    # files only and leave sub-directories alone.
    if entry.is_file() and not entry.name.startswith('.'):
        os.remove(entry.path)

In [None]:
# Timestamped run directory under logs/fit, e.g. "logs/fit/20240131-235959".
log_dir = "logs/fit/" + format(datetime.datetime.now(), "%Y%m%d-%H%M%S")

In [None]:
class BahdanauAttention(tf.keras.Model):
    """Additive (Bahdanau-style) attention: ``V(tanh(W1(features) + W2(hidden)))``.

    Args:
        units: Output width of the ``W1`` and ``W2`` projections.
    """

    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.units = units
        # NOTE(review): the input_shape hints are hard-coded. 35 presumably
        # equals latent size + action size and 256 the LSTM width - confirm.
        self.W1 = tf.keras.layers.Dense(units, input_shape=(35,))
        self.W2 = tf.keras.layers.Dense(units, input_shape=(256,))
        self.V = tf.keras.layers.Dense(1, input_shape=(256,))

    def call(self, features, hidden):
        """Weight ``features`` by attention scores computed against ``hidden``.

        Args:
            features: Tensor to attend over. Presumably a batch of input
                sequences, (batch, seq_len, 35) - TODO confirm with caller.
            hidden: State tensor projected by ``W2``. Presumably
                (n, 256) - TODO confirm with caller.

        Returns:
            Tuple ``(context_vector, attention_weights)``. ``context_vector``
            is the element-wise product of the weights and the expanded
            features; no reduction over the attended axis is applied, so both
            outputs keep the leading axis of size 1 that is added here.
        """
        # Add a leading axis of size 1 to both inputs.
        features = tf.expand_dims(features, 0)
        hidden = tf.expand_dims(hidden, 0)

        # Additive score. The two projections are combined by broadcasting.
        # NOTE(review): `features` has one more axis than `hidden`, so the
        # first axis of the original `hidden` lines up with the second-to-last
        # axis of `features` - confirm this alignment is intended.
        score = tf.nn.tanh(self.W1(features) + self.W2(hidden))

        # V maps every score vector to a single logit. The softmax normalises
        # over axis 1 of the expanded tensor, i.e. axis 0 of the original
        # `features` argument.
        attention_weights = tf.nn.softmax(self.V(score), axis=1)

        # Re-weight the features; the last axis of the weights (size 1)
        # broadcasts over the feature axis.
        context_vector = attention_weights * features

        return context_vector, attention_weights

    def get_config(self):
        """Return the base config extended with ``units``.

        NOTE(review): some TensorFlow 2.x releases raise NotImplementedError
        from the base ``Model.get_config()`` for subclassed models - verify
        against the installed version.
        """
        config = super(BahdanauAttention, self).get_config()
        config.update({
            'units': self.units
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild an instance from a ``get_config()`` dictionary.

        Must be a classmethod: Keras calls it on the class, not on an
        instance, so without the decorator ``config`` would be bound to
        ``cls`` and the call would fail.
        """
        return cls(**config)
    

In [None]:
class attention_mdn_rnn(tf.keras.Model):
    """Attention -> LSTM -> MDN head.

    Args:
        seq_len: Stored and reported by ``get_config``; not used to build
            any layer.
        act_len: Stored and reported by ``get_config``; not used to build
            any layer.
        latent_size: Stored and reported by ``get_config``; not used to
            build any layer.
        cells: Number of LSTM units; also the ``units`` of the attention
            block and the width of the state from ``reset_state``.
        output_dim: First argument of the ``mdn.MDN`` output layer.
        n_mixes: Second argument of the ``mdn.MDN`` output layer.
    """

    def __init__(self,
                 seq_len=128,
                 act_len=3,
                 latent_size=32,
                 cells=256,
                 output_dim=32,
                 n_mixes=5):
        super(attention_mdn_rnn, self).__init__()

        self.seq_len = seq_len
        self.act_len = act_len
        self.latent_size = latent_size
        self.cells = cells
        self.output_dim = output_dim
        self.n_mixes = n_mixes

        self.lstm = LSTM(self.cells,
                         return_sequences=True,
                         return_state=True,
                         recurrent_initializer='glorot_uniform')

        self.attention = BahdanauAttention(self.cells)
        self.out = mdn.MDN(self.output_dim, self.n_mixes)

    def call(self, x, hidden):
        """Run attention, the LSTM and the MDN head on one batch.

        Args:
            x: Input batch handed to the attention block.
            hidden: State handed to the attention block. It is not passed to
                the LSTM, which is called without an ``initial_state``.

        Returns:
            Tuple ``(mdn_output, hidden_out)``: the MDN layer applied to the
            LSTM's per-step outputs, and the first state tensor the LSTM
            returns.
        """
        # Attention weights are computed but not returned to the caller.
        context_vector, _ = self.attention(x, hidden)

        # The attention block adds a leading axis of size 1; index it away
        # before feeding the LSTM. The second LSTM state is not needed.
        x, hidden_out, _ = self.lstm(context_vector[0])
        x = self.out(x)

        return x, hidden_out

    def get_config(self):
        """Return the base config extended with the constructor arguments.

        NOTE(review): some TensorFlow 2.x releases raise NotImplementedError
        from the base ``Model.get_config()`` for subclassed models - verify
        against the installed version.
        """
        config = super(attention_mdn_rnn, self).get_config()
        config.update({'seq_len': self.seq_len,
                       'act_len': self.act_len,
                       'latent_size': self.latent_size,
                       'cells': self.cells,
                       'output_dim': self.output_dim,
                       'n_mixes': self.n_mixes})
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild an instance from a ``get_config()`` dictionary.

        Must be a classmethod: Keras calls it on the class, not on an
        instance, so without the decorator ``config`` would be bound to
        ``cls`` and the call would fail.
        """
        return cls(**config)

    def reset_state(self, batch_size):
        """Return an all-zero state tensor of shape ``(batch_size, cells)``."""
        return tf.zeros((batch_size, self.cells))

In [None]:
M = attention_mdn_rnn()

In [None]:
loss_function = mdn.get_mixture_loss_func(32, 5)

In [None]:
# Adam with the library defaults (no arguments are overridden).
optimizer = tf.keras.optimizers.Adam()

In [None]:
# Training-loss history; the training loop appends one value per epoch.
loss_plot = []

In [None]:
# NOTE(review): same values as `n_mixtures` / `output_dims` defined near the
# top; neither name is read again in the visible part of the notebook, so
# these are presumably leftovers.
num_mixes = 5
output_dim = 32

In [None]:
@tf.function
def train_step(pair, target):
    """Run one optimisation step of the global model ``M`` on a single batch.

    Args:
        pair: Batch of input sequences fed to ``M``.
        target: Batch of targets handed to ``loss_function`` together with
            the model output.

    Returns:
        Tuple ``(loss, total_loss)``: the batch loss, and the same loss
        divided by ``target.shape[1]``.
    """
    # NOTE(review): 128 is hard-coded. The attention layer combines this
    # state with `pair` by broadcasting, so it presumably has to line up with
    # an axis of `pair` - confirm before tying it to the batch size.
    hidden = M.reset_state(128)

    with tf.GradientTape() as tape:
        z, hidden = M(pair, hidden)
        # Let failures propagate: dropping into pdb from inside a traced
        # tf.function hides the real error and leaves `loss` unusable.
        loss = loss_function(target, z)

    total_loss = loss / int(target.shape[1])

    trainable_variables = M.trainable_variables
    gradients = tape.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(gradients, trainable_variables))

    return loss, total_loss

In [None]:
# First epoch index used by the training loop (`range(start_epoch, EPOCHS)`).
start_epoch = 0

In [None]:
# NOTE(review): this changes the default float type only for layers created
# after this point. `M` is instantiated in an earlier cell - confirm that the
# model is meant to be float64 and that this runs before it is built.
tf.keras.backend.set_floatx('float64')

In [None]:
# Validation-loss history; the training loop appends one value per epoch.
val_loss = []

In [None]:
# Re-import the module on purpose: an earlier `from time import time` rebinds
# the name `time` to the function, while the training loop calls `time.time()`.
import time

In [None]:
EPOCHS = 3


@tf.function
def val_step(pair, target):
    """Forward pass and loss on one batch, without updating any weights.

    Returns the same ``(loss, loss / target.shape[1])`` pair as ``train_step``.
    """
    hidden = M.reset_state(128)  # same initial state as train_step
    z, _ = M(pair, hidden)
    loss = loss_function(target, z)
    return loss, loss / int(target.shape[1])


for epoch in range(start_epoch, EPOCHS):
    start = time.time()

    # --- training: the weights are updated on every batch ---
    train_loss_sum = 0
    train_batches = 0
    for z_tensor, target in train:
        batch_loss, t_loss = train_step(z_tensor, target)
        train_loss_sum += t_loss
        train_batches += 1

    # Average over the batches actually seen in this split, not over the
    # batch count of the whole dataset.
    epoch_train_loss = train_loss_sum / max(train_batches, 1)
    loss_plot.append(epoch_train_loss)

    # --- validation: evaluation only, with its own accumulator ---
    # Running train_step here would fit the model on the validation data and
    # add the validation loss on top of the training total.
    val_loss_sum = 0
    val_batches = 0
    for z_tensor, target in val:
        batch_loss, t_loss = val_step(z_tensor, target)
        val_loss_sum += t_loss
        val_batches += 1

    epoch_val_loss = val_loss_sum / max(val_batches, 1)
    val_loss.append(epoch_val_loss)

    print('Epoch {} train loss {:.6f} val loss {:.6f} ({:.1f} sec)'.format(
        epoch + 1, float(epoch_train_loss), float(epoch_val_loss),
        time.time() - start))

In [None]:
# Persist the trained weights in TensorFlow checkpoint format, using
# './data/weights/attn_mdn_rnn' as the checkpoint prefix.
M.save_weights('./data/weights/attn_mdn_rnn', save_format='tf') 

In [None]:
# Training-loss history: one point per epoch appended by the training loop.
plt.plot(loss_plot)

In [None]:
# Validation-loss history: one point per epoch appended by the training loop.
plt.plot(val_loss)

In [None]:
!pwd