In [1]:
import numpy as np
import tensorflow as tf
import random



In [2]:
from dataloader import Gen_Data_loader, Dis_Data_loader
import pickle
from generator import Generator
from discriminator import Discriminator

In [3]:
# Show which TensorFlow build this kernel imported (the recorded output is 2.1.0).
print(tf.__version__)

2.1.0


In [4]:

# ---------------------------------------------------------------------------
# Generator hyper-parameters
# ---------------------------------------------------------------------------
EMB_DIM = 200         # embedding dimension
HIDDEN_DIM = 200      # hidden-state dimension of the LSTM cell
MAX_SEQ_LENGTH = 17   # maximum sequence length
BATCH_SIZE = 64       # generator batch size


# ---------------------------------------------------------------------------
# Discriminator hyper-parameters
# ---------------------------------------------------------------------------
dis_embedding_dim = 64

# Both lists below have 11 entries.
# NOTE(review): presumably dis_num_filters[i] is the filter count for
# dis_filter_sizes[i] -- confirm against discriminator.py.
dis_filter_sizes = [
    1, 2, 3, 4, 5,
    6, 7, 8, 9, 10,
    15,
]
dis_num_filters = [
    100, 200, 200, 200, 200,
    100, 100, 100, 100, 100,
    160,
]

dis_dropout_keep_prob = 0.75   # keep probability (not drop probability), per the name
dis_l2_reg_lambda = 0.2
dis_batch_size = 64


In [5]:
# ---------------------------------------------------------------------------
# Basic training parameters and file locations
# ---------------------------------------------------------------------------
TOTAL_BATCH = 2000

# Every corpus path below is built by appending to this prefix, so it must
# keep its trailing slash.
dataset_path = "../../data/movie/"
emb_dict_file = dataset_path + "imdb_word.vocab"

# IMDB corpus: sentence text and the matching id file.
imdb_file_txt = dataset_path + "imdb/imdb_sentences.txt"
imdb_file_id = dataset_path + "imdb/imdb_sentences.id"

# SSTB corpus: separate positive / negative files, each as text and ids.
sst_pos_file_txt = dataset_path + "sstb/sst_pos_sentences.txt"
sst_pos_file_id = dataset_path + "sstb/sst_pos_sentences.id"
sst_neg_file_txt = dataset_path + "sstb/sst_neg_sentences.txt"
sst_neg_file_id = dataset_path + "sstb/sst_neg_sentences.id"


# Output locations, relative to the notebook's working directory.
eval_file = "save/eval_file.txt"
eval_text_file = "save/eval_text_file.txt"    # word-level log written by generate_samples
negative_file = "save/generator_sample.txt"
infer_file = "save/infer/"                    # path prefix; generate_infer appends "<epoch>.txt"


In [6]:
def generate_samples(sess, trainable_model, generated_num, output_file, vocab_list, if_log=False, epoch=0):
    """Sample from the model and write the samples to ``output_file`` as ids.

    ``trainable_model.generate(sess)`` is called ``int(generated_num)`` times
    and everything it returns is pooled.  Each sample is cut just before its
    first token id ``2`` (left whole when no ``2`` is present) and written as
    one line of space-separated ids.  ``output_file`` is always overwritten.

    When ``if_log`` is true, the same truncated samples are first written to
    the module-level ``eval_text_file`` with every id mapped through
    ``vocab_list``.  That log is overwritten when ``epoch == 0`` and appended
    to for any other epoch.

    NOTE(review): ``generated_num`` counts calls to ``generate``, not
    individual samples -- presumably each call yields a whole batch; confirm
    against the callers.

    Returns:
        None.
    """
    def _cut_at_eos(sample):
        # Keep only the tokens before the first id 2 (marked _EOS in
        # produce_samples); a sample without a 2 is returned whole.
        tokens = list(sample)
        return tokens[:tokens.index(2)] if 2 in tokens else tokens

    samples = [
        sample
        for _ in range(int(generated_num))
        for sample in trainable_model.generate(sess)
    ]

    if if_log:
        log_mode = 'w' if epoch == 0 else 'a'
        with open(eval_text_file, log_mode) as log_out:
            for sample in samples:
                words = [vocab_list[token] for token in _cut_at_eos(sample)]
                log_out.write(' '.join(words) + '\n')

    with open(output_file, 'w') as ids_out:
        for sample in samples:
            ids = [str(token) for token in _cut_at_eos(sample)]
            ids_out.write(' '.join(ids) + '\n')


In [7]:
def generate_infer(sess, trainable_model, epoch, vocab_list, num_batches=100):
    """Sample from the model and save the samples as words for one epoch.

    Calls ``trainable_model.generate(sess)`` ``num_batches`` times, pools the
    returned samples, cuts each one just before its first token id ``2``
    (marked _EOS in ``produce_samples``), maps the remaining ids through
    ``vocab_list`` and writes one space-separated line per sample.

    The output path is the module-level ``infer_file`` prefix followed by
    ``"<epoch>.txt"``; an existing file at that path is overwritten.  The
    directory part of the prefix is not created here, so it must already
    exist.

    Args:
        sess: passed through unchanged to ``trainable_model.generate``.
        trainable_model: object exposing ``generate(sess)`` that returns an
            iterable of token-id sequences.
        epoch: value embedded (via ``str``) in the output file name.
        vocab_list: sequence indexed by token id, yielding the word string.
        num_batches: how many times to call ``generate``.  Defaults to 100,
            the count that was previously hard-coded.

    Returns:
        None.  The path written is also reported on stdout.
    """
    generated_samples = []
    for _ in range(int(num_batches)):
        generated_samples.extend(trainable_model.generate(sess))

    # Named `out_path` rather than `file` to avoid shadowing a builtin name.
    out_path = infer_file + str(epoch) + '.txt'
    with open(out_path, 'w') as fout:
        for sample in generated_samples:
            tokens = list(sample)
            if 2 in tokens:
                # Drop the end marker and everything after it.
                tokens = tokens[:tokens.index(2)]
            fout.write(' '.join([vocab_list[token] for token in tokens]) + '\n')
    print("%s saves" % out_path)

In [8]:
def produce_samples(generated_samples):
    """Strip padding and end-of-sequence tails from generated id sequences.

    For every sequence in ``generated_samples`` a new list is built that
    stops at the first token equal to 2 (_EOS, not included) and omits every
    token equal to 0 (_PAD) seen before that point.

    Returns:
        A list with one cleaned list of tokens per input sequence, in the
        same order as the input.
    """
    cleaned = []
    for sequence in generated_samples:
        kept = []
        for token in sequence:
            if token == 2:  # _EOS: nothing after this is kept
                break
            if token != 0:  # anything except _PAD
                kept.append(token)
        cleaned.append(kept)
    return cleaned

In [9]:
def load_emb_data(emb_dict_file):
    """Read a one-entry-per-line vocabulary file.

    Each line is stripped of surrounding whitespace and its zero-based line
    number becomes its index.

    Returns:
        A tuple ``(word_dict, length, word_list)`` where ``word_list`` holds
        the stripped lines in file order, ``word_dict`` maps each word to its
        line index, and ``length`` is ``len(word_dict)``.  If a word occurs on
        several lines, ``word_dict`` keeps the last index, so ``length`` can
        be smaller than ``len(word_list)``.
    """
    with open(emb_dict_file, 'r') as vocab_in:
        word_list = [raw_line.strip() for raw_line in vocab_in]

    word_dict = {}
    for index, word in enumerate(word_list):
        word_dict[word] = index

    length = len(word_dict)
    print("Load embedding success! Num: %d" % length)
    return word_dict, length, word_list




In [None]:
def pre_train_epoch(sess, trainable_model, data_loader, num_batches=200):
    """Pre-train the generator with MLE for one pass and return the mean loss.

    Resets ``data_loader``, then for ``num_batches`` iterations fetches a
    batch with ``data_loader.next_batch()`` and feeds it to
    ``trainable_model.pretrain_step(sess, batch)``, collecting the second
    element of each returned pair as that step's loss.

    Args:
        sess: passed through unchanged to ``pretrain_step``.
        trainable_model: object exposing ``pretrain_step(sess, batch)`` that
            returns a 2-tuple whose second item is the loss.
        data_loader: object exposing ``reset_pointer()`` and ``next_batch()``.
        num_batches: number of training steps to run.  Defaults to 200, the
            count that was previously hard-coded.  It is independent of how
            many batches the loader holds.
            NOTE(review): the loader is presumably expected to keep yielding
            batches beyond its own length -- confirm in dataloader.py.

    Returns:
        ``np.mean`` of the collected losses.

    Raises:
        ValueError: if ``num_batches`` is not positive; averaging an empty
            loss list would otherwise yield NaN with only a runtime warning.
    """
    if num_batches <= 0:
        raise ValueError("num_batches must be positive, got %r" % (num_batches,))

    data_loader.reset_pointer()

    supervised_g_losses = []
    for _ in range(num_batches):
        batch = data_loader.next_batch()
        _, g_loss = trainable_model.pretrain_step(sess, batch)
        supervised_g_losses.append(g_loss)

    return np.mean(supervised_g_losses)