In [1]:
%load_ext autoreload
%autoreload
from IPython.display import clear_output

import os
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import sys
import subprocess
import pdb
import time
import datetime
import math
import copy
import random
import _pickle as cPickle
from collections import defaultdict
import matplotlib.pyplot as plt

%matplotlib inline
# %matplotlib nbagg

from six.moves import zip_longest
import numpy as np
import pandas as pd
from scipy.stats import hmean

import tensorflow as tf
from tensorflow import distributions as tfd
from tensorflow.keras.preprocessing.sequence import pad_sequences

from data_structure import get_batches, get_test_batches
from components import tf_log, sample_latents, compute_kl_loss, dynamic_rnn, dynamic_bi_rnn

from topic_beam_search_decoder import BeamSearchDecoder

# load data & set config

In [30]:
def del_all_flags(FLAGS):
    flags_dict = FLAGS._flags()    
    keys_list = [keys for keys in flags_dict]    
    for keys in keys_list:
        FLAGS.__delattr__(keys)

del_all_flags(tf.flags.FLAGS)

flags = tf.app.flags

flags.DEFINE_string('gpu', '3', 'visible gpu')

flags.DEFINE_string('mode', 'train', 'set train or eval')

flags.DEFINE_string('data_path', 'data/old/20news/instances.pkl', 'path of data')
flags.DEFINE_string('modeldir', 'model/topic_vae', 'directory of model')
flags.DEFINE_string('modelname', '20news', 'name of model')

flags.DEFINE_integer('epochs', 1000, 'epochs')
flags.DEFINE_integer('batch_size', 64, 'number of sentences in each batch')
flags.DEFINE_integer('log_period', 3000, 'valid period')

flags.DEFINE_string('opt', 'Adagrad', 'optimizer')
# flags.DEFINE_string('opt', 'Adam', 'optimizer')
flags.DEFINE_float('lr', 0.01, 'lr')
flags.DEFINE_float('reg', 1., 'regularization term')
flags.DEFINE_float('grad_clip', 5., 'grad_clip')

flags.DEFINE_float('keep_prob', 0.8, 'dropout rate')
flags.DEFINE_float('word_keep_prob', 0.75, 'word dropout rate')

flags.DEFINE_bool('warmup', True, 'flg of warming up')
flags.DEFINE_integer('epochs_cycle', 5, 'number of epochs within a cycle')
flags.DEFINE_float('r_cycle', 0.5, 'proportion used to increase beta within a cycle')
flags.DEFINE_integer('warmup_topic', 0, 'warmup period for KL of topic')

flags.DEFINE_integer('beam_width', 2, 'beam_width')
flags.DEFINE_float('length_penalty_weight', 0.0, 'length_penalty_weight')

flags.DEFINE_integer('n_topic', 20, 'number of topic')
flags.DEFINE_integer('dim_hidden_bow', 256, 'dim of hidden bow')
flags.DEFINE_integer('dim_latent_bow', 32, 'dim of latent topic')
flags.DEFINE_integer('dim_emb', 256, 'dim_emb')
flags.DEFINE_integer('dim_hidden', 512, 'dim_hidden')
flags.DEFINE_integer('dim_hidden_topic', 512, 'dim_hidden_topic')
flags.DEFINE_integer('dim_latent', 32, 'dim_latent')
flags.DEFINE_bool('bidirectional', True, 'flg of bidirectional encoding')

# for evaluation
flags.DEFINE_string('refdir', 'ref', 'refdir')
flags.DEFINE_string('outdir', 'out', 'outdir')

flags.DEFINE_string('f', '', 'kernel')
flags.DEFINE_bool('logtostderr', True, 'kernel')
flags.DEFINE_bool('showprefixforinfo', False, '')
flags.DEFINE_bool('verbosity', False, '')
# flags.DEFINE_integer('stderrthreshold', 20, 'kernel')

config = flags.FLAGS

flags.DEFINE_string('modelpath', os.path.join(config.modeldir, config.modelname), 'path of model')

In [31]:
os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu

In [32]:
instances_train, instances_dev, instances_test, word_to_idx, idx_to_word, bow_idxs = cPickle.load(open(config.data_path,'rb'))

In [33]:
def get_batches(instances, batch_size, iterator=False):
    iter_instances = iter(instances)
    n_batch = len(instances)//batch_size
    
    batches = [(i_batch, [next(iter_instances) for i_doc in range(batch_size)]) for i_batch in range(n_batch)]
    
    if iterator: batches = iter(batches)
    return batches

train_batches = get_batches(instances_train, config.batch_size)
dev_batches = get_batches(instances_dev, config.batch_size)
test_batches = get_batches(instances_test, config.batch_size)

In [34]:
flags.DEFINE_integer('n_vocab', len(word_to_idx), 'n_vocab')
flags.DEFINE_integer('dim_bow', len(bow_idxs), 'dim_bow')

flags.DEFINE_integer('cycle_steps', len(train_batches)*config.epochs_cycle, 'number of steps for each cycle')

In [35]:
def debug_shape(variables, model):
    sample_batch = dev_batches[0][1]
    feed_dict = model.get_feed_dict(sample_batch)
    _variables = sess.run(variables, feed_dict=feed_dict)
    for _variable, variable in zip(_variables, variables):
        if hasattr(variable, 'name'):
            print(variable.name, ':', _variable.shape)
        else:
            print(_variable.shape)

def debug_value(variables, model, return_value=False):
    sample_batch = test_batches[0][1]
    feed_dict = model.get_feed_dict(sample_batch)
    _variables = sess.run(variables, feed_dict=feed_dict)

    if return_value: 
        return _variables
    else:
        for _variable, variable in zip(_variables, variables):
            if hasattr(variable, 'name'):
                print(variable.name, ':', _variable)
            else:
                print(_variable)
                
def check_shape(variables):
    if 'sess' in globals(): raise
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    sample_batch = test_batches[0][1]
    feed_dict = get_feed_dict(sample_batch)
    _variables = sess.run(variables, feed_dict=feed_dict)
    for _variable, variable in zip(_variables, variables):
        if hasattr(variable, 'name'):
            print(variable.name, ':', _variable.shape)
        else:
            print(_variable.shape)
            
    sess.close()
    
def check_value(variables):
    if 'sess' in globals(): raise
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    sample_batch = test_batches[0][1]
    feed_dict = get_feed_dict(sample_batch)
    _variables = sess.run(variables, feed_dict=feed_dict)
    for _variable, variable in zip(_variables, variables):
        if hasattr(variable, 'name'):
            print(variable.name, ':', _variable)
        else:
            print(_variable.shape)
            
    sess.close()    
    

#  model 

## doubly rnn

In [8]:
class DoublyRNNCell:
    def __init__(self, dim_hidden, output_layer=None):
        self.dim_hidden = dim_hidden
        
        self.ancestral_layer=tf.layers.Dense(units=dim_hidden, activation=tf.nn.tanh, name='ancestral')
        self.fraternal_layer=tf.layers.Dense(units=dim_hidden, activation=tf.nn.tanh, name='fraternal')
        self.hidden_layer = tf.layers.Dense(units=dim_hidden, name='hidden')
        
        self.output_layer=output_layer
        
    def __call__(self, state_ancestral, state_fraternal, reuse=True):
        with tf.variable_scope('input', reuse=reuse):
            state_ancestral = self.ancestral_layer(state_ancestral)
            state_fraternal = self.fraternal_layer(state_fraternal)

        with tf.variable_scope('output', reuse=reuse):
            state_hidden = self.hidden_layer(state_ancestral + state_fraternal)
            if self.output_layer is not None: 
                output = self.output_layer(state_hidden)
            else:
                output = state_hidden
            
        return output, state_hidden
    
    def get_initial_state(self, name):
        initial_state = tf.get_variable(name, [1, self.dim_hidden], dtype=tf.float32)
        return initial_state
    
    def get_zero_state(self, name):
        zero_state = tf.zeros([1, self.dim_hidden], dtype=tf.float32, name=name)
        return zero_state

In [9]:
def doubly_rnn(dim_hidden, tree_idxs, initial_state_parent=None, initial_state_sibling=None, output_layer=None, name=''):
    outputs, states_parent = {}, {}
    
    with tf.variable_scope(name, reuse=False):
        doubly_rnn_cell = DoublyRNNCell(dim_hidden, output_layer)

        if initial_state_parent is None: 
            initial_state_parent = doubly_rnn_cell.get_initial_state('init_state_parent')
#             initial_state_parent = doubly_rnn_cell.get_zero_state('init_state_parent')
        if initial_state_sibling is None: 
#             initial_state_sibling = doubly_rnn_cell.get_initial_state('init_state_sibling')
            initial_state_sibling = doubly_rnn_cell.get_zero_state('init_state_sibling')
        output, state_sibling = doubly_rnn_cell(initial_state_parent, initial_state_sibling, reuse=False)
        outputs[0], states_parent[0] = output, state_sibling

        for parent_idx, child_idxs in tree_idxs.items():
            state_parent = states_parent[parent_idx]
            state_sibling = initial_state_sibling
            for child_idx in child_idxs:
                output, state_sibling = doubly_rnn_cell(state_parent, state_sibling)
                outputs[child_idx], states_parent[child_idx] = output, state_sibling

    return outputs, states_parent

## nCRP model

In [17]:
class Model():
    def __init__(self, config, tree_idxs):
        def get_depth(parent_idx=0, tree_depth=None, depth=1):
            if tree_depth is None: tree_depth={0: depth}

            child_idxs = tree_idxs[parent_idx]
            depth +=1
            for child_idx in child_idxs:
                tree_depth[child_idx] = depth
                if child_idx in tree_idxs: get_depth(child_idx, tree_depth, depth)
            return tree_depth
        
        self.config = config
        
        self.t_variables = {}
        
        self.tree_idxs = tree_idxs
        self.topic_idxs = [0] + [idx for child_idxs in tree_idxs.values() for idx in child_idxs]
        self.child_to_parent_idxs = {child_idx: parent_idx for parent_idx, child_idxs in self.tree_idxs.items() for child_idx in child_idxs}
        self.tree_depth = get_depth()
        self.n_depth = max(self.tree_depth.values())
        
        self.build()
        
    def build(self):
        def nCRP(tree_sticks_topic):
            tree_prob_topic = {}
            tree_prob_leaf = {}
            # calculate topic probability and save
            tree_prob_topic[0] = 1.

            for parent_idx, child_idxs in self.tree_idxs.items():
                rest_prob_topic = tree_prob_topic[parent_idx]
                for child_idx in child_idxs:
                    stick_topic = tree_sticks_topic[child_idx]
                    if child_idx == child_idxs[-1]:
                        prob_topic = rest_prob_topic * 1.
                    else:
                        prob_topic = rest_prob_topic * stick_topic

                    if not child_idx in self.tree_idxs: # leaf childs
                        tree_prob_leaf[child_idx] = prob_topic
                    else:
                        tree_prob_topic[child_idx] = prob_topic

                    rest_prob_topic -= prob_topic
            return tree_prob_leaf

        def get_prob_topic(tree_prob_leaf, prob_depth):
            def get_ancestor_idxs(leaf_idx, ancestor_idxs = None):
                if ancestor_idxs is None: ancestor_idxs = [leaf_idx]
                parent_idx = self.child_to_parent_idxs[leaf_idx]
                ancestor_idxs += [parent_idx]
                if parent_idx in self.child_to_parent_idxs: get_ancestor_idxs(parent_idx, ancestor_idxs)
                return ancestor_idxs[::-1]
            
            tree_prob_topic = defaultdict(float)
            leaf_ancestor_idxs = {leaf_idx: get_ancestor_idxs(leaf_idx) for leaf_idx in tree_prob_leaf}
            for leaf_idx, ancestor_idxs in leaf_ancestor_idxs.items():
                prob_leaf = tree_prob_leaf[leaf_idx]
                for i, ancestor_idx in enumerate(ancestor_idxs):
                    prob_ancestor = prob_leaf * tf.expand_dims(prob_depth[:, i], -1)
                    tree_prob_topic[ancestor_idx] += prob_ancestor
            prob_topic = tf.concat([tree_prob_topic[topic_idx] for topic_idx in self.topic_idxs], -1)
            return prob_topic
        
        def get_tree_topic_bow(tree_topic_embeddings):
            def softmax_with_temperature(logits, axis=None, name=None, temperature=1.):
                if axis is None:
                    axis = -1
                return tf.exp(logits / temperature) / tf.reduce_sum(tf.exp(logits / temperature), axis=axis)

            tree_topic_bow = {}
            for topic_idx, depth in self.tree_depth.items():
                topic_embedding = tree_topic_embeddings[topic_idx]
                temperature = tf.constant(10 ** (1./depth), dtype=tf.float32)
                logits = tf.matmul(topic_embedding, bow_embeddings, transpose_b=True)
                tree_topic_bow[topic_idx] = softmax_with_temperature(logits, axis=-1, temperature=temperature)
            return tree_topic_bow

        def get_tree_mask_reg():
            def get_descendant_idxs(parent_idx, descendant_idxs = None):
                if descendant_idxs is None: descendant_idxs = []

                child_idxs = self.tree_idxs[parent_idx]
                descendant_idxs += child_idxs
                for child_idx in child_idxs:
                    if child_idx in self.tree_idxs: get_descendant_idxs(child_idx, descendant_idxs)
                return descendant_idxs

            tree_mask_reg = np.ones([len(self.topic_idxs), len(self.topic_idxs)], dtype=np.float32)
            parent_to_descendant_idxs = {parent_idx: get_descendant_idxs(parent_idx) for parent_idx in self.tree_idxs}

            for parent_idx, descendant_idxs in parent_to_descendant_idxs.items():
                for descendant_idx in descendant_idxs:
                    parent_index = self.topic_idxs.index(parent_idx)
                    descendant_index = self.topic_idxs.index(descendant_idx)
                    tree_mask_reg[parent_index, descendant_index] = tree_mask_reg[descendant_index, parent_index] = 0.
            return tree_mask_reg
       
        # -------------- Build Model --------------
        tf.reset_default_graph()
        
        self.t_variables['bow'] = tf.placeholder(tf.float32, [None, self.config.dim_bow])
        self.t_variables['keep_prob'] = tf.placeholder(tf.float32)
        
        # encode bow
        with tf.variable_scope('topic/enc', reuse=False):
            hidden_bow_ = tf.layers.Dense(units=self.config.dim_hidden_bow, activation=tf.nn.tanh, name='hidden_bow')(self.t_variables['bow'])
            hidden_bow = tf.layers.Dropout(self.t_variables['keep_prob'])(hidden_bow_)
            means_bow = tf.layers.Dense(units=self.config.dim_latent_bow, name='mean_bow')(hidden_bow)
            logvars_bow = tf.layers.Dense(units=self.config.dim_latent_bow, kernel_initializer=tf.constant_initializer(0), bias_initializer=tf.constant_initializer(0), name='logvar_bow')(hidden_bow)
            latents_bow = sample_latents(means_bow, logvars_bow) # sample latent vectors
            prob_layer = lambda h: tf.nn.sigmoid(tf.matmul(latents_bow, h, transpose_b=True))

            tree_sticks_topic, tree_states_sticks_topic = doubly_rnn(self.config.dim_latent_bow, self.tree_idxs, output_layer=prob_layer, name='sticks_topic')
            tree_prob_leaf = nCRP(tree_sticks_topic)
            self.tree_prob_leaf = tree_prob_leaf
            prob_depth = tf.layers.Dense(units=self.n_depth, activation=tf.nn.softmax, name='prob_depth')(latents_bow) # inference of topic probabilities
            self.prob_depth = prob_depth

            prob_topic = get_prob_topic(tree_prob_leaf, prob_depth)
            self.prob_topic = prob_topic # N_BATCH x K

        # decode bow
        with tf.variable_scope('shared', reuse=False):
            bow_embeddings = tf.get_variable('emb', [self.config.dim_bow, self.config.dim_emb], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer()) # embeddings of vocab

        with tf.variable_scope('topic/dec', reuse=False):
        #     tree_topic_embeddings, tree_states_topic_embeddings = doubly_rnn(self.config.dim_emb, self.tree_idxs, name='emb_topic')
            emb_layer = lambda h: tf.layers.Dense(units=self.config.dim_emb, name='output')(tf.nn.tanh(h))
            tree_topic_embeddings, tree_states_topic_embeddings = doubly_rnn(self.config.dim_emb, self.tree_idxs, output_layer=emb_layer, name='emb_topic')
#             topic_embeddings = tf.get_variable('topic_emb', [len(self.topic_idxs), self.config.dim_emb], dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer()) # embeddings of topics
#             tree_topic_embeddings = {topic_idx: tf.expand_dims(topic_embeddings[self.topic_idxs.index(topic_idx)], 0) for topic_idx in self.topic_idxs}

            tree_topic_bow = get_tree_topic_bow(tree_topic_embeddings) # bow vectors for each topic

            topic_bow = tf.concat([tree_topic_bow[topic_idx] for topic_idx in self.topic_idxs], 0) # KxV
            self.topic_bow = topic_bow
            logits_bow = tf_log(tf.matmul(prob_topic, topic_bow)) # predicted bow distribution N_Batch x  V
            self.logits_bow = logits_bow
            
        # define losses
        self.topic_losses_recon = -tf.reduce_sum(tf.multiply(self.t_variables['bow'], logits_bow), 1)
        self.topic_loss_recon = tf.reduce_mean(self.topic_losses_recon) # negative log likelihood of each words

        self.topic_loss_kl = compute_kl_loss(means_bow, logvars_bow) # KL divergence b/w latent dist & gaussian std

#         topic_bow_norm = topic_bow / tf.norm(topic_bow, axis=1, keepdims=True)
#         self.topic_dots = tf.clip_by_value(tf.matmul(topic_bow_norm, tf.transpose(topic_bow_norm)), -1., 1.)
        
        topic_embeddings = tf.concat([tree_topic_embeddings[topic_idx] for topic_idx in self.topic_idxs], 0)
        topic_embeddings_norm = topic_embeddings / tf.norm(topic_embeddings, axis=1, keepdims=True)
        self.topic_dots = tf.clip_by_value(tf.matmul(topic_embeddings_norm, tf.transpose(topic_embeddings_norm)), -1., 1.)        
        
        self.tree_mask_reg = get_tree_mask_reg()
        self.topic_losses_reg = tf.square(self.topic_dots - tf.eye(len(self.topic_idxs))) * self.tree_mask_reg
        self.topic_loss_reg = tf.reduce_sum(self.topic_losses_reg) / tf.reduce_sum(self.tree_mask_reg)

        self.global_step = tf.Variable(0, name='global_step',trainable=False)

        self.loss = self.topic_loss_recon + self.topic_loss_kl + self.config.reg * self.topic_loss_reg

        # define optimizer
        if self.config.opt == 'Adam':
            optimizer = tf.train.AdamOptimizer(self.config.lr)
        elif self.config.opt == 'Adagrad':
            optimizer = tf.train.AdagradOptimizer(self.config.lr)

        self.grad_vars = optimizer.compute_gradients(self.loss)
        self.clipped_grad_vars = [(tf.clip_by_value(grad, -self.config.grad_clip, self.config.grad_clip), var) for grad, var in self.grad_vars]
        self.opt = optimizer.apply_gradients(self.clipped_grad_vars, global_step=self.global_step)

        # monitor
        self.n_bow = tf.reduce_sum(self.t_variables['bow'], 1)
        self.topic_ppls = tf.divide(self.topic_losses_recon, tf.maximum(1e-5, self.n_bow))
        self.topics_freq_bow_indices = tf.nn.top_k(topic_bow, 10, name='topic_freq_bow').indices
    
        # growth criteria
    #         self.dist_bow = tf.reduce_sum(tf.square(tf.expand_dims(tf.exp(logits_bow), 1) - tf.expand_dims(topic_bow, 0)), -1)
    #         self.rads_bow = tf.sqrt(tf.reduce_sum(tf.multiply(prob_topic, self.dist_bow), 0) / tf.reduce_sum(prob_topic, 0))        
        self.dist_bow = -tf.matmul(self.t_variables['bow'], tf.log(topic_bow), transpose_b=True)
        self.rads_bow = tf.divide(tf.multiply(self.dist_bow, prob_topic), tf.expand_dims(self.n_bow, -1))
        self.n_topics = tf.multiply(tf.expand_dims(self.n_bow, -1), prob_topic)
    
    def get_feed_dict(self, batch, mode='train'):
        bow = np.array([instance.bow for instance in batch]).astype(np.float32)
        keep_prob = self.config.keep_prob if mode == 'train' else 1.0
        feed_dict = {
                    self.t_variables['bow']: bow, 
                    self.t_variables['keep_prob']: keep_prob
        }
        return  feed_dict

# run

In [18]:
def get_loss(sess, batches, model):
    losses = []
    ppl_list = []
    rads_bow_list = []
    prob_topic_list = []
    n_bow_list = []
    n_topics_list = []
    for ct, batch in batches:
        feed_dict = model.get_feed_dict(batch, mode='test')
        loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch, ppls_batch, rads_bow_batch, prob_topic_batch, n_bow_batch, n_topics_batch \
            = sess.run([model.loss, model.topic_loss_recon, model.topic_loss_kl, model.topic_loss_reg, model.topic_ppls, model.rads_bow, model.prob_topic, model.n_bow, model.n_topics], feed_dict = feed_dict)
        losses += [[loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch]]
        ppl_list += list(ppls_batch)
        rads_bow_list.append(rads_bow_batch)
        prob_topic_list.append(prob_topic_batch)
        n_bow_list.append(n_bow_batch)
        n_topics_list.append(n_topics_batch)
    loss_mean, topic_loss_recon_mean, topic_loss_kl_mean, topic_loss_reg_mean = np.mean(losses, 0)
    ppl_mean = np.exp(np.mean(ppl_list))
    probs_topic = np.concatenate(prob_topic_list, 0)
    rads_bow = np.concatenate(rads_bow_list, 0)
    rads_bow_mean = np.sum(rads_bow, 0) / np.sum(rads_bow)
    n_bow = np.concatenate(n_bow_list, 0)
    n_topics = np.concatenate(n_topics_list, 0)
    probs_topic_mean = np.sum(n_topics, 0) / np.sum(n_bow)
    return loss_mean, topic_loss_recon_mean, topic_loss_kl_mean, topic_loss_reg_mean, ppl_mean, rads_bow_mean, probs_topic_mean

In [19]:
def print_topic_sample(tree_idxs, sess=None, model=None, rads_bow=None, probs_topic=None, parent_idx=0, topics_freq_bow_idxs=None, depth = 0):
    if topics_freq_bow_idxs is None:
        topics_freq_bow_idxs = bow_idxs[sess.run(model.topics_freq_bow_indices)]
        topic_freq_bow_idxs = topics_freq_bow_idxs[model.topic_idxs.index(parent_idx)]
        rad_bow = rads_bow[model.topic_idxs.index(parent_idx)]
        prob_topic = probs_topic[model.topic_idxs.index(parent_idx)]
        print(parent_idx, 'R: %.2f' % rad_bow, 'P: %.3f' % prob_topic, ' '.join([idx_to_word[idx] for idx in topic_freq_bow_idxs]))
    
    child_idxs = tree_idxs[parent_idx]
    depth += 1
    for child_idx in child_idxs:
        topic_freq_bow_idxs = topics_freq_bow_idxs[model.topic_idxs.index(child_idx)]
        rad_bow = rads_bow[model.topic_idxs.index(child_idx)]
        prob_topic = probs_topic[model.topic_idxs.index(child_idx)]
        print('  '*depth, child_idx, 'R: %.2f' % rad_bow, 'P: %.3f' % prob_topic, ' '.join([idx_to_word[idx] for idx in topic_freq_bow_idxs]))
        
        if child_idx in tree_idxs: print_topic_sample(tree_idxs, model=model, rads_bow=rads_bow, probs_topic=probs_topic, parent_idx=child_idx, topics_freq_bow_idxs=topics_freq_bow_idxs, depth=depth)


In [20]:
def update_tree(rads_bow, probs_topic, model, add_threshold=0.3, remove_threshold=0.1):
    def add_topic(topic_idx, tree_idxs):
        if topic_idx in tree_idxs:
            child_idx = max(tree_idxs[topic_idx])+1
            tree_idxs[topic_idx].append(child_idx)        
        else:
            child_idx = 10*topic_idx+1
            tree_idxs[topic_idx] = [10*topic_idx+1]
        return tree_idxs, child_idx
    
    assert len(model.topic_idxs) == len(rads_bow) == len(probs_topic)
    update_tree_flg = False
    
    topic_rad_bow = {topic_idx: rad_bow for topic_idx, rad_bow in zip(model.topic_idxs, rads_bow)}
    added_tree_idxs = copy.deepcopy(model.tree_idxs)
    for parent_idx, child_idxs in model.tree_idxs.items():
        rad_bow = np.max([topic_rad_bow[child_idx] for child_idx in child_idxs])
        if rad_bow > add_threshold:
            update_tree_flg = True
            for depth in range(model.tree_depth[parent_idx], model.n_depth):
                added_tree_idxs, parent_idx = add_topic(parent_idx, added_tree_idxs)
    
    def remove_topic(parent_idx, child_idx, tree_idxs):
        if parent_idx in tree_idxs:
            tree_idxs[parent_idx].remove(child_idx)
            if child_idx in tree_idxs:
                tree_idxs.pop(child_idx)    
        return tree_idxs
    
    topic_prob_topic = {topic_idx: prob_topic for topic_idx, prob_topic in zip(model.topic_idxs, probs_topic)}
    removed_tree_idxs = copy.deepcopy(added_tree_idxs)
    for parent_idx, child_idxs in model.tree_idxs.items():
        probs_child = np.array([topic_prob_topic[child_idx] for child_idx in child_idxs])
        prob_child = np.min(probs_child)
        child_idx = child_idxs[np.argmin(probs_child)]
        if prob_child < remove_threshold:
            update_tree_flg = True
            removed_tree_idxs = remove_topic(parent_idx, child_idx, removed_tree_idxs)
            if parent_idx in removed_tree_idxs:
                if len(removed_tree_idxs[parent_idx]) == 0:
                    ancestor_idx = model.child_to_parent_idxs[parent_idx]
                    removed_tree_idxs = remove_topic(ancestor_idx, parent_idx, removed_tree_idxs)
    return removed_tree_idxs, update_tree_flg

In [27]:
losses_train = []
ppls_train = []
loss_min = np.inf
beta_eval = 1.
epoch = 0
train_batches = get_batches(instances_train, config.batch_size, iterator=True)

log_df = pd.DataFrame(columns=pd.MultiIndex.from_tuples(
                    list(zip(*[['','','','TRAIN:','TM','','','','VALID:','TM','','',''],
                            ['Time','Ep','Ct','LOSS','PPL','NLL','KL','REG','LOSS','PPL','NLL','KL','REG']]))))

In [28]:
# tree_idxs = {0:[1, 2], 
#           1:[10, 11], 2:[20, 21]}

tree_idxs = {0:[1, 2, 3, 4], 
              1:[10, 11], 2:[20, 21], 3:[30, 31], 4:[40, 41]}

# tree_idxs = {0:[1, 2, 3], 
#               1:[10, 11, 12], 2:[20, 21, 22], 3:[30, 31, 32]}


if 'sess' in globals(): sess.close()
model = Model(config, tree_idxs)
sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))
sess.run(tf.global_variables_initializer())
name_variables = {tensor.name: variable for tensor, variable in zip(tf.global_variables(), sess.run(tf.global_variables()))}
saver = tf.train.Saver(max_to_keep=10)
update_tree_flg = False

In [29]:
if len(log_df) == 0:
    cmd_rm = 'rm -r %s' % config.modeldir
    res = subprocess.call(cmd_rm.split())

    cmd_mk = 'mkdir %s' % config.modeldir
    res = subprocess.call(cmd_mk.split())

time_start = time.time()
while epoch < config.epochs:    
    # train
    for ct, batch in train_batches:
        feed_dict = model.get_feed_dict(batch)
        _, loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch, ppls_batch, global_step_log = \
        sess.run([model.opt, model.loss, model.topic_loss_recon, model.topic_loss_kl, model.topic_loss_reg, model.topic_ppls, tf.train.get_global_step()], feed_dict = feed_dict)

        losses_train += [[loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch]]
        ppls_train += list(ppls_batch)

        # validate
#         if global_step_log % config.log_period == 0:
        if global_step_log % 3000 == 0:            
            loss_train, topic_loss_recon_train, topic_loss_kl_train, topic_loss_reg_train = np.mean(losses_train, 0)
            ppl_train = np.exp(np.mean(ppls_train))
            loss_dev, topic_loss_recon_dev, topic_loss_kl_dev, topic_loss_reg_dev, ppl_dev, rads_bow_dev, probs_topic_dev = get_loss(sess, dev_batches, model)

            # log
            clear_output()
            time_log = int(time.time() - time_start)
            log_series = pd.Series([time_log, epoch, ct, \
                    '%.2f'%loss_train, '%.0f'%ppl_train, '%.2f'%topic_loss_recon_train, '%.2f'%topic_loss_kl_train, '%.2f'%topic_loss_reg_train, \
                    '%.2f'%loss_dev, '%.0f'%ppl_dev, '%.2f'%topic_loss_recon_dev, '%.2f'%topic_loss_kl_dev, '%.2f'%topic_loss_reg_dev],
                    index=log_df.columns)
            log_df.loc[global_step_log] = log_series
            display(log_df)

            # visualize topic
            print_topic_sample(tree_idxs, sess, model, rads_bow_dev, probs_topic_dev)
            time_start = time.time()

            # update tree
            if global_step_log % 3000 == 0:
                tree_idxs, update_tree_flg = update_tree(rads_bow_dev, probs_topic_dev, model, add_threshold=0.2, remove_threshold=0.02)
                if update_tree_flg:
                    print(tree_idxs)
                    name_variables = {tensor.name: variable for tensor, variable in zip(tf.global_variables(), sess.run(tf.global_variables()))} # store paremeters
                    if 'sess' in globals(): sess.close()
                    model = Model(config, tree_idxs)
                    sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))
                    name_tensors = {tensor.name: tensor for tensor in tf.global_variables()}
                    sess.run([name_tensors[name].assign(variable) for name, variable in name_variables.items()]) # restore parameters

    train_batches = get_batches(instances_train, config.batch_size, iterator=True)
    epoch += 1

display(log_df)
print_topic_sample()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,TRAIN:,TM,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,VALID:,TM,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0
Unnamed: 0_level_1,Time,Ep,Ct,LOSS,PPL,NLL,KL,REG,LOSS,PPL.1,NLL,KL,REG
3000,37,21,59,434.0,1075,430.99,2.9,0.11,428.91,1021,425.58,3.26,0.06
6000,39,42,119,431.17,1026,427.9,3.2,0.07,426.91,982,423.29,3.6,0.03
9000,33,64,39,429.57,998,426.14,3.36,0.06,426.54,959,422.86,3.66,0.02
12000,38,85,99,428.55,981,425.02,3.47,0.05,425.88,952,422.09,3.77,0.02
15000,32,107,19,427.84,968,424.25,3.55,0.04,425.59,950,421.78,3.8,0.02
18000,32,128,79,427.27,958,423.62,3.61,0.04,425.21,938,421.3,3.9,0.02
21000,32,149,139,426.87,950,423.17,3.66,0.03,424.99,938,421.06,3.92,0.01
24000,32,171,59,426.49,944,422.76,3.69,0.03,424.88,932,420.95,3.92,0.01
27000,32,192,119,426.19,939,422.43,3.73,0.03,424.72,926,420.73,3.97,0.01
30000,32,214,39,425.92,934,422.13,3.75,0.03,424.6,925,420.63,3.96,0.01


0 R: 0.22 P: 0.171 car little power down dod around enough bike lot buy
   1 R: 0.05 P: 0.064 key encryption chip government keys clipper security public law privacy
     10 R: 0.10 P: 0.086 team game play games season hockey win players league best
     11 R: 0.06 P: 0.082 she her mr president didn down went come told made
   2 R: 0.03 P: 0.052 space nasa research launch gov center earth national moon data
     20 R: 0.06 P: 0.064 health food drugs news day medical disease uiuc steve men
     21 R: 0.06 P: 0.114 government israel turkish state gun rights israeli jews against war
   3 R: 0.10 P: 0.108 program window available ftp image files graphics server version output
     31 R: 0.08 P: 0.041 uk email ac university list send cs interested internet sale
   4 R: 0.10 P: 0.076 drive windows card dos scsi mb disk mac bit pc
     40 R: 0.07 P: 0.075 evidence science true truth argument wrong human example claim mean
     41 R: 0.06 P: 0.068 jesus christian bible church christians christ

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,TRAIN:,TM,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,VALID:,TM,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0
Unnamed: 0_level_1,Time,Ep,Ct,LOSS,PPL,NLL,KL,REG,LOSS,PPL.1,NLL,KL,REG
3000,37,21,59,434.0,1075,430.99,2.9,0.11,428.91,1021,425.58,3.26,0.06
6000,39,42,119,431.17,1026,427.9,3.2,0.07,426.91,982,423.29,3.6,0.03
9000,33,64,39,429.57,998,426.14,3.36,0.06,426.54,959,422.86,3.66,0.02
12000,38,85,99,428.55,981,425.02,3.47,0.05,425.88,952,422.09,3.77,0.02
15000,32,107,19,427.84,968,424.25,3.55,0.04,425.59,950,421.78,3.8,0.02
18000,32,128,79,427.27,958,423.62,3.61,0.04,425.21,938,421.3,3.9,0.02
21000,32,149,139,426.87,950,423.17,3.66,0.03,424.99,938,421.06,3.92,0.01
24000,32,171,59,426.49,944,422.76,3.69,0.03,424.88,932,420.95,3.92,0.01
27000,32,192,119,426.19,939,422.43,3.73,0.03,424.72,926,420.73,3.97,0.01
30000,32,214,39,425.92,934,422.13,3.75,0.03,424.6,925,420.63,3.96,0.01


TypeError: print_topic_sample() missing 1 required positional argument: 'tree_idxs'

# confirm variables

In [None]:
states_topic_embeddings = tf.concat([tree_states_topic_embeddings[topic_idx] for topic_idx in topic_idxs], 0)

In [None]:
debug_value([states_topic_embeddings[:, :6]])

In [None]:
topic_embeddings = tf.concat([tree_topic_embeddings[topic_idx] for topic_idx in topic_idxs], 0)

In [None]:
debug_value([topic_embeddings[:, :6]])

In [None]:
_topics_bow, = debug_value([topic_bow], return_value=True)
np.max(_topics_bow)

In [None]:
_topics_bow, = debug_value([topic_bow], return_value=True)

plt.figure(figsize=(12, 20))
    
_topic_bow = _topics_bow[0]
plt.subplot(5,3,2)
plt.ylim([0, 0.1])
plt.bar(bow_idxs, _topic_bow)

for i in range(1, len(topic_idxs)):
    _topic_bow = _topics_bow[i]
    plt.subplot(5,3,i+3)
    plt.ylim([0, 0.1])
#     plt.axis('off')
    plt.bar(bow_idxs, _topic_bow)

plt.show()

In [None]:
_prob_topics = []
for ct, batch in dev_batches:
    feed_dict = get_feed_dict(batch)
    _prob_topic, = sess.run([prob_topic], feed_dict = feed_dict)
    _prob_topics.append(_prob_topic)
    
_prob_topics = np.concatenate(_prob_topics, 0)
_prob_topic_mean = np.mean(_prob_topics, 0)

print(_prob_topic_mean)

In [None]:
debug_value([topic_dots])

In [None]:
debug_value([topic_losses_reg])

In [None]:
tree_mask_reg

In [None]:
_topic_bow, = debug_value([topic_bow], return_value=True)

In [None]:
plt.bar(bow_idxs, _topic_bow[0])

In [None]:
plt.bar(bow_idxs, _topic_bow[1])

In [None]:
plt.bar(bow_idxs, _topic_bow[2])

In [None]:
plt.bar(bow_idxs, _topic_bow[3])

In [None]:
np.max(_topic_bow, 1)

In [None]:
plt.bar(bow_idxs, _topic_bow[-5])

In [None]:
plt.bar(bow_idxs, _topic_bow[-1])

In [None]:
len(bow_idxs)

In [None]:
debug_value([prob_topic[3]])

In [None]:
debug_value([tf.exp(-tf.divide(topic_losses_recon, n_bow))])

### test

In [None]:
debug_shape([bow, hidden_bow, latents_bow, prob_topic, bow_embeddings, topic_embeddings, topic_bow, prob_bow])

In [None]:
debug_shape([topic_losses_recon, topic_loss_recon, n_bow, ppls, topic_embeddings_norm, tf.expand_dims(topic_angles_mean, -1), topic_angles_vars])

In [None]:
debug_value([tf.reduce_sum(tf.square(topic_embeddings_norm), 1)], return_value=True)[0]

In [None]:
debug_value([tf.reduce_sum(prob_topic, -1), tf.reduce_sum(topic_bow, -1), tf.reduce_sum(tf.exp(prob_bow), 1)])

In [None]:
sigma_bow = tf.exp(0.5 * logvars_bow)
dist_bow = tfd.Normal(means_bow, sigma_bow)
dist_std = tfd.Normal(0., 1.)
topic_loss_kl_tmp = tf.reduce_mean(tf.reduce_sum(tfd.kl_divergence(dist_bow, dist_std), 1))

In [None]:
debug_value([topic_loss_recon, topic_loss_kl, topic_loss_kl_tmp])

In [None]:
_logvars, _means, _kl_losses, _latents, _output_logits = sess.run([logvars, means, kl_losses, latents, output_logits], feed_dict=feed_dict)


In [None]:
_logvars.shape, _means.shape, _kl_losses.shape, _latents.shape

In [None]:
_output_logits

In [None]:
_output_logits, _dec_target_idxs_do, _dec_mask_tokens_do, _recon_loss, _kl_losses, _ = sess.run([output_logits, dec_target_idxs_do, dec_mask_tokens_do, recon_loss, kl_losses, opt], feed_dict=feed_dict)


In [None]:
tf.reduce_max(output_logits, 2).eval(session=sess, feed_dict=feed_dict).shape

In [None]:
_output_logits.shape, _dec_target_idxs_do.shape, _dec_mask_tokens_do.shape

In [None]:
_logits = np.exp(_output_logits) / np.sum(np.exp(_output_logits), 2)[:, :, None]

In [None]:
_idxs = _dec_target_idxs_do

In [None]:
_losses = np.array([[-np.log(_logits[i, j, _idxs[i, j]]) for j in range(_idxs.shape[1])] for i in range(_idxs.shape[0])]) * _dec_mask_tokens_do

In [None]:
np.sum(_losses)/np.sum(_dec_mask_tokens_do)

In [None]:
_recon_loss

In [None]:
_kl_losses.shape