This notebook contains code from Weili Nie's 'RelGAN' repository (https://github.com/weilinie/RelGAN) and runs the RelGAN model. It was designed specifically to be run in the Google Colaboratory (Colab) environment.

## SET-UP ##

In [None]:
# Mount Google Drive at /content/drive so data files stored there can be loaded
# from this notebook.
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

In [None]:
import nltk
# Download NLTK's 'punkt' resource; nltk.word_tokenize, used throughout this
# notebook, depends on it.
nltk.download('punkt')
  

In [None]:
# coding=utf-8
import nltk


# text tokens to code strings
def text_to_code(tokens, dictionary, seq_len):
    """Encode tokenized sentences as lines of space-separated word codes.

    Each word is replaced by ``str(dictionary[word])``. Sentences shorter than
    ``seq_len`` are right-padded with the end-of-file code ``len(dictionary)``.
    Longer sentences are not truncated. Every code (including the last) is
    followed by a space, and every sentence ends with a newline.
    """
    # Padding code: len(dictionary) is outside the 0..len-1 range get_dict assigns.
    eof_code = str(len(dictionary))
    lines = []
    for sentence in tokens:
        encoded = [str(dictionary[word]) for word in sentence]
        encoded.extend([eof_code] * (seq_len - len(encoded)))
        lines.append(''.join(code + ' ' for code in encoded) + '\n')
    return ''.join(lines)


# code tokens to text strings
def code_to_text(codes, dictionary):
    """Decode lines of numeric codes back into text.

    ``codes`` is an iterable of token sequences whose items are accepted by
    ``int``. Each code is looked up as ``dictionary[str(code)]``, and the padding
    code ``len(dictionary)`` is dropped. Every decoded word is followed by a
    space and every line ends with a newline.
    """
    pad = len(dictionary)
    out = []
    for line in codes:
        words = [dictionary[str(n)] for n in map(int, line) if n != pad]
        out.append(''.join(word + ' ' for word in words) + '\n')
    return ''.join(out)


# tokenlize the file
def get_tokenlized(file):
    """Return one lower-cased NLTK token list per line of ``file``."""
    with open(file) as handle:
        return [nltk.word_tokenize(line.lower()) for line in handle]


# get word set
def get_word_list(tokens):
    """Return the vocabulary of ``tokens`` as a sorted list of unique words.

    ``tokens`` is an iterable of token lists, such as the output of
    ``get_tokenlized``. The result is sorted so that the word order, and with it
    the indices ``get_dict`` assigns, is identical on every run. The order of
    ``list(set(...))`` depends on Python's per-process string hash seed.
    """
    return sorted({word for sentence in tokens for word in sentence})


# get word_index_dict and index_word_dict
def get_dict(word_set):
    """Build lookup tables between words and their string indices.

    Indices start at 0, follow the iteration order of ``word_set``, and are
    stored as strings. Returns ``(word_index_dict, index_word_dict)``.
    """
    pairs = [(str(position), word) for position, word in enumerate(word_set)]
    index_word_dict = dict(pairs)
    word_index_dict = {word: key for key, word in pairs}
    return word_index_dict, index_word_dict


# get sequence length and dict size
def text_precess(train_text_loc, test_text_loc=None):
    """Scan the corpus file(s) and return ``(sequence_len, vocab_size)``.

    ``sequence_len`` is the token count of the longest line across the training
    file and, if given, the test file. ``vocab_size`` is the number of distinct
    words plus one slot for the padding code that ``text_to_code`` uses.

    Raises:
        ValueError: if the training file contains no lines.
    """
    train_tokens = get_tokenlized(train_text_loc)
    if not train_tokens:
        raise ValueError('training file {} contains no lines'.format(train_text_loc))
    test_tokens = [] if test_text_loc is None else get_tokenlized(test_text_loc)
    all_tokens = train_tokens + test_tokens

    word_index_dict, _ = get_dict(get_word_list(all_tokens))
    # Measure over the combined list. An empty (but present) test file then
    # contributes nothing instead of making max() raise.
    sequence_len = max(len(sentence) for sentence in all_tokens)
    return sequence_len, len(word_index_dict) + 1


In [None]:
import numpy as np
import tensorflow as tf
import matplotlib

matplotlib.use('Agg')
import matplotlib.pyplot as plt
import os
import pprint

# Module-level pretty-printer (not referenced elsewhere in the visible code).
pp = pprint.PrettyPrinter()


def _collect_samples(sess, gen_x, batch_size, num):
    """Run ``gen_x`` ``int(num / batch_size)`` times and concatenate the returned batches."""
    samples = []
    for _ in range(int(num / batch_size)):
        samples.extend(sess.run(gen_x))
    return samples


def _write_unique_samples(samples, output_file, get_code):
    """Write each distinct sample once to ``output_file``.

    Each sample is written as one line of space-separated tokens. A sample whose
    line was already written is counted as repeated and skipped.
    Returns ``(codes, repeated_count, unique_count)``. ``codes`` holds the
    distinct samples in first-seen order when ``get_code`` is true; otherwise
    it is empty.
    """
    seen = set()  # O(1) membership instead of scanning a list of lines
    codes = []
    repeated_count = 0
    with open(output_file, 'w') as fout:
        for sent in samples:
            line = ' '.join(str(x) for x in sent) + '\n'
            if line in seen:
                repeated_count += 1
                continue
            seen.add(line)
            fout.write(line)
            if get_code:
                codes.append(sent)
    return codes, repeated_count, len(seen)


def generate_samples(sess, gen_x, batch_size, generated_num, output_file=None, get_code=True, iteration=0,
                     regen_num=100000):
    """Sample from ``gen_x`` and return the samples as text, or write them to a file.

    ``sess.run(gen_x)`` is called ``int(generated_num / batch_size)`` times and
    every returned sequence is collected.

    * ``output_file is None``: returns one string containing every sample
      (duplicates included) as a line of space-separated tokens.
    * Otherwise: writes each distinct sample once to ``output_file``, prints
      the number of repeated samples, and returns ``np.array`` of the distinct
      samples (empty if ``get_code`` is false). If fewer than 40 repeats were
      seen and ``iteration > 1000``, ``regen_num`` fresh samples are drawn and
      de-duplicated into the same file, and the repeated/unique counts are
      printed.
    """
    generated_samples = _collect_samples(sess, gen_x, batch_size, generated_num)
    if output_file is None:
        return ''.join(' '.join(str(x) for x in sent) + '\n' for sent in generated_samples)

    codes, repeated_count, _ = _write_unique_samples(generated_samples, output_file, get_code)
    print('Repeated samples: ', repeated_count)
    # NOTE(review): presumably a large late-training dump for evaluation once
    # the generator rarely repeats itself; thresholds are kept as they were.
    if repeated_count < 40 and iteration > 1000:
        generated_samples = _collect_samples(sess, gen_x, batch_size, regen_num)
        codes, repeated_count, unique_count = _write_unique_samples(generated_samples, output_file, get_code)
        print('Repeated samples: ', repeated_count)
        print('Unique samples: ', unique_count)
    return np.array(codes)

def init_sess():
    """Create a TF1 session with on-demand GPU memory and initialise all global variables."""
    session_config = tf.ConfigProto()
    # Let GPU memory usage grow as needed instead of reserving it all up front.
    session_config.gpu_options.allow_growth = True
    session = tf.Session(config=session_config)
    session.run(tf.global_variables_initializer())
    return session


def pre_train_epoch(sess, g_pretrain_op, g_pretrain_loss, x_real, data_loader):
    """Run one pre-training epoch over ``data_loader`` and return the mean loss.

    The loader is rewound first. For each of its ``num_batch`` batches,
    ``g_pretrain_op`` and ``g_pretrain_loss`` are evaluated together with the
    batch fed to ``x_real``.
    """
    data_loader.reset_pointer()
    fetches = [g_pretrain_op, g_pretrain_loss]
    batch_losses = [
        sess.run(fetches, feed_dict={x_real: data_loader.next_batch()})[1]
        for _ in range(data_loader.num_batch)
    ]
    return np.mean(batch_losses)


def plot_csv(csv_file, pre_epoch_num, metrics, method):
    """Plot every metric column of ``csv_file`` against epochs and save one PDF per metric.

    The CSV has no header. Column 0 is the x axis ("training epochs") and
    column ``i + 1`` holds the values of ``metrics[i]``. A dashed black vertical
    line is drawn at ``pre_epoch_num`` (presumably where pre-training ends).
    Each plot is saved as ``{method}_{metric_name}.pdf`` in the directory of
    ``csv_file``, and its path is printed.
    """
    names = [str(i) for i in range(len(metrics) + 1)]
    data = np.genfromtxt(csv_file, delimiter=',', skip_header=0, skip_footer=0, names=names)
    out_dir = os.path.dirname(csv_file)
    for idx, metric in enumerate(metrics):
        metric_name = metric.get_name()
        fig = plt.figure()
        try:
            plt.plot(data[names[0]], data[names[idx + 1]], color='r', label=method)
            plt.axvline(x=pre_epoch_num, color='k', linestyle='--')
            plt.xlabel('training epochs')
            plt.ylabel(metric_name)
            plt.legend()
            plot_file = os.path.join(out_dir, '{}_{}.pdf'.format(method, metric_name))
            print(plot_file)
            plt.savefig(plot_file)
        finally:
            # Close the figure. Otherwise every call leaves len(metrics) figures
            # open in pyplot's global registry and memory keeps growing.
            plt.close(fig)


def get_oracle_file(data_file, oracle_file, seq_len):
    """Encode ``data_file`` into ``oracle_file`` as padded word codes.

    A vocabulary is built from the tokenized data, and the ``text_to_code``
    encoding (padded to ``seq_len``) is written to ``oracle_file``. Returns the
    index -> word dictionary needed to decode the codes again.
    """
    tokens = get_tokenlized(data_file)
    word_index_dict, index_word_dict = get_dict(get_word_list(tokens))
    with open(oracle_file, 'w') as outfile:
        outfile.write(text_to_code(tokens, word_index_dict, seq_len))
    return index_word_dict


def get_real_test_file(generator_file, gen_save_file, iw_dict):
    """Decode a file of generated codes into text using ``iw_dict``.

    Every line of ``generator_file`` is tokenized and decoded with
    ``code_to_text``, and the result is written to ``gen_save_file``.
    """
    code_lines = get_tokenlized(generator_file)
    with open(gen_save_file, 'w') as outfile:
        decoded = code_to_text(codes=code_lines, dictionary=iw_dict)
        outfile.write(decoded)

## LSTM ##

In [None]:
import tensorflow as tf
from tensorflow.python.ops import tensor_array_ops, control_flow_ops
import numpy as np


class OracleLstm(object):
    """Fixed, seeded LSTM language model (the name suggests it is the synthetic-data oracle).

    Every weight is drawn from a seeded normal distribution. The constructor
    builds:

    * ``gen_x``: ``[batch_size, sequence_length]`` int32 tokens sampled
      autoregressively, starting from ``start_token``.
    * ``g_predictions``: ``[batch_size, sequence_length, num_vocabulary]``
      softmax outputs under teacher forcing on the placeholder ``x``.
    * ``pretrain_loss``: mean per-token negative log-likelihood of ``x``.
    * ``out_loss``: per-sequence summed NLL, shape ``[batch_size]``.
    """

    def __init__(self, num_vocabulary, batch_size, emb_dim, hidden_dim, sequence_length, start_token):
        self.num_vocabulary = num_vocabulary
        self.batch_size = batch_size
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length
        self.start_token = tf.constant([start_token] * self.batch_size, dtype=tf.int32)
        self.g_params = []
        self.temperature = 1.0  # NOTE(review): not read anywhere in this class

        with tf.variable_scope('generator'):
            tf.set_random_seed(1234)
            self.g_embeddings = tf.Variable(
                tf.random_normal([self.num_vocabulary, self.emb_dim], 0.0, 1.0, seed=123314154))
            self.g_params.append(self.g_embeddings)
            self.g_recurrent_unit = self.create_recurrent_unit(self.g_params)  # maps h_tm1 to h_t for generator
            self.g_output_unit = self.create_output_unit(self.g_params)  # maps h_t to o_t (output token logits)

        # placeholder for token sequences to be scored: [batch_size, sequence_length]
        self.x = tf.placeholder(tf.int32, shape=[self.batch_size,
                                                 self.sequence_length])  # sequence of tokens generated by generator

        # processed for batch
        with tf.device("/cpu:0"):
            tf.set_random_seed(1234)
            self.processed_x = tf.transpose(tf.nn.embedding_lookup(self.g_embeddings, self.x),
                                            perm=[1, 0, 2])  # seq_length x batch_size x emb_dim

        # initial states: stacked (hidden, cell), both zero
        self.h0 = tf.zeros([self.batch_size, self.hidden_dim])
        self.h0 = tf.stack([self.h0, self.h0])

        # generator on initial randomness
        gen_o = tensor_array_ops.TensorArray(dtype=tf.float32, size=self.sequence_length,
                                             dynamic_size=False, infer_shape=True)
        gen_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.sequence_length,
                                             dynamic_size=False, infer_shape=True)

        def _g_recurrence(i, x_t, h_tm1, gen_o, gen_x):
            # One sampling step: advance the LSTM, sample the next token from the
            # softmax, record its probability and id, and feed its embedding back in.
            h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
            o_t = self.g_output_unit(h_t)  # batch x vocab , logits not prob
            log_prob = tf.log(tf.nn.softmax(o_t))
            next_token = tf.cast(tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]), tf.int32)
            x_tp1 = tf.nn.embedding_lookup(self.g_embeddings, next_token)  # batch x emb_dim
            gen_o = gen_o.write(i, tf.reduce_sum(
                tf.multiply(tf.one_hot(next_token, self.num_vocabulary, 1.0, 0.0), tf.nn.softmax(o_t)),
                1))  # [batch_size] , prob
            gen_x = gen_x.write(i, next_token)  # indices, batch_size
            return i + 1, x_tp1, h_t, gen_o, gen_x

        # self.gen_o stays a TensorArray (per-step probability of the sampled token).
        _, _, _, self.gen_o, self.gen_x = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4: i < self.sequence_length,
            body=_g_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings, self.start_token), self.h0, gen_o, gen_x)
        )

        self.gen_x = self.gen_x.stack()  # seq_length x batch_size
        self.gen_x = tf.transpose(self.gen_x, perm=[1, 0])  # batch_size x seq_length

        # supervised pretraining for generator
        g_predictions = tensor_array_ops.TensorArray(
            dtype=tf.float32, size=self.sequence_length,
            dynamic_size=False, infer_shape=True)

        ta_emb_x = tensor_array_ops.TensorArray(
            dtype=tf.float32, size=self.sequence_length)
        ta_emb_x = ta_emb_x.unstack(self.processed_x)

        def _pretrain_recurrence(i, x_t, h_tm1, g_predictions):
            # Teacher forcing: record the softmax at step i, then feed the true
            # embedding of x[:, i] as the next input.
            h_t = self.g_recurrent_unit(x_t, h_tm1)
            o_t = self.g_output_unit(h_t)
            g_predictions = g_predictions.write(i, tf.nn.softmax(o_t))  # batch x vocab_size
            x_tp1 = ta_emb_x.read(i)
            return i + 1, x_tp1, h_t, g_predictions

        _, _, _, self.g_predictions = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3: i < self.sequence_length,
            body=_pretrain_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings, self.start_token),
                       self.h0, g_predictions))

        self.g_predictions = tf.transpose(
            self.g_predictions.stack(), perm=[1, 0, 2])  # batch_size x seq_length x vocab_size

        # pretraining loss: mean NLL over all batch_size * sequence_length tokens
        self.pretrain_loss = -tf.reduce_sum(
            tf.one_hot(tf.to_int32(tf.reshape(self.x, [-1])), self.num_vocabulary, 1.0, 0.0) * tf.log(
                tf.reshape(self.g_predictions, [-1, self.num_vocabulary]))) / (self.sequence_length * self.batch_size)

        # per-token NLL reshaped to [batch, seq_len] and summed over time
        self.out_loss = tf.reduce_sum(
            tf.reshape(
                -tf.reduce_sum(
                    tf.one_hot(tf.to_int32(tf.reshape(self.x, [-1])), self.num_vocabulary, 1.0, 0.0) * tf.log(
                        tf.reshape(self.g_predictions, [-1, self.num_vocabulary])), 1
                ), [-1, self.sequence_length]
            ), 1
        )  # batch_size

    def generate(self, session):
        """Evaluate ``gen_x`` once and return the sampled ``[batch_size, sequence_length]`` tokens."""
        # h0 = np.random.normal(size=self.hidden_dim)
        outputs = session.run(self.gen_x)
        return outputs

    def init_matrix(self, shape):
        """Return a seeded standard-normal tensor of ``shape`` (not used inside this class)."""
        return tf.random_normal(shape, stddev=1.0, seed=10)

    def create_recurrent_unit(self, params):
        """Create the LSTM weights, append them to ``params`` and return the step function.

        The step function maps ``(x, stack([h_prev, c_prev]))`` to
        ``stack([h, c])`` using input, forget and output gates plus a tanh
        candidate cell.
        """
        # Weights and Bias for input and hidden tensor
        # NOTE(review): Wi and Uc share seed=111; kept unchanged to preserve the oracle.
        self.Wi = tf.Variable(tf.random_normal([self.emb_dim, self.hidden_dim], 0.0, 1.0, seed=111))
        self.Ui = tf.Variable(tf.random_normal([self.hidden_dim, self.hidden_dim], 0.0, 1.0, seed=211))
        self.bi = tf.Variable(tf.random_normal([self.hidden_dim, ], 0.0, 1.0, seed=311))

        self.Wf = tf.Variable(tf.random_normal([self.emb_dim, self.hidden_dim], 0.0, 1.0, seed=114))
        self.Uf = tf.Variable(tf.random_normal([self.hidden_dim, self.hidden_dim], 0.0, 1.0, seed=115))
        self.bf = tf.Variable(tf.random_normal([self.hidden_dim, ], 0.0, 1.0, seed=116))

        self.Wog = tf.Variable(tf.random_normal([self.emb_dim, self.hidden_dim], 0.0, 1.0, seed=997))
        self.Uog = tf.Variable(tf.random_normal([self.hidden_dim, self.hidden_dim], 0.0, 1.0, seed=998))
        self.bog = tf.Variable(tf.random_normal([self.hidden_dim, ], 0.0, 1.0, seed=999))

        self.Wc = tf.Variable(tf.random_normal([self.emb_dim, self.hidden_dim], 0.0, 1.0, seed=110))
        self.Uc = tf.Variable(tf.random_normal([self.hidden_dim, self.hidden_dim], 0.0, 1.0, seed=111))
        self.bc = tf.Variable(tf.random_normal([self.hidden_dim, ], 0.0, 1.0, seed=112))
        params.extend([
            self.Wi, self.Ui, self.bi,
            self.Wf, self.Uf, self.bf,
            self.Wog, self.Uog, self.bog,
            self.Wc, self.Uc, self.bc])

        def unit(x, hidden_memory_tm1):
            previous_hidden_state, c_prev = tf.unstack(hidden_memory_tm1)

            # Input Gate
            i = tf.sigmoid(
                tf.matmul(x, self.Wi) +
                tf.matmul(previous_hidden_state, self.Ui) + self.bi
            )

            # Forget Gate
            f = tf.sigmoid(
                tf.matmul(x, self.Wf) +
                tf.matmul(previous_hidden_state, self.Uf) + self.bf
            )

            # Output Gate
            o = tf.sigmoid(
                tf.matmul(x, self.Wog) +
                tf.matmul(previous_hidden_state, self.Uog) + self.bog
            )

            # New Memory Cell
            c_ = tf.nn.tanh(
                tf.matmul(x, self.Wc) +
                tf.matmul(previous_hidden_state, self.Uc) + self.bc
            )

            # Final Memory cell
            c = f * c_prev + i * c_

            # Current Hidden state
            current_hidden_state = o * tf.nn.tanh(c)

            return tf.stack([current_hidden_state, c])

        return unit

    def create_output_unit(self, params):
        """Create the vocabulary projection, append its weights to ``params`` and return it.

        The returned function maps ``stack([h, c])`` to logits
        ``h @ Wo + bo`` of shape ``[batch, num_vocabulary]``.
        """
        self.Wo = tf.Variable(tf.random_normal([self.hidden_dim, self.num_vocabulary], 0.0, 1.0, seed=12341))
        self.bo = tf.Variable(tf.random_normal([self.num_vocabulary], 0.0, 1.0, seed=56865246))
        params.extend([self.Wo, self.bo])

        def unit(hidden_memory_tuple):
            hidden_state, c_prev = tf.unstack(hidden_memory_tuple)
            logits = tf.matmul(hidden_state, self.Wo) + self.bo
            return logits

        return unit

    def set_similarity(self, valid_examples=None, pca=True):
        """Build ``self.similarity``: cosine similarity of selected embeddings with all rows.

        Args:
            valid_examples: token ids to compare. Defaults to the first
                ``min(20, num_vocabulary)`` ids when ``pca`` is truthy, and to
                every id otherwise.
            pca: when ``pca == True`` and ``num_vocabulary >= 20``, the
                row-normalised embeddings are first replaced by
                ``u[:20, :] @ normalized`` (20 rows), where ``u`` comes from
                ``tf.svd`` of their Gram matrix.
        """
        # Use `is None`, not `== None`: on a NumPy array `== None` compares
        # element-wise and the `if` then raises "truth value is ambiguous".
        if valid_examples is None:
            if pca:
                # Cap at num_vocabulary so the default ids exist even when the
                # PCA projection below is skipped (num_vocabulary < 20).
                valid_examples = np.array(range(min(20, self.num_vocabulary)))
            else:
                valid_examples = np.array(range(self.num_vocabulary))
        self.valid_dataset = tf.constant(valid_examples, dtype=tf.int32)
        self.norm = tf.sqrt(tf.reduce_sum(tf.square(self.g_embeddings), 1, keep_dims=True))
        self.normalized_embeddings = self.g_embeddings / self.norm
        # PCA
        # NOTE(review): this slices the first 20 *rows* of u; the leading
        # singular vectors are u's columns. Confirm the intent before relying on it.
        if self.num_vocabulary >= 20 and pca == True:
            emb = tf.matmul(self.normalized_embeddings, tf.transpose(self.normalized_embeddings))
            s, u, v = tf.svd(emb)
            u_r = tf.strided_slice(u, begin=[0, 0], end=[20, self.num_vocabulary], strides=[1, 1])
            self.normalized_embeddings = tf.matmul(u_r, self.normalized_embeddings)
        self.valid_embeddings = tf.nn.embedding_lookup(
            self.normalized_embeddings, self.valid_dataset)
        self.similarity = tf.matmul(self.valid_embeddings, tf.transpose(self.normalized_embeddings))


## METRICS ## 

In [None]:
from abc import abstractmethod

class Metrics:
    """Minimal base class for evaluation metrics.

    Holds a display ``name`` (default ``'Metric'``). Concrete metrics override
    ``get_score``. NOTE: ``@abstractmethod`` is not enforced here because the
    class does not use ``ABCMeta``. ``Metrics()`` can therefore still be
    instantiated, and its ``get_score`` returns None.
    """

    def __init__(self):
        self.name = 'Metric'

    def get_name(self):
        """Return the metric's display name."""
        return self.name

    def set_name(self, name):
        """Set the metric's display name."""
        self.name = name

    @abstractmethod
    def get_score(self):
        """Compute the metric value; every concrete metric overrides this."""
        return None


In [None]:
import numpy as np

class Nll(Metrics):
    """Mean of ``pretrain_loss`` over every batch of ``data_loader``.

    For each batch, ``pretrain_loss`` is evaluated in ``sess`` with the batch
    fed to the ``x_real`` placeholder. The score is the mean of those values.
    """

    def __init__(self, data_loader, pretrain_loss, x_real, sess, name='Nll'):
        super().__init__()
        self.name = name
        self.data_loader = data_loader
        self.sess = sess
        self.pretrain_loss = pretrain_loss
        self.x_real = x_real

    def set_name(self, name):
        self.name = name

    def get_name(self):
        return self.name

    def get_score(self):
        """Return ``nll_loss()``."""
        return self.nll_loss()

    def nll_loss(self):
        """Rewind the loader and average the loss over all of its batches."""
        loader = self.data_loader
        loader.reset_pointer()
        batch_losses = [
            self.sess.run(self.pretrain_loss, {self.x_real: loader.next_batch()})
            for _ in range(loader.num_batch)
        ]
        return np.mean(batch_losses)


In [None]:
import random

import nltk
from nltk.translate.bleu_score import SmoothingFunction

class Bleu(Metrics):
    """Sentence-level BLEU of a text file against a reference text file.

    Each of the first ``sample_size`` lines of ``test_text`` is scored with
    ``sentence_bleu`` against a shuffled ``portion`` of the lines of
    ``real_text``. It uses uniform ``gram``-gram weights and smoothing
    method1. The score is the mean.
    """

    def __init__(self, test_text='', real_text='', gram=3, name='Bleu', portion=1):
        super().__init__()
        self.name = name
        self.test_data = test_text  # path of the file being evaluated
        self.real_data = real_text  # path of the reference file
        self.gram = gram
        self.sample_size = 200  # BLEU scores remain nearly unchanged for self.sample_size >= 200
        self.reference = None  # tokenized reference corpus, cached after the first load
        self.is_first = True
        self.portion = portion  # fraction of the reference file to keep (1 = all of it)

    def get_name(self):
        return self.name

    def get_score(self, is_fast=False, ignore=False):
        """Return the mean BLEU score, or 0 when ``ignore`` is true (``is_fast`` is unused)."""
        if ignore:
            return 0
        if self.is_first:
            self.get_reference()
            self.is_first = False
        return self.get_bleu()

    def get_reference(self):
        """Load, shuffle and truncate the reference corpus once; later calls return the cache."""
        if self.reference is not None:
            return self.reference
        with open(self.real_data) as real_data:
            reference = [nltk.word_tokenize(text) for text in real_data]
        # Shuffle in place so the kept portion is a random subset.
        random.shuffle(reference)
        self.reference = reference[:int(self.portion * len(reference))]
        return self.reference

    def get_bleu(self):
        """Average sentence BLEU over the first ``sample_size`` lines of ``test_data``."""
        ngram = self.gram
        reference = self.get_reference()
        weight = tuple(1. / ngram for _ in range(ngram))
        scores = []
        with open(self.test_data) as test_data:
            for count, hypothesis in enumerate(test_data):
                if count >= self.sample_size:
                    break
                scores.append(self.calc_bleu(reference, nltk.word_tokenize(hypothesis), weight))
        return sum(scores) / len(scores)

    def calc_bleu(self, reference, hypothesis, weight):
        """Sentence BLEU of ``hypothesis`` against ``reference`` with smoothing method1."""
        smoother = SmoothingFunction().method1
        return nltk.translate.bleu_score.sentence_bleu(reference, hypothesis, weight,
                                                       smoothing_function=smoother)


In [None]:
import nltk
from nltk.translate.bleu_score import SmoothingFunction

class SelfBleu(Metrics):
    """Self-BLEU of a file of generated sentences.

    Each of the first ``sample_size`` lines of ``test_text`` is scored with
    ``sentence_bleu`` against the *other* lines of the same file, restricted to
    the first ``portion`` of it. It uses uniform ``gram``-gram weights and
    smoothing method1. The score is the mean. Higher values mean the sentences
    share more n-grams with one another.
    """

    def __init__(self, test_text='', gram=3, name='SelfBleu', portion=1):
        super().__init__()
        self.name = name
        self.test_data = test_text
        self.gram = gram
        self.sample_size = 200  # SelfBLEU scores remain nearly unchanged for self.sample_size >= 200
        self.portion = portion  # how many portions to use in the evaluation, default to use the whole test dataset

    def get_name(self):
        return self.name

    def get_score(self, is_fast=False, ignore=False):
        """Return the mean Self-BLEU score, or 0 when ``ignore`` is true (``is_fast`` is unused)."""
        if ignore:
            return 0

        return self.get_bleu()

    def get_reference(self):
        """Tokenize every line of ``test_data`` and keep the first ``portion`` of them."""
        reference = list()
        with open(self.test_data) as real_data:
            for text in real_data:
                text = nltk.word_tokenize(text)
                reference.append(text)
        len_ref = len(reference)

        return reference[:int(self.portion*len_ref)]

    def get_bleu(self):
        """Average sentence BLEU of each sampled line against all *other* lines.

        Needs at least two lines. A hypothesis with no other references cannot
        be scored, and an empty file makes the final division fail.
        """
        ngram = self.gram
        bleu = list()
        reference = self.get_reference()
        weight = tuple((1. / ngram for _ in range(ngram)))
        with open(self.test_data) as test_data:
            for i, hypothesis in enumerate(test_data):
                if i >= self.sample_size:
                    break
                hypothesis = nltk.word_tokenize(hypothesis)
                # Line i of the file is reference[i]: same file, same
                # tokenization, no shuffling. Leave it out. If it stays in, the
                # hypothesis matches itself and every score is ~1.0.
                if i < len(reference):
                    others = reference[:i] + reference[i + 1:]
                else:
                    others = reference
                bleu.append(self.calc_bleu(others, hypothesis, weight))

        return sum(bleu) / len(bleu)

    def calc_bleu(self, reference, hypothesis, weight):
        """Sentence BLEU of ``hypothesis`` against ``reference`` with smoothing method1."""
        return nltk.translate.bleu_score.sentence_bleu(reference, hypothesis, weight,
                                                       smoothing_function=SmoothingFunction().method1)



In [None]:
import collections
import math
import random

import nltk
import numpy as np
import tensorflow as tf
from scipy.spatial.distance import cosine

class DocEmbSim(Metrics):
    """Embedding-similarity score between an oracle file and a generator file.

    For each file, a small skip-gram embedding model (sampled softmax) is
    trained on its token ids. The model then computes cosine similarities
    between a fixed set of frequent oracle ids and every vocabulary entry. The
    score is ``log10`` of the mean cosine similarity between corresponding
    rows of the two resulting similarity matrices.
    """

    def __init__(self, oracle_file=None, generator_file=None, num_vocabulary=None, name='DocEmbSim'):
        super().__init__()
        self.name = name
        self.oracle_sim = None
        self.gen_sim = None
        self.is_first = True
        self.oracle_file = oracle_file
        self.generator_file = generator_file
        self.num_vocabulary = num_vocabulary
        self.batch_size = 64
        self.embedding_size = 32
        self.data_index = 0  # read position carried across generate_batch calls
        self.valid_examples = None  # cached result of get_frequent_word

    def get_score(self):
        """Return the score. The oracle side is computed only on the first call."""
        if self.is_first:
            self.get_oracle_sim()
            self.is_first = False
        self.get_gen_sim()
        return self.get_dis_corr()

    def get_frequent_word(self):
        """Return (and cache) the most frequent token ids in ``oracle_file``.

        Tokens are read as integers. At most ``num_vocabulary // 10`` ids are
        kept, ordered by descending frequency; ties keep first-seen order.
        """
        if self.valid_examples is not None:
            return self.valid_examples

        counts = collections.Counter()
        with open(self.oracle_file, 'r') as file:
            for line in file:
                counts.update(map(int, nltk.word_tokenize(line)))
        # most_common() keeps the frequency ranking. The previous
        # list(set(sorted_words)) threw that order away, so the "top" ids
        # were effectively arbitrary.
        by_frequency = [word for word, _ in counts.most_common()]
        self.valid_examples = by_frequency[:self.num_vocabulary // 10]
        return self.valid_examples

    def read_data(self, file):
        """Return one NLTK token list (of strings) per line of ``file``."""
        words = []
        with open(file, 'r') as file:
            for line in file:
                text = nltk.word_tokenize(line)
                words.append(text)
        return words

    def generate_batch(self, batch_size, num_skips, skip_window, data=None):
        """Produce one skip-gram batch from the non-empty token sequence ``data``.

        A window of ``2 * skip_window + 1`` tokens slides over ``data``
        (wrapping around), starting at ``self.data_index``. For each centre
        token, ``num_skips`` distinct context positions are drawn at random.
        Returns ``(batch, labels)`` int32 arrays of shapes ``[batch_size]`` and
        ``[batch_size, 1]``, and advances ``self.data_index``.
        """
        assert batch_size % num_skips == 0
        assert num_skips <= 2 * skip_window
        # data_index carries over from the previous (possibly longer) sequence.
        # Wrap it into range so a shorter sequence cannot raise IndexError.
        self.data_index %= len(data)
        batch = np.ndarray(shape=(batch_size), dtype=np.int32)
        labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
        span = 2 * skip_window + 1  # [ skip_window target skip_window ]
        buffer = collections.deque(maxlen=span)  # deque to slide the window
        for _ in range(span):
            buffer.append(data[self.data_index])
            self.data_index = (self.data_index + 1) % len(data)
        for i in range(batch_size // num_skips):
            target = skip_window  # target label at the center of the buffer
            targets_to_avoid = [skip_window]
            for j in range(num_skips):
                while target in targets_to_avoid:
                    target = random.randint(0, span - 1)
                targets_to_avoid.append(target)
                batch[i * num_skips + j] = buffer[skip_window]
                labels[i * num_skips + j, 0] = buffer[target]
            buffer.append(data[self.data_index])
            self.data_index = (self.data_index + 1) % len(data)
        return batch, labels

    def get_wordvec(self, file):
        """Train skip-gram embeddings on ``file`` and return a cosine-similarity matrix.

        Makes ``num_steps`` passes over the lines, one batch per non-empty
        line. Returns an array of shape
        ``[len(valid_examples), num_vocabulary]``.
        """
        graph = tf.Graph()
        batch_size = self.batch_size
        embedding_size = self.embedding_size
        vocabulary_size = self.num_vocabulary
        num_sampled = 64
        if num_sampled > vocabulary_size:
            num_sampled = vocabulary_size
        num_steps = 2
        skip_window = 1  # How many words to consider left and right.
        num_skips = 2  # How many times to reuse an input to generate a label.
        if self.valid_examples is None:
            self.get_frequent_word()

        with graph.as_default():

            # Input data.
            train_dataset = tf.placeholder(tf.int32, shape=[batch_size])
            train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
            valid_dataset = tf.constant(self.valid_examples, dtype=tf.int32)

            # initial Variables.
            embeddings = tf.Variable(
                tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0, seed=11))
            softmax_weights = tf.Variable(
                tf.truncated_normal([vocabulary_size, embedding_size],
                                    stddev=1.0 / math.sqrt(embedding_size), seed=12))
            softmax_biases = tf.Variable(tf.zeros([vocabulary_size]))

            # Model.
            # Look up embeddings for inputs.
            embed = tf.nn.embedding_lookup(embeddings, train_dataset)
            # Compute the softmax loss, using a sample of the negative labels each time.
            loss = tf.reduce_mean(
                tf.nn.sampled_softmax_loss(weights=softmax_weights, biases=softmax_biases, inputs=embed,
                                           labels=train_labels, num_sampled=num_sampled, num_classes=vocabulary_size))

            optimizer = tf.train.AdagradOptimizer(1.0).minimize(loss)

            # Compute the similarity between minibatch examples and all embeddings.
            norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))
            normalized_embeddings = embeddings / norm
            valid_embeddings = tf.nn.embedding_lookup(
                normalized_embeddings, valid_dataset)
            similarity = tf.matmul(valid_embeddings, tf.transpose(normalized_embeddings))

            data = self.read_data(file)

        with tf.Session(graph=graph) as session:
            tf.global_variables_initializer().run()
            generate_num = len(data)
            for step in range(num_steps):
                for index in range(generate_num):
                    if not data[index]:
                        # An empty line has no tokens to build a window from.
                        continue
                    cur_batch_data, cur_batch_labels = self.generate_batch(
                        batch_size, num_skips, skip_window, data[index])
                    feed_dict = {train_dataset: cur_batch_data, train_labels: cur_batch_labels}
                    session.run([optimizer, loss], feed_dict=feed_dict)
            similarity_value = similarity.eval()
            return similarity_value

    def get_oracle_sim(self):
        self.oracle_sim = self.get_wordvec(self.oracle_file)  # evaluate word embedding on the models file

    def get_gen_sim(self):
        self.gen_sim = self.get_wordvec(self.generator_file)  # evaluate word embedding on the generator file

    def get_dis_corr(self):
        """Return log10 of the mean row-wise cosine similarity of the two similarity matrices.

        Raises:
            ArithmeticError: if the matrices have a different number of rows.
        """
        if len(self.oracle_sim) != len(self.gen_sim):
            raise ArithmeticError
        corr = 0
        for index in range(len(self.oracle_sim)):
            corr += (1 - cosine(np.array(self.oracle_sim[index]), np.array(self.gen_sim[index])))
        return np.log10(corr / len(self.oracle_sim))


## HELPER FUNCTIONS ## 

In [None]:
import math
import tensorflow as tf


def hw_flatten(x):
    """Merge axes 1 and 2 of ``x``: ``[bs, h, w, c] -> [bs, h * w, c]``."""
    flat_positions = x.shape[1] * x.shape[2]
    channels = x.shape[-1]
    return tf.reshape(x, shape=[-1, flat_positions, channels])


def l2_norm(v, eps=1e-12):
    """Divide ``v`` by its L2 norm taken over all elements; ``eps`` guards against division by zero."""
    norm = tf.reduce_sum(v ** 2) ** 0.5
    return v / (norm + eps)


def lrelu(x, alpha=0.2):
    """Leaky ReLU whose negative side has slope ``alpha``."""
    activated = tf.nn.leaky_relu(x, alpha=alpha)
    return activated


def create_linear_initializer(input_size, dtype=tf.float32):
    """Return a truncated-normal initializer with stddev ``1 / sqrt(input_size)`` for linear weights."""
    scale = math.sqrt(input_size * 1.0)
    return tf.truncated_normal_initializer(stddev=1 / scale, dtype=dtype)


def create_bias_initializer(dtype=tf.float32):
    """Return an all-zeros initializer for bias vectors of the given ``dtype``."""
    initializer = tf.zeros_initializer(dtype=dtype)
    return initializer


def linear(input_, output_size, use_bias=False, sn=False, scope=None):
    """Affine map of a single 2-D tensor: ``input_ @ W^T (+ bias)``.

    Args:
        input_: 2-D tensor ``[batch, input_size]`` with a statically known,
            non-zero second dimension.
        output_size: number of output features.
        use_bias: if True, add a zero-initialised ``Bias`` vector.
        sn: if True, the weight matrix is passed through ``spectral_norm``.
        scope: variable scope name; defaults to ``"Linear"``.

    Returns:
        Tensor of shape ``[batch, output_size]``.

    Raises:
        ValueError: if ``input_`` is not 2-D or its second dimension is unknown/zero.
    """
    shape = input_.get_shape().as_list()
    if len(shape) != 2:
        raise ValueError("Linear is expecting 2D arguments: %s" % str(shape))
    input_size = shape[1]
    if not input_size:
        raise ValueError("Linear expects shape[1] of arguments: %s" % str(shape))

    dtype = input_.dtype
    with tf.variable_scope(scope or "Linear"):
        # Stored as [output_size, input_size]; transposed in the matmul below.
        weights = tf.get_variable("Matrix", shape=[output_size, input_size],
                                  initializer=create_linear_initializer(input_size, dtype),
                                  dtype=dtype)
        if sn:
            weights = spectral_norm(weights)
        result = tf.matmul(input_, tf.transpose(weights))
        if use_bias:
            result += tf.get_variable("Bias", [output_size],
                                      initializer=create_bias_initializer(dtype),
                                      dtype=dtype)
    return result


def highway(input_, size, num_layers=1, bias=-2.0, f=tf.nn.relu, scope='Highway'):
    """Highway network (cf. http://arxiv.org/abs/1505.00387).

    Each of the ``num_layers`` layers computes

        t = sigmoid(W_t y + bias)
        z = t * f(W_g y) + (1 - t) * y

    where ``t`` is the transform gate and ``(1 - t)`` the carry gate. The
    linear maps have no bias of their own; a negative ``bias`` pushes ``t``
    toward 0, i.e. toward carrying ``y`` through. ``size`` must broadcast
    against the last dimension of ``input_``, because the carry path reuses
    ``y`` unchanged. With ``num_layers=0`` the input is returned as-is.
    """
    with tf.variable_scope(scope):
        for idx in range(num_layers):
            g = f(linear(input_, size, scope='highway_lin_%d' % idx))

            t = tf.sigmoid(linear(input_, size, scope='highway_gate_%d' % idx) + bias)

            # Feed each layer's output into the next. Returning input_ also
            # covers num_layers=0, where the old `output` was never assigned.
            input_ = t * g + (1. - t) * input_

    return input_


def mlp(input_, output_sizes, act_func=tf.nn.relu, use_bias=True):
    """Stack of ``linear`` layers with ``act_func`` between them.

    Layer ``k`` has ``output_sizes[k]`` units and lives in variable scope
    ``linear_{k}``. No activation is applied after the last layer.

    :param input_: 2-D input tensor (as required by ``linear``)
    :param output_sizes: sequence of output widths, one per layer
    :param act_func: activation applied after every layer except the last
    :param use_bias: whether each linear layer adds a bias term
    :return: output of the last layer
    """
    net = input_
    last_layer = len(output_sizes) - 1
    for layer_id, width in enumerate(output_sizes):
        net = linear(net, width, use_bias=use_bias, scope='linear_{}'.format(layer_id))
        if layer_id < last_layer:
            net = act_func(net)
    return net


def conv2d(input_, out_nums, k_h=2, k_w=1, d_h=2, d_w=1, stddev=None, sn=False, padding='SAME', scope=None):
    """2-D convolution with a bias term and optional spectral normalization.

    Variables ``Matrix`` (kernel, shape [k_h, k_w, in_channels, out_nums]) and ``Bias``
    (shape [out_nums]) are created under ``scope`` (default "Conv2d").

    Args:
        input_: input tensor; the strides [1, d_h, d_w, 1] assume channels-last (NHWC) layout,
            and the last dimension must be statically known.
        out_nums: number of output channels.
        k_h, k_w: kernel height and width.
        d_h, d_w: strides along height and width.
        stddev: standard deviation of the truncated-normal kernel initializer; defaults to
            sqrt(2 / fan_in) with fan_in = k_h * k_w * in_channels.
        sn: if True, the kernel is passed through ``spectral_norm`` before use.
        padding: padding mode forwarded to ``tf.nn.conv2d``.
        scope: variable scope name.

    Returns:
        The convolution output with the bias added.
    """
    in_nums = input_.get_shape().as_list()[-1]
    if stddev is None:
        # He (Kaiming) initialization: variance 2 / fan_in. (This is not Glorot init,
        # which would scale by the sum of fan_in and fan_out.)
        stddev = math.sqrt(2. / (k_h * k_w * in_nums))
    with tf.variable_scope(scope or "Conv2d"):
        W = tf.get_variable("Matrix", shape=[k_h, k_w, in_nums, out_nums],
                            initializer=tf.truncated_normal_initializer(stddev=stddev))
        if sn:
            W = spectral_norm(W)
        b = tf.get_variable("Bias", shape=[out_nums], initializer=tf.zeros_initializer())
        conv = tf.nn.conv2d(input_, filter=W, strides=[1, d_h, d_w, 1], padding=padding)
        conv = tf.nn.bias_add(conv, b)

    return conv


def self_attention(x, ch, sn=False):
    """Self-attention block for GANs over all spatial positions of ``x``.

    Three 1x1 convolutions (scopes 'f_conv', 'g_conv', 'h_conv') produce two ``ch // 8``
    channel maps used for the attention scores and a ``ch`` channel value map. The
    attended values are scaled by a learnable scalar ``gamma`` (initialized to 0, so the
    block initially returns ``x`` unchanged) and added back to ``x``.
    """
    one_by_one = dict(k_h=1, d_h=1, sn=sn)
    f_map = conv2d(x, ch // 8, scope='f_conv', **one_by_one)  # [bs, h, w, c']
    g_map = conv2d(x, ch // 8, scope='g_conv', **one_by_one)  # [bs, h, w, c']
    h_map = conv2d(x, ch, scope='h_conv', **one_by_one)  # [bs, h, w, c]

    # scores between every pair of the N = h * w positions
    scores = tf.matmul(hw_flatten(g_map), hw_flatten(f_map), transpose_b=True)  # [bs, N, N]
    attention_map = tf.nn.softmax(scores, dim=-1)

    attended = tf.matmul(attention_map, hw_flatten(h_map))  # [bs, N, C]
    gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))

    attended = tf.reshape(attended, [-1] + x.get_shape().as_list()[1:])  # [bs, h, w, C]
    return gamma * attended + x


def spectral_norm(w, iteration=1):
    """Spectral normalization for GANs.

    ``w`` is reshaped to a 2-D matrix of shape [-1, w_shape[-1]]. Its largest singular
    value ``sigma`` is estimated with ``iteration`` steps of power iteration, and
    ``w / sigma`` is returned reshaped back to the original shape of ``w``.

    A non-trainable variable named ``u`` (shape [1, w_shape[-1]], created in the current
    variable scope) stores the power-iteration vector between steps. Its update is a
    control dependency of the returned tensor, so it runs whenever that tensor is evaluated.

    Args:
        w: weight tensor with a fully defined static shape.
        iteration: number of power-iteration steps; must be >= 1 (one is usually enough).

    Returns:
        The spectrally normalized weight, with the same shape as ``w``.

    Raises:
        ValueError: if ``iteration`` < 1 (no singular-vector estimate would be produced).
    """
    if iteration < 1:
        raise ValueError("spectral_norm needs at least one power iteration, got %r" % (iteration,))

    w_shape = w.shape.as_list()
    w_mat = tf.reshape(w, [-1, w_shape[-1]])

    u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False)

    u_hat = u
    v_hat = None
    for _ in range(iteration):
        # power iteration: alternately refine the two singular-vector estimates
        v_ = tf.matmul(u_hat, tf.transpose(w_mat))
        v_hat = l2_norm(v_)

        u_ = tf.matmul(v_hat, w_mat)
        u_hat = l2_norm(u_)

    sigma = tf.matmul(tf.matmul(v_hat, w_mat), tf.transpose(u_hat))
    w_norm = w_mat / sigma

    # The reshape is created under the control dependency, so evaluating the result
    # also stores the refined vector back into ``u``.
    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm


def create_output_unit(output_size, vocab_size):
    """Create the affine output layer that maps hidden states to vocabulary logits.

    Variables ``Wo`` [output_size, vocab_size] and ``bo`` [vocab_size] are created in the
    current variable scope. ``output_size`` is presumably the generator memory's output
    size (an earlier version read it from ``self.gen_mem.output_size``).

    Returns:
        A function ``unit(hidden_mem_o)`` that computes ``hidden_mem_o @ Wo + bo``.
    """
    weight = tf.get_variable('Wo', shape=[output_size, vocab_size], initializer=create_linear_initializer(output_size))
    bias = tf.get_variable('bo', shape=[vocab_size], initializer=create_bias_initializer())

    def unit(hidden_mem_o):
        # hidden_mem_o presumably has shape [batch, output_size]; TODO confirm
        return tf.matmul(hidden_mem_o, weight) + bias

    return unit


def add_gumbel(o_t, eps=1e-10):
    """Add i.i.d. Gumbel(0, 1) noise to ``o_t``.

    The noise uses the inverse-CDF transform g = -log(-log(u + eps) + eps) with
    u ~ Uniform(0, 1) (float32). ``eps`` keeps both logarithms finite.
    """
    uniform = tf.random_uniform(tf.shape(o_t), minval=0, maxval=1, dtype=tf.float32)
    noise = -tf.log(-tf.log(uniform + eps) + eps)
    return tf.add(o_t, noise)


def add_gumbel_cond(o_t, next_token_onehot, eps=1e-10):
    """Draw the reparameterization z of a categorical variable b from p(z | b).

    The entry selected by ``next_token_onehot`` gets Gumbel noise shifted by
    ``logsumexp(o_t)`` over the last axis. Every other entry gets ``o_t`` plus Gumbel
    noise, truncated so that it stays below that selected value.

    Args:
        o_t: logits; the last axis indexes categories.
        next_token_onehot: one-hot encoding of the chosen category b.
        eps: small constant keeping the logarithms finite.

    Returns:
        Tensor shaped like ``o_t`` containing the conditioned perturbed logits.
    """

    def truncated_gumbel(gumbel, truncation):
        # soft upper bound: the result is always below ``truncation``
        return -tf.log(eps + tf.exp(-gumbel) + tf.exp(-truncation))

    uniform = tf.random_uniform(tf.shape(o_t), minval=0, maxval=1, dtype=tf.float32)

    print("shape of v: {}".format(uniform.get_shape().as_list()))
    print("shape of next_token_onehot: {}".format(next_token_onehot.get_shape().as_list()))

    noise = -tf.log(-tf.log(uniform + eps) + eps, name="gumbel")
    shifted = noise + tf.reduce_logsumexp(o_t, axis=-1, keep_dims=True)
    top_value = tf.reduce_sum(next_token_onehot * shifted, axis=-1, keep_dims=True)

    others = truncated_gumbel(noise + o_t, top_value)
    return (1. - next_token_onehot) * others + next_token_onehot * shifted


def gradient_penalty(discriminator, x_real_onehot, x_fake_onehot_appr, config):
    """Compute the gradient penalty for the WGAN-GP loss.

    Builds x_hat = a * x_real + (1 - a) * x_fake with a ~ Uniform(0, 1) drawn per example,
    and returns reg_param * mean((||dD(x_hat)/dx_hat||_2 - 1)^2), where the norm is taken
    over all non-batch dimensions.

    The mixing coefficient has shape [config['batch_size'], 1, 1], so the inputs are
    expected to be rank-3 tensors with that batch size.
    """
    mix = tf.random_uniform(shape=[config['batch_size'], 1, 1], minval=0., maxval=1.)
    x_hat = mix * x_real_onehot + (1. - mix) * x_fake_onehot_appr

    critic_out = discriminator(x_onehot=x_hat)

    slopes = tf.gradients(critic_out, x_hat)[0]  # gradient of D(x_hat)
    slope_norm = tf.norm(tf.layers.flatten(slopes), axis=1)  # per-example l2 norm

    return config['reg_param'] * tf.reduce_mean(tf.square(slope_norm - 1.))


## TRAINING FUNCTIONS ## 

### ORACLE ### 

In [None]:
import numpy as np
import tensorflow as tf
from tqdm import tqdm
import time

# Small constant added to gen_o inside tf.log (log_pg in get_losses) to avoid log(0).
EPS = 1e-10


def oracle_train(generator, discriminator, oracle_model, oracle_loader, gen_loader, config):
    """Build the RelGAN graph for the synthetic-oracle experiment and train it.

    Oracle sequences are sampled from ``oracle_model.gen_x`` into ``<sample_dir>/oracle.txt``.
    The generator is then pre-trained for ``npre_epochs`` epochs. This is followed by
    ``nadv_steps`` adversarial steps; each step runs ``config['gsteps']`` generator and
    ``config['dsteps']`` discriminator updates, then sets the temperature from
    ``get_fixed_temperature``.

    Metrics are evaluated on pre-training epochs divisible by 10 and on adversarial steps
    divisible by ``config['ntest']``. They are written to TensorBoard (``<log_dir>/summary``)
    and, as text lines, to ``<log_dir>/experiment-log-relgan.csv``. The log file and the
    summary writer are closed when training ends, and the log is flushed after every line.

    Args:
        generator: callable ``generator(x_real=..., temperature=...)`` returning
            ``(x_fake_onehot_appr, x_fake, g_pretrain_loss, gen_o)``.
        discriminator: callable ``discriminator(x_onehot=...)`` returning logits.
        oracle_model: oracle network. ``gen_x`` is sampled to create the training data; the
            'nll_oracle' metric also uses ``pretrain_loss`` and ``x``.
        oracle_loader: data loader over the oracle samples.
        gen_loader: data loader over generated samples (used by the metrics).
        config: hyper-parameter dict; see the keys read at the top of the body.

    Note:
        ``get_losses``, ``get_train_ops``, ``get_metrics`` and ``get_metric_summary_op`` are
        looked up in the notebook's global namespace at call time. The REAL cell redefines
        ``get_losses``, and judging by ``real_train``'s calls it also redefines
        ``get_train_ops`` and ``get_metrics`` with different signatures. Re-run this cell
        before calling ``oracle_train`` if that cell ran after it.
    """
    batch_size = config['batch_size']
    vocab_size = config['vocab_size']
    seq_len = config['seq_len']
    num_sentences = config['num_sentences']
    data_dir = config['data_dir']
    log_dir = config['log_dir']
    sample_dir = config['sample_dir']
    npre_epochs = config['npre_epochs']
    nadv_steps = config['nadv_steps']
    seed = config['seed']
    temper = config['temperature']
    adapt = config['adapt']
    ntest_pre = 10  # evaluate every ntest_pre pre-training epochs

    # set random seed
    tf.set_random_seed(seed)
    np.random.seed(seed)

    # filename
    oracle_file = os.path.join(sample_dir, 'oracle.txt')
    gen_file = os.path.join(sample_dir, 'generator.txt')
    csv_file = os.path.join(log_dir, 'experiment-log-relgan.csv')

    # create necessary directories
    for directory in (data_dir, sample_dir, log_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # placeholder definitions
    x_real = tf.placeholder(tf.int32, [batch_size, seq_len], name="x_real")  # tokens of oracle sequences

    temperature = tf.Variable(1., trainable=False, name='temperature')

    x_real_onehot = tf.one_hot(x_real, vocab_size)  # batch_size x seq_len x vocab_size
    assert x_real_onehot.get_shape().as_list() == [batch_size, seq_len, vocab_size]

    x_fake_onehot_appr, x_fake, g_pretrain_loss, gen_o = generator(x_real=x_real, temperature=temperature)

    d_out_real = discriminator(x_onehot=x_real_onehot)
    d_out_fake = discriminator(x_onehot=x_fake_onehot_appr)

    # GAN / Divergence type
    log_pg, g_loss, d_loss = get_losses(d_out_real, d_out_fake, x_real_onehot, x_fake_onehot_appr,
                                        gen_o, discriminator, config)

    # Global step
    global_step = tf.Variable(0, trainable=False)
    global_step_op = global_step.assign_add(1)

    # Train ops (temp_train_op is built but not used by this training loop)
    g_pretrain_op, g_train_op, d_train_op, temp_train_op = get_train_ops(config, g_pretrain_loss, g_loss, d_loss,
                                                                         log_pg, temperature, global_step)

    # Record wall clock time
    time_diff = tf.placeholder(tf.float32)
    Wall_clock_time = tf.Variable(0., trainable=False)
    update_Wall_op = Wall_clock_time.assign_add(time_diff)

    # Temperature placeholder
    temp_var = tf.placeholder(tf.float32)
    update_temperature_op = temperature.assign(temp_var)

    # Loss summaries
    loss_summaries = [
        tf.summary.scalar('loss/discriminator', d_loss),
        tf.summary.scalar('loss/g_loss', g_loss),
        tf.summary.scalar('loss/log_pg', log_pg),
        tf.summary.scalar('loss/Wall_clock_time', Wall_clock_time),
        tf.summary.scalar('loss/temperature', temperature),
    ]
    loss_summary_op = tf.summary.merge(loss_summaries)

    # Metric Summaries
    metrics_pl, metric_summary_op = get_metric_summary_op(config)

    # ------------- initial the graph --------------
    with init_sess() as sess, open(csv_file, 'w') as log:
        sum_writer = tf.summary.FileWriter(os.path.join(log_dir, 'summary'), sess.graph)
        try:
            # generate oracle data and create batches
            generate_samples(sess, oracle_model.gen_x, batch_size, num_sentences, oracle_file)
            oracle_loader.create_batches(oracle_file)

            metrics = get_metrics(config, oracle_loader, gen_loader, oracle_file, gen_file,
                                  oracle_model, g_pretrain_loss, x_real, sess)

            def evaluate_and_log(msg, summary_step):
                # Score every metric and record the scores in TensorBoard at summary_step.
                # Then append them to msg, print it, and write it to the CSV log.
                scores = [metric.get_score() for metric in metrics]
                metrics_summary_str = sess.run(metric_summary_op, feed_dict=dict(zip(metrics_pl, scores)))
                sum_writer.add_summary(metrics_summary_str, summary_step)
                for metric, score in zip(metrics, scores):
                    msg += ', ' + metric.get_name() + ': %.4f' % score
                print(msg)
                log.write(msg + '\n')
                log.flush()  # keep the log current even if the Colab runtime disconnects

            print('Start pre-training...')
            for epoch in range(npre_epochs):
                # run pre-training
                g_pretrain_loss_np = pre_train_epoch(sess, g_pretrain_op, g_pretrain_loss, x_real, oracle_loader)

                # Test
                if np.mod(epoch, ntest_pre) == 0:
                    # generate fake data and create batches
                    gen_save_file = os.path.join(sample_dir, 'pre_samples_{:05d}.txt'.format(epoch))
                    generate_samples(sess, x_fake, batch_size, num_sentences, gen_file)
                    generate_samples(sess, x_fake, batch_size, 120, gen_save_file)
                    gen_loader.create_batches(gen_file)

                    evaluate_and_log('pre_gen_epoch:' + str(epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np,
                                     epoch)

            print('Start adversarial training...')
            progress = tqdm(range(nadv_steps))
            for _ in progress:
                niter = sess.run(global_step)

                t0 = time.time()

                # run adversarial training
                for _ in range(config['gsteps']):
                    sess.run(g_train_op, feed_dict={x_real: oracle_loader.random_batch()})
                for _ in range(config['dsteps']):
                    sess.run(d_train_op, feed_dict={x_real: oracle_loader.random_batch()})

                t1 = time.time()
                sess.run(update_Wall_op, feed_dict={time_diff: t1 - t0})

                # temperature
                temp_var_np = get_fixed_temperature(temper, niter, nadv_steps, adapt)
                sess.run(update_temperature_op, feed_dict={temp_var: temp_var_np})

                feed = {x_real: oracle_loader.random_batch()}
                g_loss_np, d_loss_np, loss_summary_str = sess.run([g_loss, d_loss, loss_summary_op], feed_dict=feed)
                sum_writer.add_summary(loss_summary_str, niter)

                sess.run(global_step_op)

                progress.set_description('g_loss: %4.4f, d_loss: %4.4f' % (g_loss_np, d_loss_np))

                # Test
                if np.mod(niter, config['ntest']) == 0:
                    # generate fake data and create batches
                    gen_save_file = os.path.join(sample_dir, 'adv_samples_{:05d}.txt'.format(niter))
                    generate_samples(sess, x_fake, batch_size, num_sentences, gen_file)
                    generate_samples(sess, x_fake, batch_size, 120, gen_save_file)
                    gen_loader.create_batches(gen_file)

                    # metric summaries are placed after the pre-training epochs on the step axis
                    evaluate_and_log('adv_step: ' + str(niter), niter + config['npre_epochs'])
        finally:
            sum_writer.close()


# A function to get different GAN losses
def get_losses(d_out_real, d_out_fake, x_real_onehot, x_fake_onehot_appr, gen_o, discriminator, config):
    """Build the discriminator and generator losses for ``config['gan_type']``.

    Supported ``gan_type`` values:
        'standard'  non-saturating GAN: G minimizes BCE(D(fake), 1)
        'JS'        vanilla (minimax) GAN: G maximizes D's loss on fake samples
        'KL'        D as in 'standard'; G minimizes -mean(D(fake)) (the raw logits)
        'hinge'     hinge loss
        'tv'        total variation distance via tanh
        'wgan-gp'   Wasserstein loss plus gradient penalty
        'LS'        least-squares GAN
        'RSGAN'     relativistic standard GAN

    Args:
        d_out_real, d_out_fake: discriminator outputs (logits) for real and generated inputs.
        x_real_onehot, x_fake_onehot_appr: discriminator inputs; only 'wgan-gp' uses them
            (for the gradient penalty on their interpolations).
        gen_o: generator output whose mean log (plus EPS) is reported as ``log_pg``;
            presumably the probabilities of the sampled tokens (confirm against the generator).
        discriminator: callable; only 'wgan-gp' uses it (for the gradient penalty).
        config: must contain 'gan_type'; 'wgan-gp' also reads 'batch_size' and 'reg_param'.

    Returns:
        Tuple ``(log_pg, g_loss, d_loss)`` of scalar tensors.

    Raises:
        NotImplementedError: for an unknown ``gan_type``.
    """
    gan_type = config['gan_type']

    def mean_sigmoid_xent(logits, target_like):
        # mean binary cross-entropy of logits against a constant target tensor shaped like them
        return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logits, labels=target_like(logits)
        ))

    if gan_type in ('standard', 'JS', 'KL'):
        # all three share the usual binary-classification discriminator loss
        d_loss_real = mean_sigmoid_xent(d_out_real, tf.ones_like)
        d_loss_fake = mean_sigmoid_xent(d_out_fake, tf.zeros_like)
        d_loss = d_loss_real + d_loss_fake

        if gan_type == 'standard':  # the non-saturating GAN loss
            g_loss = mean_sigmoid_xent(d_out_fake, tf.ones_like)
        elif gan_type == 'JS':  # the vanilla GAN loss
            g_loss = -d_loss_fake
        else:  # 'KL': the GAN loss implicitly minimizing KL-divergence
            g_loss = tf.reduce_mean(-d_out_fake)

    elif gan_type == 'hinge':  # the hinge loss
        d_loss_real = tf.reduce_mean(tf.nn.relu(1.0 - d_out_real))
        d_loss_fake = tf.reduce_mean(tf.nn.relu(1.0 + d_out_fake))
        d_loss = d_loss_real + d_loss_fake

        g_loss = -tf.reduce_mean(d_out_fake)

    elif gan_type == 'tv':  # the total variation distance
        d_loss = tf.reduce_mean(tf.tanh(d_out_fake) - tf.tanh(d_out_real))
        g_loss = tf.reduce_mean(-tf.tanh(d_out_fake))

    elif gan_type == 'wgan-gp':  # WGAN-GP
        d_loss = tf.reduce_mean(d_out_fake) - tf.reduce_mean(d_out_real)
        GP = gradient_penalty(discriminator, x_real_onehot, x_fake_onehot_appr, config)
        d_loss += GP

        g_loss = -tf.reduce_mean(d_out_fake)

    elif gan_type == 'LS':  # LS-GAN
        d_loss_real = tf.reduce_mean(tf.squared_difference(d_out_real, 1.0))
        d_loss_fake = tf.reduce_mean(tf.square(d_out_fake))
        d_loss = d_loss_real + d_loss_fake

        g_loss = tf.reduce_mean(tf.squared_difference(d_out_fake, 1.0))

    elif gan_type == 'RSGAN':  # relativistic standard GAN
        d_loss = mean_sigmoid_xent(d_out_real - d_out_fake, tf.ones_like)
        g_loss = mean_sigmoid_xent(d_out_fake - d_out_real, tf.ones_like)

    else:
        raise NotImplementedError("Divergence '%s' is not implemented" % gan_type)

    log_pg = tf.reduce_mean(tf.log(gen_o + EPS))  # [1], measures the log p_g(x)

    return log_pg, g_loss, d_loss


# A function to calculate the gradients and get training operations
def get_train_ops(config, g_pretrain_loss, g_loss, d_loss, log_pg, temperature, global_step):
    """Create the training ops for pre-training, adversarial training and the temperature.

    Generator variables are collected from the 'generator' scope and discriminator
    variables from the 'discriminator' scope. Pre-training always uses Adam at
    ``config['gpre_lr']``. Generator and discriminator updates clip their gradients to a
    global norm of 5.0. When ``config['decay']`` is set, the adversarial learning rates
    decay exponentially (rate 0.1 per ``nadv_steps`` steps of ``global_step``).

    Returns:
        Tuple ``(g_pretrain_op, g_train_op, d_train_op, temp_train_op)``.

    Raises:
        NotImplementedError: if ``config['optimizer']`` is neither 'adam' nor 'rmsprop'.
    """
    optimizer_name = config['optimizer']
    nadv_steps = config['nadv_steps']
    d_lr = config['d_lr']
    gpre_lr = config['gpre_lr']
    gadv_lr = config['gadv_lr']

    g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
    d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')

    grad_clip = 5.0  # keep the same with the previous setting

    def clipped_update(optimizer, loss, var_list):
        # gradients of loss w.r.t. var_list, clipped by global norm, then applied
        grads, _ = tf.clip_by_global_norm(tf.gradients(loss, var_list), grad_clip)
        return optimizer.apply_gradients(zip(grads, var_list)), grads

    # generator pre-training
    g_pretrain_op, _ = clipped_update(tf.train.AdamOptimizer(gpre_lr, beta1=0.9, beta2=0.999),
                                      g_pretrain_loss, g_vars)

    if config['decay']:
        d_lr = tf.train.exponential_decay(d_lr, global_step=global_step, decay_steps=nadv_steps, decay_rate=0.1)
        gadv_lr = tf.train.exponential_decay(gadv_lr, global_step=global_step, decay_steps=nadv_steps, decay_rate=0.1)

    # NOTE: the optimizers are constructed (and applied below) in a fixed order;
    # changing it would change the auto-generated optimizer variable names in checkpoints.
    if optimizer_name == 'adam':
        def make_optimizer(lr):
            return tf.train.AdamOptimizer(lr, beta1=0.9, beta2=0.999)
    elif optimizer_name == 'rmsprop':
        make_optimizer = tf.train.RMSPropOptimizer
    else:
        raise NotImplementedError
    d_optimizer = make_optimizer(d_lr)
    g_optimizer = make_optimizer(gadv_lr)
    temp_optimizer = make_optimizer(1e-2)

    g_train_op, g_grads = clipped_update(g_optimizer, g_loss, g_vars)

    print('len of g_grads without None: {}'.format(sum(1 for grad in g_grads if grad is not None)))
    print('len of g_grads: {}'.format(len(g_grads)))

    d_train_op, _ = clipped_update(d_optimizer, d_loss, d_vars)

    # the temperature is updated without clipping, to maximize log_pg
    temp_grads = tf.gradients(-log_pg, [temperature])
    temp_train_op = temp_optimizer.apply_gradients(zip(temp_grads, [temperature]))

    return g_pretrain_op, g_train_op, d_train_op, temp_train_op


# A function to get various evaluation metrics
def get_metrics(config, oracle_loader, gen_loader, oracle_file, gen_file, oracle_model, g_pretrain_loss, x_real, sess):
    """Build the list of evaluation metrics enabled in ``config``.

    Metrics are created in the fixed order 'nll_oracle', 'nll_gen', 'doc_embsim'. This
    order must match the placeholders returned by ``get_metric_summary_op``, because the
    training loop zips the two lists together.

    - 'nll_oracle': ``Nll`` over ``gen_loader`` using the oracle's ``pretrain_loss`` / ``x``.
    - 'nll_gen': ``Nll`` over ``oracle_loader`` using the generator's pre-training loss.
    - 'doc_embsim': ``DocEmbSim`` comparing ``oracle_file`` with ``gen_file``.
    """
    builders = (
        ('nll_oracle', lambda: Nll(gen_loader, oracle_model.pretrain_loss, oracle_model.x, sess, name='nll_oracle')),
        ('nll_gen', lambda: Nll(oracle_loader, g_pretrain_loss, x_real, sess, name='nll_gen')),
        ('doc_embsim', lambda: DocEmbSim(oracle_file, gen_file, config['vocab_size'], name='doc_embsim')),
    )
    return [build() for key, build in builders if config[key]]


# A function to get the summary for each metric
def get_metric_summary_op(config):
    """Create one float32 placeholder plus scalar summary per enabled metric.

    The order ('nll_oracle', 'nll_gen', 'doc_embsim') matches ``get_metrics``, so that
    ``dict(zip(metrics_pl, scores))`` feeds each score to its own summary.

    Returns:
        Tuple ``(metrics_pl, metric_summary_op)``: the placeholder list and the merged summary op.
    """
    tagged_metrics = (
        ('nll_oracle', 'metrics/nll_oracle'),
        ('nll_gen', 'metrics/nll_gen'),
        ('doc_embsim', 'metrics/doc_embsim'),
    )
    placeholders = []
    summaries = []
    for key, tag in tagged_metrics:
        if not config[key]:
            continue
        value = tf.placeholder(tf.float32)
        placeholders.append(value)
        summaries.append(tf.summary.scalar(tag, value))

    return placeholders, tf.summary.merge(summaries)


def get_fixed_temperature(temper, i, nadv_steps, adapt, max_steps=5000):
    """Return the temperature for adversarial step ``i`` under a fixed schedule.

    All schedules start at (approximately) 1 when ``i == 0`` and move toward ``temper``
    over a fixed horizon of N = ``max_steps`` steps. N does not depend on ``nadv_steps``,
    so runs of different lengths follow the same curve.

    Args:
        temper: target temperature.
        i: current adversarial step (0-based).
        nadv_steps: total number of adversarial steps; must not exceed ``max_steps``.
        adapt: schedule shape:
            'no' (constant ``temper``), 'lin', 'exp', 'log', 'sigmoid', 'quad' or 'sqrt'.
        max_steps: schedule horizon N (default 5000, the previously hard-coded value).

    Returns:
        The temperature for step ``i``.

    Raises:
        ValueError: if ``nadv_steps > max_steps`` or ``adapt`` is unknown. (ValueError is
            used instead of an ``assert``, which ``python -O`` would strip.)
    """
    # using a fixed number of maximum adversarial steps
    N = max_steps
    if nadv_steps > N:
        raise ValueError('nadv_steps (%s) exceeds the temperature schedule horizon (%s)' % (nadv_steps, N))

    if adapt == 'no':
        return temper  # no increase
    if adapt == 'lin':
        return 1 + i / (N - 1) * (temper - 1)  # linear increase
    if adapt == 'exp':
        return temper ** (i / N)  # exponential increase
    if adapt == 'log':
        return 1 + (temper - 1) / np.log(N) * np.log(i + 1)  # logarithm increase
    if adapt == 'sigmoid':
        return (temper - 1) * 1 / (1 + np.exp((N / 2 - i) * 20 / N)) + 1  # sigmoid increase
    if adapt == 'quad':
        return (temper - 1) / (N - 1) ** 2 * i ** 2 + 1
    if adapt == 'sqrt':
        return (temper - 1) / np.sqrt(N - 1) * np.sqrt(i) + 1
    raise ValueError("Unknown adapt type!")


In [None]:
import numpy as np
import random


class OracleDataLoader():
    """Batches sequences of integer token ids read from a text file.

    Each line of the file holds whitespace-separated integer ids and becomes one
    sequence. Lines longer than ``seq_length`` are truncated, and shorter ones are
    right-padded with ``end_token``. Sequences that do not fill a whole final batch are
    dropped.

    Batches are numpy arrays of shape [batch_size, seq_length]. ``random_batch`` draws
    from NumPy's global RNG, so ``np.random.seed`` (as set by ``oracle_train``) makes the
    batch order reproducible.
    """

    def __init__(self, batch_size, seq_length, end_token=0):
        self.batch_size = batch_size
        self.token_stream = []
        self.seq_length = seq_length
        self.end_token = end_token

    def create_batches(self, data_file):
        """Read ``data_file`` and split it into ``self.num_batch`` equally sized batches.

        Raises:
            ValueError: if the file contains fewer than ``batch_size`` lines, so not even
                one batch can be formed. (Previously this failed with ZeroDivisionError
                inside ``np.split``.)
        """
        self.token_stream = []

        with open(data_file, 'r') as raw:
            for line in raw:
                tokens = [int(x) for x in line.split()][:self.seq_length]
                tokens += [self.end_token] * (self.seq_length - len(tokens))
                self.token_stream.append(tokens)

        self.num_batch = len(self.token_stream) // self.batch_size
        if self.num_batch == 0:
            raise ValueError('%s holds %d sequences, fewer than batch_size=%d'
                             % (data_file, len(self.token_stream), self.batch_size))
        self.token_stream = self.token_stream[:self.num_batch * self.batch_size]
        self.sequence_batches = np.split(np.array(self.token_stream), self.num_batch, 0)
        self.pointer = 0

    def next_batch(self):
        """Return the next batch in file order, wrapping around after the last one."""
        ret = self.sequence_batches[self.pointer]
        self.pointer = (self.pointer + 1) % self.num_batch
        return ret

    def random_batch(self):
        """Return a uniformly chosen batch (drawn from np.random, so it respects np.random.seed)."""
        return self.sequence_batches[np.random.randint(self.num_batch)]

    def reset_pointer(self):
        """Restart ``next_batch`` from the first batch."""
        self.pointer = 0


### REAL ### 

In [None]:
import numpy as np
import random


class RealDataLoader():
    """Batches sequences of integer token ids read from a text file.

    Each line of the file holds whitespace-separated integer ids and becomes one
    sequence. Lines longer than ``seq_length`` are truncated, and shorter ones are
    right-padded with ``end_token``. Sequences that do not fill a whole final batch are
    dropped.

    Batches are numpy arrays of shape [batch_size, seq_length]. ``random_batch`` draws
    from NumPy's global RNG, so seeding it with ``np.random.seed`` makes the batch order
    reproducible.
    """

    def __init__(self, batch_size, seq_length, end_token=0):
        self.batch_size = batch_size
        self.token_stream = []
        self.seq_length = seq_length
        self.end_token = end_token

    def create_batches(self, data_file):
        """Read ``data_file`` and split it into ``self.num_batch`` equally sized batches.

        Raises:
            ValueError: if the file contains fewer than ``batch_size`` lines, so not even
                one batch can be formed. (Previously this failed with ZeroDivisionError
                inside ``np.split``.)
        """
        self.token_stream = []

        with open(data_file, 'r') as raw:
            for line in raw:
                tokens = [int(x) for x in line.split()][:self.seq_length]
                tokens += [self.end_token] * (self.seq_length - len(tokens))
                self.token_stream.append(tokens)

        self.num_batch = len(self.token_stream) // self.batch_size
        if self.num_batch == 0:
            raise ValueError('%s holds %d sequences, fewer than batch_size=%d'
                             % (data_file, len(self.token_stream), self.batch_size))
        self.token_stream = self.token_stream[:self.num_batch * self.batch_size]
        self.sequence_batches = np.split(np.array(self.token_stream), self.num_batch, 0)
        self.pointer = 0

    def next_batch(self):
        """Return the next batch in file order, wrapping around after the last one."""
        ret = self.sequence_batches[self.pointer]
        self.pointer = (self.pointer + 1) % self.num_batch
        return ret

    def random_batch(self):
        """Return a uniformly chosen batch (drawn from np.random, so it respects np.random.seed)."""
        return self.sequence_batches[np.random.randint(self.num_batch)]

    def reset_pointer(self):
        """Restart ``next_batch`` from the first batch."""
        self.pointer = 0


In [None]:
import numpy as np
import tensorflow as tf
from tqdm import tqdm
import time

# Small numerical-stability constant. In the ORACLE cell it keeps tf.log(gen_o + EPS)
# away from log(0) in get_losses; presumably it is used the same way by this cell's
# get_losses (TODO confirm).
EPS = 1e-10


# A function to initiate the graph and train the networks
def real_train(generator, discriminator, oracle_loader, config):
    """Build the RelGAN graph for a real text dataset and train it.

    The dataset ``<data_dir>/<dataset>.txt`` is encoded into
    ``<sample_dir>/oracle_<dataset>.txt`` by ``get_oracle_file``. If
    ``config['checkpoint_restore']`` is set, the variables are restored from
    ``config['checkpoint']`` and pre-training is skipped. Otherwise the generator is
    pre-trained for ``npre_epochs`` epochs. Then ``nadv_steps`` adversarial steps run
    ``config['gsteps']`` generator and ``config['dsteps']`` discriminator updates each,
    with the temperature taken from ``get_fixed_temperature``.

    Generated samples are decoded to text with the index-to-word dictionary. Metrics go
    to TensorBoard (``<log_dir>/summary``) and to ``<log_dir>/experiment-log-rmcgan.csv``.
    The log file and the summary writer are closed when training ends, and the log is
    flushed after every line.

    Args:
        generator: callable ``generator(x_real=..., temperature=...)`` returning
            ``(x_fake_onehot_appr, x_fake, g_pretrain_loss, gen_o)``.
        discriminator: callable ``discriminator(x_onehot=...)`` returning logits.
        oracle_loader: data loader over the encoded dataset.
        config: hyper-parameter dict; see the keys read at the top of the body.

    Raises:
        NotImplementedError: if ``config['dataset']`` has no known test file.

    Note:
        ``get_losses``, ``get_train_ops``, ``get_metrics`` and ``get_metric_summary_op`` are
        looked up in the notebook's global namespace at call time. The ORACLE cell defines
        versions with different signatures; for example, its ``get_train_ops`` returns four
        ops, while this function unpacks three. Re-run this cell before calling
        ``real_train`` if the ORACLE cell ran after it.
    """
    batch_size = config['batch_size']
    num_sentences = config['num_sentences']
    vocab_size = config['vocab_size']
    seq_len = config['seq_len']
    data_dir = config['data_dir']
    dataset = config['dataset']
    log_dir = config['log_dir']
    sample_dir = config['sample_dir']
    npre_epochs = config['npre_epochs']
    nadv_steps = config['nadv_steps']
    temper = config['temperature']
    adapt = config['adapt']
    ntest_pre = 10  # evaluate every ntest_pre pre-training epochs

    # filename
    oracle_file = os.path.join(sample_dir, 'oracle_{}.txt'.format(dataset))
    gen_file = os.path.join(sample_dir, 'generator.txt')
    gen_text_file = os.path.join(sample_dir, 'generator_text.txt')
    csv_file = os.path.join(log_dir, 'experiment-log-rmcgan.csv')
    data_file = os.path.join(data_dir, '{}.txt'.format(dataset))
    test_files = {
        'image_coco': 'testdata/test_coco.txt',
        'emnlp_news': 'testdata/test_emnlp.txt',
        'fce_train_new': 'testdata/fce_test_new.txt',
    }
    if dataset not in test_files:
        raise NotImplementedError('Unknown dataset!')
    test_file = os.path.join(data_dir, test_files[dataset])

    # create necessary directories
    for directory in (data_dir, sample_dir, log_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # placeholder definitions
    x_real = tf.placeholder(tf.int32, [batch_size, seq_len], name="x_real")  # tokens of oracle sequences

    temperature = tf.Variable(1., trainable=False, name='temperature')

    x_real_onehot = tf.one_hot(x_real, vocab_size)  # batch_size x seq_len x vocab_size
    assert x_real_onehot.get_shape().as_list() == [batch_size, seq_len, vocab_size]

    # generator and discriminator outputs
    x_fake_onehot_appr, x_fake, g_pretrain_loss, gen_o = generator(x_real=x_real, temperature=temperature)
    d_out_real = discriminator(x_onehot=x_real_onehot)
    d_out_fake = discriminator(x_onehot=x_fake_onehot_appr)

    # GAN / Divergence type
    log_pg, g_loss, d_loss = get_losses(d_out_real, d_out_fake, x_real_onehot, x_fake_onehot_appr,
                                        gen_o, discriminator, config)

    # Global step
    global_step = tf.Variable(0, trainable=False)
    global_step_op = global_step.assign_add(1)

    # Train ops
    g_pretrain_op, g_train_op, d_train_op = get_train_ops(config, g_pretrain_loss, g_loss, d_loss,
                                                          log_pg, temperature, global_step)

    # Record wall clock time
    time_diff = tf.placeholder(tf.float32)
    Wall_clock_time = tf.Variable(0., trainable=False)
    update_Wall_op = Wall_clock_time.assign_add(time_diff)

    # Temperature placeholder
    temp_var = tf.placeholder(tf.float32)
    update_temperature_op = temperature.assign(temp_var)

    # Loss summaries
    loss_summaries = [
        tf.summary.scalar('loss/discriminator', d_loss),
        tf.summary.scalar('loss/g_loss', g_loss),
        tf.summary.scalar('loss/log_pg', log_pg),
        tf.summary.scalar('loss/Wall_clock_time', Wall_clock_time),
        tf.summary.scalar('loss/temperature', temperature),
    ]
    loss_summary_op = tf.summary.merge(loss_summaries)

    # Metric Summaries
    metrics_pl, metric_summary_op = get_metric_summary_op(config)

    # ------------- initial the graph --------------
    with init_sess() as sess:
        if config['checkpoint_restore']:
            saver = tf.train.Saver()
            saver.restore(sess, config['checkpoint'])
            print('[INFO]: Restored checkpoint')

        log = open(csv_file, 'w')
        sum_writer = tf.summary.FileWriter(os.path.join(log_dir, 'summary'), sess.graph)
        try:
            # generate oracle data and create batches
            index_word_dict = get_oracle_file(data_file, oracle_file, seq_len)
            oracle_loader.create_batches(oracle_file)

            metrics = get_metrics(config, oracle_loader, test_file, gen_text_file, g_pretrain_loss, x_real, sess)

            def evaluate_and_log(msg, summary_step):
                # Score every metric and record the scores in TensorBoard at summary_step.
                # Then append them to msg, print it, and write it to the CSV log.
                scores = [metric.get_score() for metric in metrics]
                metrics_summary_str = sess.run(metric_summary_op, feed_dict=dict(zip(metrics_pl, scores)))
                sum_writer.add_summary(metrics_summary_str, summary_step)
                for metric, score in zip(metrics, scores):
                    msg += ', ' + metric.get_name() + ': %.4f' % score
                print(msg)
                log.write(msg + '\n')
                log.flush()  # keep the log current even if the Colab runtime disconnects

            if not config['checkpoint_restore']:
                print('Start pre-training...')
                for epoch in range(npre_epochs):
                    # pre-training
                    g_pretrain_loss_np = pre_train_epoch(sess, g_pretrain_op, g_pretrain_loss, x_real, oracle_loader)

                    # Test
                    if np.mod(epoch, ntest_pre) == 0:
                        # generate fake data and create batches
                        gen_save_file = os.path.join(sample_dir, 'pre_samples_{:05d}.txt'.format(epoch))
                        generate_samples(sess, x_fake, batch_size, num_sentences, gen_file)
                        get_real_test_file(gen_file, gen_save_file, index_word_dict)
                        get_real_test_file(gen_file, gen_text_file, index_word_dict)

                        evaluate_and_log('pre_gen_epoch:' + str(epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np,
                                         epoch)

            print('Start adversarial training...')
            progress = tqdm(range(nadv_steps))
            for _ in progress:
                niter = sess.run(global_step)

                t0 = time.time()

                # adversarial training
                for _ in range(config['gsteps']):
                    sess.run(g_train_op, feed_dict={x_real: oracle_loader.random_batch()})
                for _ in range(config['dsteps']):
                    sess.run(d_train_op, feed_dict={x_real: oracle_loader.random_batch()})

                t1 = time.time()
                sess.run(update_Wall_op, feed_dict={time_diff: t1 - t0})

                # temperature
                temp_var_np = get_fixed_temperature(temper, niter, nadv_steps, adapt)
                sess.run(update_temperature_op, feed_dict={temp_var: temp_var_np})

                feed = {x_real: oracle_loader.random_batch()}
                g_loss_np, d_loss_np, loss_summary_str = sess.run([g_loss, d_loss, loss_summary_op], feed_dict=feed)
                sum_writer.add_summary(loss_summary_str, niter)

                sess.run(global_step_op)

                progress.set_description('g_loss: %4.4f, d_loss: %4.4f' % (g_loss_np, d_loss_np))

                # Test
                if np.mod(niter, config['ntest']) == 0:
                    # To checkpoint the model here, uncomment the lines below (ideally creating
                    # the Saver once, before the loop, so the graph does not grow on every save):
                    # saver = tf.train.Saver()
                    # saver.save(sess, '/content/drive/My Drive/model_saver/' + str(niter) + '.ckpt')

                    # generate fake data and create batches
                    gen_save_file = os.path.join(sample_dir, 'adv_samples_{:05d}.txt'.format(niter))
                    generate_samples(sess, x_fake, batch_size, num_sentences, gen_file, iteration=niter)
                    get_real_test_file(gen_file, gen_save_file, index_word_dict)
                    get_real_test_file(gen_file, gen_text_file, index_word_dict)

                    # metric summaries are placed after the pre-training epochs on the step axis
                    evaluate_and_log('adv_step: ' + str(niter), niter + config['npre_epochs'])
        finally:
            sum_writer.close()
            log.close()


# A function to get different GAN losses
def get_losses(d_out_real, d_out_fake, x_real_onehot, x_fake_onehot_appr, gen_o, discriminator, config):
    """Build the discriminator and generator losses for ``config['gan_type']``.

    Args:
      d_out_real: discriminator logits for real data.
      d_out_fake: discriminator logits for generated data.
      x_real_onehot: one-hot real data; only used by the 'wgan-gp' penalty.
      x_fake_onehot_appr: relaxed one-hot generated data; only used by 'wgan-gp'.
      gen_o: per-token probabilities used for ``log_pg`` (presumably the
        ``gen_o`` output of ``generator`` — confirm against the caller).
      discriminator: discriminator callable; only used by 'wgan-gp'.
      config: configuration dict; ``config['gan_type']`` selects the loss.

    Returns:
      A tuple ``(log_pg, g_loss, d_loss)`` of tensors.

    Raises:
      NotImplementedError: if ``gan_type`` is not one of the supported names.
    """
    gan_type = config['gan_type']

    def _sigmoid_xent(logits, labels):
        # Mean sigmoid cross-entropy of `logits` against a constant target tensor.
        return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logits, labels=labels
        ))

    if gan_type in ('standard', 'JS', 'KL'):
        # These three share the usual discriminator objective and differ only in
        # the generator loss.
        d_loss_real = _sigmoid_xent(d_out_real, tf.ones_like(d_out_real))
        d_loss_fake = _sigmoid_xent(d_out_fake, tf.zeros_like(d_out_fake))
        d_loss = d_loss_real + d_loss_fake

        if gan_type == 'standard':  # the non-saturating GAN loss
            g_loss = _sigmoid_xent(d_out_fake, tf.ones_like(d_out_fake))
        elif gan_type == 'JS':  # the vanilla (minimax) GAN loss
            g_loss = -d_loss_fake
        else:  # 'KL': the GAN loss implicitly minimizing KL-divergence
            g_loss = tf.reduce_mean(-d_out_fake)

    elif gan_type == 'hinge':  # the hinge loss
        d_loss_real = tf.reduce_mean(tf.nn.relu(1.0 - d_out_real))
        d_loss_fake = tf.reduce_mean(tf.nn.relu(1.0 + d_out_fake))
        d_loss = d_loss_real + d_loss_fake

        g_loss = -tf.reduce_mean(d_out_fake)

    elif gan_type == 'tv':  # the total variation distance
        d_loss = tf.reduce_mean(tf.tanh(d_out_fake) - tf.tanh(d_out_real))
        g_loss = tf.reduce_mean(-tf.tanh(d_out_fake))

    elif gan_type == 'wgan-gp':  # WGAN-GP
        d_loss = tf.reduce_mean(d_out_fake) - tf.reduce_mean(d_out_real)
        GP = gradient_penalty(discriminator, x_real_onehot, x_fake_onehot_appr, config)
        d_loss += GP

        g_loss = -tf.reduce_mean(d_out_fake)

    elif gan_type == 'LS':  # LS-GAN
        d_loss_real = tf.reduce_mean(tf.squared_difference(d_out_real, 1.0))
        d_loss_fake = tf.reduce_mean(tf.square(d_out_fake))
        d_loss = d_loss_real + d_loss_fake

        g_loss = tf.reduce_mean(tf.squared_difference(d_out_fake, 1.0))

    elif gan_type == 'RSGAN':  # relativistic standard GAN
        d_loss = _sigmoid_xent(d_out_real - d_out_fake, tf.ones_like(d_out_real))
        g_loss = _sigmoid_xent(d_out_fake - d_out_real, tf.ones_like(d_out_fake))

    else:
        raise NotImplementedError("Divergence '%s' is not implemented" % gan_type)

    log_pg = tf.reduce_mean(tf.log(gen_o + EPS))  # [1], measures the log p_g(x)

    return log_pg, g_loss, d_loss


# A function to calculate the gradients and get training operations
def get_train_ops(config, g_pretrain_loss, g_loss, d_loss, log_pg, temperature, global_step):
    """Create the generator pre-training op and the adversarial G/D training ops.

    All gradients are clipped to a global norm of 5.0 before being applied.

    Args:
      config: dict providing 'optimizer', 'nadv_steps', 'd_lr', 'gpre_lr',
        'gadv_lr' and 'decay'.
      g_pretrain_loss: generator MLE pre-training loss.
      g_loss: generator adversarial loss.
      d_loss: discriminator loss.
      log_pg: accepted for interface compatibility; not used here.
      temperature: accepted for interface compatibility; not used here.
      global_step: step tensor driving the optional exponential LR decay.

    Returns:
      ``(g_pretrain_op, g_train_op, d_train_op)``.

    Raises:
      NotImplementedError: if ``config['optimizer']`` is not 'adam' or 'rmsprop'.
    """
    optimizer_name = config['optimizer']
    nadv_steps = config['nadv_steps']
    d_lr = config['d_lr']
    gpre_lr = config['gpre_lr']
    gadv_lr = config['gadv_lr']

    g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
    d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')

    grad_clip = 5.0  # keep the same with the previous setting

    # generator pre-training (always Adam, independent of `optimizer_name`)
    pretrain_opt = tf.train.AdamOptimizer(gpre_lr, beta1=0.9, beta2=0.999)
    pretrain_grad, _ = tf.clip_by_global_norm(tf.gradients(g_pretrain_loss, g_vars), grad_clip)  # gradient clipping
    g_pretrain_op = pretrain_opt.apply_gradients(zip(pretrain_grad, g_vars))

    # optionally decay the adversarial learning rates by 10x over `nadv_steps`
    if config['decay']:
        d_lr = tf.train.exponential_decay(d_lr, global_step=global_step, decay_steps=nadv_steps, decay_rate=0.1)
        gadv_lr = tf.train.exponential_decay(gadv_lr, global_step=global_step, decay_steps=nadv_steps, decay_rate=0.1)

    if optimizer_name == 'adam':
        d_optimizer = tf.train.AdamOptimizer(d_lr, beta1=0.9, beta2=0.999)
        g_optimizer = tf.train.AdamOptimizer(gadv_lr, beta1=0.9, beta2=0.999)
    elif optimizer_name == 'rmsprop':
        d_optimizer = tf.train.RMSPropOptimizer(d_lr)
        g_optimizer = tf.train.RMSPropOptimizer(gadv_lr)
    else:
        raise NotImplementedError("Optimizer '%s' is not implemented" % optimizer_name)

    # generator: clip, then apply
    g_grads, _ = tf.clip_by_global_norm(tf.gradients(g_loss, g_vars), grad_clip)
    g_train_op = g_optimizer.apply_gradients(zip(g_grads, g_vars))

    # Diagnostic: tf.gradients yields None for variables that do not affect g_loss.
    print('len of g_grads without None: {}'.format(len([i for i in g_grads if i is not None])))
    print('len of g_grads: {}'.format(len(g_grads)))

    # discriminator: clip, then apply
    d_grads, _ = tf.clip_by_global_norm(tf.gradients(d_loss, d_vars), grad_clip)
    d_train_op = d_optimizer.apply_gradients(zip(d_grads, d_vars))

    return g_pretrain_op, g_train_op, d_train_op


# A function to get various evaluation metrics
def get_metrics(config, oracle_loader, test_file, gen_file, g_pretrain_loss, x_real, sess):
    """Instantiate the evaluation metrics enabled in ``config``.

    The returned list is ordered nll_gen, doc_embsim, bleu2..bleu5,
    selfbleu2..selfbleu5 (each group only if enabled). This matches the
    placeholder order produced by ``get_metric_summary_op``.
    """
    enabled = []

    if config['nll_gen']:
        enabled.append(Nll(oracle_loader, g_pretrain_loss, x_real, sess, name='nll_gen'))

    if config['doc_embsim']:
        enabled.append(DocEmbSim(test_file, gen_file, config['vocab_size'], name='doc_embsim'))

    if config['bleu']:
        enabled.extend(
            Bleu(test_text=gen_file, real_text=test_file, gram=gram, name='bleu' + str(gram))
            for gram in range(2, 6)
        )

    if config['selfbleu']:
        enabled.extend(
            SelfBleu(test_text=gen_file, gram=gram, name='selfbleu' + str(gram))
            for gram in range(2, 6)
        )

    return enabled


# A function to get the summary for each metric
def get_metric_summary_op(config):
    """Create one float placeholder and one scalar summary per enabled metric.

    Placeholders are created in the order nll_gen, doc_embsim, bleu2..bleu5,
    selfbleu2..selfbleu5 (each only if enabled). This is the same order in
    which ``get_metrics`` builds the metric objects, so
    ``zip(metrics_pl, scores)`` pairs each score with its placeholder.

    Returns:
      ``(metrics_pl, metric_summary_op)``: the list of placeholders and the
      merged summary op.
    """
    placeholders = []
    summaries = []

    def _add(tag, pl_name=None):
        # One placeholder feeding one 'metrics/<tag>' scalar summary.
        placeholder = tf.placeholder(tf.float32, name=pl_name)
        placeholders.append(placeholder)
        summaries.append(tf.summary.scalar('metrics/' + tag, placeholder))

    if config['nll_gen']:
        _add('nll_gen')

    if config['doc_embsim']:
        _add('doc_embsim')

    # BLEU-style metrics come in n-gram families with named placeholders.
    for family in ('bleu', 'selfbleu'):
        if config[family]:
            for gram in range(2, 6):
                tag = '{}{}'.format(family, gram)
                _add(tag, pl_name=tag)

    return placeholders, tf.summary.merge(summaries)


# A function to set up different temperature control policies
def get_fixed_temperature(temper, i, nadv_steps, adapt, max_steps=5000):
    """Return the temperature for adversarial step ``i`` under schedule ``adapt``.

    Every schedule is laid out over a fixed horizon of ``max_steps`` steps
    (5000 by default) rather than over ``nadv_steps``. ``nadv_steps`` is only
    checked against that horizon.

    Args:
      temper: the largest temperature, i.e. the value the schedules grow towards.
      i: current adversarial step, starting at 0.
      nadv_steps: total number of adversarial steps; must not exceed ``max_steps``.
      adapt: one of 'no', 'lin', 'exp', 'log', 'sigmoid', 'quad', 'sqrt'.
      max_steps: length of the schedule horizon.

    Returns:
      The temperature value. 'no' returns ``temper`` unchanged. 'lin', 'log',
      'quad' and 'sqrt' reach ``temper`` at ``i == max_steps - 1``; 'exp'
      reaches it at ``i == max_steps``.

    Raises:
      ValueError: if ``nadv_steps > max_steps`` or ``adapt`` is unknown.
    """
    horizon = max_steps
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if nadv_steps > horizon:
        raise ValueError('nadv_steps ({}) must not exceed the temperature schedule horizon ({})'
                         .format(nadv_steps, horizon))

    if adapt == 'no':
        return temper  # no increase
    if adapt == 'lin':
        return 1 + i / (horizon - 1) * (temper - 1)  # linear increase
    if adapt == 'exp':
        return temper ** (i / horizon)  # exponential increase
    if adapt == 'log':
        return 1 + (temper - 1) / np.log(horizon) * np.log(i + 1)  # logarithm increase
    if adapt == 'sigmoid':
        return (temper - 1) * 1 / (1 + np.exp((horizon / 2 - i) * 20 / horizon)) + 1  # sigmoid increase
    if adapt == 'quad':
        return (temper - 1) / (horizon - 1)**2 * i ** 2 + 1  # quadratic increase
    if adapt == 'sqrt':
        return (temper - 1) / np.sqrt(horizon - 1) * np.sqrt(i) + 1  # square-root increase
    # ValueError is an Exception subclass, so existing `except Exception` callers still work.
    raise ValueError("Unknown adapt type!")


## MEMORY ARCHITECTURE ## 

In [None]:
"""Relational Memory architecture.

An implementation of the architecture described in "Relational Recurrent
Neural Networks", Santoro et al., 2018.
"""
import tensorflow as tf


class RelationalMemory(object):
    """Relational Memory Core."""

    def __init__(self, mem_slots, head_size, num_heads=1, num_blocks=1,
                 forget_bias=1.0, input_bias=0.0, gate_style='unit',
                 attention_mlp_layers=2, key_size=None, name='relational_memory'):
        """Constructs a `RelationalMemory` object.

        Args:
          mem_slots: The total number of memory slots to use.
          head_size: The size of an attention head.
          num_heads: The number of attention heads to use. Defaults to 1.
          num_blocks: Number of times to compute attention per time step. Defaults
            to 1.
          forget_bias: Bias to use for the forget gate, assuming we are using
            some form of gating. Defaults to 1.
          input_bias: Bias to use for the input gate, assuming we are using
            some form of gating. Defaults to 0.
          gate_style: Whether to use per-element gating ('unit'),
            per-memory slot gating ('memory'), or no gating at all (None).
            Defaults to `unit`.
          attention_mlp_layers: Number of layers to use in the post-attention
            MLP. Defaults to 2.
          key_size: Size of vector to use for key & query vectors in the attention
            computation. Defaults to None, in which case we use `head_size`.
          name: Name of the module.

        Raises:
          ValueError: gate_style not one of [None, 'memory', 'unit'].
          ValueError: num_blocks is < 1.
          ValueError: attention_mlp_layers is < 1.
        """

        self._mem_slots = mem_slots
        self._head_size = head_size
        self._num_heads = num_heads
        self._mem_size = self._head_size * self._num_heads
        self._name = name

        if num_blocks < 1:
            raise ValueError('num_blocks must be >= 1. Got: {}.'.format(num_blocks))
        self._num_blocks = num_blocks

        self._forget_bias = forget_bias
        self._input_bias = input_bias

        if gate_style not in ['unit', 'memory', None]:
            raise ValueError(
                'gate_style must be one of [\'unit\', \'memory\', None]. Got: '
                '{}.'.format(gate_style))
        self._gate_style = gate_style

        if attention_mlp_layers < 1:
            raise ValueError('attention_mlp_layers must be >= 1. Got: {}.'.format(
                attention_mlp_layers))
        self._attention_mlp_layers = attention_mlp_layers

        self._key_size = key_size if key_size else self._head_size

        # The gate tensors are only created by `_build` when gating is enabled.
        # Initialise them so the `input_gate`/`forget_gate` properties return None
        # (instead of raising AttributeError) before a build or with gate_style=None.
        self._input_gate = None
        self._forget_gate = None

        self._template = tf.make_template(self._name, self._build)  # wrapper for variable sharing

    def initial_state(self, batch_size):
        """Creates the initial memory.

        We should ensure each row of the memory is initialized to be unique,
        so initialize the matrix to be the identity. We then pad or truncate
        as necessary so that init_state is of size
        (batch_size, self._mem_slots, self._mem_size).

        Args:
          batch_size: The size of the batch.

        Returns:
          init_state: A truncated or padded matrix of size
            (batch_size, self._mem_slots, self._mem_size).
        """
        init_state = tf.eye(self._mem_slots, batch_shape=[batch_size])

        # Pad the matrix with zeros.
        if self._mem_size > self._mem_slots:
            difference = self._mem_size - self._mem_slots
            pad = tf.zeros((batch_size, self._mem_slots, difference))
            init_state = tf.concat([init_state, pad], -1)
        # Truncation. Take the first `self._mem_size` components.
        elif self._mem_size < self._mem_slots:
            init_state = init_state[:, :, :self._mem_size]
        return init_state

    def _multihead_attention(self, memory):
        """Perform multi-head attention from 'Attention is All You Need'.

        Implementation of the attention mechanism from
        https://arxiv.org/abs/1706.03762.

        Args:
          memory: Memory tensor to perform attention on, with size [B, N, H*V].

        Returns:
          new_memory: New memory tensor.
        """

        qkv_size = 2 * self._key_size + self._head_size
        total_size = qkv_size * self._num_heads  # Denote as F.
        batch_size = memory.get_shape().as_list()[0]  # Denote as B
        memory_flattened = tf.reshape(memory, [-1, self._mem_size])  # [B * N, H * V]
        qkv = linear(memory_flattened, total_size, use_bias=False, scope='lin_qkv')  # [B*N, F]
        qkv = tf.reshape(qkv, [batch_size, -1, total_size])  # [B, N, F]
        qkv = tf.contrib.layers.layer_norm(qkv, trainable=True)  # [B, N, F]

        # [B, N, F] -> [B, N, H, F/H]
        qkv_reshape = tf.reshape(qkv, [batch_size, -1, self._num_heads, qkv_size])

        # [B, N, H, F/H] -> [B, H, N, F/H]
        qkv_transpose = tf.transpose(qkv_reshape, [0, 2, 1, 3])
        q, k, v = tf.split(qkv_transpose, [self._key_size, self._key_size, self._head_size], -1)

        # NOTE(review): scaled dot-product attention in the cited paper scales by
        # key_size ** -0.5; this scales by qkv_size ** -0.5. Left unchanged so
        # existing trained models behave identically — confirm intent.
        q *= qkv_size ** -0.5
        dot_product = tf.matmul(q, k, transpose_b=True)  # [B, H, N, N]
        weights = tf.nn.softmax(dot_product)

        output = tf.matmul(weights, v)  # [B, H, N, V]

        # [B, H, N, V] -> [B, N, H, V]
        output_transpose = tf.transpose(output, [0, 2, 1, 3])

        # [B, N, H, V] -> [B, N, H * V]
        new_memory = tf.reshape(output_transpose, [batch_size, -1, self._mem_size])
        return new_memory

    @property
    def state_size(self):
        return tf.TensorShape([self._mem_slots, self._mem_size])

    @property
    def output_size(self):
        return tf.TensorShape(self._mem_slots * self._mem_size)

    def _calculate_gate_size(self):
        """Calculate the gate size from the gate_style.

        Returns:
          The per sample, per head parameter size of each gate.
        """
        if self._gate_style == 'unit':
            return self._mem_size
        elif self._gate_style == 'memory':
            return 1
        else:  # self._gate_style == None
            return 0

    def _create_gates(self, inputs, memory):
        """Create input and forget gates for this step using `inputs` and `memory`.

        Args:
          inputs: Tensor input.
          memory: The current state of memory.

        Returns:
          input_gate: A LSTM-like insert gate.
          forget_gate: A LSTM-like forget gate.
        """
        # We'll create the input and forget gates at once. Hence, calculate double
        # the gate size.
        num_gates = 2 * self._calculate_gate_size()
        batch_size = memory.get_shape().as_list()[0]

        memory = tf.tanh(memory)  # B x N x H * V

        inputs = tf.reshape(inputs, [batch_size, -1])  # B x In_size
        gate_inputs = linear(inputs, num_gates, use_bias=False, scope='gate_in')  # B x num_gates
        gate_inputs = tf.expand_dims(gate_inputs, axis=1)  # B x 1 x num_gates

        memory_flattened = tf.reshape(memory, [-1, self._mem_size])  # [B * N, H * V]
        gate_memory = linear(memory_flattened, num_gates, use_bias=False, scope='gate_mem')  # [B * N, num_gates]
        gate_memory = tf.reshape(gate_memory, [batch_size, self._mem_slots, num_gates])  # [B, N, num_gates]

        gates = tf.split(gate_memory + gate_inputs, num_or_size_splits=2, axis=2)
        input_gate, forget_gate = gates  # B x N x num_gates/2, B x N x num_gates/2

        input_gate = tf.sigmoid(input_gate + self._input_bias)
        forget_gate = tf.sigmoid(forget_gate + self._forget_bias)

        return input_gate, forget_gate

    def _attend_over_memory(self, memory):
        """Perform multiheaded attention over `memory`.

        Args:
          memory: Current relational memory.

        Returns:
          The attended-over memory.
        """

        for _ in range(self._num_blocks):
            attended_memory = self._multihead_attention(memory)  # [B, N, H * V]

            # Add a skip connection to the multiheaded attention's input.
            memory = tf.contrib.layers.layer_norm(memory + attended_memory, trainable=True)  # [B, N, H * V]

            # Add a mlp map
            batch_size = memory.get_shape().as_list()[0]

            memory_mlp = tf.reshape(memory, [-1, self._mem_size])  # [B * N, H * V]
            memory_mlp = mlp(memory_mlp, [self._mem_size] * self._attention_mlp_layers)  # [B * N, H * V]
            memory_mlp = tf.reshape(memory_mlp, [batch_size, -1, self._mem_size])

            # Add a skip connection to the memory_mlp's input.
            memory = tf.contrib.layers.layer_norm(memory + memory_mlp, trainable=True)  # [B, N, H * V]

        return memory

    def _build(self, inputs, memory):
        """Adds relational memory to the TensorFlow graph.

        Args:
          inputs: Tensor input.
          memory: Memory output from the previous time step.

        Returns:
          output: This time step's output.
          next_memory: The next version of memory to use.
        """

        batch_size = memory.get_shape().as_list()[0]
        inputs = tf.reshape(inputs, [batch_size, -1])  # [B, In_size]
        # NOTE: the scope name keeps its original spelling so variable names (and
        # therefore checkpoints) stay compatible.
        inputs = linear(inputs, self._mem_size, use_bias=True, scope='input_for_cancat')  # [B, V * H]
        inputs_reshape = tf.expand_dims(inputs, 1)  # [B, 1, V * H]

        # The projected input is appended as an extra memory row so attention can
        # attend over it; that row is dropped again afterwards.
        memory_plus_input = tf.concat([memory, inputs_reshape], axis=1)  # [B, N + 1, V * H]
        next_memory = self._attend_over_memory(memory_plus_input)  # [B, N + 1, V * H]

        n = inputs_reshape.get_shape().as_list()[1]
        next_memory = next_memory[:, :-n, :]  # [B, N, V * H]

        if self._gate_style == 'unit' or self._gate_style == 'memory':
            self._input_gate, self._forget_gate = self._create_gates(inputs_reshape, memory)
            next_memory = self._input_gate * tf.tanh(next_memory)
            next_memory += self._forget_gate * memory

        output = tf.reshape(next_memory, [batch_size, -1])
        return output, next_memory

    def __call__(self, *args, **kwargs):
        """Operator overload for calling.

        This is the entry point when users connect a Module into the Graph. The
        underlying _build method will have been wrapped in a Template by the
        constructor, and we call this template with the provided inputs here.

        Args:
          *args: Arguments for underlying _build method.
          **kwargs: Keyword arguments for underlying _build method.

        Returns:
          The result of the underlying _build method.
        """
        outputs = self._template(*args, **kwargs)

        return outputs

    @property
    def input_gate(self):
        """Returns the input gate Tensor from the latest build, or None if there is none."""
        return self._input_gate

    @property
    def forget_gate(self):
        """Returns the forget gate Tensor from the latest build, or None if there is none."""
        return self._forget_gate

    @property
    def rmc_params(self):
        """Returns the parameters in the RMC module"""
        # NOTE(review): tf.get_collection filters with re.match, i.e. from the start
        # of each variable name. If this module is built inside another scope or
        # template (e.g. 'generator/'), this may return an empty list — confirm.
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self._name)

    def set_rmc_params(self, ref_rmc_params):
        """Build an op copying the reference module's parameters into this module.

        Args:
          ref_rmc_params: reference variables/tensors, in the same order as
            `self.rmc_params`.

        Returns:
          A grouped `tf.Operation`; run it in a session to perform the copy.

        Raises:
          ValueError: if the two modules have a different number of parameters.
        """
        rmc_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self._name)
        if len(rmc_params) != len(ref_rmc_params):
            raise ValueError("the number of parameters in the two RMC modules does not match")
        # Rebinding entries of the Python list would leave the variables untouched;
        # an explicit assign op is required to change their values.
        return tf.group(*[tf.assign(param, ref) for param, ref in zip(rmc_params, ref_rmc_params)])

    def update_rmc_params(self, ref_rmc_params, update_ratio):
        """Build an op blending this module's parameters towards a reference module.

        Each parameter becomes
        ``update_ratio * param + (1 - update_ratio) * ref``.

        Args:
          ref_rmc_params: reference variables/tensors, in the same order as
            `self.rmc_params`.
          update_ratio: weight kept on this module's current values.

        Returns:
          A grouped `tf.Operation`; run it in a session to perform the update.

        Raises:
          ValueError: if the two modules have a different number of parameters.
        """
        rmc_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self._name)
        if len(rmc_params) != len(ref_rmc_params):
            raise ValueError("the number of parameters in the two RMC modules does not match")
        return tf.group(*[
            tf.assign(param, update_ratio * param + (1 - update_ratio) * ref)
            for param, ref in zip(rmc_params, ref_rmc_params)
        ])


## GENERATOR AND DISCRIMINATOR ## 

In [None]:
from tensorflow.python.ops import tensor_array_ops, control_flow_ops

# The generator network based on the Relational Memory
def generator(x_real, temperature, vocab_size, batch_size, seq_len, gen_emb_dim, mem_slots, head_size, num_heads,
              hidden_dim, start_token):
    """Build the relational-memory generator graph.

    Two unrolled `while_loop`s share one embedding table, one RelationalMemory
    and one output unit:
      * an adversarial sampler, which samples each next token as the argmax of
        the Gumbel-perturbed logits and also emits the softmax relaxation
        ``softmax(gumbel_logits * temperature)`` of that one-hot sample;
      * a teacher-forced pre-training loop over `x_real`.

    Args:
      x_real: real token ids, [batch_size, seq_len] (per the transpose comment below).
      temperature: multiplied into the Gumbel-perturbed logits before the softmax,
        so larger values give a harder (closer to one-hot) relaxation.
      vocab_size, batch_size, seq_len: graph dimensions.
      gen_emb_dim: token embedding size.
      mem_slots, head_size, num_heads: RelationalMemory configuration.
      hidden_dim: not used by this function.
      start_token: token id fed at the first step of both loops.

    Returns:
      gen_x_onehot_adv: [batch_size, seq_len, vocab_size] relaxed one-hot samples.
      gen_x: [batch_size, seq_len] sampled token ids.
      pretrain_loss: scalar mean per-token negative log-likelihood of `x_real`.
      gen_o: [batch_size, seq_len] relaxed probability of each sampled token.
    """
    start_tokens = tf.constant([start_token] * batch_size, dtype=tf.int32)
    output_size = mem_slots * head_size * num_heads

    # build relation memory module
    g_embeddings = tf.get_variable('g_emb', shape=[vocab_size, gen_emb_dim],
                                   initializer=create_linear_initializer(vocab_size))
    gen_mem = RelationalMemory(mem_slots=mem_slots, head_size=head_size, num_heads=num_heads)
    g_output_unit = create_output_unit(output_size, vocab_size)

    # initial states
    init_states = gen_mem.initial_state(batch_size)

    # ---------- generate tokens and approximated one-hot results (Adversarial) ---------
    gen_o = tensor_array_ops.TensorArray(dtype=tf.float32, size=seq_len, dynamic_size=False, infer_shape=True)
    gen_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=seq_len, dynamic_size=False, infer_shape=True)
    gen_x_onehot_adv = tensor_array_ops.TensorArray(dtype=tf.float32, size=seq_len, dynamic_size=False,
                                                    infer_shape=True)  # generator output (relaxed of gen_x)

    # the generator recurrent module used for adversarial training
    def _gen_recurrence(i, x_t, h_tm1, gen_o, gen_x, gen_x_onehot_adv):
        mem_o_t, h_t = gen_mem(x_t, h_tm1)  # hidden_memory_tuple
        o_t = g_output_unit(mem_o_t)  # batch x vocab, logits not probs
        # add_gumbel is defined elsewhere; presumably adds Gumbel noise to the logits.
        gumbel_t = add_gumbel(o_t)
        # Hard sample; no gradient flows through the argmax.
        next_token = tf.stop_gradient(tf.argmax(gumbel_t, axis=1, output_type=tf.int32))
        next_token_onehot = tf.one_hot(next_token, vocab_size, 1.0, 0.0)

        # Differentiable relaxation of the sample, used as the discriminator input.
        x_onehot_appr = tf.nn.softmax(tf.multiply(gumbel_t, temperature))  # one-hot-like, [batch_size x vocab_size]

        # x_tp1 = tf.matmul(x_onehot_appr, g_embeddings)  # approximated embeddings, [batch_size x emb_dim]
        # The next step is fed the embedding of the hard-sampled token.
        x_tp1 = tf.nn.embedding_lookup(g_embeddings, next_token)  # embeddings, [batch_size x emb_dim]

        gen_o = gen_o.write(i, tf.reduce_sum(tf.multiply(next_token_onehot, x_onehot_appr), 1))  # [batch_size], prob
        gen_x = gen_x.write(i, next_token)  # indices, [batch_size]

        gen_x_onehot_adv = gen_x_onehot_adv.write(i, x_onehot_appr)

        return i + 1, x_tp1, h_t, gen_o, gen_x, gen_x_onehot_adv

    # build a graph for outputting sequential tokens
    _, _, _, gen_o, gen_x, gen_x_onehot_adv = control_flow_ops.while_loop(
        cond=lambda i, _1, _2, _3, _4, _5: i < seq_len,
        body=_gen_recurrence,
        loop_vars=(tf.constant(0, dtype=tf.int32), tf.nn.embedding_lookup(g_embeddings, start_tokens),
                   init_states, gen_o, gen_x, gen_x_onehot_adv))

    gen_o = tf.transpose(gen_o.stack(), perm=[1, 0])  # batch_size x seq_len
    gen_x = tf.transpose(gen_x.stack(), perm=[1, 0])  # batch_size x seq_len

    gen_x_onehot_adv = tf.transpose(gen_x_onehot_adv.stack(), perm=[1, 0, 2])  # batch_size x seq_len x vocab_size

    # ----------- pre-training for generator -----------------
    x_emb = tf.transpose(tf.nn.embedding_lookup(g_embeddings, x_real), perm=[1, 0, 2])  # seq_len x batch_size x emb_dim
    g_predictions = tensor_array_ops.TensorArray(dtype=tf.float32, size=seq_len, dynamic_size=False, infer_shape=True)

    ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32, size=seq_len)
    ta_emb_x = ta_emb_x.unstack(x_emb)

    # the generator recurrent module used for pre-training (teacher forcing:
    # step i is followed by the embedding of the real token x_real[:, i])
    def _pretrain_recurrence(i, x_t, h_tm1, g_predictions):
        mem_o_t, h_t = gen_mem(x_t, h_tm1)
        o_t = g_output_unit(mem_o_t)
        g_predictions = g_predictions.write(i, tf.nn.softmax(o_t))  # batch_size x vocab_size
        x_tp1 = ta_emb_x.read(i)
        return i + 1, x_tp1, h_t, g_predictions

    # build a graph for outputting sequential tokens
    _, _, _, g_predictions = control_flow_ops.while_loop(
        cond=lambda i, _1, _2, _3: i < seq_len,
        body=_pretrain_recurrence,
        loop_vars=(tf.constant(0, dtype=tf.int32), tf.nn.embedding_lookup(g_embeddings, start_tokens),
                   init_states, g_predictions))

    g_predictions = tf.transpose(g_predictions.stack(),
                                 perm=[1, 0, 2])  # batch_size x seq_length x vocab_size

    # pre-training loss: cross-entropy summed over all tokens and divided by
    # seq_len * batch_size; probabilities are clipped to avoid log(0).
    pretrain_loss = -tf.reduce_sum(
        tf.one_hot(tf.to_int32(tf.reshape(x_real, [-1])), vocab_size, 1.0, 0.0) * tf.log(
            tf.clip_by_value(tf.reshape(g_predictions, [-1, vocab_size]), 1e-20, 1.0)
        )
    ) / (seq_len * batch_size)

    return gen_x_onehot_adv, gen_x, pretrain_loss, gen_o


# The discriminator network based on the CNN classifier
def discriminator(x_onehot, batch_size, seq_len, vocab_size, dis_emb_dim, num_rep, sn):
    """CNN discriminator producing one logit per (sequence, representation) pair.

    The (possibly relaxed) one-hot input is embedded into `dis_emb_dim`
    dimensions. The embedding is split into `num_rep` equal slices, and each
    slice is scored separately by a multi-width convolution + max-pool,
    highway, dropout and two linear layers.

    Args:
      x_onehot: one-hot or relaxed one-hot input, reshaped to
        [batch_size, seq_len, dis_emb_dim] after embedding.
      batch_size, seq_len, vocab_size: input dimensions.
      dis_emb_dim: TOTAL embedding size; must be a positive multiple of `num_rep`.
      num_rep: number of embedded representations.
      sn: whether the conv/linear helpers use spectral normalisation.

    Returns:
      A 1-D tensor of logits of length batch_size * num_rep.

    Raises:
      ValueError: if `dis_emb_dim` is not a positive multiple of `num_rep`.
    """
    # Each representation gets an equal slice of the embedding. The former
    # `assert isinstance(int(...), int)` was always true, so a non-divisible
    # configuration silently produced the wrong number of representations.
    if num_rep <= 0 or dis_emb_dim <= 0 or dis_emb_dim % num_rep != 0:
        raise ValueError('dis_emb_dim ({}) must be a positive multiple of num_rep ({})'
                         .format(dis_emb_dim, num_rep))
    emb_dim_single = dis_emb_dim // num_rep

    filter_sizes = [2, 3, 4, 5]
    num_filters = [300, 300, 300, 300]
    dropout_keep_prob = 0.75

    d_embeddings = tf.get_variable('d_emb', shape=[vocab_size, dis_emb_dim],
                                   initializer=create_linear_initializer(vocab_size))
    # Embedding via matmul so relaxed (non one-hot) inputs stay differentiable.
    input_x_re = tf.reshape(x_onehot, [-1, vocab_size])
    emb_x_re = tf.matmul(input_x_re, d_embeddings)
    emb_x = tf.reshape(emb_x_re, [batch_size, seq_len, dis_emb_dim])  # batch_size x seq_len x dis_emb_dim

    emb_x_expanded = tf.expand_dims(emb_x, -1)  # batch_size x seq_len x dis_emb_dim x 1
    print('shape of emb_x_expanded: {}'.format(emb_x_expanded.get_shape().as_list()))

    # Create a convolution + maxpool layer for each filter size. The width stride
    # equals the kernel width, so each representation slice is convolved on its own.
    pooled_outputs = []
    for filter_size, num_filter in zip(filter_sizes, num_filters):
        conv = conv2d(emb_x_expanded, num_filter, k_h=filter_size, k_w=emb_dim_single,
                      d_h=1, d_w=emb_dim_single, sn=sn, stddev=None, padding='VALID',
                      scope="conv-%s" % filter_size)  # batch_size x (seq_len-k_h+1) x num_rep x num_filter
        out = tf.nn.relu(conv, name="relu")
        pooled = tf.nn.max_pool(out, ksize=[1, seq_len - filter_size + 1, 1, 1],
                                strides=[1, 1, 1, 1], padding='VALID',
                                name="pool")  # batch_size x 1 x num_rep x num_filter
        pooled_outputs.append(pooled)

    # Combine all the pooled features
    num_filters_total = sum(num_filters)
    h_pool = tf.concat(pooled_outputs, 3)  # batch_size x 1 x num_rep x num_filters_total
    print('shape of h_pool: {}'.format(h_pool.get_shape().as_list()))
    h_pool_flat = tf.reshape(h_pool, [-1, num_filters_total])

    # Add highway
    h_highway = highway(h_pool_flat, h_pool_flat.get_shape()[1], 1, 0)  # (batch_size*num_rep) x num_filters_total

    # Add dropout. NOTE: there is no training flag, so dropout is also active
    # whenever the discriminator is evaluated.
    h_drop = tf.nn.dropout(h_highway, dropout_keep_prob, name='dropout')

    # fc
    fc_out = linear(h_drop, output_size=100, use_bias=True, sn=sn, scope='fc')
    logits = linear(fc_out, output_size=1, use_bias=True, sn=sn, scope='logits')
    logits = tf.squeeze(logits, -1)  # batch_size*num_rep

    return logits



In [None]:
import tensorflow as tf


# Registries mapping an architecture name to its model-building function.
generator_dict = dict(rmc_vanilla=generator)

discriminator_dict = dict(rmc_vanilla=discriminator)


def get_generator(model_name, scope='generator', **kwargs):
    """Wrap the registered generator `model_name` in a variable-sharing template.

    Raises KeyError if `model_name` is not in `generator_dict`.
    """
    return tf.make_template(scope, generator_dict[model_name], **kwargs)


def get_discriminator(model_name, scope='discriminator', **kwargs):
    """Wrap the registered discriminator `model_name` in a variable-sharing template.

    Raises KeyError if `model_name` is not in `discriminator_dict`.
    """
    return tf.make_template(scope, discriminator_dict[model_name], **kwargs)

## RUN ## 

In [None]:
import argparse
import os


# Command-line options. NOTE: in this notebook `main()` fills `config` by hand
# (its `parser.parse_args()` call is commented out), so these mainly document
# the available settings and their defaults.
parser = argparse.ArgumentParser(description='Train and run a RmcGAN')
# Architecture
parser.add_argument('--gf-dim', default=64, type=int, help='Number of filters to use for generator')
parser.add_argument('--df-dim', default=64, type=int, help='Number of filters to use for discriminator')
# Default to the only architecture registered in generator_dict/discriminator_dict.
parser.add_argument('--g-architecture', default='rmc_vanilla', type=str, help='Architecture for generator')
parser.add_argument('--d-architecture', default='rmc_vanilla', type=str, help='Architecture for discriminator')
parser.add_argument('--gan-type', default='standard', type=str, help='Which type of GAN to use')
parser.add_argument('--hidden-dim', default=32, type=int, help='only used for OracleLstm and lstm_vanilla (generator)')
parser.add_argument('--sn', default=False, action='store_true', help='if using spectral norm')

# Training
parser.add_argument('--gsteps', default=1, type=int, help='How many training steps to use for generator')
parser.add_argument('--dsteps', default=5, type=int, help='How many training steps to use for discriminator')
parser.add_argument('--npre-epochs', default=150, type=int, help='Number of steps to run pre-training')
parser.add_argument('--nadv-steps', default=5000, type=int, help='Number of steps to run adversarial training')
parser.add_argument('--ntest', default=50, type=int, help='How often to run tests')
parser.add_argument('--d-lr', default=1e-4, type=float, help='Learning rate for the discriminator')
parser.add_argument('--gpre-lr', default=1e-2, type=float, help='Learning rate for the generator in pre-training')
parser.add_argument('--gadv-lr', default=1e-4, type=float, help='Learning rate for the generator in adv-training')
parser.add_argument('--batch-size', default=64, type=int, help='Batch size for training')
parser.add_argument('--log-dir', default='./oracle/logs', type=str, help='Where to store log and checkpoint files')
parser.add_argument('--sample-dir', default='./oracle/samples', type=str, help='Where to put samples during training')
parser.add_argument('--optimizer', default='adam', type=str, help='training method')
parser.add_argument('--decay', default=False, action='store_true', help='if decaying learning rate')
parser.add_argument('--adapt', default='exp', type=str, help='temperature control policy: [no, lin, exp, log, sigmoid, quad, sqrt]')
parser.add_argument('--seed', default=123, type=int, help='for reproducing the results')
parser.add_argument('--temperature', default=1000, type=float, help='the largest temperature')

# evaluation
parser.add_argument('--nll-oracle', default=False, action='store_true', help='if using nll-oracle metric')
parser.add_argument('--nll-gen', default=False, action='store_true', help='if using nll-gen metric')
parser.add_argument('--bleu', default=False, action='store_true', help='if using bleu metric, [2,3,4,5]')
parser.add_argument('--selfbleu', default=False, action='store_true', help='if using selfbleu metric, [2,3,4,5]')
parser.add_argument('--doc-embsim', default=False, action='store_true', help='if using DocEmbSim metric')

# relational memory
parser.add_argument('--mem-slots', default=1, type=int, help="memory size")
parser.add_argument('--head-size', default=512, type=int, help="head size or memory size")
parser.add_argument('--num-heads', default=2, type=int, help="number of heads")

# Data
parser.add_argument('--dataset', default='oracle', type=str, help='[oracle, image_coco, emnlp_news]')
parser.add_argument('--vocab-size', default=5000, type=int, help="vocabulary size")
parser.add_argument('--start-token', default=0, type=int, help="start token for a sentence")
parser.add_argument('--seq-len', default=20, type=int, help="sequence length: [20, 40]")
parser.add_argument('--num-sentences', default=10000, type=int, help="number of total sentences")
parser.add_argument('--gen-emb-dim', default=32, type=int, help="generator embedding dimension")
parser.add_argument('--dis-emb-dim', default=64, type=int, help="TOTAL discriminator embedding dimension")
parser.add_argument('--num-rep', default=64, type=int, help="number of discriminator embedded representations")
parser.add_argument('--data-dir', default='./data', type=str, help='Where data is stored')


def _default_config():
    """Return a fresh dict holding the notebook's built-in experiment settings.

    Several values differ from the argparse defaults declared earlier in the
    file (e.g. ``seed``, ``temperature``, ``head_size``, ``dataset``).
    """
    return {
        # architecture
        'gf_dim': 64,
        'df_dim': 64,
        'g_architecture': 'rmc_vanilla',
        'd_architecture': 'rmc_vanilla',
        'gan_type': 'RSGAN',
        'hidden_dim': 32,
        'sn': False,

        # training schedule / optimisation
        'gsteps': 1,
        'dsteps': 1,
        'npre_epochs': 30,
        'nadv_steps': 5000,
        'ntest': 25,
        'd_lr': 1e-4,
        'gpre_lr': 1e-2,
        'gadv_lr': 1e-4,
        'batch_size': 32,
        'log_dir': '/content/drive/My Drive/oracle/logs',
        'sample_dir': '/content/drive/My Drive/oracle/samples',
        'optimizer': 'adam',
        'decay': False,
        'adapt': 'exp',
        'seed': 171,
        'temperature': 100,

        # evaluation metrics
        'nll_oracle': False,
        'nll_gen': True,
        'bleu': True,
        'selfbleu': False,
        'doc_embsim': False,

        # relational memory
        'mem_slots': 1,
        'head_size': 256,
        'num_heads': 2,

        # data (seq_len / vocab_size are overwritten for real-text datasets)
        'dataset': 'fce_train_new',
        'vocab_size': 5000,
        'start_token': 0,
        'seq_len': 20,
        'num_sentences': 300,
        'num_rep': 64,
        'data_dir': '/content/drive/My Drive/data',

        'gen_emb_dim': 32,
        'dis_emb_dim': 64,

        # checkpointing
        'checkpoint_restore': False,
        'checkpoint': '/content/drive/My Drive/model_saver/1525.ckpt',
        # NOTE(review): identical to 'checkpoint'. If the training code feeds this
        # to a meta-graph importer it probably needs the '.meta' suffix - confirm.
        'check_meta': '/content/drive/My Drive/model_saver/1525.ckpt',
    }


def _build_gan(config):
    """Construct the (generator, discriminator) pair described by ``config``.

    Reads ``config['vocab_size']`` and ``config['seq_len']``, so callers must
    have set them to their final values before calling this.
    """
    generator = get_generator(config['g_architecture'], vocab_size=config['vocab_size'],
                              batch_size=config['batch_size'], seq_len=config['seq_len'],
                              gen_emb_dim=config['gen_emb_dim'], mem_slots=config['mem_slots'],
                              head_size=config['head_size'], num_heads=config['num_heads'],
                              hidden_dim=config['hidden_dim'], start_token=config['start_token'])
    discriminator = get_discriminator(config['d_architecture'], batch_size=config['batch_size'],
                                      seq_len=config['seq_len'], vocab_size=config['vocab_size'],
                                      dis_emb_dim=config['dis_emb_dim'], num_rep=config['num_rep'],
                                      sn=config['sn'])
    return generator, discriminator


def main(config=None):
    """Build a RelGAN generator/discriminator and train it on the configured dataset.

    Args:
        config: Optional dict of overrides merged on top of ``_default_config()``,
            e.g. ``main({'nadv_steps': 100})``. The caller's dict is never
            mutated. ``None`` (the default) runs the built-in configuration.

    Raises:
        NotImplementedError: if ``config['dataset']`` is not one of
            'oracle', 'image_coco', 'emnlp_news' or 'fce_train_new'.
    """
    # The argparse flags are not consulted: parse_args() stays disabled,
    # presumably because a Colab kernel's own argv would make it fail.
    # Merge into a fresh dict so the seq_len/vocab_size writes below never
    # leak back into a caller-supplied mapping.
    merged = _default_config()
    if config:
        merged.update(config)
    config = merged

    dataset = config['dataset']

    # train with different datasets
    if dataset == 'oracle':
        # Synthetic setting: a fixed oracle LSTM supplies the "real" data.
        oracle_model = OracleLstm(num_vocabulary=config['vocab_size'], batch_size=config['batch_size'],
                                  emb_dim=config['gen_emb_dim'], hidden_dim=config['hidden_dim'],
                                  sequence_length=config['seq_len'], start_token=config['start_token'])
        oracle_loader = OracleDataLoader(config['batch_size'], config['seq_len'])
        gen_loader = OracleDataLoader(config['batch_size'], config['seq_len'])

        generator, discriminator = _build_gan(config)
        oracle_train(generator, discriminator, oracle_model, oracle_loader, gen_loader, config)

    elif dataset in ['image_coco', 'emnlp_news', 'fce_train_new']:
        data_file = os.path.join(config['data_dir'], '{}.txt'.format(dataset))
        # The real corpus determines the sequence length and vocabulary size,
        # overriding whatever the config held.
        seq_len, vocab_size = text_precess(data_file)
        config['seq_len'] = seq_len
        config['vocab_size'] = vocab_size
        print('seq_len: %d, vocab_size: %d' % (seq_len, vocab_size))

        oracle_loader = RealDataLoader(config['batch_size'], config['seq_len'])

        generator, discriminator = _build_gan(config)
        real_train(generator, discriminator, oracle_loader, config)

    else:
        raise NotImplementedError('{}: unknown dataset!'.format(dataset))

        
if __name__ == "__main__":
    # Start training when run as a script; in a notebook kernel __name__ is
    # also "__main__", so executing this cell launches a run immediately.
    main()