In [1]:
%load_ext autoreload
%autoreload
from IPython.display import clear_output

import os
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import sys
import subprocess
import pdb
import time
import datetime
import math
import random
import _pickle as cPickle
from collections import defaultdict

from six.moves import zip_longest
import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow import distributions as tfd
from tensorflow.keras.preprocessing.sequence import pad_sequences

from data_structure import get_batches, get_test_batches
from components import tf_log, sample_latents, compute_kl_losses, dynamic_rnn, dynamic_bi_rnn

from topic_beam_search_decoder import BeamSearchDecoder

In [2]:
PAD = '<pad>' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence
UNK = '<unk>' # This has a vocab id, which is used to represent out-of-vocabulary words
BOS = '<p>' # This has a vocab id, which is used at the beginning of every decoder input sequence
EOS = '</p>' # This has a vocab id, which is used at the end of untruncated target sequences

# load data & set config

In [3]:
def del_all_flags(FLAGS):
    flags_dict = FLAGS._flags()    
    keys_list = [keys for keys in flags_dict]    
    for keys in keys_list:
        FLAGS.__delattr__(keys)

del_all_flags(tf.flags.FLAGS)

flags = tf.app.flags

flags.DEFINE_string('gpu', '2', 'visible gpu')

flags.DEFINE_string('mode', 'train', 'set train or eval')

flags.DEFINE_string('data_path', 'data/bags/instances.pkl', 'path of data')
flags.DEFINE_string('modeldir', 'model/topic_vae', 'directory of model')
flags.DEFINE_string('modelname', 'bags', 'name of model')

flags.DEFINE_integer('epochs', 50, 'epochs')
flags.DEFINE_integer('batch_size', 64, 'number of sentences in each batch')
flags.DEFINE_integer('log_period', 1000, 'valid period')

flags.DEFINE_string('opt', 'Adagrad', 'optimizer')
flags.DEFINE_float('lr', 0.05, 'lr')
flags.DEFINE_float('reg', 1., 'regularization term')
flags.DEFINE_float('grad_clip', 5., 'grad_clip')

flags.DEFINE_float('keep_prob', 0.8, 'dropout rate')
flags.DEFINE_float('word_keep_prob', 0.75, 'word dropout rate')

flags.DEFINE_bool('warmup', True, 'flg of warming up')
flags.DEFINE_integer('epochs_cycle', 5, 'number of epochs within a cycle')
flags.DEFINE_float('r_cycle', 0.5, 'proportion used to increase beta within a cycle')
flags.DEFINE_integer('warmup_topic', 0, 'warmup period for KL of topic')

flags.DEFINE_integer('beam_width', 2, 'beam_width')
flags.DEFINE_float('length_penalty_weight', 0.0, 'length_penalty_weight')

flags.DEFINE_integer('dim_hidden_bow', 256, 'dim of hidden bow')
flags.DEFINE_integer('dim_latent_bow', 32, 'dim of latent topic')
flags.DEFINE_integer('dim_emb', 256, 'dim_emb')
flags.DEFINE_integer('dim_hidden', 512, 'dim_hidden')
flags.DEFINE_integer('dim_hidden_topic', 512, 'dim_hidden_topic')
flags.DEFINE_integer('dim_latent_topic', 32, 'dim_latent_topic')
flags.DEFINE_bool('bidirectional', True, 'flg of bidirectional encoding')

# for evaluation
flags.DEFINE_string('refdir', 'ref', 'refdir')
flags.DEFINE_string('outdir', 'out', 'outdir')

flags.DEFINE_string('f', '', 'kernel')
flags.DEFINE_bool('logtostderr', True, 'kernel')
flags.DEFINE_bool('showprefixforinfo', False, '')
flags.DEFINE_bool('verbosity', False, '')
# flags.DEFINE_integer('stderrthreshold', 20, 'kernel')

config = flags.FLAGS

flags.DEFINE_string('modelpath', os.path.join(config.modeldir, config.modelname), 'path of model')

In [4]:
os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu

In [5]:
instances_train, instances_dev, instances_test, word_to_idx, idx_to_word, bow_idxs = cPickle.load(open(config.data_path,'rb'))

In [6]:
train_batches = get_batches(instances_train, config.batch_size)
dev_batches = get_batches(instances_dev, config.batch_size)
test_batches = get_test_batches(instances_test, config.batch_size)

In [7]:
flags.DEFINE_integer('PAD_IDX', word_to_idx[PAD], 'PAD_IDX')
flags.DEFINE_integer('UNK_IDX', word_to_idx[UNK], 'UNK_IDX')
flags.DEFINE_integer('BOS_IDX', word_to_idx[BOS], 'BOS_IDX')
flags.DEFINE_integer('EOS_IDX', word_to_idx[EOS], 'EOS_IDX')

flags.DEFINE_integer('n_vocab', len(word_to_idx), 'n_vocab')
flags.DEFINE_integer('dim_bow', len(bow_idxs), 'dim_bow')

maximum_iterations = max([max([instance.max_sent_l for instance in batch]) for ct, batch in dev_batches])
flags.DEFINE_integer('maximum_iterations', maximum_iterations, 'maximum_iterations')

flags.DEFINE_integer('cycle_steps', len(train_batches)*config.epochs_cycle, 'number of steps for each cycle')

In [8]:
def debug_shape(variables):
    sample_batch = dev_batches[0][1]
    feed_dict = get_feed_dict(sample_batch)
    _variables = sess.run(variables, feed_dict=feed_dict)
    for _variable, variable in zip(_variables, variables):
        if hasattr(variable, 'name'):
            print(variable.name, ':', _variable.shape)
        else:
            print(_variable.shape)

def debug_value(variables, return_value=False):
    sample_batch = test_batches[0][1]
    feed_dict = get_feed_dict(sample_batch)
    _variables = sess.run(variables, feed_dict=feed_dict)

    if return_value: 
        return _variables
    else:
        for _variable, variable in zip(_variables, variables):
            if hasattr(variable, 'name'):
                print(variable.name, ':', _variable)
            else:
                print(_variable)
                
def check_shape(variables):
    if 'sess' in globals(): raise
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    sample_batch = test_batches[0][1]
    feed_dict = get_feed_dict(sample_batch)
    _variables = sess.run(variables, feed_dict=feed_dict)
    for _variable, variable in zip(_variables, variables):
        if hasattr(variable, 'name'):
            print(variable.name, ':', _variable.shape)
        else:
            print(_variable.shape)
            
    sess.close()
    
def check_value(variables):
    if 'sess' in globals(): raise
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    sample_batch = test_batches[0][1]
    feed_dict = get_feed_dict(sample_batch)
    _variables = sess.run(variables, feed_dict=feed_dict)
    for _variable, variable in zip(_variables, variables):
        if hasattr(variable, 'name'):
            print(variable.name, ':', _variable)
        else:
            print(_variable.shape)
            
    sess.close()    
    

# run model 

## init

In [9]:
def get_feed_dict(batch, mode='train'):
    bow = np.array([instance.bow for instance in batch]).astype(np.float32)
    keep_prob = config.keep_prob if mode == 'train' else 1.0
    feed_dict = {
                t_variables['bow']: bow, 
                t_variables['keep_prob']: keep_prob
    }
    return  feed_dict

In [10]:
tf.reset_default_graph()

t_variables = {}
t_variables['bow'] = tf.placeholder(tf.float32, [None, config.dim_bow])
t_variables['keep_prob'] = tf.placeholder(tf.float32)

# tree_idxs = {0:[1, 2], 
#              1:[10, 11], 2:[20, 21, 22], 
#              10: [100, 101], 11: [110, 111, 112], 20: [200, 201], 21: [210, 211], 22:[220, 221, 222]
#              }

# tree_idxs = {0:[1, 2, 3], 
#                       1:[10, 11], 2:[20, 21], 3:[30, 31]}

tree_idxs = {0:[1, 2, 3], 
                      1:[10, 11], 2:[20, 21], 3:[30, 31],
                      10: [100, 101], 11: [110, 111], 20: [200, 201], 21: [210, 211], 30:[300, 301], 31:[310, 311]}


topic_idxs = [0] + [idx for child_idxs in tree_idxs.values() for idx in child_idxs]
child_to_parent_idxs = {child_idx: parent_idx for parent_idx, child_idxs in tree_idxs.items() for child_idx in child_idxs}

## doubly rnn

In [11]:
class DoublyRNNCell:
    def __init__(self, dim_hidden, output_layer=None):
        self.dim_hidden = dim_hidden
        
        self.ancestral_layer=tf.layers.Dense(units=dim_hidden, activation=tf.nn.tanh, name='ancestral')
        self.fraternal_layer=tf.layers.Dense(units=dim_hidden, activation=tf.nn.tanh, name='fraternal')
        self.hidden_layer = tf.layers.Dense(units=dim_hidden, activation=tf.nn.tanh, name='hidden')
        
        self.output_layer=output_layer
        
    def __call__(self, state_ancestral, state_fraternal, reuse=True):
        with tf.variable_scope('input', reuse=reuse):
            state_ancestral = self.ancestral_layer(state_ancestral)
            state_fraternal = self.fraternal_layer(state_fraternal)

        with tf.variable_scope('output', reuse=reuse):
            state_hidden = self.hidden_layer(state_ancestral + state_fraternal)
            if self.output_layer is not None: 
                output = self.output_layer(state_hidden)
            else:
                output = state_hidden
            
        return output, state_hidden
    
    def get_initial_state(self, name):
        initial_state = tf.get_variable(name, [1, self.dim_hidden], dtype=tf.float32)
        return initial_state
    
    def get_zero_state(self, name):
        zero_state = tf.zeros([1, self.dim_hidden], dtype=tf.float32, name=name)
        return zero_state

In [12]:
def doubly_rnn(dim_hidden, tree_idxs, initial_state_parent=None, initial_state_sibling=None, output_layer=None, name=''):
    outputs, states_parent = {}, {}
    
    with tf.variable_scope(name, reuse=False):
        doubly_rnn_cell = DoublyRNNCell(dim_hidden, output_layer)

        if initial_state_parent is None: initial_state_parent = doubly_rnn_cell.get_initial_state('init_state_parent')
        if initial_state_sibling is None: 
#             initial_state_sibling = doubly_rnn_cell.get_initial_state('init_state_sibling')
            initial_state_sibling = doubly_rnn_cell.get_zero_state('init_state_sibling')
        output, state_sibling = doubly_rnn_cell(initial_state_parent, initial_state_sibling, reuse=False)
        outputs[0], states_parent[0] = output, state_sibling

        for parent_idx, child_idxs in tree_idxs.items():
            state_parent = states_parent[parent_idx]
            state_sibling = initial_state_sibling
            for child_idx in child_idxs:
                output, state_sibling = doubly_rnn_cell(state_parent, state_sibling)
                outputs[child_idx], states_parent[child_idx] = output, state_sibling

    return outputs

## stick break

In [13]:
def hierarchical_sbp(tree_sticks_topic, tree_sticks_branch):
    tree_prob_topic = {}
    rest_topics = {}

    # calculate topic probability and save
    stick_topic = tree_sticks_topic[0]
    tree_prob_topic[0] = stick_topic
    rest_topics[0] = 1.-stick_topic
    for parent_idx, child_idxs in tree_idxs.items():
        rest_topic = rest_topics[parent_idx]
        rest_branch = 1.
        for child_idx in child_idxs:
            # calculate topic probability
            if child_idx == child_idxs[-1]: # last child
                prob_branch = rest_branch # phi
            else:
                stick_branch = tree_sticks_branch[child_idx] # psi
                prob_branch = stick_branch * rest_branch # phi

            if not child_idx in tree_idxs: # leaf childs
                prob_topic = prob_branch * rest_topic # pi
            else:
                stick_topic = tree_sticks_topic[child_idx] # upsilon
                prob_topic = stick_topic * prob_branch * rest_topic # pi

            # save topic probability and update rest stick length
            tree_prob_topic[child_idx] = prob_topic
            rest_branch = (1.- stick_branch) * rest_branch
            rest_topics[child_idx] = (1.-stick_topic)*prob_branch * rest_topic
            
    return tree_prob_topic

## build model

In [14]:
def get_tree_topic_bow(tree_topic_embeddings):
    def get_depth(parent_idx=0, tree_depth = None, depth=1):
        if tree_depth is None: tree_depth={0: depth}

        child_idxs = tree_idxs[parent_idx]
        depth +=1
        for child_idx in child_idxs:
            tree_depth[child_idx] = depth
            if child_idx in tree_idxs: get_depth(child_idx, tree_depth, depth)
        return tree_depth
       
    def softmax_with_temperature(logits, axis=None, name=None, temperature=1.):
        if axis is None:
            axis = -1
        return tf.exp(logits / temperature) / tf.reduce_sum(tf.exp(logits / temperature), axis=axis)
    
    tree_depth = get_depth()

    tree_topic_bow = {}
    for topic_idx, depth in tree_depth.items():
        topic_embedding = tree_topic_embeddings[topic_idx]
        temperature = tf.constant(10. ** (1./depth), dtype=tf.float32)
        logits = tf.matmul(topic_embedding, bow_embeddings, transpose_b=True)
        tree_topic_bow[topic_idx] = softmax_with_temperature(logits, axis=-1, temperature=temperature)
    
    return tree_topic_bow

In [15]:
# encode bow
with tf.variable_scope('topic/enc', reuse=False):
    hidden_bow_ = tf.layers.Dense(units=config.dim_hidden_bow, activation=tf.nn.relu, name='hidden_bow')(t_variables['bow'])
    hidden_bow = tf.layers.Dropout(t_variables['keep_prob'])(hidden_bow_)
    means_bow = tf.layers.Dense(units=config.dim_latent_bow, name='mean_bow')(hidden_bow)
    logvars_bow = tf.layers.Dense(units=config.dim_latent_bow, kernel_initializer=tf.constant_initializer(0), bias_initializer=tf.constant_initializer(0), name='logvar_bow')(hidden_bow)
    latents_bow = sample_latents(means_bow, logvars_bow) # sample latent vectors
    prob_layer = lambda h: tf.nn.sigmoid(tf.matmul(latents_bow, h, transpose_b=True))
    
    tree_sticks_topic = doubly_rnn(config.dim_latent_bow, tree_idxs, output_layer=prob_layer, name='sticks_topic')
    tree_sticks_branch = doubly_rnn(config.dim_latent_bow, tree_idxs, output_layer=prob_layer, name='sticks_branch')

    tree_prob_topic = hierarchical_sbp(tree_sticks_topic, tree_sticks_branch)
    prob_topic = tf.concat(list(tree_prob_topic.values()), 1)

# decode bow
with tf.variable_scope('shared', reuse=False):
    embeddings = tf.get_variable('emb', [config.n_vocab, config.dim_emb], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer()) # embeddings of vocab

bow_embeddings = tf.nn.embedding_lookup(embeddings, bow_idxs) # embeddings of each bow features

with tf.variable_scope('topic/dec', reuse=False):
    tree_hidden_topic_emb = doubly_rnn(config.dim_emb, tree_idxs, name='hidden_topic_emb')
    hidden_topic_emb = tf.concat(list(tree_hidden_topic_emb.values()), 0)
    
    means_topic_emb = tf.layers.Dense(units=config.dim_latent_topic, name='mean_topic_emb')(hidden_topic_emb)
    logvars_topic_emb = tf.layers.Dense(units=config.dim_latent_topic, kernel_initializer=tf.constant_initializer(0), bias_initializer=tf.constant_initializer(0), name='logvar_topic_emb')(hidden_topic_emb)
    latents_topic_emb = sample_latents(means_topic_emb, logvars_topic_emb) # sample latent vectors
    topic_embeddings = tf.layers.Dense(units=config.dim_emb, name='topic_emb')(latents_topic_emb)
    
    tree_topic_embeddings = {topic_idx: tf.expand_dims(topic_embeddings[topic_idxs.index(topic_idx)], 0) for topic_idx in topic_idxs}
    tree_topic_bow = get_tree_topic_bow(tree_topic_embeddings) # bow vectors for each topic
    
    topic_bow = tf.concat([tree_topic_bow[topic_idx] for topic_idx in topic_idxs], 0)
    logits_bow = tf_log(tf.matmul(prob_topic, topic_bow)) # predicted bow distribution

## define loss

In [16]:
def get_tree_mask_reg(tree_idxs):
    tree_mask_reg = np.ones([len(topic_idxs), len(topic_idxs)], dtype=np.float32)
    parent_to_descendant_idxs = {parent_idx: get_descendant_idxs(parent_idx) for parent_idx in tree_idxs}
    
    for parent_idx, descendant_idxs in parent_to_descendant_idxs.items():
        for descendant_idx in descendant_idxs:
            tree_mask_reg[topic_idxs.index(parent_idx), topic_idxs.index(descendant_idx)] = tree_mask_reg[topic_idxs.index(descendant_idx), topic_idxs.index(parent_idx)] = 0.
            
    return tree_mask_reg

def get_descendant_idxs(parent_idx, descendant_idxs = None):
    if descendant_idxs is None: descendant_idxs = []
    
    child_idxs = tree_idxs[parent_idx]
    descendant_idxs += child_idxs
    for child_idx in child_idxs:
        if child_idx in tree_idxs: get_descendant_idxs(child_idx, descendant_idxs)
    return descendant_idxs

In [17]:
# define losses
topic_losses_recon = -tf.reduce_sum(tf.multiply(t_variables['bow'], logits_bow), 1)
topic_loss_recon = tf.reduce_mean(topic_losses_recon) # negative log likelihood of each words

topic_losses_kl_bow = compute_kl_losses(means_bow, logvars_bow) # KL divergence b/w latent dist & gaussian std
topic_loss_kl_bow = tf.reduce_mean(topic_losses_kl_bow, [0]) #mean of kl_losses over batches

means_topic_emb_prior = tf.concat([tf.expand_dims(latents_topic_emb[topic_idxs.index(child_to_parent_idxs[topic_idx])], 0) if topic_idx in child_to_parent_idxs else tf.zeros([1, config.dim_latent_topic], dtype=tf.float32) for topic_idx in topic_idxs], 0)
topic_losses_kl_emb = compute_kl_losses(means_topic_emb, logvars_topic_emb, means_topic_emb_prior)
topic_loss_kl_emb = tf.reduce_sum(topic_losses_kl_emb, [0]) #sum of kl_losses over topics

topic_bow_norm = topic_bow / tf.norm(topic_bow, axis=1, keepdims=True)
topic_dots = tf.clip_by_value(tf.matmul(topic_bow_norm, tf.transpose(topic_bow_norm)), -1., 1.)
tree_mask_reg = get_tree_mask_reg(tree_idxs)
topic_loss_reg = tf.reduce_mean(tf.square(topic_dots - tf.eye(len(topic_idxs))) * tree_mask_reg)

global_step = tf.Variable(0, name='global_step',trainable=False)
tau = tf.cast(tf.divide(tf.mod(global_step, tf.constant(config.cycle_steps)), tf.constant(config.cycle_steps)), dtype=tf.float32)
beta = tf.minimum(1., tau/config.r_cycle)

loss = topic_loss_recon + beta*topic_loss_kl_bow + topic_loss_kl_emb + config.reg * topic_loss_reg

In [18]:
# define optimizer
if config.opt == 'Adam':
    optimizer = tf.train.AdamOptimizer(config.lr)
elif config.opt == 'Adagrad':
    optimizer = tf.train.AdagradOptimizer(config.lr)

grad_vars = optimizer.compute_gradients(loss)
clipped_grad_vars = [(tf.clip_by_value(grad, -config.grad_clip, config.grad_clip), var) for grad, var in grad_vars]
opt = optimizer.apply_gradients(clipped_grad_vars, global_step=global_step)

# monitor
n_bow = tf.reduce_sum(t_variables['bow'], 1)
topic_ppls = tf.divide(topic_losses_recon, tf.maximum(1e-5, n_bow))
topics_freq_bow_indices = tf.nn.top_k(topic_bow, 10, name='topic_freq_bow').indices

In [19]:
def get_loss(sess, batches):
    losses = []
    ppl_list = []
    for ct, batch in batches:
        feed_dict = get_feed_dict(batch, mode='test')
        loss_batch, topic_loss_recon_batch, topic_loss_kl_bow_batch, topic_loss_kl_emb_batch, topic_loss_reg_batch, ppls_batch \
            = sess.run([loss, topic_loss_recon, topic_loss_kl_bow, topic_loss_kl_emb, topic_loss_reg, topic_ppls], feed_dict = feed_dict)
        losses += [[loss_batch, topic_loss_recon_batch, topic_loss_kl_bow_batch, topic_loss_kl_emb_batch, topic_loss_reg_batch]]
        ppl_list += list(ppls_batch)
    loss_mean, topic_loss_recon_mean, topic_loss_kl_bow_mean, topic_loss_kl_emb_mean, topic_loss_reg_mean = np.mean(losses, 0)
    ppl_mean = np.exp(np.mean(ppl_list))
    return loss_mean, topic_loss_recon_mean, topic_loss_kl_bow_mean, topic_loss_kl_emb_mean, topic_loss_reg_mean, ppl_mean

In [20]:
def print_topic_sample(parent_idx=0, topics_freq_bow_idxs=None, depth = 0):
    if topics_freq_bow_idxs is None:
        topics_freq_bow_idxs = bow_idxs[sess.run(topics_freq_bow_indices)]
        topic_freq_bow_idxs = topics_freq_bow_idxs[topic_idxs.index(parent_idx)]
        print(parent_idx, ' '.join([idx_to_word[idx] for idx in topic_freq_bow_idxs]))
    
    child_idxs = tree_idxs[parent_idx]
    depth += 1
    for child_idx in child_idxs:
        topic_freq_bow_idxs = topics_freq_bow_idxs[topic_idxs.index(child_idx)]
        print('  '*depth, child_idx, ' '.join([idx_to_word[idx] for idx in topic_freq_bow_idxs]))
        
        if child_idx in tree_idxs: print_topic_sample(child_idx, topics_freq_bow_idxs, depth)

In [21]:
if 'sess' in globals(): sess.close()
# sess = tf.Session()
sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))
sess.run(tf.global_variables_initializer())

losses_train = []
ppls_train = []
loss_min = np.inf
beta_eval = 1.
epoch = 0
train_batches = get_batches(instances_train, config.batch_size, iterator=True)
saver = tf.train.Saver(max_to_keep=10)

log_df = pd.DataFrame(columns=pd.MultiIndex.from_tuples(
                    list(zip(*[['','','','','TRAIN:','TM','','','', '','VALID:','TM','','','', ''],
                            ['Time','Ep','Ct','Beta','LOSS','PPL','NLL','KL(BOW)','KL(EMB)','REG','LOSS','PPL','NLL','KL(BOW)','KL(EMB)','REG']]))))

In [22]:
if len(log_df) == 0:
    cmd_rm = 'rm -r %s' % config.modeldir
    res = subprocess.call(cmd_rm.split())

    cmd_mk = 'mkdir %s' % config.modeldir
    res = subprocess.call(cmd_mk.split())

time_start = time.time()
while epoch < config.epochs:
    for ct, batch in train_batches:
        feed_dict = get_feed_dict(batch)

        _, loss_batch, topic_loss_recon_batch, topic_loss_kl_bow_batch, topic_loss_kl_emb_batch, topic_loss_reg_batch, ppls_batch = \
        sess.run([opt, loss, topic_loss_recon, topic_loss_kl_bow, topic_loss_kl_emb, topic_loss_reg, topic_ppls], feed_dict = feed_dict)
            
        losses_train += [[loss_batch, topic_loss_recon_batch, topic_loss_kl_bow_batch, topic_loss_kl_emb_batch, topic_loss_reg_batch]]
        ppls_train += list(ppls_batch)

        if ct%config.log_period==0:
            loss_train, topic_loss_recon_train, topic_loss_kl_bow_train, topic_loss_kl_emb_train, topic_loss_reg_train = np.mean(losses_train, 0)
            ppl_train = np.exp(np.mean(ppls_train))
            loss_dev, topic_loss_recon_dev, topic_loss_kl_bow_dev, topic_loss_kl_emb_dev, topic_loss_reg_dev, ppl_dev = get_loss(sess, dev_batches)
            global_step_log, beta_eval = sess.run([tf.train.get_global_step(), beta])
            
#             if loss_dev < loss_min:
#                 loss_min = loss_dev
#                 saver.save(sess, config.modelpath, global_step=global_step_log)

            clear_output()
    
            time_log = int(time.time() - time_start)
            log_series = pd.Series([time_log, epoch, ct,  '%.3f'%beta_eval, \
                    '%.2f'%loss_train, '%.0f'%ppl_train, '%.2f'%topic_loss_recon_train, '%.2f'%topic_loss_kl_bow_train, '%.2f'%topic_loss_kl_emb_train, '%.2f'%topic_loss_reg_train, \
                    '%.2f'%loss_dev, '%.0f'%ppl_dev, '%.2f'%topic_loss_recon_dev, '%.2f'%topic_loss_kl_bow_dev, '%.2f'%topic_loss_kl_emb_dev, '%.2f'%topic_loss_reg_dev],
                    index=log_df.columns)
            log_df.loc[global_step_log] = log_series
            display(log_df)
            
            # visualize topic
            print_topic_sample()

            time_start = time.time()
            
    epoch += 1
    train_batches = get_batches(instances_train, config.batch_size, iterator=True)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,TRAIN:,TM,Unnamed: 7_level_0,Unnamed: 8_level_0,Unnamed: 9_level_0,Unnamed: 10_level_0,VALID:,TM,Unnamed: 13_level_0,Unnamed: 14_level_0,Unnamed: 15_level_0,Unnamed: 16_level_0
Unnamed: 0_level_1,Time,Ep,Ct,Beta,LOSS,PPL,NLL,KL(BOW),KL(EMB),REG,LOSS,PPL.1,NLL,KL(BOW),KL(EMB),REG
1,3,0,0,0.000,481.92,1035,138.84,0.45,342.34,0.74,1056.84,1034,116.08,0.48,940.03,0.73
1001,24,0,1000,0.176,416.97,593,114.25,3.37,301.83,0.64,345.30,529,105.22,2.10,239.14,0.57
2001,27,0,2000,0.352,383.91,569,113.29,2.30,269.77,0.57,337.72,519,104.80,0.82,232.21,0.43
2276,9,1,0,0.400,379.81,566,113.17,2.12,265.80,0.55,343.26,516,104.79,0.69,237.81,0.39
3276,33,1,1000,0.576,370.40,558,112.86,1.68,256.75,0.50,342.33,525,104.97,0.51,236.74,0.32
4276,29,1,2000,0.752,365.07,555,112.83,1.40,251.47,0.46,343.64,523,104.96,0.38,238.10,0.29
4551,7,2,0,0.800,364.02,554,112.84,1.34,250.43,0.45,334.45,522,104.90,0.33,228.99,0.31
5551,34,2,1000,0.976,361.04,553,112.79,1.15,247.54,0.42,336.43,528,105.16,0.22,230.77,0.29
6551,27,2,2000,1.000,358.93,552,112.81,1.02,245.43,0.40,334.98,523,104.98,0.21,229.51,0.28
6826,6,3,0,1.000,358.36,552,112.77,0.99,244.90,0.39,338.06,526,105.07,0.21,232.52,0.25


0 bought nice quality price carry pro sleeve color room perfect
   1 cover nice ! ; price bought carry quality room pockets
     10 mac ; air bought recommend highly keyboard pro perfect sleeve
       100 perfectly & ... love cover carry perfect sleeve ! room
       101 perfectly & nice love quality bought pro carry ! price
     11 air cover pro sleeve love price ... carry bought room
       110 ! quality sleeve ; price perfect carry room 'm pocket
       111 sleeve bought color price ! perfectly pro room 'm carry
   2 sleeve cover love air mac bought nice keyboard & picture
     20 perfectly & love price carry room color nice mac pockets
       200 ! cover nice price quality love room pockets color 'm
       201 nice cover pro air & ... keyboard cute inside carry
     21 love ; carry color ! bought perfectly quality put room
       210 ! nice recommend ... pro carry perfect price perfectly room
       211 ! color ; nice carry sleeve - perfect 'm pro
   3 ! bought love carry room perfe

# confirm variables

In [23]:
debug_value([latents_topic_emb[:, :10]])

strided_slice_21:0 : [[-2.5982551e-02 -5.6432390e-01  4.0929565e-01  8.5030925e-01
  -1.7676382e-01 -2.6671773e-01 -1.3552337e+00  1.6929364e-01
  -2.8348271e-02 -5.3082258e-01]
 [-4.4079715e-01  5.8031768e-01  1.3397443e+00  7.8160560e-01
   7.4835449e-01  3.9578742e-01  2.4156820e-02 -9.8926175e-01
   5.8203346e-01 -5.7427430e-01]
 [ 1.4157270e-01 -8.6036511e-02 -1.3325736e-02 -8.7471586e-01
   3.1318909e-01 -5.0508392e-01 -5.8611864e-01  4.0088370e-01
   3.8623053e-01 -3.7124756e-01]
 [ 4.7123161e-01 -1.8215059e-01 -2.0138383e-02  6.6253620e-01
  -8.3562136e-01  5.4736203e-01  6.2694035e-02 -2.3076111e-01
   2.3890631e-01  8.7376183e-01]
 [-2.5074175e-01 -2.3621376e-01 -1.0168192e-01  2.9831025e-01
  -5.8458930e-01 -1.3564000e-01  1.7179342e-01 -1.5082026e-01
   2.5052178e-01  4.1207787e-01]
 [ 3.5270238e-01 -9.8326705e-02 -6.8281919e-02  3.2828712e-01
   3.4977266e-01  3.0041650e-01 -1.8540964e-01  5.0612062e-02
   9.9283493e-01 -1.3885978e-01]
 [ 2.5382605e-01  6.9582961e-02  5.51

In [24]:
debug_value([topic_embeddings[:, :10]])

strided_slice_22:0 : [[ 4.71705675e-01  1.09541512e+00 -7.74620295e-01 -3.93208861e-01
   8.80491138e-02 -6.71725273e-01  2.66618967e-01 -2.75311857e-01
   1.08393244e-01  9.25271094e-01]
 [ 9.54161644e-01  6.09670222e-01 -1.16206026e+00 -3.57270807e-01
  -2.10517257e-01  8.99863243e-03  3.53915513e-01  3.86625081e-01
  -2.11970270e-01  6.96777463e-01]
 [ 3.69215131e-01  1.34054244e+00 -4.91978109e-01 -3.48054767e-01
  -2.68570751e-01 -5.28739929e-01  1.92318186e-01 -2.11869419e-01
   2.62838244e-01  9.38213050e-01]
 [ 6.09872460e-01  1.50683057e+00 -3.68538469e-01  5.77289462e-01
  -1.43320754e-01 -1.50334388e-02 -6.82511330e-02 -1.59655303e-01
  -4.28955257e-02  8.77734184e-01]
 [ 2.92465776e-01  5.53109884e-01 -5.43974817e-01 -2.01078653e-01
  -1.26955599e-01 -2.20333427e-01  3.80943477e-01  4.80760932e-01
   3.20056170e-01  5.08757830e-01]
 [ 3.73121709e-01  6.45740628e-01 -9.42302465e-01 -3.39754760e-01
   4.03995246e-01  3.21732879e-01  3.95074606e-01  1.08387262e-01
  -1.1707646

In [25]:
debug_value([prob_topic[0]])

strided_slice_23:0 : [0.6871471  0.14580384 0.04415834 0.03305959 0.01411436 0.01188824
 0.00636758 0.00626512 0.00537253 0.0048368  0.00586794 0.00592379
 0.00602836 0.00587011 0.00236442 0.00223661 0.0024917  0.00231555
 0.00206869 0.00198038 0.00198575 0.00185321]


In [26]:
debug_value([prob_topic[3]])

strided_slice_24:0 : [0.36514091 0.16972587 0.11559934 0.13012975 0.03454283 0.05279605
 0.01175086 0.01729055 0.0098858  0.01412573 0.00954426 0.01053242
 0.00957782 0.01061257 0.00459532 0.00503028 0.00516685 0.00581277
 0.00407689 0.00456276 0.00445924 0.00504113]


In [27]:
debug_value([tf.exp(-tf.divide(topic_losses_recon, n_bow))])

Exp_2:0 : [0.00139044 0.00210681 0.00238436 0.00163164 0.00118835 0.00111976
 0.00139357 0.00143787 0.00207097 0.00240417 0.00173485 0.00243893
 0.00226123 0.00227776 0.00243891 0.00169865 0.00146405 0.00216443
 0.00420724 0.0018686  0.00165764 0.00141199 0.00232757 0.00190915
 0.00114597 0.00348853 0.00224065 0.00175064 0.00178051 0.00132692
 0.00177498 0.00210237 0.00444817 0.0014329  0.00388049 0.00172206
 0.00227368 0.00110812 0.00186194 0.00205981 0.00305567 0.00192876
 0.0017374  0.00137587 0.00164526 0.00230987 0.00143312 0.00219786
 0.00153281 0.00192624 0.0018971  0.00127488 0.00229584]


### test

In [28]:
debug_shape([bow, hidden_bow, latents_bow, prob_topic, bow_embeddings, topic_embeddings, topic_bow, prob_bow])

NameError: name 'bow' is not defined

In [None]:
debug_shape([topic_losses_recon, topic_loss_recon, n_bow, ppls, topic_embeddings_norm, tf.expand_dims(topic_angles_mean, -1), topic_angles_vars])

In [None]:
debug_value([tf.reduce_sum(tf.square(topic_embeddings_norm), 1)], return_value=True)[0]

In [None]:
debug_value([tf.reduce_sum(prob_topic, -1), tf.reduce_sum(topic_bow, -1), tf.reduce_sum(tf.exp(prob_bow), 1)])

In [None]:
sigma_bow = tf.exp(0.5 * logvars_bow)
dist_bow = tfd.Normal(means_bow, sigma_bow)
dist_std = tfd.Normal(0., 1.)
topic_loss_kl_tmp = tf.reduce_mean(tf.reduce_sum(tfd.kl_divergence(dist_bow, dist_std), 1))

In [None]:
debug_value([topic_loss_recon, topic_loss_kl, topic_loss_kl_tmp])

In [None]:
_logvars, _means, _kl_losses, _latents, _output_logits = sess.run([logvars, means, kl_losses, latents, output_logits], feed_dict=feed_dict)


In [None]:
_logvars.shape, _means.shape, _kl_losses.shape, _latents.shape

In [None]:
_output_logits

In [None]:
_output_logits, _dec_target_idxs_do, _dec_mask_tokens_do, _recon_loss, _kl_losses, _ = sess.run([output_logits, dec_target_idxs_do, dec_mask_tokens_do, recon_loss, kl_losses, opt], feed_dict=feed_dict)


In [None]:
tf.reduce_max(output_logits, 2).eval(session=sess, feed_dict=feed_dict).shape

In [None]:
_output_logits.shape, _dec_target_idxs_do.shape, _dec_mask_tokens_do.shape

In [None]:
_logits = np.exp(_output_logits) / np.sum(np.exp(_output_logits), 2)[:, :, None]

In [None]:
_idxs = _dec_target_idxs_do

In [None]:
_losses = np.array([[-np.log(_logits[i, j, _idxs[i, j]]) for j in range(_idxs.shape[1])] for i in range(_idxs.shape[0])]) * _dec_mask_tokens_do

In [None]:
np.sum(_losses)/np.sum(_dec_mask_tokens_do)

In [None]:
_recon_loss

In [None]:
_kl_losses.shape