In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import tensorflow as tf
from generator import Generator
from discriminator import Discriminator
from rollout_max_ent import ROLLOUT
from target_lstm import TARGET_LSTM
import pickle
from matplotlib import pyplot as plt
import numpy as np
from sequence_gan_max_ent import *
from IPython.display import clear_output

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [3]:

#########################################################################################
#  Generator  Hyper-parameters
######################################################################################
EMB_DIM = 32 # embedding dimension
HIDDEN_DIM = 32 # hidden state dimension of lstm cell
SEQ_LENGTH = 20 # sequence length
START_TOKEN = 0
PRE_EPOCH_NUM = 120 # supervise (maximum likelihood estimation) epochs
SEED = 88
BATCH_SIZE = 64

#########################################################################################
#  Discriminator  Hyper-parameters
#########################################################################################
dis_embedding_dim = 64
dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
dis_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160]
dis_dropout_keep_prob = 0.75
dis_l2_reg_lambda = 0.2
dis_batch_size = 64
# Each filter size needs exactly one filter count.
assert len(dis_filter_sizes) == len(dis_num_filters), "dis_filter_sizes / dis_num_filters length mismatch"

#########################################################################################
#  Basic Training Parameters
#########################################################################################
TOTAL_BATCH = 200
positive_file = 'save/real_data.txt'
negative_file = 'save/generator_sample.txt'
eval_file = 'save/eval_file.txt'
generated_num = 10000

vocab_size = 5000
# Second argument of ROLLOUT. Presumably the rate at which the rollout network tracks
# the generator's parameters — confirm in rollout_max_ent.
ROLLOUT_UPDATE_RATE = 0.8

#########################################################################################
#  Models
#########################################################################################
# Seed NumPy and TensorFlow's graph-level RNG before any op is created, so runs are
# repeatable. SEED was previously defined but never applied.
np.random.seed(SEED)
tf.set_random_seed(SEED)

generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)

# Oracle parameters. `with` closes the file handle, which the bare open() leaked.
# pickle can execute arbitrary code on load, so only use a trusted parameter file.
with open('save/target_params_py3.pkl', 'rb') as params_file:
    target_params = pickle.load(params_file)
target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model

# Use SEQ_LENGTH (not a hard-coded 20) so the discriminator always matches the generator.
discriminator = Discriminator(sequence_length=SEQ_LENGTH, num_classes=2, vocab_size=vocab_size, embedding_size=dis_embedding_dim,
                              filter_sizes=dis_filter_sizes, num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)
rollout = ROLLOUT(generator, ROLLOUT_UPDATE_RATE)


Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use tf.random.categorical instead.
Instructions for updating:
Use tf.cast instead.


KeyboardInterrupt: 

In [None]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all at start-up.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.Session(config=config)

In [None]:
# Run from scratch: give every global variable its initial value in `sess`.
tf.global_variables_initializer().run(session=sess)

In [None]:
# Run from a saved checkpoint (alternative to the "run from scratch" cell).
# Do NOT call tf.reset_default_graph() here. `sess`, this Saver and the
# generator/discriminator/target_lstm objects all belong to the current default
# graph, and TF documents that resetting the default graph while a Session is
# active results in undefined behavior.
saver = tf.train.Saver()
saver.restore(sess, 'model/pretrain_max_ent.ckpt')
#saver.restore(sess, 'model/advtrain.ckpt')

In [None]:
# Generator's distribution over the first token: embed the start token, take one
# recurrent step from the initial state h0, and softmax the output logits.
x_t = tf.nn.embedding_lookup(generator.g_embeddings, generator.start_token)
h_tm1 = generator.h0
h_t = generator.g_recurrent_unit(x_t, h_tm1)  # (hidden, memory) tuple
o_t = generator.g_output_unit(h_t)  # batch x vocab logits (not probabilities)
dist0 = sess.run(tf.nn.softmax(o_t))[0]  # keep only row 0 of the batch

In [None]:
# Oracle (target LSTM) distribution over the first token, computed like the generator's.
# It uses the oracle's OWN start-token embedding. The previous version reused the
# generator's `x_t`, which mixed the two models and depended on the previous cell.
# Assumes TARGET_LSTM exposes g_embeddings/start_token like Generator does — confirm in target_lstm.py.
x_t_target = tf.nn.embedding_lookup(target_lstm.g_embeddings, target_lstm.start_token)
h_t = target_lstm.g_recurrent_unit(x_t_target, target_lstm.h0)  # hidden_memory_tuple
o_t = target_lstm.g_output_unit(h_t)  # batch x vocab, logits not prob
dist0_target = sess.run(tf.nn.softmax(o_t))[0, :]

In [None]:
np.shape(dist0)  # length of the first-token distribution

In [None]:
# Sanity check: a softmax output should sum to ~1. ndarray.sum is vectorized and uses
# pairwise summation; the builtin sum looped in Python and accumulated float32 sequentially.
dist0_target.sum()

In [None]:
# Plot the first-step (first-token) distribution of the generator

In [None]:
#plt.plot(np.array([dist0,dist0_target]).T)

In [None]:
# Compare the generator's and the oracle's first-token distributions over the vocabulary.
fig, ax = plt.subplots()
ax.plot(dist0, label='generator')
ax.plot(dist0_target, label='oracle (target LSTM)')
ax.set(title='First-token distribution', xlabel='token id', ylabel='probability')
ax.legend();

In [None]:
# Largest oracle first-token probability. The vectorized ndarray.max propagates NaN,
# whereas the builtin max could hide a NaN depending on its position.
dist0_target.max()

In [None]:
# Check the discriminator

In [None]:
# Draw one batch of samples from the current generator and display it.
generator.generate(sess)

In [None]:
# How well does the discriminator recognize real samples?

In [None]:
class_ = 1  # output column inspected for real (oracle) samples
predictions = sess.run(
    discriminator.ypred_for_auc,
    feed_dict={
        discriminator.input_x: target_lstm.generate(sess),
        discriminator.dropout_keep_prob: 1.0,  # no dropout at evaluation time
    },
)[:, class_]

In [None]:
# Lowest score in the batch. Vectorized, and NaN-propagating, unlike the builtin min.
predictions.min()

In [None]:
predictions.mean()  # average score over the batch

In [None]:
# How about generated ones?

In [None]:
class_ = 0  # output column inspected for generated samples
predictions = sess.run(
    discriminator.ypred_for_auc,
    feed_dict={
        discriminator.input_x: generator.generate(sess),
        discriminator.dropout_keep_prob: 1.0,  # no dropout at evaluation time
    },
)[:, class_]

# I.e. before pretraining it is possible to fool the discriminator (TODO: this note was left unfinished — confirm)

In [None]:
# Even though the discriminator always picks the right class, it can still end up in a state where it is less certain...

In [None]:
# Lowest score in the batch. Vectorized, and NaN-propagating, unlike the builtin min.
predictions.min()

In [None]:
predictions.mean()  # average score over the batch

In [None]:
# Rollout rewards for one freshly generated batch.
samples = generator.generate(sess)
rewards = rollout.get_reward(
    sess,
    samples,
    16,  # presumably the number of rollouts per sample — confirm in rollout_max_ent
    discriminator,
    999,  # NOTE(review): meaning of this argument is not visible here — confirm in rollout_max_ent
)
rewards

In [None]:
np.asarray(rewards).std()  # spread of all reward entries (array flattened)

In [None]:
np.std(rewards, axis=None)  # same statistic as the previous cell (flattened std)

In [None]:
# Evaluate `generator.lls` on a freshly generated batch.
sess.run(
    generator.lls,
    feed_dict={generator.x: generator.generate(sess)},
)

In [None]:
# Shape of one_hot(x) * log(clip(g_predictions)) evaluated on `samples`.
# tf.cast replaces tf.to_int32, which is deprecated ("Use tf.cast instead").
sess.run(
    tf.one_hot(tf.cast(generator.x, tf.int32), generator.num_emb, 1.0, 0.0)
    * tf.log(tf.clip_by_value(generator.g_predictions, 1e-20, 1.0)),
    feed_dict={generator.x: samples},
).shape

In [None]:
# Adversarial phase with a FIXED discriminator: only the generator is updated, using
# rollout rewards scored by the discriminator. After each iteration, report how the
# discriminator rates fresh samples and the entropy of the generator's first-token
# distribution.
N_ADV_STEPS = 100      # outer iterations in this cell
G_STEPS = 1            # generator updates per iteration
N_EVAL_BATCHES = 10    # generated batches scored by the discriminator per iteration
ROLLOUT_NUM = 16       # presumably rollouts per sample — confirm in rollout_max_ent
class_ = 0             # discriminator output column inspected (as in the "generated ones" cell)

# Build the first-token distribution op ONCE. Creating embedding/LSTM/softmax ops
# inside the loop added new nodes to the graph on every iteration, so the graph
# kept growing.
x_t = tf.nn.embedding_lookup(generator.g_embeddings, generator.start_token)
h_tm1 = generator.h0
h_t = generator.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
o_t = generator.g_output_unit(h_t)  # batch x vocab, logits not prob
first_token_probs = tf.nn.softmax(o_t)

for total_batch in range(N_ADV_STEPS):
    # Train the generator
    for it in range(G_STEPS):
        samples = generator.generate(sess)
        rewards = rollout.get_reward(sess, samples, ROLLOUT_NUM, discriminator, 99999)
        feed = {generator.x: samples, generator.rewards: rewards}
        _ = sess.run(generator.g_updates, feed_dict=feed)

    # Test
    #if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
    #    generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    #    likelihood_data_loader.create_batches(eval_file)
    #    test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
    #    buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
    #    print('total_batch: ', total_batch, 'test_loss: ', test_loss)
    #    writer.add_scalar('Loss/oracle_nll', test_loss, total_batch)
    #    log.write(buffer)

    # Update roll-out parameters
    rollout.update_params()

    # Score several fresh batches. Concatenate once instead of re-growing an array.
    predictions = np.concatenate([
        sess.run(discriminator.ypred_for_auc,
                 {discriminator.input_x: generator.generate(sess),
                  discriminator.dropout_keep_prob: 1.0})[:, class_]
        for _eval_batch in range(N_EVAL_BATCHES)
    ])

    dist0 = sess.run(first_token_probs)[0, :]
    # Entropy uses the convention 0 * log(0) = 0. Exact zeros (float32 softmax can
    # underflow) are skipped; otherwise 0 * -inf would make the entropy NaN.
    support = dist0[dist0 > 0]
    ent_dist0 = -np.sum(support * np.log(support))

    clear_output(wait=True)
    print("min: {}, mean: {}, dist0_ent: {}".format(predictions.min(), np.mean(predictions), ent_dist0))
    plt.plot(dist0)
    plt.xlabel('token id')
    plt.ylabel('probability')
    plt.show()
    

In [None]:
# Discriminator output column `class_` (set to 0 by the training loop) for one new generated batch.
sess.run(
    discriminator.ypred_for_auc,
    feed_dict={
        discriminator.input_x: generator.generate(sess),
        discriminator.dropout_keep_prob: 1.0,
    },
)[:, class_]

In [None]:
# This example shows that we might need to run the REINFORCE algorithm much longer than expected.

In [None]:
# This example shows that keeping the discriminator fixed yields only limited progress in fooling it
# (as seen by the ... — TODO: this note was left unfinished)

In [None]:
# The results can show high variance when earlier progress is being undone. They also show that we can
# actually fool the discriminator, although only for very few samples. Note, however, that the mean also falls.

In [None]:
#There is 