In [1]:
%%javascript
// Ask the Python kernel to define `nb_name` as this notebook's file name, so later cells can
// derive their configuration from it.
// NOTE(review): relies on the `IPython.notebook` JS global, presumably only present in the classic
// Notebook front end, and the kernel call is issued from the browser — confirm `nb_name` is set
// before the config cell runs under "Run All".
IPython.notebook.kernel.execute('nb_name = "' + IPython.notebook.notebook_name + '"')

<IPython.core.display.Javascript object>

In [2]:
%load_ext autoreload
%autoreload
from IPython.display import clear_output

import os
os.environ['PYTHONHASHSEED'] = '0'
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import _pickle as cPickle
import argparse
import glob
import pdb
import random
import shutil
import subprocess
import sys
import time

import matplotlib.pyplot as plt

%matplotlib inline

import numpy as np
import pandas as pd
import tensorflow as tf

from data_structure import get_batches
from hntm import HierarchicalNeuralTopicModel
from nhdp import nestedHierarchicalNeuralTopicModel
from tsgntm import TreeStructuredGaussianNeuralTopicModel
from tree import get_descendant_idxs
from evaluation import get_hierarchical_affinity, get_topic_specialization, print_topic_sample
from configure import get_config

pd.set_option('display.max_columns', 30)

# load data & set config

In [3]:
# Obtain the run configuration for this notebook from its file name.
# `nb_name` is not defined in Python code: it is injected into the kernel by the
# %%javascript cell at the top, so that cell must have run first.
config = get_config(nb_name)

In [4]:
# Restrict CUDA to the device(s) named in config.gpu (set before any tf.Session is created below).
os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu
# Seed NumPy's and Python's global random generators from the configured seed.
# NOTE(review): no TensorFlow-level seed is set in this cell — confirm the model sets one itself.
np.random.seed(config.seed)
random.seed(config.seed)

In [5]:
# Load the pickled corpus: train/dev/test instances, the two vocabulary maps and the
# bag-of-words index array. The context manager closes the file handle once loading is done
# (the previous bare `open(...)` left it open for the life of the kernel).
with open(config.path_data, 'rb') as f_data:
    instances_train, instances_dev, instances_test, word_to_idx, idx_to_word, bow_idxs = cPickle.load(f_data)

# Split every data partition into batches of config.batch_size.
train_batches = get_batches(instances_train, config.batch_size)
dev_batches = get_batches(instances_dev, config.batch_size)
test_batches = get_batches(instances_test, config.batch_size)

# The bag-of-words dimensionality is the number of bag-of-words indices in the data file.
config.dim_bow = len(bow_idxs)

In [6]:
def debug(variable, sample_batch=None):
    """Evaluate `variable` in the notebook's global session on a single batch.

    Reads the notebook-level globals `model`, `sess` and `test_batches`.

    Args:
        variable: tensor (or structure of tensors) passed to `sess.run`.
        sample_batch: batch to feed; when None, the batch stored in the first
            entry of `test_batches` is used.

    Returns:
        Whatever `sess.run` returns for `variable`.
    """
    # Entries of test_batches are pairs; index 1 holds the batch itself.
    batch = test_batches[0][1] if sample_batch is None else sample_batch
    return sess.run(variable, feed_dict=model.get_feed_dict(batch, mode='test'))

def check(variable):
    """Evaluate `variable` with freshly initialised parameters on the first test batch.

    Unlike `debug`, this does not touch the notebook's global session: it opens a
    throw-away session, runs the global variable initialiser, evaluates `variable`
    and closes the session again.

    Reads the notebook-level globals `model` and `test_batches`.

    Args:
        variable: tensor (or structure of tensors) passed to `sess.run`.

    Returns:
        Whatever `sess.run` returns for `variable`.
    """
    # Entries of test_batches are (ct, batch) pairs (see `validate` and `debug`);
    # feed the batch itself rather than the whole pair.
    sample_batch = test_batches[0][1]
    feed_dict = model.get_feed_dict(sample_batch, mode='test', assertion=True)
    # The context manager closes the temporary session, which was previously leaked on every call.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        return sess.run(variable, feed_dict=feed_dict)

# model

## initialize model

In [7]:
# Close the session left behind by a previous execution of this cell, if there is one.
if 'sess' in globals():
    sess.close()

# Lookup table from the `config.model` key to the topic-model class to instantiate.
TopicModels = {
    'hntm': HierarchicalNeuralTopicModel,
    'nhdp': nestedHierarchicalNeuralTopicModel,
    'tsgntm': TreeStructuredGaussianNeuralTopicModel,
}
TopicModel = TopicModels[config.model]
model = TopicModel(config)

# Session limited to one intra-op and one inter-op thread; variables are initialised
# right away and a saver keeping at most config.max_to_keep checkpoints is attached.
sess = tf.Session(
    config=tf.ConfigProto(
        intra_op_parallelism_threads=1,
        inter_op_parallelism_threads=1,
    )
)
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(max_to_keep=config.max_to_keep)

# Set by model.update_tree() during training.
update_tree_flg = False

# run

## initialize log

In [8]:
# Bookkeeping shared with the training loop below.
checkpoint = []       # paths of the checkpoints recorded by update_checkpoint (oldest first)
losses_train = []     # one [loss, recon, kl, gauss, reg] row per training batch
ppls_train = []       # flattened entries of model.topic_ppls over all training batches
ppl_min = np.inf      # best validation perplexity seen so far
epoch = 0
train_batches = get_batches(instances_train, config.batch_size, iterator=True)

# Two-level column header of the log table: a group label on top, the metric name below.
log_groups = ['', '', '', 'TRAIN:', '', '', '', '', '', 'VALID:', '', '', '', '', '', 'TEST:', '', 'SPEC:', '', '', 'HIER:', '']
log_names = ['Time', 'Ep', 'Ct', 'LOSS', 'PPL', 'NLL', 'KL', 'GAUSS', 'REG', 'LOSS', 'PPL', 'NLL', 'KL', 'GAUSS', 'REG', 'LOSS', 'PPL', '1', '2', '3', 'CHILD', 'OTHER']
log_df = pd.DataFrame(columns=pd.MultiIndex.from_tuples(list(zip(log_groups, log_names))))

# Start every run from an empty model directory. Done with the standard library instead of
# shelling out to `rm -r` / `mkdir`: the shell commands were built with string formatting and
# `.split()`, which breaks on paths containing whitespace and silently ignores failures.
shutil.rmtree(config.dir_model, ignore_errors=True)  # a missing directory is not an error
os.makedirs(config.dir_model, exist_ok=True)

def update_checkpoint(config, checkpoint, global_step):
    """Record a new checkpoint, prune the oldest one if needed, and persist the list.

    Args:
        config: object providing `path_model` (checkpoint path prefix),
            `max_to_keep` (maximum number of recorded checkpoints) and
            `path_checkpoint` (file the list is pickled to).
        checkpoint: list of checkpoint path prefixes, oldest first; modified in place.
        global_step: integer step appended to `config.path_model` as '-<step>'.

    Side effects:
        Deletes every file matching '<oldest prefix>.*' when the list grows beyond
        `config.max_to_keep`, and rewrites `config.path_checkpoint`.
    """
    checkpoint.append(config.path_model + '-%i' % global_step)
    if len(checkpoint) > config.max_to_keep:
        stale_prefix = checkpoint.pop(0)
        # Escape the prefix so glob metacharacters in the path ('[', '*', '?') are matched
        # literally; only the trailing '.*' acts as a wildcard.
        for stale_path in glob.glob(glob.escape(stale_prefix) + '.*'):
            os.remove(stale_path)
    # Context manager so the pickle file is flushed and closed deterministically.
    with open(config.path_checkpoint, 'wb') as f_checkpoint:
        cPickle.dump(checkpoint, f_checkpoint)
    
def validate(sess, batches, model):
    """Run `model` in test mode over `batches` and aggregate the results.

    Args:
        sess: session whose `run(fetches, feed_dict=...)` evaluates the model tensors.
        batches: iterable of (ct, batch) pairs; only the batch is used.
        model: object exposing `get_feed_dict(batch, mode=...)` and the tensors
            `loss`, `topic_loss_recon`, `topic_loss_kl`, `topic_loss_gauss`,
            `topic_loss_reg`, `topic_ppls`, `n_bow` and `n_topics`.

    Returns:
        Tuple `(loss_mean, recon_mean, kl_mean, gauss_mean, reg_mean, ppl_mean,
        probs_topic_mean)`: the five loss terms averaged over batches, the exponential
        of the mean of all `topic_ppls` entries, and the `n_topics` totals summed over
        axis 0 divided by the total of `n_bow`.

    Raises:
        ValueError: if `batches` yields no batch at all.
    """
    # The fetch list does not depend on the batch, so build it once.
    fetches = [model.loss, model.topic_loss_recon, model.topic_loss_kl, model.topic_loss_gauss,
               model.topic_loss_reg, model.topic_ppls, model.n_bow, model.n_topics]

    losses = []
    ppl_list = []
    n_bow_list = []
    n_topics_list = []
    for _, batch in batches:
        feed_dict = model.get_feed_dict(batch, mode='test')
        # First five results are the loss terms, in the order of `fetches`.
        *loss_terms, ppls_batch, n_bow_batch, n_topics_batch = sess.run(fetches, feed_dict=feed_dict)
        losses.append(loss_terms)
        ppl_list.extend(ppls_batch)
        n_bow_list.append(n_bow_batch)
        n_topics_list.append(n_topics_batch)

    if not losses:
        # Without this guard the unpacking below fails with an unrelated-looking error.
        raise ValueError('validate() received no batches')

    loss_mean, topic_loss_recon_mean, topic_loss_kl_mean, topic_loss_gauss_mean, topic_loss_reg_mean = np.mean(losses, 0)
    ppl_mean = np.exp(np.mean(ppl_list))

    n_bow = np.concatenate(n_bow_list, 0)
    n_topics = np.concatenate(n_topics_list, 0)
    probs_topic_mean = np.sum(n_topics, 0) / np.sum(n_bow)

    return loss_mean, topic_loss_recon_mean, topic_loss_kl_mean, topic_loss_gauss_mean, topic_loss_reg_mean, ppl_mean, probs_topic_mean

## train & validate model

In [9]:
# Main training loop. Every config.log_period global steps it validates on the dev set, evaluates
# and checkpoints on the test set whenever dev perplexity improves, logs a row to `log_df`, prints
# the topic tree and, unless config.static is set, lets the model update its tree structure.
time_start = time.time()
while epoch < config.n_epochs:
    # train
    for ct, batch in train_batches:
        feed_dict = model.get_feed_dict(batch)
        _, loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_gauss_batch, topic_loss_reg_batch, ppls_batch, global_step_log = \
        sess.run([model.opt, model.loss, model.topic_loss_recon, model.topic_loss_kl, model.topic_loss_gauss, model.topic_loss_reg, model.topic_ppls, tf.train.get_global_step()], feed_dict = feed_dict)

        losses_train += [[loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_gauss_batch, topic_loss_reg_batch]]
        ppls_train += list(ppls_batch)

        if global_step_log % config.log_period == 0:
            # validate (training statistics are running means over all batches seen so far)
            loss_train, topic_loss_recon_train, topic_loss_kl_train, topic_loss_gauss_train, topic_loss_reg_train = np.mean(losses_train, 0)
            ppl_train = np.exp(np.mean(ppls_train))
            loss_dev, topic_loss_recon_dev, topic_loss_kl_dev, topic_loss_gauss_dev, topic_loss_reg_dev, ppl_dev, probs_topic_dev = validate(sess, dev_batches, model)

            # test: only when dev perplexity improves; otherwise the previous test figures are re-logged
            if ppl_dev < ppl_min:
                ppl_min = ppl_dev
                loss_test, _, _, _, _, ppl_test, _ = validate(sess, test_batches, model)
                saver.save(sess, config.path_model, global_step=global_step_log)
                # context manager so the config pickle is closed as soon as it is written
                with open(config.path_config % global_step_log, 'wb') as f_config:
                    cPickle.dump(config, f_config)
                update_checkpoint(config, checkpoint, global_step_log)

            # visualize topic: the config.n_freq highest-weighted words of every topic
            topics_freq_indices = np.argsort(sess.run(model.topic_bow), 1)[:, ::-1][:, :config.n_freq]
            topics_freq_idxs = bow_idxs[topics_freq_indices]
            topic_freq_tokens = {topic_idx: [idx_to_word[idx] for idx in topic_freq_idxs] for topic_idx, topic_freq_idxs in zip(model.topic_idxs, topics_freq_idxs)}
            topic_prob_topic = {topic_idx: prob_topic for topic_idx, prob_topic in zip(model.topic_idxs, probs_topic_dev)}
            descendant_idxs = {parent_idx: get_descendant_idxs(model, parent_idx) for parent_idx in model.topic_idxs}
            recur_prob_topic = {parent_idx: np.sum([topic_prob_topic[child_idx] for child_idx in recur_child_idxs]) for parent_idx, recur_child_idxs in descendant_idxs.items()}

            depth_specs = get_topic_specialization(sess, model, instances_test)
            hierarchical_affinities = get_hierarchical_affinity(sess, model)

            # log
            clear_output()
            time_log = int(time.time() - time_start)
            log_series = pd.Series([time_log, epoch, ct, \
                    '%.2f'%loss_train, '%.0f'%ppl_train, '%.2f'%topic_loss_recon_train, '%.2f'%topic_loss_kl_train, '%.2f'%topic_loss_gauss_train, '%.2f'%topic_loss_reg_train, \
                    '%.2f'%loss_dev, '%.0f'%ppl_dev, '%.2f'%topic_loss_recon_dev, '%.2f'%topic_loss_kl_dev, '%.2f'%topic_loss_gauss_dev, '%.2f'%topic_loss_reg_dev, \
                    '%.2f'%loss_test, '%.0f'%ppl_test, \
                    '%.2f'%depth_specs[1], '%.2f'%depth_specs[2], '%.2f'%depth_specs[3], \
                    '%.2f'%hierarchical_affinities[0], '%.2f'%hierarchical_affinities[1]],
                    index=log_df.columns)
            log_df.loc[global_step_log] = log_series
            display(log_df)
            with open(config.path_log, 'wb') as f_log:
                cPickle.dump(log_df, f_log)
            print_topic_sample(sess, model, topic_prob_topic=topic_prob_topic, recur_prob_topic=recur_prob_topic, topic_freq_tokens=topic_freq_tokens)

            # update tree
            if not config.static:
                config.tree_idxs, update_tree_flg = model.update_tree(topic_prob_topic, recur_prob_topic)
                if update_tree_flg:
                    print(config.tree_idxs)
                    name_variables = {tensor.name: variable for tensor, variable in zip(tf.global_variables(), sess.run(tf.global_variables()))} # store parameters
                    if 'sess' in globals(): sess.close()
                    # Rebuild with the class selected by config.model (previously hard-coded to
                    # HierarchicalNeuralTopicModel regardless of the configured model).
                    model = TopicModel(config)
                    sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))
                    name_tensors = {tensor.name: tensor for tensor in tf.global_variables()}
                    sess.run([name_tensors[name].assign(variable) for name, variable in name_variables.items()]) # restore parameters
                    # Same retention as the initial saver, so it stays consistent with update_checkpoint().
                    saver = tf.train.Saver(max_to_keep=config.max_to_keep)

            # measure the time spent per logging period
            time_start = time.time()

    # the batch iterator is exhausted after one pass; rebuild it for the next epoch
    train_batches = get_batches(instances_train, config.batch_size, iterator=True)
    epoch += 1

# Final report on the dev set. validate() returns seven values; the gauss term was missing from
# this unpacking, which raised ValueError as soon as training finished.
loss_dev, topic_loss_recon_dev, topic_loss_kl_dev, topic_loss_gauss_dev, topic_loss_reg_dev, ppl_dev, probs_topic_dev = validate(sess, dev_batches, model)
topics_freq_indices = np.argsort(sess.run(model.topic_bow), 1)[:, ::-1][:, :config.n_freq]
topics_freq_idxs = bow_idxs[topics_freq_indices]
topic_freq_tokens = {topic_idx: [idx_to_word[idx] for idx in topic_freq_idxs] for topic_idx, topic_freq_idxs in zip(model.topic_idxs, topics_freq_idxs)}
topic_prob_topic = {topic_idx: prob_topic for topic_idx, prob_topic in zip(model.topic_idxs, probs_topic_dev)}
descendant_idxs = {parent_idx: get_descendant_idxs(model, parent_idx) for parent_idx in model.topic_idxs}
recur_prob_topic = {parent_idx: np.sum([topic_prob_topic[child_idx] for child_idx in recur_child_idxs]) for parent_idx, recur_child_idxs in descendant_idxs.items()}
display(log_df)
print_topic_sample(sess, model, topic_prob_topic=topic_prob_topic, recur_prob_topic=recur_prob_topic, topic_freq_tokens=topic_freq_tokens)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,TRAIN:,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,Unnamed: 9_level_0,VALID:,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0,Unnamed: 14_level_0,Unnamed: 15_level_0,TEST:,Unnamed: 17_level_0,SPEC:,Unnamed: 19_level_0,Unnamed: 20_level_0,HIER:,Unnamed: 22_level_0
Unnamed: 0_level_1,Time,Ep,Ct,LOSS,PPL,NLL,KL,GAUSS,REG,LOSS,PPL,NLL,KL,GAUSS,REG,LOSS,PPL,1,2,3,CHILD,OTHER
5000,77,8,463,9266.38,1971,9219.21,16.53,30.64,0.14,9240.28,1861,9182.74,20.37,37.17,0.12,9238.98,1858,0.16,0.49,0.34,0.5,0.31
10000,72,17,360,9226.75,1894,9169.75,18.57,38.43,0.13,9236.92,1830,9162.81,20.43,53.67,0.13,9238.99,1834,0.31,0.49,0.34,0.34,0.21
15000,72,26,257,9214.31,1864,9149.26,19.29,45.73,0.13,9242.35,1823,9156.83,20.77,64.76,0.14,9242.0,1823,0.13,0.6,0.35,0.35,0.22
20000,74,35,154,9207.12,1845,9136.51,19.76,50.8,0.14,9233.41,1808,9145.03,21.14,67.24,0.15,9233.44,1807,0.17,0.63,0.37,0.3,0.19
25000,73,44,51,9202.1,1832,9127.21,20.09,54.8,0.14,9236.87,1803,9142.52,21.35,73.0,0.14,9239.36,1806,0.19,0.61,0.39,0.32,0.19
30000,72,52,515,9198.77,1822,9120.34,20.32,58.1,0.14,9236.95,1802,9141.24,21.29,74.42,0.15,9238.17,1803,0.17,0.57,0.37,0.33,0.21
35000,74,61,412,9196.06,1814,9114.93,20.5,60.61,0.14,9234.92,1795,9137.19,21.39,76.33,0.16,9236.31,1797,0.61,0.67,0.36,0.27,0.17
40000,71,70,309,9194.0,1807,9110.57,20.64,62.78,0.14,9234.03,1791,9133.34,21.49,79.2,0.14,9236.62,1794,0.48,0.55,0.4,0.2,0.13
45000,72,79,206,9192.04,1802,9106.73,20.75,64.54,0.14,9235.15,1795,9136.35,21.33,77.47,0.15,9236.62,1794,0.46,0.62,0.37,0.21,0.14
50000,72,88,103,9190.22,1797,9103.36,20.83,66.01,0.14,9234.32,1788,9131.98,21.54,80.8,0.14,9233.64,1787,0.2,0.65,0.39,0.35,0.23


0 R: 1.000 P: 0.001 type bridging hpsg category modal tense coreference graph arg connectives
   1 R: 0.266 P: 0.056 translation english source target phrase alignment systems sentences pairs translations
     11 R: 0.114 P: 0.114 models training embeddings neural network vector learning vectors performance context
     12 R: 0.096 P: 0.096 features sentences topic sentiment analysis scores score level human evaluation
   2 R: 0.225 P: 0.036 document query documents question answer questions terms retrieval search queries
     21 R: 0.101 P: 0.101 features training performance feature classification learning test domain classifier systems
     22 R: 0.088 P: 0.088 similarity sense semantic wordnet relations senses relation method pairs knowledge
   3 R: 0.179 P: 0.011 annotation annotated agreement annotations annotators annotator scheme inter annotate guidelines
     31 R: 0.082 P: 0.082 morphological languages english character dictionary errors forms form chinese characters
     32 

KeyboardInterrupt: 

In [10]:
# Evaluate model.topic_logvars on the default test batch and sum it over axis 1.
np.sum(debug(model.topic_logvars), 1)

array([108.88487 ,  98.85533 , 101.73428 , 138.29723 ,  90.70738 ,
        97.048004,  37.88733 ,  38.16417 ,  41.808304,  45.578796,
        53.009644,  54.871647,  50.3262  ,  45.501328,  54.289352,
        48.97226 ], dtype=float32)

In [17]:
# NOTE(review): `debug` as defined in this notebook requires a `variable` argument, so this bare
# call raises TypeError on a fresh kernel; the output below presumably came from an earlier
# definition of `debug`. The execution count (In [17] right after In [10]) also shows this cell
# was run out of order — pass the tensor to inspect explicitly and re-run.
debug()

array([[1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        ],
       [1.        , 0.89822125, 1.        , 1.        , 1.        ,
        1.        , 0.        , 1.        , 1.        , 0.27806905,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        0.        , 0.        , 0.44281808, 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        0.89801407, 1.        , 0.8412844 , 0.        , 0.        ,
        0.        , 0.08885161],
       [0.        , 1.        , 1.        , 1.        , 0.85427874,
        1.        , 0.        , 1.        , 1.    