In [1]:
%load_ext autoreload
%autoreload
from IPython.display import clear_output

import os
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import sys
import subprocess
import pdb
import time
import datetime
import math
import copy
import random
import _pickle as cPickle
from collections import defaultdict
import matplotlib.pyplot as plt

%matplotlib inline
# %matplotlib nbagg

import numpy as np
import pandas as pd
import tensorflow as tf

from data_structure import get_batches
from hntm import HierarchicalNeuralTopicModel
from tree import get_descendant_idxs
from evaluation import validate, print_hierarchical_affinity, print_topic_sample, print_topic_specialization

# load data & set config

In [2]:
def del_all_flags(FLAGS):
    """Delete every flag currently registered on ``FLAGS``.

    The flag names are read from ``FLAGS._flags()`` and removed one by one
    through the object's attribute-deletion hook.
    """
    # Snapshot the names first so the mapping is not mutated while iterating it.
    for flag_name in list(FLAGS._flags()):
        delattr(FLAGS, flag_name)

# Start from a clean flag registry so this cell can be re-executed in the same kernel.
del_all_flags(tf.flags.FLAGS)

flags = tf.app.flags

flags.DEFINE_string('gpu', '3', 'visible gpu')

# data / checkpoint locations
flags.DEFINE_string('data_path', 'data/20news/instances.pkl', 'path of data')
flags.DEFINE_string('modeldir', 'model/topic_vae', 'directory of model')
flags.DEFINE_string('modelname', '20news', 'name of model')

# training schedule
flags.DEFINE_integer('epochs', 1000, 'epochs')
flags.DEFINE_integer('batch_size', 64, 'number of sentences in each batch')
flags.DEFINE_integer('log_period', 3000, 'valid period')

# optimisation
flags.DEFINE_string('opt', 'Adagrad', 'optimizer')
flags.DEFINE_float('lr', 0.01, 'lr')
flags.DEFINE_float('reg', 1., 'regularization term')
flags.DEFINE_float('grad_clip', 5., 'grad_clip')

# The flag is a *keep* probability, so the help text says so (it used to read 'dropout rate').
flags.DEFINE_float('keep_prob', 0.8, 'dropout keep probability')

# model dimensions
flags.DEFINE_integer('dim_hidden_bow', 256, 'dim of hidden bow')
flags.DEFINE_integer('dim_latent_bow', 32, 'dim of latent topic')
flags.DEFINE_integer('dim_emb', 256, 'dim_emb')

# Help text previously copy-pasted from keep_prob ('dropout rate').
flags.DEFINE_float('depth_temperature', 1., 'depth temperature')

# for evaluation
flags.DEFINE_string('refdir', 'ref', 'refdir')
flags.DEFINE_string('outdir', 'out', 'outdir')

# NOTE(review): presumably absorbs the `-f <connection file>` argument that the
# Jupyter kernel is launched with, so flag parsing does not choke on it -- confirm.
flags.DEFINE_string('f', '', 'kernel')

config = flags.FLAGS

# Derived flag: built from two flags defined above, so it must come after them.
flags.DEFINE_string('modelpath', os.path.join(config.modeldir, config.modelname), 'path of model')

In [3]:
# Export the `gpu` flag as CUDA_VISIBLE_DEVICES for this process.
# NOTE(review): this runs after `import tensorflow`; presumably it still takes
# effect because no session has been created yet at this point -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu

In [4]:
# Load the pickled corpus: three instance splits, the two vocabulary lookups and `bow_idxs`.
# The file is opened in a `with` block so the handle is closed right after loading.
# NOTE: unpickling can execute arbitrary code -- only point `data_path` at trusted files.
with open(config.data_path, 'rb') as f_data:
    instances_train, instances_dev, instances_test, word_to_idx, idx_to_word, bow_idxs = cPickle.load(f_data)

# Batch every split with the configured batch size.
train_batches = get_batches(instances_train, config.batch_size)
dev_batches = get_batches(instances_dev, config.batch_size)
test_batches = get_batches(instances_test, config.batch_size)

In [5]:
# Register `dim_bow` as a flag whose default is the length of `bow_idxs` loaded from the data pickle.
# NOTE(review): this cell is presumably not idempotent -- re-running it without first
# re-running the flag-definition cell likely raises a duplicate-flag error; confirm.
flags.DEFINE_integer('dim_bow', len(bow_idxs), 'dim_bow')

In [6]:
def debug_shape(variables, model):
    """Evaluate ``variables`` on one test batch and print the shape of each result.

    Uses the notebook-level globals ``sess`` and ``test_batches``; the batch fed
    to the model is ``test_batches[0][1]``. Entries that expose a ``name``
    attribute are printed as ``<name> : <shape>``, the rest as the bare shape.
    """
    feed_dict = model.get_feed_dict(test_batches[0][1])
    evaluated = sess.run(variables, feed_dict=feed_dict)
    for value, variable in zip(evaluated, variables):
        if not hasattr(variable, 'name'):
            print(value.shape)
            continue
        print(variable.name, ':', value.shape)

def debug_value(variables, model, return_value=False):
    """Evaluate ``variables`` on one test batch and return the results.

    Uses the notebook-level globals ``sess`` and ``test_batches``; the batch fed
    to the model is ``test_batches[0][1]``. ``return_value`` is accepted but not
    consulted -- the evaluated values are always returned.
    """
    feed_dict = model.get_feed_dict(test_batches[0][1])
    return sess.run(variables, feed_dict=feed_dict)

# run

In [7]:
def update_tree(recur_prob_topic, topic_prob_topic, model, add_threshold=0.3, remove_threshold=0.1):
    """Return a grown-and-pruned copy of ``model.tree_idxs``.

    Growth: every parent (key of ``model.tree_idxs``) whose value in
    ``topic_prob_topic`` exceeds ``add_threshold`` receives a new chain of
    descendants -- one node per step of ``range(model.tree_depth[parent], model.n_depth)``.
    Child indices follow the ``10 * parent + slot`` scheme with ``slot`` in 1..9.

    Pruning: every child listed in ``model.tree_idxs`` whose value in
    ``recur_prob_topic`` is below ``remove_threshold`` is removed together with
    its whole subtree; a parent left without children is removed from its own
    parent (looked up in ``model.child_to_parent_idxs``) as well.

    Args:
        recur_prob_topic: mapping topic index -> probability used for pruning.
        topic_prob_topic: mapping topic index -> probability used for growth.
        model: object exposing ``topic_idxs``, ``tree_idxs``, ``tree_depth``,
            ``n_depth`` and ``child_to_parent_idxs``. ``model.tree_idxs`` is
            never modified; all edits happen on deep copies.
        add_threshold: growth threshold (strictly greater-than comparison).
        remove_threshold: pruning threshold (strictly less-than comparison).

    Returns:
        ``(tree_idxs, update_tree_flg)`` -- the proposed tree and whether any
        node was added or any child fell below the pruning threshold.
    """
    assert len(model.topic_idxs) == len(recur_prob_topic) == len(topic_prob_topic)
    update_tree_flg = False

    def add_topic(topic_idx, tree_idxs):
        """Append the lowest free child slot under ``topic_idx``; child is None if all 9 are taken."""
        taken = tree_idxs.get(topic_idx, [])
        free_slots = [10*topic_idx+i for i in range(1, 10) if 10*topic_idx+i not in taken]
        if not free_slots:
            # Previously `min([])` raised ValueError here.
            return tree_idxs, None
        child_idx = free_slots[0]  # candidates are generated in ascending order
        tree_idxs.setdefault(topic_idx, []).append(child_idx)
        return tree_idxs, child_idx

    added_tree_idxs = copy.deepcopy(model.tree_idxs)
    for parent_idx in model.tree_idxs:
        if topic_prob_topic[parent_idx] > add_threshold:
            # Extend a chain downwards: each new node becomes the parent of the next one.
            grow_idx = parent_idx
            for _ in range(model.tree_depth[parent_idx], model.n_depth):
                added_tree_idxs, new_idx = add_topic(grow_idx, added_tree_idxs)
                if new_idx is None:
                    break
                update_tree_flg = True
                grow_idx = new_idx

    def remove_topic(parent_idx, child_idx, tree_idxs):
        """Detach ``child_idx`` from ``parent_idx`` and drop the child's whole subtree."""
        siblings = tree_idxs.get(parent_idx)
        if siblings is None:
            # Parent already gone (removed earlier in this pass): nothing to do.
            return tree_idxs
        if child_idx in siblings:
            # Guarded: the child may already have been detached by an earlier cascade.
            siblings.remove(child_idx)
        # Drop the child's entry and every descendant entry so no orphaned
        # sub-tree is left behind in the returned dict.
        pending = [child_idx]
        while pending:
            pending.extend(tree_idxs.pop(pending.pop(), []))
        return tree_idxs

    removed_tree_idxs = copy.deepcopy(added_tree_idxs)
    for parent_idx, child_idxs in model.tree_idxs.items():
        probs_child = [recur_prob_topic[child_idx] for child_idx in child_idxs]
        for prob_child, child_idx in zip(probs_child, child_idxs):
            if prob_child < remove_threshold:
                update_tree_flg = True
                removed_tree_idxs = remove_topic(parent_idx, child_idx, removed_tree_idxs)
                # A parent that just lost its last child is removed from its own parent too.
                if parent_idx in removed_tree_idxs and len(removed_tree_idxs[parent_idx]) == 0:
                    # NOTE(review): if the root ever loses all children this lookup
                    # presumably fails (no parent recorded for the root) -- confirm.
                    ancestor_idx = model.child_to_parent_idxs[parent_idx]
                    removed_tree_idxs = remove_topic(ancestor_idx, parent_idx, removed_tree_idxs)
    return removed_tree_idxs, update_tree_flg

In [8]:
# Training-state accumulators and counters.
losses_train, ppls_train = [], []
loss_min = np.inf
beta_eval = 1.
epoch = 0

# Rebuild the training batches, this time requesting an iterator.
train_batches = get_batches(instances_train, config.batch_size, iterator=True)

# Empty log table with a two-level column header: group label on top, metric name below.
log_header_top = ['','','','TRAIN:','TM','','','','VALID:','TM','','','']
log_header_bottom = ['Time','Ep','Ct','LOSS','PPL','NLL','KL','REG','LOSS','PPL','NLL','KL','REG']
log_columns = pd.MultiIndex.from_tuples(list(zip(log_header_top, log_header_bottom)))
log_df = pd.DataFrame(columns=log_columns)

In [9]:
# Initial topic tree: each key is a parent topic index, each value the list of its children.
tree_idxs = {
    0: [1, 2, 3, 4, 5],
    1: [11, 12],
    2: [21, 22],
    3: [31, 32],
    4: [41, 42],
    5: [51, 52],
}

# Close a session left over from a previous execution of this cell, if there is one.
if 'sess' in globals():
    sess.close()

model = HierarchicalNeuralTopicModel(config, tree_idxs)
session_config = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
sess = tf.Session(config=session_config)
sess.run(tf.global_variables_initializer())

# Snapshot of the current value of every global variable, keyed by variable name.
global_vars = tf.global_variables()
name_variables = dict(zip((var.name for var in global_vars), sess.run(global_vars)))

saver = tf.train.Saver(max_to_keep=10)
update_tree_flg = False

In [None]:
# Fresh run (nothing logged yet): start from an empty model directory.
if len(log_df) == 0:
    # Argument list instead of str.split() so a path containing whitespace survives;
    # the return code is ignored because the directory may simply not exist yet.
    subprocess.call(['rm', '-rf', config.modeldir])
    # makedirs also creates missing parent directories, which a bare `mkdir` does not.
    os.makedirs(config.modeldir, exist_ok=True)

time_start = time.time()
while epoch < config.epochs:
    # train
    for ct, batch in train_batches:
        feed_dict = model.get_feed_dict(batch)
        _, loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch, ppls_batch, global_step_log = sess.run(
            [model.opt, model.loss, model.topic_loss_recon, model.topic_loss_kl, model.topic_loss_reg, model.topic_ppls, tf.train.get_global_step()],
            feed_dict=feed_dict)

        losses_train += [[loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch]]
        ppls_train += list(ppls_batch)

        # validate
        # NOTE(review): the period is hard-coded to 5000 steps; `config.log_period`
        # (default 3000) is defined but not used here -- kept as is so the logged
        # step grid does not change.
        if global_step_log % 5000 == 0:
            # Training statistics are averaged over everything accumulated so far.
            loss_train, topic_loss_recon_train, topic_loss_kl_train, topic_loss_reg_train = np.mean(losses_train, 0)
            ppl_train = np.exp(np.mean(ppls_train))
            loss_dev, topic_loss_recon_dev, topic_loss_kl_dev, topic_loss_reg_dev, ppl_dev, rads_bow_dev, probs_topic_dev = validate(sess, dev_batches, model)

            # log
            clear_output()
            time_log = int(time.time() - time_start)
            log_series = pd.Series([time_log, epoch, ct,
                    '%.2f'%loss_train, '%.0f'%ppl_train, '%.2f'%topic_loss_recon_train, '%.2f'%topic_loss_kl_train, '%.2f'%topic_loss_reg_train,
                    '%.2f'%loss_dev, '%.0f'%ppl_dev, '%.2f'%topic_loss_recon_dev, '%.2f'%topic_loss_kl_dev, '%.2f'%topic_loss_reg_dev],
                    index=log_df.columns)
            log_df.loc[global_step_log] = log_series
            display(log_df)

            # visualize topic
            topics_freq_idxs = bow_idxs[sess.run(model.topics_freq_bow_indices)]
            topic_freq_token = {topic_idx: ' '.join([idx_to_word[idx] for idx in topic_freq_idxs]) for topic_idx, topic_freq_idxs in zip(model.topic_idxs, topics_freq_idxs)}
            topic_prob_topic = {topic_idx: prob_topic for topic_idx, prob_topic in zip(model.topic_idxs, probs_topic_dev)}
            descendant_idxs = {parent_idx: get_descendant_idxs(model, parent_idx) for parent_idx in model.topic_idxs}
            # Probability of a topic summed over the indices returned by get_descendant_idxs.
            recur_prob_topic = {parent_idx: np.sum([topic_prob_topic[child_idx] for child_idx in recur_child_idxs]) for parent_idx, recur_child_idxs in descendant_idxs.items()}

            print_topic_sample(sess, model, topic_prob_topic=topic_prob_topic, recur_prob_topic=recur_prob_topic, topic_freq_token=topic_freq_token)
            print_topic_specialization(sess, model, instances_test)
            print_hierarchical_affinity(sess, model)
            # Restart the clock so the next "Time" entry covers one logging period only.
            time_start = time.time()

            # update tree
            tree_idxs, update_tree_flg = update_tree(recur_prob_topic, topic_prob_topic, model, add_threshold=0.05, remove_threshold=0.05)
            if update_tree_flg:
                print(tree_idxs)
                # Store parameters, rebuild model + session for the new tree, then restore by variable name.
                name_variables = {tensor.name: variable for tensor, variable in zip(tf.global_variables(), sess.run(tf.global_variables()))} # store parameters
                if 'sess' in globals(): sess.close()
                model = HierarchicalNeuralTopicModel(config, tree_idxs)
                sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))
                name_tensors = {tensor.name: tensor for tensor in tf.global_variables()}
                sess.run([name_tensors[name].assign(variable) for name, variable in name_variables.items()]) # restore parameters

    # The batch iterator is exhausted after one pass; build a new one for the next epoch.
    train_batches = get_batches(instances_train, config.batch_size, iterator=True)
    epoch += 1

# Final evaluation on the dev set. `validate` is the imported helper used inside the loop
# (the previous call to an undefined `get_loss` raised NameError).
loss_dev, topic_loss_recon_dev, topic_loss_kl_dev, topic_loss_reg_dev, ppl_dev, rads_bow_dev, probs_topic_dev = validate(sess, dev_batches, model)
topics_freq_idxs = bow_idxs[sess.run(model.topics_freq_bow_indices)]
topic_freq_token = {topic_idx: ' '.join([idx_to_word[idx] for idx in topic_freq_idxs]) for topic_idx, topic_freq_idxs in zip(model.topic_idxs, topics_freq_idxs)}
topic_prob_topic = {topic_idx: prob_topic for topic_idx, prob_topic in zip(model.topic_idxs, probs_topic_dev)}

descendant_idxs = {parent_idx: get_descendant_idxs(model, parent_idx) for parent_idx in model.topic_idxs}
recur_prob_topic = {parent_idx: np.sum([topic_prob_topic[child_idx] for child_idx in recur_child_idxs]) for parent_idx, recur_child_idxs in descendant_idxs.items()}
display(log_df)
# Same call signature as inside the loop (the stale extra positional `tree_idxs` is gone).
print_topic_sample(sess, model, topic_prob_topic=topic_prob_topic, recur_prob_topic=recur_prob_topic, topic_freq_token=topic_freq_token)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,TRAIN:,TM,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,VALID:,TM,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0
Unnamed: 0_level_1,Time,Ep,Ct,LOSS,PPL,NLL,KL,REG,LOSS,PPL.1,NLL,KL,REG
5000,65,28,99,595.57,851,592.37,3.04,0.16,597.72,1143,593.52,4.07,0.13
10000,50,57,24,595.49,832,591.66,3.7,0.13,584.1,1021,579.83,4.19,0.08
15000,62,85,124,594.27,825,590.1,4.03,0.13,573.77,885,568.89,4.76,0.12
20000,65,114,49,593.86,815,589.37,4.35,0.14,569.07,834,563.7,5.22,0.15
25000,58,142,149,593.21,809,588.5,4.57,0.14,568.4,829,563.17,5.15,0.08
30000,65,171,74,592.49,802,587.61,4.74,0.13,566.71,805,561.2,5.41,0.11


0 R: 1.000 P: 0.303 write get article one go like think say know make
   2 R: 0.127 P: 0.055 space system launch program use include nasa datum information computer
     21 R: 0.028 P: 0.028 use key chip government system law encryption phone need number
     22 R: 0.045 P: 0.045 government people state president american country law clinton say rights
   3 R: 0.260 P: 0.059 please thanks anyone email post know write send use look
     33 R: 0.081 P: 0.081 one write argument say make evidence think people moral god
     31 R: 0.079 P: 0.079 god jesus christian one say bible believe people write man
     32 R: 0.040 P: 0.040 people turkish armenian armenians one say kill woman war turkey
   5 R: 0.103 P: 0.091 window use file program server application run display windows widget
     52 R: 0.012 P: 0.012 file entry program output line title section return rule build
   6 R: 0.085 P: 0.039 car bike ride dod engine drive use buy road speed
     61 R: 0.023 P: 0.023 gun stephanopoulos tax 