In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import os

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from tabular_simple import TabularSimple
from generator import Generator
from discriminator import Discriminator
from rollout_max_ent import ROLLOUT
from gan_trainer import GanTrainer
from dataloader import Gen_Data_loader, Dis_dataloader

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [3]:
# Ground-truth model: a TabularSimple over `n_vocabulary` tokens with `n_modes` modes.
# NOTE(review): the leading 4 is presumably the sequence length (SEQ_LENGTH below is also 4) — confirm.
n_modes, n_vocabulary = 2, 4
vocab = range(n_vocabulary)  # x-axis (token ids) for the bar plots below
tabular_model = TabularSimple(4, n_vocabulary, n_modes)

In [4]:
# Plot the entry of the tabular model's lookup table stored under key "10".
# NOTE(review): this lookup raised KeyError('10') in the saved run. Fail with a message that
# shows some real keys so the intended key format can be checked against TabularSimple.
table_key = "10"
if table_key not in tabular_model.table.keys():
    raise KeyError("{!r} not in tabular_model.table; example keys: {}".format(
        table_key, list(tabular_model.table.keys())[:5]))
dist = tabular_model.table[table_key]
fig, ax = plt.subplots()
ax.bar(vocab, dist)
ax.set(title="tabular model: table[{!r}]".format(table_key), xlabel="token id", ylabel="table entry");

KeyError: '10'

In [None]:
# Draw the ground-truth corpus from the tabular model.
size = 10 ** 4  # number of sequences
samples = tabular_model.sample(size)

In [None]:
# Write the ground-truth corpus as integer token ids, one row of `samples` per line, space-separated.
# The same path is used as positive_file in the hyper-parameter cell below.
os.makedirs('save', exist_ok=True)  # np.savetxt does not create missing directories
np.savetxt('save/real_data_tab.txt', samples, fmt='%d', delimiter=' ')

In [None]:
# Number of keys in the tabular model's lookup table.
table_keys = list(tabular_model.table.keys())
len(table_keys)

In [None]:
# Score the sampled corpus with `ll` of the model that generated it (presumably log-likelihood);
# serves as the reference value for the baseline in the next cell.
tabular_model.ll(samples)

In [None]:
# Baseline: score the same samples under a freshly constructed TabularSimple
# (presumably a different random table) for comparison with tabular_model.ll above.
baseline_model = TabularSimple(4, n_vocabulary, n_modes)
baseline_model.ll(samples)

In [None]:
#########################################################################################
#  Generator  Hyper-parameters
#########################################################################################
EMB_DIM = 4 # embedding dimension
HIDDEN_DIM = 4 # hidden state dimension of lstm cell
SEQ_LENGTH = 4 # sequence length
START_TOKEN = 0
PRE_EPOCH_NUM = 120 # supervise (maximum likelihood estimation) epochs
SEED = 88
BATCH_SIZE = 128
vocab_size = n_vocabulary  # tied to the tabular model's vocabulary (was a separate literal 4)

#########################################################################################
#  Discriminator  Hyper-parameters
#########################################################################################
dis_embedding_dim = 4
dis_filter_sizes = [1, 2, 3, 4]
#dis_num_filters = [200, 200, 200, 200]
dis_num_filters = [10, 10, 10, 10]
dis_dropout_keep_prob = 0.75
dis_l2_reg_lambda = 0.2
dis_batch_size = 128

#########################################################################################
#  Basic Training Parameters
#########################################################################################
TOTAL_BATCH = 200
positive_file = 'save/real_data_tab.txt'
negative_file = 'save/generator_sample_tab.txt'
negative_file_ent = 'save/generator_sample_tab_ent.txt'
#eval_file = 'save/eval_file_tab.txt'
generated_num = 10000
sequence_length = SEQ_LENGTH  # discriminator input length must equal the generator's SEQ_LENGTH
g_lr = 0.01

# SEED was previously unused. Seed numpy and TF's graph-level seed before any model ops are
# created, so weight initialisation and sampling are reproducible from here on.
# (The tabular samples drawn in earlier cells are not covered by this seed.)
np.random.seed(SEED)
tf.set_random_seed(SEED)

# Two identically configured generators: `generator` (baseline) and `generator_ent`.
generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN,learning_rate=g_lr)
generator_ent = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN,learning_rate=g_lr)

discriminator = Discriminator(sequence_length=sequence_length, num_classes=2, vocab_size=vocab_size, embedding_size=dis_embedding_dim, 
                            filter_sizes=dis_filter_sizes, num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)
# One rollout per generator, both with update rate 0.8.
rollout = ROLLOUT(generator, 0.8)
rollout_ent = ROLLOUT(generator_ent, 0.8)


In [None]:
# Each generator gets its own data loader. The discriminator loader is shared, and so is the
# discriminator itself: both trainers receive the same `discriminator` object.
gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
gen_data_loader_ent = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
# Use the discriminator's own batch-size setting (defined but previously unused; equal to BATCH_SIZE today).
dis_data_loader = Dis_dataloader(dis_batch_size, SEQ_LENGTH)
# NOTE(review): both trainers are given the same 'pretrain_notebook' / 'advtrain_notebook' names
# (presumably log/summary locations), so their outputs may be mixed — consider distinct names.
gan_trainer = GanTrainer(generator, discriminator, rollout, gen_data_loader, dis_data_loader,
                         tabular_model, 'pretrain_notebook', 'advtrain_notebook',
                         positive_file, negative_file, BATCH_SIZE)
gan_trainer_ent = GanTrainer(generator_ent, discriminator, rollout_ent, gen_data_loader_ent, dis_data_loader,
                             tabular_model, 'pretrain_notebook', 'advtrain_notebook',
                             positive_file, negative_file_ent, BATCH_SIZE)

In [None]:
# TF1 session that grows GPU memory on demand instead of reserving it all up front.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.Session(config=config)

In [None]:
# Optional: restore weights from a saved checkpoint instead of training from scratch.
# NOTE(review): uncommenting these lines as written would not work:
#   - tf.reset_default_graph() makes later cells create ops in a fresh graph that `sess`
#     (created above on the original graph) cannot run, e.g. the initializer in the next cell;
#   - the next cell runs tf.global_variables_initializer(), which would overwrite restored weights.
#   Restore after the initializer cell, reuse the `saver` created in a later cell, and drop reset_default_graph().
#saver = tf.train.Saver()
#tf.reset_default_graph()
#saver.restore(sess, 'model/pretrain_max_ent_tab.ckpt')
#saver.restore(sess, 'model/advtrain.ckpt')

In [None]:
# Initialise every global variable created so far (generators, discriminator, rollouts).
init_op = tf.global_variables_initializer()
sess.run(init_op)

In [None]:
# Saver over all saveable variables built so far (default var_list); handed to pretrain below.
saver = tf.train.Saver()

In [None]:
# Display generated_num (set in the hyper-parameter cell), which is passed to pretrain below.
generated_num

In [None]:
# Pre-train both setups with identical arguments; they share the same discriminator.
pretrain_args = (20, 20, 3, saver, dis_dropout_keep_prob, generated_num)
gan_trainer.pretrain(sess, *pretrain_args)
gan_trainer_ent.pretrain(sess, *pretrain_args)

In [None]:
# Even with a very high entropy constant we don't see any major difference between generator and generator_ent.

In [None]:
# Adversarial training of both generators side by side. Each iteration performs one
# advtrain_gen step per generator. `temp` is the extra constant handed to advtrain_gen
# (printed as "GenT"; presumably the entropy/temperature weight): 9999 for `generator`,
# 0.25 for `generator_ent`.
N_ADV_ITERS = 10000   # adversarial iterations
LOG_EVERY = 10        # print diagnostics every LOG_EVERY iterations
N_EVAL_BATCHES = 10   # generated batches scored by the discriminator per log line
class_ = 1            # column of discriminator.ypred_for_auc that is inspected

for it in range(N_ADV_ITERS):
    for temp, gen, gan in zip([9999, .25], [generator, generator_ent], [gan_trainer, gan_trainer_ent]):
        test_loss, g_loss = gan.advtrain_gen(sess, 1, 64, temp)
        # The diagnostics below only feed the log line, so skip their extra sess.run calls
        # (1 + N_EVAL_BATCHES generations per generator) on non-logging iterations.
        if it % LOG_EVERY != 0:
            continue
        policy_ent = sess.run(gen.pretrain_loss, {gen.x: gen.generate(sess)})
        # Score generated samples with keep_prob 1.0 (dropout off): this is evaluation, and
        # feeding the training keep-prob would inject random noise into the reported statistics.
        predictions = np.concatenate([
            sess.run(discriminator.ypred_for_auc,
                     {discriminator.input_x: gen.generate(sess),
                      discriminator.dropout_keep_prob: 1.0})[:, class_]
            for _ in range(N_EVAL_BATCHES)
        ])
        print("GenT: {:.4f} -  test_loss: {:.4f}, g_loss: {:.4f}, pol_ent: {:.4f}, ll_disc: {:.4f}, maxp_disc: {:.4f}, minp_disc: {:.4f}"
              .format(temp, test_loss, g_loss, policy_ent, np.mean(np.log(predictions)), max(predictions), min(predictions)))
    

In [None]:
# The log-likelihood (ll) is not affected much by training.

In [None]:
# Next-token distribution of `generator` after one recurrent step from its initial state h0.
# NOTE(review): the original fed token id 10, which is outside [0, vocab_size) for vocab_size=4.
# Gather-based lookups raise on CPU and silently return zeros on GPU for such ids.
# START_TOKEN is presumably what the generator feeds at its first step — confirm this is the
# intended input token.
# Note: each run of this cell adds new ops to the TF graph.
query_token = START_TOKEN
assert 0 <= query_token < vocab_size, "token id must be a valid row of g_embeddings"
x_t = tf.nn.embedding_lookup(generator.g_embeddings, [query_token] * BATCH_SIZE)
h_tm1 = generator.h0
h_t = generator.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
o_t = generator.g_output_unit(h_t)  # batch x vocab, logits not probabilities
# All rows use the same input token, so row 0 suffices (assuming h0 is identical across rows).
dist0 = sess.run(tf.nn.softmax(o_t))[0, :]

In [None]:
# Softmax over the vocabulary after one step of `generator` (dist0 from the previous cell).
fig, ax = plt.subplots()
ax.bar(vocab, dist0)
ax.set(title="generator: next-token distribution after one step", xlabel="token id", ylabel="probability");

In [None]:
# Next-token distribution of `generator_ent` after one recurrent step from its initial state h0.
# NOTE(review): the original fed token id 10, which is outside [0, vocab_size) for vocab_size=4.
# Gather-based lookups raise on CPU and silently return zeros on GPU for such ids.
# START_TOKEN is presumably what the generator feeds at its first step — confirm this is the
# intended input token.
# Note: each run of this cell adds new ops to the TF graph.
query_token = START_TOKEN
assert 0 <= query_token < vocab_size, "token id must be a valid row of g_embeddings"
x_t = tf.nn.embedding_lookup(generator_ent.g_embeddings, [query_token] * BATCH_SIZE)
h_tm1 = generator_ent.h0
h_t = generator_ent.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
o_t = generator_ent.g_output_unit(h_t)  # batch x vocab, logits not probabilities
# All rows use the same input token, so row 0 suffices (assuming h0 is identical across rows).
dist0_ent = sess.run(tf.nn.softmax(o_t))[0, :]

In [None]:
# Softmax over the vocabulary after one step of `generator_ent` (dist0_ent from the previous cell).
fig, ax = plt.subplots()
ax.bar(vocab, dist0_ent)
ax.set(title="generator_ent: next-token distribution after one step", xlabel="token id", ylabel="probability");

In [None]:
# Tabular model's table entry under key "10", for visual comparison with the generator plots.
# NOTE(review): the same key raised KeyError('10') in the saved run of the first tabular plot.
# Fail with a message listing real keys so the intended key format can be checked.
table_key = "10"
if table_key not in tabular_model.table.keys():
    raise KeyError("{!r} not in tabular_model.table; example keys: {}".format(
        table_key, list(tabular_model.table.keys())[:5]))
fig, ax = plt.subplots()
ax.bar(vocab, tabular_model.table[table_key])
ax.set(title="tabular model: table[{!r}]".format(table_key), xlabel="token id", ylabel="table entry");