# In [None]:
import tensorflow as tf
import numpy as np
import sys
import time
import matplotlib.pyplot as plt

class S2S(object):
    """Sequence-to-sequence model built with TF1 ``tf.contrib.legacy_seq2seq``.

    ``__init__`` resets the default graph and builds:
      * one int64 placeholder per encoder step (``encoder_input``, ``x_len``
        of them) and per decoder step (``labels``, ``y_len`` of them);
      * decoder inputs = a zero "GO" token followed by ``labels[:-1]``
        (teacher forcing);
      * a ``num_layers``-deep stack of ``BasicLSTMCell(emb_dim)``, each with
        output dropout controlled by the ``keep_prob`` placeholder;
      * a training decoder (``decode_outputs``) and a weight-sharing inference
        decoder (``decode_outputs_test``, built with ``feed_previous=True``);
      * a ``sequence_loss`` with all-ones weights and an Adam ``train_op``.
    """

    def __init__(self, x_len, y_len, x_vocab_size, y_vocab_size, emb_dim, num_layers,
                 ckpt_path, lr=0.001, epochs=10000, model_name='s2s_model'):
        """Store the hyper-parameters and build the TensorFlow graph.

        Args:
            x_len: number of encoder time steps (maximum input length).
            y_len: number of decoder time steps (maximum output length).
            x_vocab_size: encoder vocabulary size.
            y_vocab_size: decoder vocabulary size.
            emb_dim: embedding size; also the hidden size of every LSTM layer.
            num_layers: number of stacked LSTM layers.
            ckpt_path: checkpoint location. ``train`` saves to
                ``ckpt_path + model_name + '.ckpt'`` and
                ``restore_last_session`` looks up the latest checkpoint in it.
            lr: Adam learning rate.
            epochs: number of training iterations (one batch each) run by ``train``.
            model_name: file-name prefix for saved checkpoints.
        """
        self.x_len = x_len
        self.y_len = y_len
        self.ckpt_path = ckpt_path
        self.epochs = epochs
        self.model_name = model_name

        print('x_len', x_len)
        print('y_len', y_len)
        print('ckpt_path', ckpt_path)
        print('epochs', epochs)
        print('model_name', model_name)
        print('num_layers', num_layers)

        def build_graph():
            tf.reset_default_graph()

            # Encoder inputs: one batch of token ids per time step (x_len steps).
            # Ids must be integers: the embedding lookup inside
            # embedding_rnn_seq2seq and the sparse softmax cross-entropy inside
            # sequence_loss only accept int32/int64 indices/labels.
            self.encoder_input = [tf.placeholder(shape=[None, ], dtype=tf.int64,
                                                 name='einput_{}'.format(t))
                                  for t in range(x_len)]
            print('self.encoder_input', len(self.encoder_input))
            print('self.encoder_input[0]', self.encoder_input[0])

            # Decoder targets: one batch of token ids per time step (y_len steps).
            self.labels = [tf.placeholder(shape=[None, ], dtype=tf.int64,
                                          name='doutput_{}'.format(t))
                           for t in range(y_len)]
            print('self.labels', len(self.labels))
            print('self.labels[0]', self.labels[0])

            # Decoder inputs: '_GO' (id 0) + [y1, y2, ..., y_{n-1}].
            go_token = tf.zeros_like(self.encoder_input[0], dtype=tf.int64, name='GO')
            self.decoder_input = [go_token] + self.labels[:-1]
            print('self.decoder_input', len(self.decoder_input))
            print('self.decoder_input[0]', self.decoder_input[0])

            # Dropout on the output of every LSTM layer; the keep probability is
            # fed at run time (0.5 in train_batch, 1.0 for evaluation/prediction).
            self.keep_prob = tf.placeholder(tf.float32)
            cells = [tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.BasicLSTMCell(emb_dim),
                                                   output_keep_prob=self.keep_prob)
                     for _ in range(num_layers)]
            stacked_cells = tf.nn.rnn_cell.MultiRNNCell(cells)

            with tf.variable_scope('decoder') as scope:
                # Training decoder: consumes the ground-truth decoder inputs.
                self.decode_outputs, self.decode_states = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(
                    self.encoder_input,
                    self.decoder_input,
                    stacked_cells,
                    x_vocab_size,
                    y_vocab_size,
                    emb_dim
                )

                # The inference decoder below shares all weights with the one above.
                scope.reuse_variables()

                # Inference decoder: with feed_previous=True only the GO input is
                # used; later steps are fed the previous step's prediction.
                self.decode_outputs_test, self.decode_states_test = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(
                    self.encoder_input,
                    self.decoder_input,
                    stacked_cells,
                    x_vocab_size,
                    y_vocab_size,
                    emb_dim,
                    feed_previous=True
                )

            # Loss over the teacher-forced outputs, every step weighted equally.
            # (The 4th positional parameter of sequence_loss is
            # average_across_timesteps, not a vocabulary size, so it is left at
            # its default of True.)
            loss_weights = [tf.ones_like(label, dtype=tf.float32) for label in self.labels]
            self.loss = tf.contrib.legacy_seq2seq.sequence_loss(
                self.decode_outputs,
                self.labels,
                loss_weights
            )

            self.train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(self.loss)

        sys.stdout.write('Building graph...\n')
        build_graph()
        print('param num method 3:', self._count_trainable_params())
        sys.stdout.write('Done...\n')

    @staticmethod
    def _count_trainable_params():
        """Return the number of scalars in all trainable variables of the default
        graph, printing each variable's shape along the way."""
        total_parameters = 0
        for variable in tf.trainable_variables():
            shape = variable.get_shape()
            print(shape)
            variable_parameters = 1
            for dim in shape:
                # int() works for both tf.Dimension (TF1) and plain ints (TF2 shapes).
                variable_parameters *= int(dim)
            total_parameters += variable_parameters
        return total_parameters

    @staticmethod
    def _to_batch_major(step_outputs):
        """Stack per-step decoder outputs (a list of y_len arrays, each
        (batch, vocab)) and return them as one (batch, y_len, vocab) array."""
        return np.array(step_outputs).transpose([1, 0, 2])

    def _run_test_decoder(self, sess, x):
        """Run the inference decoder on encoder batch ``x`` with dropout disabled
        and return its batch-major logits."""
        feed_dict = {self.encoder_input[t]: x[t] for t in range(self.x_len)}
        feed_dict[self.keep_prob] = 1.0
        dec_op_v = sess.run(self.decode_outputs_test, feed_dict)
        return self._to_batch_major(dec_op_v)

    # ------------------------------------------------------------------ training

    def get_feed(self, x, y, keep_prob):
        """Build a feed dict for one batch.

        Args:
            x: indexed by encoder step; ``x[t]`` is fed to ``encoder_input[t]``.
            y: indexed by decoder step; ``y[t]`` is fed to ``labels[t]``.
            keep_prob: dropout keep probability fed to ``self.keep_prob``.
        """
        feed_dict = {self.encoder_input[t]: x[t] for t in range(self.x_len)}
        feed_dict.update({self.labels[t]: y[t] for t in range(self.y_len)})
        feed_dict[self.keep_prob] = keep_prob
        return feed_dict

    def train_batch(self, sess, train_batch_gen):
        """Run one optimisation step on the next ``(batchX, batchY)`` from
        ``train_batch_gen`` (dropout keep_prob 0.5) and return the batch loss."""
        batchX, batchY = next(train_batch_gen)
        feed_dict = self.get_feed(batchX, batchY, keep_prob=0.5)
        _, loss_v = sess.run([self.train_op, self.loss], feed_dict)
        return loss_v

    def evaluate_steps(self, sess, eval_batch_gen):
        """Evaluate the next batch from ``eval_batch_gen`` with dropout disabled.

        Returns:
            ``(loss, logits, batchX, batchY)``: ``loss`` is the teacher-forced
            loss; ``logits`` are the inference decoder outputs in batch-major
            order, shape (batch, y_len, vocab).
        """
        batchX, batchY = next(eval_batch_gen)
        feed_dict = self.get_feed(batchX, batchY, keep_prob=1.)
        loss_v, dec_op_v = sess.run([self.loss, self.decode_outputs_test], feed_dict)
        return loss_v, self._to_batch_major(dec_op_v), batchX, batchY

    def eval_batches(self, sess, eval_batch_gen, num_batches):
        """Return the mean loss over ``num_batches`` batches of ``eval_batch_gen``."""
        losses = [self.evaluate_steps(sess, eval_batch_gen)[0] for _ in range(num_batches)]
        return np.mean(losses)

    @staticmethod
    def _plot_histories(loss_history, execution_time_history, loss_val_history):
        """Show one figure each for training loss, per-iteration time and
        validation loss."""
        curves = (
            (loss_history, 'Funkcija greske'),                           # training loss
            (execution_time_history, 'Vreme iteracija'),                 # iteration time
            (loss_val_history, 'Funkcija greske za validacioni skup'),   # validation loss
        )
        for history, label in curves:
            plt.plot(range(len(history)), history, label=label)
            plt.legend()
            plt.show()

    def train(self, train_set, valid_set, sess=None, eval_every=50, save_every=1000,
              num_eval_batches=16, stop_val_loss=3):
        """Run up to ``self.epochs`` training iterations, one batch each.

        Every ``eval_every`` iterations the mean loss over ``num_eval_batches``
        validation batches is recorded. Training stops early once it drops below
        ``stop_val_loss``. Every ``save_every`` iterations a checkpoint is written
        to ``ckpt_path + model_name + '.ckpt'``. Ctrl-C stops training
        gracefully. Loss, time and validation-loss curves are plotted at the end.

        Args:
            train_set: iterator yielding ``(batchX, batchY)`` training batches.
            valid_set: iterator yielding ``(batchX, batchY)`` validation batches.
            sess: session to continue training in. When omitted, a new session
                is created, variables are initialised and the graph is written
                with a ``tf.summary.FileWriter``.
            eval_every: validation interval in iterations.
            save_every: checkpoint interval in iterations.
            num_eval_batches: number of validation batches averaged per evaluation.
            stop_val_loss: early-stopping threshold on the validation loss.

        Returns:
            The session used for training (also stored as ``self.session``).
        """
        saver = tf.train.Saver()
        loss_history = []
        loss_val_history = []
        execution_time_history = []
        summary_writer = None  # only created (and closed) for a fresh session
        if sess is None:
            print('Pravim novu sesiju...')  # "Creating a new session..."
            sess = tf.Session()
            sess.run(tf.global_variables_initializer())
            summary_writer = tf.summary.FileWriter('/path/s2s', graph=sess.graph)

        sys.stdout.write('>>>Training started...\n')

        i = 0
        try:
            for i in range(self.epochs):
                start = time.time()
                loss = self.train_batch(sess, train_set)
                elapsed = time.time() - start

                print('Loss', loss, 'at iteration', i, '/', self.epochs, 'Time -', elapsed)
                loss_history.append(loss)
                execution_time_history.append(elapsed)

                val_loss = None
                if i and i % eval_every == 0:
                    val_loss = self.eval_batches(sess, valid_set, num_eval_batches)
                    print('val_loss', val_loss)
                    loss_val_history.append(val_loss)
                    if val_loss < stop_val_loss:
                        print('Validation set loss is less than {}. Stopping...'.format(stop_val_loss))
                        break

                if i and i % save_every == 0:
                    saver.save(sess, self.ckpt_path + self.model_name + '.ckpt', global_step=i)
                    # Reuse this iteration's validation loss if it was just computed.
                    if val_loss is None:
                        val_loss = self.eval_batches(sess, valid_set, num_eval_batches)
                    print('Model saved after', i, 'iterations.')
                    print('Validate loss:', val_loss)
                    sys.stdout.flush()
        except KeyboardInterrupt:
            print('Interrupted by user at iteration', i)
        finally:
            if summary_writer is not None:
                summary_writer.close()

        self._plot_histories(loss_history, execution_time_history, loss_val_history)
        self.session = sess
        return sess

    def restore_last_session(self):
        """Open a new session and restore the latest checkpoint in ``ckpt_path``.

        If no checkpoint is found, a notice is printed and the session is
        returned with its variables NOT initialised.
        """
        saver = tf.train.Saver()
        sess = tf.Session()
        ckpt = tf.train.get_checkpoint_state(self.ckpt_path)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Session restored')
        else:
            print('No checkpoint found in', self.ckpt_path)
        return sess

    def predict(self, sess, x):
        """Greedy-decode encoder batch ``x``.

        Returns an int array of shape (batch, y_len) holding the arg-max output
        token id at every decoder step.
        """
        return np.argmax(self._run_test_decoder(sess, x), axis=2)

    def advance_predict(self, sess, x, axis=3):
        """Like ``predict`` but with a selectable arg-max axis.

        The decoder logits have shape (batch, y_len, vocab). ``axis=3`` (the
        default, used as an out-of-range sentinel) takes the arg-max over the
        flattened array and returns a single flat index. Any other value is
        passed to ``np.argmax`` as the axis.
        """
        dec_op_v = self._run_test_decoder(sess, x)
        if axis == 3:
            return np.argmax(dec_op_v)
        return np.argmax(dec_op_v, axis=axis)
