In [1]:
%load_ext autoreload
%autoreload
from IPython.display import clear_output

import os
os.environ['PYTHONHASHSEED'] = '0'
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import sys
import subprocess
import pdb
import time
import datetime
import math
import copy
import random
import _pickle as cPickle
from collections import defaultdict
import matplotlib.pyplot as plt

%matplotlib inline

import numpy as np
import pandas as pd
import tensorflow as tf

from data_structure import get_batches
from hntm import HierarchicalNeuralTopicModel
from tree import get_descendant_idxs
from evaluation import validate, print_hierarchical_affinity, print_topic_sample, print_topic_specialization

seed = 0
np.random.seed(seed)
random.seed(seed)

# load data & set config

In [2]:
def del_all_flags(FLAGS):
    flags_dict = FLAGS._flags()    
    keys_list = [keys for keys in flags_dict]    
    for keys in keys_list:
        FLAGS.__delattr__(keys)

del_all_flags(tf.flags.FLAGS)

flags = tf.app.flags

flags.DEFINE_string('gpu', '3', 'visible gpu')
flags.DEFINE_integer('seed', seed, 'random seed')

flags.DEFINE_string('data_path', 'data/20news/instances.pkl', 'path of data')
flags.DEFINE_string('modeldir', 'model/topic_vae', 'directory of model')
flags.DEFINE_string('modelname', '20news', 'name of model')

flags.DEFINE_integer('epochs', 1000, 'epochs')
flags.DEFINE_integer('batch_size', 64, 'number of sentences in each batch')
flags.DEFINE_integer('log_period', 3000, 'valid period')

flags.DEFINE_string('opt', 'Adagrad', 'optimizer')
flags.DEFINE_float('lr', 0.01, 'lr')
flags.DEFINE_float('reg', 1., 'regularization term')
flags.DEFINE_float('grad_clip', 5., 'grad_clip')

flags.DEFINE_float('keep_prob', 0.8, 'dropout rate')

flags.DEFINE_integer('dim_hidden_bow', 256, 'dim of hidden bow')
flags.DEFINE_integer('dim_latent_bow', 32, 'dim of latent topic')
flags.DEFINE_integer('dim_emb', 256, 'dim_emb')

flags.DEFINE_float('depth_temperature', 1., 'dropout rate')

# for evaluation
flags.DEFINE_string('refdir', 'ref', 'refdir')
flags.DEFINE_string('outdir', 'out', 'outdir')

flags.DEFINE_string('f', '', 'kernel')

config = flags.FLAGS

flags.DEFINE_string('modelpath', os.path.join(config.modeldir, config.modelname), 'path of model')

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu

In [4]:
instances_train, instances_dev, instances_test, word_to_idx, idx_to_word, bow_idxs = cPickle.load(open(config.data_path,'rb'))
train_batches = get_batches(instances_train, config.batch_size)
dev_batches = get_batches(instances_dev, config.batch_size)
test_batches = get_batches(instances_test, config.batch_size)

In [5]:
flags.DEFINE_integer('dim_bow', len(bow_idxs), 'dim_bow')

In [6]:
def debug_shape(variables, model):
    sample_batch = test_batches[0][1]
    feed_dict = model.get_feed_dict(sample_batch)
    _variables = sess.run(variables, feed_dict=feed_dict)
    for _variable, variable in zip(_variables, variables):
        if hasattr(variable, 'name'):
            print(variable.name, ':', _variable.shape)
        else:
            print(_variable.shape)

def debug_value(variables, model, return_value=False):
    sample_batch = test_batches[0][1]
    feed_dict = model.get_feed_dict(sample_batch)
    _variables = sess.run(variables, feed_dict=feed_dict)

    return _variables

# run

In [7]:
losses_train = []
ppls_train = []
loss_min = np.inf
epoch = 0
train_batches = get_batches(instances_train, config.batch_size, iterator=True)

log_df = pd.DataFrame(columns=pd.MultiIndex.from_tuples(
                    list(zip(*[['','','','TRAIN:','TM','','','','VALID:','TM','','',''],
                            ['Time','Ep','Ct','LOSS','PPL','NLL','KL','REG','LOSS','PPL','NLL','KL','REG']]))))

In [8]:
tree_idxs = {0:[1, 2, 3, 4, 5], 
              1:[11, 12], 2:[21, 22], 3:[31, 32], 4:[41, 42], 5:[51, 52]}

if 'sess' in globals(): sess.close()
model = HierarchicalNeuralTopicModel(config, tree_idxs)
sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))
sess.run(tf.global_variables_initializer())
name_variables = {tensor.name: variable for tensor, variable in zip(tf.global_variables(), sess.run(tf.global_variables()))}
saver = tf.train.Saver(max_to_keep=10)
update_tree_flg = False

In [9]:
if len(log_df) == 0:
    cmd_rm = 'rm -r %s' % config.modeldir
    res = subprocess.call(cmd_rm.split())

    cmd_mk = 'mkdir %s' % config.modeldir
    res = subprocess.call(cmd_mk.split())

time_start = time.time()
while epoch < config.epochs:    
    # train
    for ct, batch in train_batches:
        feed_dict = model.get_feed_dict(batch)
        _, loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch, ppls_batch, global_step_log = \
        sess.run([model.opt, model.loss, model.topic_loss_recon, model.topic_loss_kl, model.topic_loss_reg, model.topic_ppls, tf.train.get_global_step()], feed_dict = feed_dict)

        losses_train += [[loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch]]
        ppls_train += list(ppls_batch)

        # validate
        if global_step_log % 5000 == 0:            
            loss_train, topic_loss_recon_train, topic_loss_kl_train, topic_loss_reg_train = np.mean(losses_train, 0)
            ppl_train = np.exp(np.mean(ppls_train))
            loss_dev, topic_loss_recon_dev, topic_loss_kl_dev, topic_loss_reg_dev, ppl_dev, rads_bow_dev, probs_topic_dev = validate(sess, dev_batches, model)

            # log
            clear_output()
            time_log = int(time.time() - time_start)
            log_series = pd.Series([time_log, epoch, ct, \
                    '%.2f'%loss_train, '%.0f'%ppl_train, '%.2f'%topic_loss_recon_train, '%.2f'%topic_loss_kl_train, '%.2f'%topic_loss_reg_train, \
                    '%.2f'%loss_dev, '%.0f'%ppl_dev, '%.2f'%topic_loss_recon_dev, '%.2f'%topic_loss_kl_dev, '%.2f'%topic_loss_reg_dev],
                    index=log_df.columns)
            log_df.loc[global_step_log] = log_series
            display(log_df)

            # visualize topic
            topics_freq_idxs = bow_idxs[sess.run(model.topics_freq_bow_indices)]
            topic_freq_token = {topic_idx: ' '.join([idx_to_word[idx] for idx in topic_freq_idxs]) for topic_idx, topic_freq_idxs in zip(model.topic_idxs, topics_freq_idxs)}
            topic_prob_topic = {topic_idx: prob_topic for topic_idx, prob_topic in zip(model.topic_idxs, probs_topic_dev)}
            descendant_idxs = {parent_idx: get_descendant_idxs(model, parent_idx) for parent_idx in model.topic_idxs}
            recur_prob_topic = {parent_idx: np.sum([topic_prob_topic[child_idx] for child_idx in recur_child_idxs]) for parent_idx, recur_child_idxs in descendant_idxs.items()}
            
            print_topic_sample(sess, model, topic_prob_topic=topic_prob_topic, recur_prob_topic=recur_prob_topic, topic_freq_token=topic_freq_token)
            print_topic_specialization(sess, model, instances_test)
            print_hierarchical_affinity(sess, model)
            time_start = time.time()

            # update tree
            tree_idxs, update_tree_flg = model.update_tree(topic_prob_topic, recur_prob_topic, add_threshold=0.05, remove_threshold=0.05)
            if update_tree_flg:
                print(tree_idxs)
                name_variables = {tensor.name: variable for tensor, variable in zip(tf.global_variables(), sess.run(tf.global_variables()))} # store paremeters
                if 'sess' in globals(): sess.close()
                model = HierarchicalNeuralTopicModel(config, tree_idxs)
                sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))
                name_tensors = {tensor.name: tensor for tensor in tf.global_variables()}
                sess.run([name_tensors[name].assign(variable) for name, variable in name_variables.items()]) # restore parameters

    train_batches = get_batches(instances_train, config.batch_size, iterator=True)
    epoch += 1

loss_dev, topic_loss_recon_dev, topic_loss_kl_dev, topic_loss_reg_dev, ppl_dev, rads_bow_dev, probs_topic_dev = validate(sess, dev_batches, model)
topic_prob_topic = {topic_idx: prob_topic for topic_idx, prob_topic in zip(model.topic_idxs, probs_topic_dev)}

display(log_df)
topics_freq_idxs = bow_idxs[sess.run(model.topics_freq_bow_indices)]
topic_freq_token = {topic_idx: ' '.join([idx_to_word[idx] for idx in topic_freq_idxs]) for topic_idx, topic_freq_idxs in zip(model.topic_idxs, topics_freq_idxs)}
topic_prob_topic = {topic_idx: prob_topic for topic_idx, prob_topic in zip(model.topic_idxs, probs_topic_dev)}
descendant_idxs = {parent_idx: get_descendant_idxs(model, parent_idx) for parent_idx in model.topic_idxs}
recur_prob_topic = {parent_idx: np.sum([topic_prob_topic[child_idx] for child_idx in recur_child_idxs]) for parent_idx, recur_child_idxs in descendant_idxs.items()}

print_topic_sample(sess, model, topic_prob_topic=topic_prob_topic, recur_prob_topic=recur_prob_topic, topic_freq_token=topic_freq_token)
print_topic_specialization(sess, model, instances_test)
print_hierarchical_affinity(sess, model)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,TRAIN:,TM,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,VALID:,TM,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0
Unnamed: 0_level_1,Time,Ep,Ct,LOSS,PPL,NLL,KL,REG,LOSS,PPL.1,NLL,KL,REG
5000,73,28,99,596.02,857,592.92,2.9,0.19,597.51,1120,593.54,3.78,0.19
10000,63,57,24,595.24,836,591.22,3.86,0.15,574.03,904,569.31,4.64,0.07
15000,83,85,124,593.25,826,588.9,4.19,0.15,570.9,879,566.27,4.51,0.11
20000,80,114,49,592.86,816,588.28,4.44,0.14,568.38,835,563.28,5.0,0.1
25000,64,142,149,592.18,811,587.5,4.55,0.13,567.79,844,563.19,4.52,0.09
30000,69,171,74,591.63,805,586.84,4.66,0.12,567.55,823,562.4,5.07,0.08
35000,73,199,174,591.33,798,586.44,4.77,0.12,565.78,813,560.62,5.06,0.09
40000,78,228,99,590.62,793,585.65,4.86,0.12,566.81,813,561.4,5.3,0.11
45000,70,257,24,590.46,789,585.43,4.92,0.12,565.66,808,560.51,5.08,0.07
50000,70,285,124,589.95,786,584.87,4.97,0.11,566.09,804,560.84,5.19,0.06


0 R: 1.000 P: 0.320 write article get think know one go like say make
   1 R: 0.225 P: 0.059 space president year program work launch new tax stephanopoulos system
     11 R: 0.023 P: 0.023 team game play hockey new san la pt win period
     13 R: 0.034 P: 0.034 game player year team play win season run good last
     14 R: 0.110 P: 0.110 god christian say jesus one believe people bible church life
   2 R: 0.134 P: 0.058 gun government people law state use weapon crime police right
     22 R: 0.046 P: 0.046 israel israeli people state jews man arab write article war
     21 R: 0.030 P: 0.030 use study food disease cause medical patient doctor water msg
   6 R: 0.064 P: 0.025 key use chip encryption government clipper system privacy security law
     62 R: 0.012 P: 0.012 say go one people come know see tell take get
     63 R: 0.027 P: 0.027 turkish armenian armenians turks turkey armenia people greek government genocide
   7 R: 0.129 P: 0.071 use file window program display run server 

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,TRAIN:,TM,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,VALID:,TM,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0
Unnamed: 0_level_1,Time,Ep,Ct,LOSS,PPL,NLL,KL,REG,LOSS,PPL.1,NLL,KL,REG
5000,73,28,99,596.02,857,592.92,2.9,0.19,597.51,1120,593.54,3.78,0.19
10000,63,57,24,595.24,836,591.22,3.86,0.15,574.03,904,569.31,4.64,0.07
15000,83,85,124,593.25,826,588.9,4.19,0.15,570.9,879,566.27,4.51,0.11
20000,80,114,49,592.86,816,588.28,4.44,0.14,568.38,835,563.28,5.0,0.1
25000,64,142,149,592.18,811,587.5,4.55,0.13,567.79,844,563.19,4.52,0.09
30000,69,171,74,591.63,805,586.84,4.66,0.12,567.55,823,562.4,5.07,0.08
35000,73,199,174,591.33,798,586.44,4.77,0.12,565.78,813,560.62,5.06,0.09
40000,78,228,99,590.62,793,585.65,4.86,0.12,566.81,813,561.4,5.3,0.11
45000,70,257,24,590.46,789,585.43,4.92,0.12,565.66,808,560.51,5.08,0.07
50000,70,285,124,589.95,786,584.87,4.97,0.11,566.09,804,560.84,5.19,0.06


0 R: 1.000 P: 0.321 write article get think know one go like say make
   1 R: 0.225 P: 0.058 space president year program work launch new tax stephanopoulos system
     13 R: 0.022 P: 0.022 team game play hockey new san la pt win period
     14 R: 0.035 P: 0.035 game player year team play win season run good last
     12 R: 0.110 P: 0.110 god christian say jesus one believe people bible church life
   2 R: 0.133 P: 0.058 gun government people law state use weapon crime police right
     22 R: 0.047 P: 0.047 israel israeli people state jews man arab write article war
     23 R: 0.028 P: 0.028 use study food disease cause medical patient doctor water msg
   6 R: 0.065 P: 0.025 key use chip encryption government clipper system privacy security law
     63 R: 0.040 P: 0.040 say go one people come know see tell take get
   7 R: 0.130 P: 0.071 use file window program display run server application set widget
     72 R: 0.035 P: 0.035 drive use card system disk problem mac work driver monitor