This notebook contains code from Hareesh Bahuleyan's `probabilistic_nlg` repository (https://github.com/HareeshBahuleyan/probabilistic_nlg) and runs the standard deterministic Wasserstein autoencoder (WAE). It was designed specifically to be run in the Google Colaboratory (Colab) environment.

### SET-UP ###

In [None]:
# Mount Google Drive under /content/drive so that data files can be read from it
# (and trained models saved back to it).
from google.colab import drive

drive.mount(mountpoint="/content/drive", force_remount=True)

In [None]:
import nltk

# Fetch the Punkt models; presumably needed by nltk's word_tokenize, which is used below.
nltk.download('punkt')

### UTILS ###

In [None]:
import numpy as np
import gensim
import re
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import corpus_bleu
from nltk.collocations import BigramCollocationFinder
from nltk.probability import FreqDist
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences


def calculate_bleu_scores(references, hypotheses):
    """
    Compute corpus-level BLEU-1 to BLEU-4 with NLTK's corpus_bleu.

    Args:
        references: reference sentences, passed unchanged to nltk corpus_bleu
            (per NLTK's API: one list of tokenized references per hypothesis)
        hypotheses: generated sentences, passed unchanged to corpus_bleu
    Returns:
        bleu_1, bleu_2, bleu_3, bleu_4: scores multiplied by 100 and rounded to 2 decimals
    """
    # Uniform n-gram weights for BLEU-n (BLEU-3 uses 0.34/0.33/0.33 so the weights sum to 1).
    ngram_weights = (
        (1.0, 0., 0., 0.),
        (0.50, 0.50, 0., 0.),
        (0.34, 0.33, 0.33, 0.),
        (0.25, 0.25, 0.25, 0.25),
    )
    return tuple(
        np.round(100 * corpus_bleu(references, hypotheses, weights=weights), decimals=2)
        for weights in ngram_weights
    )


def calculate_ngram_diversity(corpus):
    """
    Calculate unigram (distinct-1) and bigram (distinct-2) diversity of a token stream.

    Args:
        corpus: flat list of tokens (e.g. the tokens of all sampled sentences
            concatenated). It is consumed token by token by NLTK's FreqDist and
            BigramCollocationFinder.from_words, so every element must be hashable.
    Returns:
        uni_diversity: distinct-1 score (number of distinct tokens / number of tokens)
        bi_diversity: distinct-2 score (number of distinct bigrams / bigram_finder.N)
        Both are 0.0 for an empty corpus.
    """
    if len(corpus) == 0:
        # Nothing sampled: report zero diversity instead of dividing by zero.
        return 0.0, 0.0

    bigram_finder = BigramCollocationFinder.from_words(corpus)
    bi_diversity = len(bigram_finder.ngram_fd) / bigram_finder.N

    dist = FreqDist(corpus)
    uni_diversity = len(dist) / len(corpus)

    return uni_diversity, bi_diversity


def calculate_entropy(corpus):
    """
    Shannon entropy (natural log) of the unigram distribution of a token stream.

    Args:
        corpus: flat list of tokens; frequencies are counted with NLTK's FreqDist
    Returns:
        ent: -sum(p * log(p)) over the distinct tokens, where p = count / len(corpus);
            0 for an empty corpus
    """
    counts = FreqDist(corpus)
    n_tokens = len(corpus)

    entropy = 0
    for count in counts.values():
        prob = count / n_tokens
        entropy -= prob * np.log(prob)
    return entropy


def tokenize_sequence(sentences, filters, max_num_words, max_vocab_size):
    """
    Tokenize sentences into fixed-length, EOS-terminated sequences of word ids.

    Ids 0-3 are reserved for PAD, UNK, GO and EOS. Real words follow from id 4, in the
    iteration order of the fitted Keras tokenizer's ``word_index``. Ids
    >= ``max_vocab_size`` are mapped to UNK. Every sequence ends with EOS and is then
    post-padded with PAD to ``max_num_words`` ids.

    Args:
        sentences: List of sentences
        filters: characters the Keras tokenizer strips (``Tokenizer(filters=...)``)
        max_num_words: length of every output sequence, EOS included
        max_vocab_size: number of ids (special tokens included) kept in the vocabulary
    Returns:
        x: padded id sequences from pad_sequences, one row per sentence, max_num_words columns
        word_index: dictionary storing the word-to-index correspondence (ids < max_vocab_size)
    """
    sentences = [' '.join(word_tokenize(s)[:max_num_words]) for s in sentences]

    tokenizer = Tokenizer(filters=filters)
    tokenizer.fit_on_texts(sentences)

    # Reserve the first ids for the special tokens and shift the tokenizer's words after them.
    word_index = {'PAD': 0, 'UNK': 1, 'GO': 2, 'EOS': 3}
    for i, word in enumerate(tokenizer.word_index, start=4):
        word_index[word] = i

    # texts_to_sequences looks words up in tokenizer.word_index, so install the shifted mapping.
    tokenizer.word_index = word_index
    sequences = tokenizer.texts_to_sequences(sentences)

    unk_id = word_index['UNK']
    eos_id = word_index['EOS']
    # Keep at most max_num_words - 1 word ids so EOS always fits. pad_sequences below truncates
    # from the end ('post'), so an over-long sequence would otherwise lose its EOS marker.
    keep = max(max_num_words - 1, 0)
    sequences = [
        [t if t < max_vocab_size else unk_id for t in seq][:keep] + [eos_id]
        for seq in sequences
    ]

    x = pad_sequences(sequences, padding='post', truncating='post', maxlen=max_num_words,
                      value=word_index['PAD'])

    word_index = {k: v for k, v in word_index.items() if v < max_vocab_size}

    return x, word_index


def create_embedding_matrix(word_index, embedding_dim, w2v_path):
    """
    Create the initial embedding matrix for the TF graph.

    Row ``i`` holds the word2vec vector of the word whose id is ``i``. Words that the
    word2vec model does not know keep a random initialisation drawn from U(-0.05, 0.05).

    Args:
        word_index: dictionary storing the word-to-index correspondence
            (ids are used directly as row indices, so they should be 0..len(word_index)-1)
        embedding_dim: word2vec dimension (must equal the saved model's vector size)
        w2v_path: file path to the saved gensim Word2Vec model
    Returns:
        embeddings_matrix : numpy 2d-array of shape (len(word_index), embedding_dim)
    """
    w2v_model = gensim.models.Word2Vec.load(w2v_path)
    # Look vectors up on the KeyedVectors (`.wv`). Indexing the model itself (`model[word]`)
    # is deprecated in gensim 3.x and removed in gensim 4.
    word_vectors = w2v_model.wv
    embeddings_matrix = np.random.uniform(-0.05, 0.05, size=(len(word_index), embedding_dim))
    for word, i in word_index.items():
        try:
            embeddings_matrix[i] = word_vectors[word]
        except KeyError:
            # Out-of-vocabulary for word2vec (possibly the special tokens): keep the random init.
            pass

    return embeddings_matrix


def get_sentences(file_path):
    """
    Read a text file and return its lines.

    Args:
        file_path: path of the file to read
    Returns:
        list of lines as produced by iterating the file (trailing newlines are kept)
    """
    with open(file_path, 'r') as handle:
        lines = list(handle)
    return lines


def clean_sentence(sent):
    """
    Normalise a raw sentence for tokenization.

    Lower-cases the text, removes punctuation other than ``,``, ``?`` and ``.``, and
    replaces every number with the ``NUM`` token.

    Args:
        sent: raw sentence string
    Returns:
        the cleaned sentence. ``NUM`` is inserted with surrounding spaces, so the result
        may contain repeated spaces; whitespace tokenizers ignore them.
    """
    # Lower case, remove punctuations (except , ? .)
    sent = re.sub(r'[^\w\s\?\.\,]', '', sent.strip().lower())
    # Replace numbers with the NUM token. An alphabetic prefix glued to the digits is absorbed
    # (e.g. "covid19"). "." or "," belongs to the number only when digits follow (1.5, 12,000),
    # so a following letter ("3rd") or a sentence-final period is preserved.
    sent = re.sub(r'[a-z]*\d+(?:[.,]\d+)?', ' NUM ', sent.strip())
    return sent


def get_batches(x, batch_size):
    """
    Yield consecutive full batches for auto-encoding (targets are the inputs themselves).

    Args:
        x: entire source sequence array (presumably padded with id 0 = PAD)
        batch_size: batch size
    Yields:
        (x_batch, y_batch, sentence_length). y_batch holds the same rows as x_batch, and
        sentence_length is the number of non-zero ids in each row. A trailing partial
        batch is dropped.
    """
    num_batches = len(x) // batch_size
    for start in range(0, num_batches * batch_size, batch_size):
        stop = start + batch_size
        inputs = x[start:stop]
        lengths = [np.count_nonzero(row) for row in inputs]
        yield inputs, x[start:stop], lengths


def get_batches_xy(x, y, batch_size):
    """
    Yield aligned full batches of source and target sequences.

    Args:
        x: entire source sequence array
        y: entire output sequence array, aligned row-by-row with x
        batch_size: batch size
    Yields:
        (x_batch, y_batch, source_sentence_length, target_sentence_length). The lengths
        are the per-row counts of non-zero ids. A trailing partial batch is dropped.
    """
    num_batches = len(x) // batch_size
    for start in range(0, num_batches * batch_size, batch_size):
        stop = start + batch_size
        src, tgt = x[start:stop], y[start:stop]
        src_lengths = [np.count_nonzero(row) for row in src]
        tgt_lengths = [np.count_nonzero(row) for row in tgt]
        yield src, tgt, src_lengths, tgt_lengths


def create_data_split(x, y, dataset_sizes):
    """
    Split aligned input/output arrays into consecutive train / validation / test parts.

    The first ``dataset_sizes[0]`` rows form the training set, the next
    ``dataset_sizes[1]`` rows the validation set and the following
    ``dataset_sizes[2]`` rows the test set.

    Args:
        x: input sequence of indices (indexed with range objects, e.g. a numpy array)
        y: output sequence of indices, aligned row-by-row with x
        dataset_sizes: (train_size, val_size, test_size); only the first three entries are read
    Returns:
        x_train, y_train, x_val, y_val, x_test, y_test: train val test split arrays
    """
    train_size, val_size, test_size = dataset_sizes[0], dataset_sizes[1], dataset_sizes[2]
    val_start = train_size
    test_start = val_start + val_size
    test_stop = test_start + test_size

    parts = []
    for rows in (range(0, val_start), range(val_start, test_start), range(test_start, test_stop)):
        parts.append(x[rows])
        parts.append(y[rows])
    return tuple(parts)


def plot_2d(zvectors, labels, method):
    """
    Project latent vectors to 2-D and scatter-plot them coloured by class.

    Args:
        zvectors: latent vectors, one row per sentence
        labels: class labels (0 = 'automobile', 1 = 'home and kitchen'); truncated to
            the number of projected points
        method: 'tsne' selects t-SNE; any other value selects PCA
    """
    if method == 'tsne':
        reducer = TSNE(n_components=2, random_state=17)
    else:  # PCA
        reducer = PCA(n_components=2, random_state=17)

    points = reducer.fit_transform(X=zvectors)
    labels = np.array(labels[:points.shape[0]])

    class_dict = {0: 'automobile', 1: 'home and kitchen'}
    _, ax = plt.subplots()
    ax.figure.set_size_inches(w=10, h=10)
    for class_id, class_name in class_dict.items():
        members = np.where(labels == class_id)
        ax.scatter(points[members, 0], points[members, 1], s=6, label=class_name)
    plt.grid()
    plt.legend(fontsize=12)
    plt.show()


### EMBEDDINGS ### 

In [None]:
import os
import gensim
import numpy as np
from nltk.tokenize import word_tokenize

def main():
    """
    Pre-train 300-d word2vec embeddings on the source sentences and save them to Drive.

    Reads '/content/drive/My Drive/sources.txt' and writes the model to
    '/content/drive/My Drive/w2v_models/fce_incorrect.pkl'.
    """
    data_path = '/content/drive/My Drive/sources.txt'
    model_path = '/content/drive/My Drive/w2v_models/fce_incorrect.pkl'

    # Make sure the folder the model is saved into exists before spending time on training.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    raw_lines = get_sentences(file_path=data_path)

    print('[INFO] Number of sentences = {}'.format(len(raw_lines)))

    sentences = [s.strip() for s in raw_lines]

    np.random.shuffle(sentences)
    sentences = [word_tokenize(s) for s in sentences]
    # `size` / `iter` are the gensim 3.x argument names (gensim 4 renamed them to
    # `vector_size` / `epochs`).
    w2v_model = gensim.models.Word2Vec(sentences,
                                       size=300,
                                       min_count=1,
                                       iter=50)

    w2v_model.save(model_path)
    print('[INFO] Word embeddings pre-trained successfully')


#if __name__ == '__main__':
#    main()

### DECODER ### 

In [None]:
import collections
import tensorflow as tf
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as layers_base
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.util import nest

# Public names exported by the custom decoder cell.
__all__ = ["BasicDecoderOutput", "BasicDecoder"]


class BasicDecoderOutput(collections.namedtuple("BasicDecoderOutput", ("rnn_output", "sample_id"))):
    """Per-step output of BasicDecoder: the (optionally projected) cell output and the sampled ids."""


class BasicDecoder(decoder.Decoder):
    """Basic sampling decoder that conditions every step on a latent vector.

    Works like the stock seq2seq BasicDecoder, except that ``latent_vector`` is
    concatenated (along the last axis) to the cell input at every step. In
    ``initialize`` it is appended to the helper's first input; in ``step`` it is
    appended to the helper's next input.
    """

    def __init__(self, cell, helper, initial_state, latent_vector, output_layer=None):
        """Initialize BasicDecoder.

        Args:
          cell: An `RNNCell` instance (its type is not checked, see below).
          helper: A `Helper` instance.
          initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
            The initial state of the RNNCell.
          latent_vector: Tensor concatenated to every decoder input along the last
            axis. In this notebook it is a [batch_size x latent_dim] code; its
            leading dimensions must match those of the helper's inputs.
          output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
            `tf.layers.Dense`. Optional layer to apply to the RNN output prior
            to storing the result or sampling.

        Raises:
          TypeError: if `helper` or `output_layer` have an incorrect type.
        """
        # The RNNCell type check is skipped. It relied on the private
        # rnn_cell_impl._like_rnncell helper, presumably unavailable in some TF versions.
        if not isinstance(helper, helper_py.Helper):
            raise TypeError("helper must be a Helper, received: %s" % type(helper))
        if output_layer is not None and not isinstance(output_layer, layers_base.Layer):
            raise TypeError("output_layer must be a Layer, received: %s" % type(output_layer))
        self._cell = cell
        self._helper = helper
        self._initial_state = initial_state
        self._output_layer = output_layer
        self._latent_vector = latent_vector

    @property
    def batch_size(self):
        return self._helper.batch_size

    def _rnn_output_size(self):
        size = self._cell.output_size
        if self._output_layer is None:
            return size
        # To use the layer's compute_output_shape, convert the RNNCell's output_size
        # entries into shapes with an unknown batch size, pass them through the layer,
        # and read off all but the first (batch) dimension.
        output_shape_with_unknown_batch = nest.map_structure(
            lambda s: tensor_shape.TensorShape([None]).concatenate(s),
            size)
        layer_output_shape = self._output_layer.compute_output_shape(  # pylint: disable=protected-access
            output_shape_with_unknown_batch)
        return nest.map_structure(lambda s: s[1:], layer_output_shape)

    @property
    def output_size(self):
        # Return the cell output and the id
        return BasicDecoderOutput(
            rnn_output=self._rnn_output_size(),
            sample_id=tensor_shape.TensorShape([]))

    @property
    def output_dtype(self):
        # Assume the dtype of the cell is the output_size structure
        # containing the input_state's first component's dtype.
        # Return that structure and int32 (the id)
        dtype = nest.flatten(self._initial_state)[0].dtype
        return BasicDecoderOutput(
            nest.map_structure(lambda _: dtype, self._rnn_output_size()),
            dtypes.int32)

    def initialize(self, name=None):
        """Initialize the decoder.

        Args:
          name: Name scope for any created operations (unused).

        Returns:
          `(finished, first_inputs, initial_state)`. first_inputs is the helper's
          first input (e.g. the <GO> embedding) with the latent vector concatenated.
        """
        # Call the helper once; each call adds a fresh set of ops to the graph.
        finished, first_inputs = self._helper.initialize()
        first_inputs = tf.concat([first_inputs, self._latent_vector], axis=-1)
        return finished, first_inputs, self._initial_state

    def step(self, time, inputs, state, name=None):
        """Perform a decoding step.

        Args:
          time: scalar `int32` tensor.
          inputs: A (structure of) input tensors.
          state: A (structure of) state tensors and TensorArrays.
          name: Name scope for any created operations.

        Returns:
          `(outputs, next_state, next_inputs, finished)`.
        """
        with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
            cell_outputs, cell_state = self._cell(inputs, state)

            if self._output_layer is not None:
                cell_outputs = self._output_layer(cell_outputs)
            sample_ids = self._helper.sample(
                time=time, outputs=cell_outputs, state=cell_state)
            (finished, next_inputs, next_state) = self._helper.next_inputs(
                time=time,
                outputs=cell_outputs,
                state=cell_state,
                sample_ids=sample_ids)

            # Concatenate the latent vector to the predicted word's embedding
            next_inputs = tf.concat([next_inputs, self._latent_vector], axis=-1)

        outputs = BasicDecoderOutput(cell_outputs, sample_ids)

        return (outputs, next_state, next_inputs, finished)


### MODEL ### 

In [None]:
import time
import pickle
import tensorflow as tf
import numpy as np
import os
from tqdm import tqdm
from nltk.tokenize import word_tokenize
from tensorflow.python.layers.core import Dense

class DetWAEModel(object):

    def __init__(self, config, embeddings_matrix, word_index):

        self.config = config

        self.lstm_hidden_units = config['lstm_hidden_units']
        self.embedding_size = config['embedding_size']
        self.latent_dim = config['latent_dim']
        self.num_layers = config['num_layers']
        
        self.lambda_val = config['lambda_val']

        self.vocab_size = config['vocab_size']
        self.num_tokens = config['num_tokens']

        self.dropout_keep_prob = config['dropout_keep_prob']

        self.initial_learning_rate = config['initial_learning_rate']
        self.learning_rate_decay = config['learning_rate_decay']
        self.min_learning_rate = config['min_learning_rate']

        self.batch_size = config['batch_size']
        self.epochs = config['n_epochs']

        self.embeddings_matrix = embeddings_matrix
        self.word_index = word_index
        self.idx_word = dict((i, word) for word, i in word_index.items())

        self.logs_dir = config['logs_dir']
        self.model_checkpoint_dir = config['model_checkpoint_dir']
        self.bleu_path = config['bleu_path']

        self.pad = self.word_index['PAD']
        self.eos = self.word_index['EOS']
        self.unk = self.word_index['UNK']
        
        self.epoch_bleu_score_val = {'1': [], '2': [], '3': [], '4': []}
        self.log_str = []

        self.build_model()

    def build_model(self):
        print("[INFO] Building Model ...")

        self.init_placeholders()
        self.embedding_layer()
        self.build_encoder()
        self.build_latent_space()
        self.sample_gaussian()
        self.build_decoder()
        self.loss()
        self.optimize()
        self.summary()

    def init_placeholders(self):
        """
        Create the feed placeholders under the ``model_inputs`` name scope.

        input_data / target_data: int32 [batch_size x num_tokens] token ids
        lr, lambda_coeff: float32 scalars; keep_prob: float32 dropout keep probability
        source_sentence_length / target_sentence_length: int32 [batch_size] lengths
        """
        batch, max_len = self.batch_size, self.num_tokens
        with tf.name_scope("model_inputs"):
            self.input_data = tf.placeholder(tf.int32, [batch, max_len], name='input')
            self.target_data = tf.placeholder(tf.int32, [batch, max_len], name='targets')
            self.lr = tf.placeholder(tf.float32, name='learning_rate', shape=())
            self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
            self.source_sentence_length = tf.placeholder(
                tf.int32, shape=(batch,), name='source_sentence_length')
            self.target_sentence_length = tf.placeholder(
                tf.int32, shape=(batch,), name='target_sentence_length')
            self.lambda_coeff = tf.placeholder(tf.float32, name='lambda_coeff', shape=())

    def embedding_layer(self):
        """
        Embed encoder and decoder token ids with a frozen embedding table.

        Sets self.embeddings (non-trainable, initialised from embeddings_matrix), and
        self.enc_embed_input (source embeddings cut to the longest source length in the batch).
        Also sets self.dec_input (GO followed by the targets shifted right by one),
        self.max_tar_len, and self.dec_embed_input (decoder-input embeddings cut to max_tar_len).
        """
        with tf.name_scope("word_embeddings"):
            table = np.array(self.embeddings_matrix, dtype=np.float32)
            self.embeddings = tf.Variable(initial_value=table, dtype=tf.float32, trainable=False)
            source_embedded = tf.nn.embedding_lookup(self.embeddings, self.input_data)  # batch x maxlen x embed_dim
            self.enc_embed_input = source_embedded[:, :tf.reduce_max(self.source_sentence_length), :]

            with tf.name_scope("decoder_inputs"):
                # Teacher-forcing input: drop the last target token and prepend GO.
                targets_without_last = self.target_data[:, :-1]  # batch x (maxlen - 1)
                go_column = tf.fill([self.batch_size, 1], self.word_index['GO'])
                self.dec_input = tf.concat([go_column, targets_without_last], 1,
                                           name='dec_input')  # batch x maxlen
                decoder_embedded = tf.nn.embedding_lookup(self.embeddings, self.dec_input)
                self.max_tar_len = tf.reduce_max(self.target_sentence_length)
                self.dec_embed_input = decoder_embedded[:, :self.max_tar_len, :]  # batch x max_tar_len x embed_dim

    def build_encoder(self):
        """
        Encode the embedded source with ``num_layers`` stacked bidirectional LSTM layers.

        Each layer (variable scope ``encoder_<k>``) uses layer-normalised LSTM cells with
        input dropout. Layer k+1 reads the concatenated forward/backward outputs of layer k.

        Sets:
            self.enc_output, self.enc_state: outputs / final states of the top layer
            self.h_N: concatenation of the top layer's final forward and backward h states
            self.enc_outputs: concatenated forward/backward outputs of the top layer
        """
        with tf.name_scope("encode"):
            layer_input = self.enc_embed_input
            for layer in range(self.num_layers):
                with tf.variable_scope('encoder_{}'.format(layer + 1)):
                    cell_fw = tf.contrib.rnn.LayerNormBasicLSTMCell(self.lstm_hidden_units)
                    cell_fw = tf.contrib.rnn.DropoutWrapper(cell_fw, input_keep_prob=self.keep_prob)

                    cell_bw = tf.contrib.rnn.LayerNormBasicLSTMCell(self.lstm_hidden_units)
                    cell_bw = tf.contrib.rnn.DropoutWrapper(cell_bw, input_keep_prob=self.keep_prob)

                    self.enc_output, self.enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,
                                                                                      cell_bw,
                                                                                      layer_input,
                                                                                      self.source_sentence_length,
                                                                                      dtype=tf.float32)

                if layer + 1 < self.num_layers:
                    # Stack the layers: the next layer consumes this layer's [fw; bw] outputs
                    # rather than the embeddings, so every layer contributes to h_N.
                    layer_input = tf.concat([self.enc_output[0], self.enc_output[1]], axis=-1)

            # Join outputs since we are using a bidirectional RNN
            self.h_N = tf.concat([self.enc_state[0][1], self.enc_state[1][1]], axis=-1,
                                 name='h_N')  # Concatenated h from the fw and bw LSTMs
            self.enc_outputs = tf.concat([self.enc_output[0], self.enc_output[1]], axis=-1, name='encoder_outputs')

    def build_latent_space(self):
        """Map the encoder summary h_N to the deterministic latent code z_tilda with a dense layer."""
        with tf.name_scope("latent_space"):
            to_latent = Dense(self.latent_dim, name='z_tilda')
            self.z_tilda = to_latent(self.h_N)  # [batch_size x latent_dim]

    def sample_gaussian(self):
        """Draw one tf.random_normal prior sample per batch element into z_sampled."""
        with tf.name_scope('sample_gaussian'):
            prior_shape = [self.batch_size, self.latent_dim]  # [batch_size x latent_dim]
            self.z_sampled = tf.random_normal(prior_shape, name='z_sampled')

    def build_decoder(self):
        """
        Build the training, validation and inference decoders.

        All three share one decoder cell and one output projection (self.output_layer).
        The cell is ``num_layers`` stacked layer-normalised LSTM layers with
        2 * lstm_hidden_units units each and input dropout.

        - training_decoder: teacher-forced on dec_embed_input and conditioned on z_tilda.
          Sets self.training_logits (vocabulary logits).
        - validate_decoder: greedy decoding conditioned on z_tilda. Sets self.validate_logits
          (decoder outputs) and self.validate_sent (predicted ids).
        - inference_decoder: greedy decoding conditioned on the prior sample z_sampled.
          Sets self.inference_logits (predicted ids).

        Raises:
            ValueError: if num_layers < 1.
        """
        if self.num_layers < 1:
            raise ValueError('num_layers must be >= 1, got {}'.format(self.num_layers))

        with tf.variable_scope("decode"):
            cells = []
            for layer in range(self.num_layers):
                with tf.variable_scope('decoder_{}'.format(layer + 1)):
                    cell = tf.contrib.rnn.LayerNormBasicLSTMCell(2 * self.lstm_hidden_units)
                    cells.append(tf.contrib.rnn.DropoutWrapper(cell, input_keep_prob=self.keep_prob))
            # Stack the layers; a single layer is used directly.
            dec_cell = cells[0] if len(cells) == 1 else tf.contrib.rnn.MultiRNNCell(cells)

            self.output_layer = Dense(self.vocab_size)

            self.init_state = dec_cell.zero_state(self.batch_size, tf.float32)

            with tf.name_scope("training_decoder"):
                training_helper = tf.contrib.seq2seq.TrainingHelper(inputs=self.dec_embed_input,
                                                                    sequence_length=self.target_sentence_length,
                                                                    time_major=False)

                training_decoder = BasicDecoder(dec_cell,
                                                training_helper,
                                                initial_state=self.init_state,
                                                latent_vector=self.z_tilda,
                                                output_layer=self.output_layer)

                training_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(training_decoder,
                                                                          output_time_major=False,
                                                                          impute_finished=True,
                                                                          maximum_iterations=self.num_tokens)

                self.training_logits = tf.identity(training_outputs.rnn_output, 'logits')

            with tf.name_scope("validate_decoder"):
                # Greedy reconstruction from the encoder's latent code
                self.validate_logits = self._greedy_decode(dec_cell, self.z_tilda)
                self.validate_sent = tf.identity(self.validate_logits.sample_id, name='predictions')

            with tf.name_scope("inference_decoder"):
                # Greedy generation from a latent code drawn from the prior
                inference_outputs = self._greedy_decode(dec_cell, self.z_sampled)
                self.inference_logits = tf.identity(inference_outputs.sample_id, name='predictions')

    def _greedy_decode(self, dec_cell, latent_vector):
        """Greedy decoding from GO until EOS or num_tokens steps, conditioned on ``latent_vector``.

        Reuses the shared decoder cell, init_state and output layer; returns the
        dynamic_decode outputs (rnn_output, sample_id).
        """
        start_token = self.word_index['GO']
        end_token = self.word_index['EOS']

        start_tokens = tf.tile(tf.constant([start_token], dtype=tf.int32), [self.batch_size],
                               name='start_tokens')

        greedy_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(self.embeddings,
                                                                 start_tokens,
                                                                 end_token)

        greedy_decoder = BasicDecoder(dec_cell,
                                      greedy_helper,
                                      initial_state=self.init_state,
                                      latent_vector=latent_vector,
                                      output_layer=self.output_layer)

        outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(greedy_decoder,
                                                          output_time_major=False,
                                                          impute_finished=True,
                                                          maximum_iterations=self.num_tokens)
        return outputs

    def mmd_penalty(self, sample_qz, sample_pz):
        """
        Estimate the Maximum Mean Discrepancy between two batches of latent codes.

        Args:
            sample_qz: samples from Q(Z) (encoder codes), presumably [batch_size x latent_dim]
            sample_pz: samples from P(Z) (prior), same shape as sample_qz
        Returns:
            Scalar tensor with the MMD estimate for config['kernel'] ('RBF' or 'IMQ').
        Raises:
            ValueError: if config['kernel'] is neither 'RBF' nor 'IMQ'.
        """
        kernel = self.config['kernel']
        if kernel not in ('RBF', 'IMQ'):
            raise ValueError("Unsupported MMD kernel {!r}; expected 'RBF' or 'IMQ'".format(kernel))

        n = tf.cast(self.batch_size, tf.int32)
        nf = tf.cast(n, tf.float32)
        # Number of distinct off-diagonal pairs. It must stay an int32 tensor because it is
        # the `k` of tf.nn.top_k and an index; `/` on int32 tensors yields float64 in Python 3.
        half_size = (n * n - n) // 2

        # Pairwise squared Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b
        norms_pz = tf.reduce_sum(tf.square(sample_pz), axis=1, keepdims=True)
        dotprods_pz = tf.matmul(sample_pz, sample_pz, transpose_b=True)
        distances_pz = norms_pz + tf.transpose(norms_pz) - 2. * dotprods_pz

        norms_qz = tf.reduce_sum(tf.square(sample_qz), axis=1, keepdims=True)
        dotprods_qz = tf.matmul(sample_qz, sample_qz, transpose_b=True)
        distances_qz = norms_qz + tf.transpose(norms_qz) - 2. * dotprods_qz

        dotprods = tf.matmul(sample_qz, sample_pz, transpose_b=True)
        distances = norms_qz + tf.transpose(norms_pz) - 2. * dotprods

        if kernel == 'RBF':
            # Median heuristic for the sigma^2 of Gaussian kernel
            sigma2_k = tf.nn.top_k(
                tf.reshape(distances, [-1]), half_size).values[half_size - 1]
            sigma2_k += tf.nn.top_k(
                tf.reshape(distances_qz, [-1]), half_size).values[half_size - 1]
            res1 = tf.exp(- distances_qz / 2. / sigma2_k)
            res1 += tf.exp(- distances_pz / 2. / sigma2_k)
            res1 = tf.multiply(res1, 1. - tf.eye(n))  # drop the i == j terms
            res1 = tf.reduce_sum(res1) / (nf * nf - nf)
            res2 = tf.exp(- distances / 2. / sigma2_k)
            res2 = tf.reduce_sum(res2) * 2. / (nf * nf)
            return res1 - res2

        # IMQ kernel k(x, y) = C / (C + ||x - y||^2), summed over several scales of C.
        # NOTE(review): for a standard-normal prior E||x - y||^2 = 2 * latent_dim, so the extra
        # factor 2. below doubles every scale; kept to preserve training behaviour - confirm intent.
        Cbase = 2. * self.config['latent_dim'] * 2. * 1.  # sigma2_p = 1 for the normal prior
        stat = 0.
        for scale in [.1, .2, .5, 1., 2., 5., 10.]:
            C = Cbase * scale
            res1 = C / (C + distances_qz)
            res1 += C / (C + distances_pz)
            res1 = tf.multiply(res1, 1. - tf.eye(n))  # drop the i == j terms
            res1 = tf.reduce_sum(res1) / (nf * nf - nf)
            res2 = C / (C + distances)
            res2 = tf.reduce_sum(res2) * 2. / (nf * nf)
            stat += res1 - res2
        return stat

    def loss(self):
        """
        Build the training objective.

        Sets:
            self.wasserstein_loss: MMD between encoder codes (z_tilda) and prior samples (z_sampled)
            self.xent_loss: masked sequence cross-entropy of the training decoder
            self.var_list: trainable variables (optimized in optimize)
            self.lossL2: L2 penalty on non-bias variables (summarised, not part of the cost)
            self.cost: xent_loss + config['lambda_val'] * wasserstein_loss
        """
        with tf.name_scope('losses'):
            # mmd_penalty takes Q(Z) samples first and P(Z) samples second. Pass them by keyword so
            # the encoder codes and the prior draws land in the matching parameters; the RBF
            # bandwidth heuristic uses the Q(Z) pairwise distances.
            self.wasserstein_loss = self.mmd_penalty(sample_qz=self.z_tilda, sample_pz=self.z_sampled)

            # Create the weights for sequence_loss (1 inside each target, 0 on padding)
            masks = tf.sequence_mask(self.target_sentence_length, self.num_tokens, dtype=tf.float32, name='masks')

            self.xent_loss = tf.contrib.seq2seq.sequence_loss(
                self.training_logits,
                self.target_data[:, :self.max_tar_len],
                weights=masks[:, :self.max_tar_len],
                average_across_batch=True)

            # L2-Regularization (computed for the summaries; not added to the cost)
            self.var_list = tf.trainable_variables()
            self.lossL2 = tf.add_n([tf.nn.l2_loss(v) for v in self.var_list if 'bias' not in v.name]) * 0.001

            # NOTE(review): the MMD weight is the Python value config['lambda_val'], not the
            # lambda_coeff placeholder, so feeding lambda_coeff only affects the 'lambda' summary.
            self.cost = self.xent_loss + self.config['lambda_val'] * self.wasserstein_loss

    def optimize(self):
        """Build self.train_op: Adam on self.cost with gradients clipped element-wise to [-5, 5]."""
        with tf.name_scope('optimization'):
            optimizer = tf.train.AdamOptimizer(self.lr)

            grads_and_vars = optimizer.compute_gradients(self.cost, var_list=self.var_list)
            clipped = []
            for grad, var in grads_and_vars:
                if grad is None:
                    # The cost does not depend on this variable; leave it out of the update.
                    continue
                clipped.append((tf.clip_by_value(grad, -5., 5.), var))
            self.train_op = optimizer.apply_gradients(clipped)

    def summary(self):
        """Register scalar summaries for the losses and the fed lambda, then merge them into summary_op."""
        with tf.name_scope('summaries'):
            loss_summaries = (
                ('xent_loss', self.xent_loss),
                ('l2_loss', self.lossL2),
                ("wasserstein_loss", self.wasserstein_loss),
                ('total_loss', self.cost),
            )
            for tag, tensor in loss_summaries:
                tf.summary.scalar(tag, tf.reduce_sum(tensor))
            tf.summary.scalar('lambda', self.lambda_coeff)

            self.summary_op = tf.summary.merge_all()

    def monitor(self, x_val, sess, epoch_i, time_consumption):
        """
        End-of-epoch report and checkpoint.

        Runs validation, prints up to 20 generated/reference pairs and a batch of random
        samples, prints the epoch summary line, and saves the session to
        ``model_checkpoint_dir + str(epoch_i) + ".ckpt"``.

        Args:
            x_val: validation data passed to self.validate
            sess: the active tf.Session
            epoch_i: current epoch number (log line and checkpoint name)
            time_consumption: time spent on the epoch, printed in the log line
        """
        # self.validate (defined outside this view) is expected to refresh
        # epoch_bleu_score_val, val_pred and val_ref, which are read below.
        self.validate(sess, x_val)
        val_bleu_str = ' | '.join(str(self.epoch_bleu_score_val[order][-1]) for order in ('1', '2', '3', '4'))

        val_str = '\t\t Generated \t|\t Actual \n'
        for pred, ref in zip(self.val_pred[:20], self.val_ref[:20]):
            val_str += '\t\t' + pred + '\t|\t' + ref + '\n'

        print(val_str)

        generated = self.random_sample_in_session(sess)

        print(generated)

        log_thisepoch = 'Epoch {:>3}/{} - Time {:>6.1f}, Train loss: {:>3.2f}, Val BLEU: {}\n\n'.format(
            epoch_i, self.epochs, time_consumption, self.train_xent, val_bleu_str)

        print(log_thisepoch)

        # Build the Saver once and reuse it: every tf.train.Saver() adds save/restore ops to the
        # graph. max_to_keep=None keeps all earlier epochs' checkpoint files on disk, as a fresh
        # Saver per epoch did.
        saver = getattr(self, '_epoch_saver', None)
        if saver is None:
            saver = tf.train.Saver(max_to_keep=None)
            self._epoch_saver = saver
        saver.save(sess, self.model_checkpoint_dir + str(epoch_i) + ".ckpt")

        self.log_str.append(log_thisepoch)

    def train(self, x_train, x_val, checkpoint_return = True, checkpoint = None):
        """Train for ``self.epochs`` epochs, validating and checkpointing after each.

        Args:
            x_train: training sequences, batched with ``get_batches``.
            x_val: validation sequences handed to ``monitor`` after every epoch.
            checkpoint_return: resume from ``checkpoint`` when one is supplied.
            checkpoint: checkpoint path to restore before training, or None to
                start from freshly initialized weights.

        A batch whose ``sess.run`` raises is reported (iteration number and
        error) and skipped; training continues with the next batch.
        """
        print('[INFO] Training process started')

        learning_rate = self.initial_learning_rate
        iter_i = 0

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # Restoring needs a real path: with the defaults (True, None),
            # saver.restore(sess, None) would raise before training started.
            if checkpoint_return and checkpoint is not None:
                saver = tf.train.Saver()
                saver.restore(sess, checkpoint)

            writer = tf.summary.FileWriter(self.logs_dir, sess.graph)
            try:
                for epoch_i in range(1, self.epochs + 1):

                    start_time = time.time()
                    for input_batch, output_batch, sent_lengths in get_batches(x_train, self.batch_size):
                        iter_i += 1
                        try:
                            _, _summary, self.train_xent = sess.run(
                                [self.train_op, self.summary_op, self.xent_loss],
                                feed_dict={self.input_data: input_batch,  # <batch x maxlen>
                                           self.target_data: output_batch,  # <batch x maxlen>
                                           self.lr: learning_rate,
                                           self.source_sentence_length: sent_lengths,
                                           self.target_sentence_length: sent_lengths,
                                           self.keep_prob: self.dropout_keep_prob,
                                           self.lambda_coeff: self.lambda_val,
                                           })
                            writer.add_summary(_summary, iter_i)
                        except Exception as e:
                            # Deliberate best-effort: report the failing iteration and move on.
                            print(iter_i, e)

                    # Reduce learning rate, but not below its minimum value
                    learning_rate = max(self.min_learning_rate, learning_rate * self.learning_rate_decay)
                    time_consumption = time.time() - start_time
                    self.monitor(x_val, sess, epoch_i, time_consumption)
            finally:
                # Flush pending summaries and release the event file even if training is interrupted.
                writer.close()

    def validate(self, sess, x_val):
        """Decode the validation set and record this epoch's corpus BLEU-1..4.

        Side effects:
            ``self._validate_logits``: logits of the last validation batch.
            ``self.val_pred`` / ``self.val_ref``: decoded hypotheses and
                references as space-joined strings, in batch order.
            ``self.epoch_bleu_score_val['1'..'4']``: one score appended to each.
        """
        # Ids dropped before mapping back to words.
        skip_ids = (self.pad, -1, self.eos)

        def to_tokens(ids):
            # Map ids to words, then re-tokenize the joined string for BLEU.
            return word_tokenize(" ".join([self.idx_word[i] for i in ids if i not in skip_ids]))

        hypotheses_val = []
        references_val = []

        for input_batch, output_batch, sent_lengths in get_batches(x_val, self.batch_size):
            pred_sentences, self._validate_logits = sess.run(
                [self.validate_sent, self.validate_logits],
                feed_dict={self.input_data: input_batch,
                           self.source_sentence_length: sent_lengths,
                           self.keep_prob: 1.0,
                           })

            for pred, actual in zip(pred_sentences, output_batch):
                hypotheses_val.append(to_tokens(pred))
                references_val.append([to_tokens(actual)])

        # Join once after all batches; rebuilding after every batch was quadratic.
        self.val_pred = [" ".join(sent) for sent in hypotheses_val]
        self.val_ref = [" ".join(sent[0]) for sent in references_val]

        bleu_scores = calculate_bleu_scores(references_val, hypotheses_val)
        for order, score in zip(('1', '2', '3', '4'), bleu_scores):
            self.epoch_bleu_score_val[order].append(score)

    def predict(self, checkpoint, x_test):
        """Decode ``x_test`` with a restored model, print BLEU-1..4, return the decodings.

        Args:
            checkpoint: checkpoint path to restore.
            x_test: sequences to decode, batched with ``get_batches``.

        Returns:
            list: the ``validate_sent`` rows of every batch, in order.
        """
        decoded_rows = []
        hyps = []
        refs = []

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            restorer = tf.train.Saver()
            restorer.restore(sess, checkpoint)

            for input_batch, output_batch, sent_lengths in get_batches(x_test, self.batch_size):
                feed = {
                    self.input_data: input_batch,
                    self.source_sentence_length: sent_lengths,
                    self.keep_prob: 1.0,
                }
                batch_rows = sess.run(self.validate_sent, feed_dict=feed)
                decoded_rows.extend(batch_rows)

                excluded = [self.pad, -1, self.eos]
                for pred, actual in zip(batch_rows, output_batch):
                    hyp_text = " ".join([self.idx_word[t] for t in pred if t not in excluded])
                    hyps.append(word_tokenize(hyp_text))
                    ref_text = " ".join([self.idx_word[t] for t in actual if t not in excluded])
                    refs.append([word_tokenize(ref_text)])

            scores = calculate_bleu_scores(refs, hyps)

        print('BLEU 1 to 4 : {}'.format(' | '.join(map(str, scores))))

        return decoded_rows

    def show_output_sentences(self, preds, x_test):
        """Print each reference sentence ("A:") above the model's decoding ("G:").

        Args:
            preds: decoded id sequences, e.g. the list returned by ``predict``.
            x_test: reference id sequences aligned with ``preds``.

        Pad, EOS and -1 ids are skipped, the same set ``validate`` and
        ``predict`` drop; otherwise a -1 would be looked up in ``self.idx_word``.
        """
        skip_ids = (self.pad, self.eos, -1)
        for pred, actual in zip(preds, x_test):
            # Actual and generated
            print('A: {}'.format(
                " ".join([self.idx_word[i] for i in actual if i not in skip_ids])))
            print('G: {}\n'.format(
                " ".join([self.idx_word[i] for i in pred if i not in skip_ids])))

    def random_sample(self, checkpoint):
        """Restore ``checkpoint`` and print one batch of sentences decoded from z ~ N(0, I).

        Pad and EOS ids are removed before the ids are mapped back to words.
        """
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            restorer = tf.train.Saver()
            restorer.restore(sess, checkpoint)

            noise = np.random.normal(size=(self.batch_size, self.latent_dim))
            decoded = sess.run(self.inference_logits,
                               feed_dict={self.z_sampled: noise, self.keep_prob: 1.0})

            for token_ids in decoded:
                words = [self.idx_word[t] for t in token_ids if t not in [self.pad, self.eos]]
                print('G: {}'.format(" ".join(words)))

    def random_sample_save(self, checkpoint, num_batches=None,
                           out_file='/content/drive/My Drive/wae_det_samples.txt'):
        """Decode sentences from z ~ N(0, I) and write the distinct ones to a file.

        Args:
            checkpoint: checkpoint path to restore.
            num_batches: number of batches of ``self.batch_size`` prior samples
                to decode. ``None`` (default) uses ``100000 // self.batch_size``
                batches, i.e. roughly 100k samples.
            out_file: text file that receives one distinct sentence per line,
                in the order the sentences were first generated. Its directory
                must already exist.
        """
        if num_batches is None:
            num_batches = 100000 // self.batch_size

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint)
            print('[INFO] NUMBER OF SAMPLES: ', num_batches * self.batch_size)

            skip_ids = (self.pad, self.eos)
            gen_samples = []  # distinct sentences, first-seen order
            seen = set()      # O(1) duplicate check; list membership was quadratic
            repeated_counter = 0

            for _ in range(num_batches):
                z_sampled = np.random.normal(size=(self.batch_size, self.latent_dim))
                result = sess.run(self.inference_logits,
                                  feed_dict={self.z_sampled: z_sampled,
                                             self.keep_prob: 1.0,
                                             })

                for pred in result:
                    sent = " ".join([self.idx_word[i] for i in pred if i not in skip_ids])
                    if sent in seen:
                        repeated_counter += 1
                    else:
                        seen.add(sent)
                        gen_samples.append(sent)

        print('[INFO] DISTINCT SAMPLES: {}, REPEATED: {}'.format(len(gen_samples), repeated_counter))

        with open(out_file, 'w', encoding='utf-8') as f:
            f.write('\n'.join(gen_samples))

    def random_sample_in_session(self, sess):
        """Decode one batch of prior samples using an already-open session.

        Returns:
            str: the first 10 decoded sentences, each prefixed by two tabs and
            terminated by a newline, with pad and EOS ids removed.
        """
        noise = np.random.normal(size=(self.batch_size, self.latent_dim))
        decoded = sess.run(self.inference_logits,
                           feed_dict={self.z_sampled: noise, self.keep_prob: 1.0})

        lines = []
        for token_ids in decoded[:10]:
            words = ' '.join(self.idx_word[t] for t in token_ids if t not in [self.pad, self.eos])
            lines.append('\t\t' + words + '\n')
        return ''.join(lines)
                
    def linear_interpolate(self, checkpoint, num_samples):
        """Print decodings along straight lines between pairs of random prior samples.

        ``self.batch_size // num_samples`` paths are built. Each path has
        ``num_samples`` evenly spaced points from one z ~ N(0, I) draw to another.
        Decodings are printed path by path, each group preceded by a blank line.

        Args:
            checkpoint: checkpoint path to restore.
            num_samples: points per path, endpoints included.

        Raises:
            ValueError: if ``num_samples`` is not between 1 and ``self.batch_size``.
        """
        if not 1 <= num_samples <= self.batch_size:
            raise ValueError('num_samples must be between 1 and batch_size ({}), got {}'
                             .format(self.batch_size, num_samples))

        num_paths = self.batch_size // num_samples
        steps = np.linspace(0, 1, num_samples)[:, None]  # interpolation weights, one row per point
        paths = []
        for _ in range(num_paths):
            z_start, z_end = np.random.normal(0, 1, (2, self.latent_dim))
            paths.append(z_start * (1 - steps) + z_end * steps)

        num_used = num_paths * num_samples
        sampled = np.reshape(np.array(paths), (num_used, self.latent_dim))
        # Feed a full batch of batch_size rows. When batch_size is not a multiple
        # of num_samples, the leftover rows are zeros and their decodings are not printed.
        if num_used < self.batch_size:
            padding = np.zeros((self.batch_size - num_used, self.latent_dim))
            sampled = np.vstack([sampled, padding])

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint)

            result = sess.run(self.inference_logits,
                              feed_dict={self.z_sampled: sampled,
                                         self.keep_prob: 1.0,
                                         })

            skip_ids = (self.pad, self.eos)
            for row, pred in enumerate(result[:num_used]):
                if row % num_samples == 0:
                    print()
                print('G: {}'.format(
                    " ".join([self.idx_word[t] for t in pred if t not in skip_ids])))
                
    def linear_interpolate_between_inputs(self, checkpoint, start_sent, end_sent, num_samples=8):
        """Print decodings along the line between the latent codes of two sentences.

        Each sentence is processed the same way:
          * tokenized with ``word_tokenize``;
          * mapped to ids, with words missing from ``self.word_index`` becoming ``self.unk``;
          * terminated with ``self.eos``;
          * zero-padded or truncated to ``self.num_tokens``;
          * encoded with ``get_zvector``.
        ``num_samples`` evenly spaced points from the start code to the end code
        are then decoded and printed.

        Args:
            checkpoint: checkpoint path to restore.
            start_sent: raw text of the first sentence.
            end_sent: raw text of the second sentence.
            num_samples: points on the line, endpoints included.

        Raises:
            ValueError: if ``num_samples`` is not between 1 and ``self.batch_size``.
        """
        if not 1 <= num_samples <= self.batch_size:
            raise ValueError('num_samples must be between 1 and batch_size ({}), got {}'
                             .format(self.batch_size, num_samples))

        def to_id_row(sentence):
            # Integer row of length num_tokens. Concatenating the ids with
            # np.zeros would silently promote them to float64.
            ids = [self.word_index.get(word, self.unk) for word in word_tokenize(sentence)] + [self.eos]
            ids = ids[:self.num_tokens]
            row = np.zeros(self.num_tokens, dtype=np.int32)
            row[:len(ids)] = ids
            return row

        # One full batch of alternating start/end rows, so rows 0 and 1 are the two sentences.
        pair = np.vstack([to_id_row(start_sent), to_id_row(end_sent)])
        inp_idx_seq = np.tile(pair, [-(-self.batch_size // 2), 1])[:self.batch_size]

        # Get z_vector of first and last sentence
        z_vecs = self.get_zvector(checkpoint, inp_idx_seq)
        s1_z, s2_z = z_vecs[0], z_vecs[1]

        steps = np.linspace(0, 1, num_samples)[:, None]
        path = s1_z * (1 - steps) + s2_z * steps  # one interpolated code per row
        # Repeat the path to fill exactly one batch; only the first copy is printed.
        sampled = np.tile(path, [-(-self.batch_size // num_samples), 1])[:self.batch_size]

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint)

            result = sess.run(self.inference_logits,
                              feed_dict={self.z_sampled: sampled,
                                         self.keep_prob: 1.0,
                                         })

            skip_ids = (self.pad, self.eos)
            for pred in result[:num_samples]:
                print('G: {}'.format(
                    " ".join([self.idx_word[t] for t in pred if t not in skip_ids])))

    def get_zvector(self, checkpoint, x_test):
        """Return the ``z_tilda`` latent codes of ``x_test`` as one NumPy array.

        Restores ``checkpoint`` in a fresh session and encodes ``x_test`` batch
        by batch with ``keep_prob`` fed as 1.0. Rows follow the order yielded
        by ``get_batches``.
        """
        codes = []
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            loader = tf.train.Saver()
            loader.restore(sess, checkpoint)

            for input_batch, _targets, sent_lengths in get_batches(x_test, self.batch_size):
                batch_codes = sess.run(self.z_tilda, feed_dict={
                    self.input_data: input_batch,
                    self.source_sentence_length: sent_lengths,
                    self.keep_prob: 1.0,
                })
                codes.extend(batch_codes)

        return np.array(codes)


### RUN MODEL ###

In [None]:
import argparse
import os


def model_argparse(*, is_train=None, output_root='/content/drive/My Drive/'):
    """Return the run configuration and, when training, create the output directories.

    Despite the name, no command-line arguments are parsed: every setting is
    fixed below.

    Args:
        is_train: when falsy, return the config without touching the filesystem.
            ``None`` (default) falls back to the notebook-level ``isTrain`` flag.
        output_root: directory under which ``bleu/``, ``runs/``,
            ``models/<fingerprint>/`` and ``summary_logs/<fingerprint>/`` are
            created when training.

    Returns:
        dict: the configuration. When training, ``model_checkpoint_dir`` (with a
        trailing separator) and ``logs_dir`` point at the created run directories.
    """
    if is_train is None:
        is_train = isTrain

    config = dict()
    config['device'] = '0'
    config['lstm_hidden_units'] = 100
    config['embedding_size'] = 300
    config['vocab_size'] = 30000
    config['num_layers'] = 1
    config['num_tokens'] = 50

    config['latent_dim'] = 100
    config['batch_size'] = 128
    config['n_epochs'] = 500

    config['dropout_keep_prob'] = 0.8
    config['initial_learning_rate'] = 0.001
    config['learning_rate_decay'] = 1.0
    config['min_learning_rate'] = 0.00001

    config['lambda_val'] = 3.0  # MMD / Wasserstein loss coefficient
    config['kernel'] = 'IMQ'    # MMD kernel: 'IMQ' or 'RBF'

    config['data'] = '/content/drive/My Drive/sources.txt'
    config['w2v_file'] = '/content/drive/My Drive/w2v_models/fce_incorrect.pkl'
    config['bleu_path'] = 'bleu/'
    config['model_checkpoint_dir'] = ''
    config['logs_dir'] = ''

    # Run identifier: names the log file and the checkpoint/summary directories.
    config_fingerprint = 'full_snli_' + \
        'lambdaWAE' + str(config['lambda_val']) + \
        '_batch' + str(config['batch_size']) + \
        '_kernel_' + str(config['kernel']) + \
        '_num_tokens_' + str(config['num_tokens'])

    if not is_train:
        return config

    # Create directories for saving model runs and stats
    root = os.path.normpath(output_root)
    os.makedirs(os.path.join(root, 'bleu'), exist_ok=True)
    os.makedirs(os.path.join(root, 'runs'), exist_ok=True)

    # Append this run's configuration to its log file; the file is closed here
    # because the handle is not reachable outside this function.
    with open(os.path.join(root, 'runs', 'log_' + config_fingerprint), 'a') as log_writer:
        log_writer.write(str(config) + '\n')

    # Model checkpoint directory. The trailing separator matters:
    # checkpoints are saved as model_checkpoint_dir + '<epoch>.ckpt'.
    model_path = os.path.join(root, 'models', config_fingerprint)
    os.makedirs(model_path, exist_ok=True)
    config['model_checkpoint_dir'] = os.path.join(model_path, '')

    # Model summary directory
    summary_path = os.path.join(root, 'summary_logs', config_fingerprint)
    os.makedirs(summary_path, exist_ok=True)
    config['logs_dir'] = summary_path

    return config

In [None]:
import os
isTrain = True

config = model_argparse()

import tensorflow as tf
tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True
sess = tf.Session(config=tf_config)

import numpy as np
from sklearn.model_selection import train_test_split

# Seed for both the sentence shuffle and the train/val/test split. The shuffle
# must be seeded as well: an unseeded shuffle changes the split on every run, so
# a model reloaded from a checkpoint in a new runtime could be evaluated on
# sentences it was trained on. (The vocabulary built by tokenize_sequence may
# also depend on sentence order, e.g. frequency tie-breaking; presumably.)
SPLIT_SEED = 10

print('FILE PATH: ', config['data'])
snli_data = get_sentences(file_path = config['data'])

print('[INFO] Number of sentences = {}'.format(len(snli_data)))

sentences = [s.strip() for s in snli_data]

np.random.RandomState(SPLIT_SEED).shuffle(sentences)

print('[INFO] Tokenizing input and output sequences')
filters = '!"#$%&()*+/:;<=>@[\\]^`{|}~\t\n'
x, word_index = tokenize_sequence(sentences,
                                  filters,
                                  config['num_tokens'],
                                  config['vocab_size'])

print('[INFO] Split data into train-validation-test sets')
x_train, _x_val_test = train_test_split(x, test_size = 0.1, random_state = SPLIT_SEED)
x_val, x_test = train_test_split(_x_val_test, test_size = 0.5, random_state = SPLIT_SEED)

w2v = config['w2v_file']
embeddings_matrix = create_embedding_matrix(word_index,
                                            config['embedding_size'],
                                            w2v)

# Re-calculate the vocab size based on the word_idx dictionary
config['vocab_size'] = len(word_index)


In [None]:

#----------------------------------------------------------------#
model = DetWAEModel(config,
                    embeddings_matrix,
                    word_index)

# Path of a checkpoint to resume from (e.g. model_checkpoint_dir + '146.ckpt'),
# or None to train from freshly initialized weights.
checkpoint = None
model.train(x_train, x_val,
            checkpoint_return=checkpoint is not None,
            checkpoint=checkpoint)

# No log_writer.close() here: log_writer is a local of model_argparse and does
# not exist at notebook level.

In [None]:
model = DetWAEModel(config, embeddings_matrix, word_index)

# monitor() saves one checkpoint per epoch as <model_checkpoint_dir><epoch>.ckpt;
# this is the epoch-146 checkpoint of the run named by the config fingerprint.
checkpoint = '/content/drive/My Drive/models/full_snli_lambdaWAE3.0_batch128_kernel_IMQ_num_tokens_50/146.ckpt'
model.random_sample_save(checkpoint)
