In [1]:
%%javascript
// Ask the Python kernel to run `nb_name = "<this notebook's name>"`, so that
// later Python cells can read the notebook file name from `nb_name`.
// NOTE(review): relies on the global `IPython.notebook` front-end object;
// presumably only available in the classic Notebook UI (not JupyterLab) - confirm.
// NOTE(review): the assignment is sent from the browser, so under "Run All" it
// may reach the kernel after cells that already use `nb_name` - verify.
IPython.notebook.kernel.execute('nb_name = "' + IPython.notebook.notebook_name + '"')

<IPython.core.display.Javascript object>

In [8]:
%load_ext autoreload
%autoreload
from IPython.display import clear_output

import os
# Environment variables are set here, before numpy / pandas / tensorflow are
# imported further down in this cell.
# NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so assigning it
# here presumably only affects child processes, not this kernel - confirm.
os.environ['PYTHONHASHSEED'] = '0'
# Ask NumExpr, MKL and OpenMP for a single thread each (presumably to make runs
# repeatable and to match the single-threaded tf.Session created later).
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import sys
import argparse
import subprocess
import pdb
import time
import random
import _pickle as cPickle
import matplotlib.pyplot as plt
import glob

%matplotlib inline

import numpy as np
import pandas as pd
import tensorflow as tf

from data_structure import get_batches
from tsgntm import TreeStructuredGaussianNeuralTopicModel
from tree import get_descendant_idxs
from evaluation import validate, get_hierarchical_affinity, get_topic_specialization, print_topic_sample
from configure import get_config

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# load data & set config

In [3]:
# Build the run configuration from the notebook name. `nb_name` is not defined
# by any Python cell: it is injected into the kernel by the %%javascript cell at
# the top of the notebook, which therefore has to have run first.
config = get_config(nb_name)

In [4]:
# Select the GPU(s) visible to this process; config.gpu must be a string because
# os.environ only accepts str values.
os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu
# Seed numpy's and Python's global random number generators from the config.
# NOTE(review): no TensorFlow seed is set in this cell - confirm whether the
# model sets one itself.
np.random.seed(config.seed)
random.seed(config.seed)

In [5]:
# Load the pickled corpus: training instances, categories, the two vocabulary
# mappings and the word embeddings.
# SECURITY: unpickling executes arbitrary code; only load files you trust.
with open(config.path_data, 'rb') as data_file:  # close the handle deterministically
    instances_train, categories, word_to_idx, idx_to_word, embeddings = cPickle.load(data_file)

# Use the full training set as one batch: batch_size equals the number of instances.
config.batch_size = len(instances_train)
train_batches = get_batches(instances_train, config.batch_size, iterator=False)

# Model dimensions derived from the data: vocabulary size and embedding width.
config.dim_bow = len(idx_to_word)
config.dim_emb = embeddings.shape[-1]

In [6]:
def debug(variable, sample_batch=None, sample=False):
    """Evaluate `variable` in the current session and return its value.

    Args:
        variable: fetch (or structure of fetches) passed to `sess.run`.
        sample_batch: batch used to build the feed dict; when None, the batch
            stored in the first entry of the global `train_batches` is used.
        sample: unused; kept so existing calls keep working.

    Returns:
        Whatever `sess.run` returns for `variable`.

    Relies on the notebook globals `model`, `sess` and `train_batches`.
    """
    batch = train_batches[0][1] if sample_batch is None else sample_batch
    # The feed dict is always built in 'eval' mode.
    return sess.run(variable, feed_dict=model.get_feed_dict(batch, mode='eval'))

# run

## initialize log

In [13]:
# Mutable training state shared with the training cell. Re-running this cell
# resets the run: history, best perplexity and the epoch counter.
checkpoint = []      # checkpoint path prefixes currently kept on disk
losses_train = []    # per-step [loss, recon, kl, reg] rows
ppls_train = []      # per-step perplexity terms, extended from model.topic_ppls
ppl_min = np.inf     # best perplexity seen so far (used for model selection)
epoch = 0

# Two-level column header for the log table: group label on top, metric below.
log_df = pd.DataFrame(columns=pd.MultiIndex.from_tuples(
                    list(zip(*[['','','','TRAIN:','','','','','VALID:','','','','','TEST:','', 'SPEC:', '', '', 'HIER:', ''],
                            ['Time','Ep','Ct','LOSS','PPL','NLL','KL','REG','LOSS','PPL','NLL','KL','REG','LOSS','PPL', '1', '2', '3', 'CHILD', 'OTHER']]))))

# Start from an empty model directory. The arguments are passed as a list rather
# than a whitespace-split command string, so a directory name containing spaces
# is handled as a single path. Return codes are kept in `res` but not checked
# (e.g. `rm` reports an error when the directory does not exist yet).
res = subprocess.call(['rm', '-r', config.dir_model])
res = subprocess.call(['mkdir', config.dir_model])

def update_checkpoint(config, checkpoint, global_step):
    """Register a new checkpoint and prune the oldest one beyond the keep limit.

    Args:
        config: object providing `path_model` (checkpoint path prefix),
            `max_to_keep` (number of checkpoints to retain) and
            `path_checkpoint` (file the checkpoint list is pickled to).
        checkpoint: list of checkpoint path prefixes, oldest first; mutated in place.
        global_step: integer step appended to `path_model` as '-<step>'.

    Side effects:
        Deletes every file matching '<oldest prefix>.*' once the list exceeds
        `max_to_keep`, then overwrites `config.path_checkpoint` with the
        pickled list.
    """
    checkpoint.append(config.path_model + '-%i' % global_step)
    if len(checkpoint) > config.max_to_keep:
        stale_prefix = checkpoint.pop(0)
        # Escape the prefix so characters such as '[' or '*' in the model path
        # are matched literally; only the trailing '.*' acts as a wildcard.
        for stale_file in glob.glob(glob.escape(stale_prefix) + '.*'):
            os.remove(stale_file)
    # Close the file explicitly instead of relying on garbage collection.
    with open(config.path_checkpoint, 'wb') as checkpoint_file:
        cPickle.dump(checkpoint, checkpoint_file)

## initialize model

In [14]:
# Re-running this cell: close the session left over from a previous run first.
if 'sess' in globals(): sess.close()
# Build the model from the config, then create a session restricted to one
# intra-op and one inter-op thread and initialise all global variables.
model = TreeStructuredGaussianNeuralTopicModel(config)
sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))
sess.run(tf.global_variables_initializer())
# The saver is created after the model so that the model's variables exist;
# it retains at most config.max_to_keep checkpoints.
saver = tf.train.Saver(max_to_keep=config.max_to_keep)
# Only referenced by the commented-out tree-update code in the training cell.
update_tree_flg = False

## train & validate model

In [15]:
# Override the run length and logging interval for this notebook:
# n_epochs bounds the training `while` loop, and a log/validation pass is run
# whenever the global step is a multiple of log_period.
config.n_epochs = 5000
config.log_period = 500

In [16]:
# Training loop.
#
# Each iteration runs one optimisation step on the first entry of `train_batches`
# and appends the step's losses and perplexity terms to the running history
# (`losses_train`, `ppls_train`, created in the "initialize log" cell; they are
# never reset, so the TRAIN columns are averages over the whole run so far).
# Whenever the global step is a multiple of config.log_period the model is
# evaluated with `validate` on `train_batches` (the VALID columns), the best
# model so far is checkpointed, and the log table plus a topic sample are shown.
#
# The loop continues from the global `epoch`, so re-running this cell after an
# interruption resumes training instead of restarting it.
time_start = time.time()

# No separate test set is evaluated in this notebook; the TEST columns of the log
# are constant placeholders. Defined up front so the log row can always be built.
loss_test, ppl_test = 0, 0

while epoch < config.n_epochs:
    ct, batch = train_batches[0]

    feed_dict = model.get_feed_dict(batch)
    _, loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch, ppls_batch, global_step_log = \
    sess.run([model.opt, model.loss, model.topic_loss_recon, model.topic_loss_kl, model.topic_loss_reg, model.topic_ppls, tf.train.get_global_step()], feed_dict = feed_dict)

    losses_train += [[loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch]]
    ppls_train += list(ppls_batch)

    if global_step_log % config.log_period == 0:
        # validate
        loss_train, topic_loss_recon_train, topic_loss_kl_train, topic_loss_reg_train = np.mean(losses_train, 0)
        ppl_train = np.exp(np.mean(ppls_train))
        loss_dev, topic_loss_recon_dev, topic_loss_kl_dev, topic_loss_reg_dev, ppl_dev, probs_topic_dev = validate(sess, train_batches, model)

        # checkpoint the best model
        # Compare the same quantity that is stored in ppl_min. Previously the
        # running-average training perplexity was compared against a stored
        # validation perplexity, so new checkpoints were only saved if the
        # training average happened to fall below the first validation value.
        if ppl_dev < ppl_min:
            ppl_min = ppl_dev
            saver.save(sess, config.path_model, global_step=global_step_log)
            with open(config.path_config % global_step_log, 'wb') as config_file:
                cPickle.dump(config, config_file)
            update_checkpoint(config, checkpoint, global_step_log)

        # visualize topic
        # Per topic: indices of the config.n_freq highest-scoring words in model.topic_bow.
        topics_freq_indices = np.argsort(sess.run(model.topic_bow), 1)[:, ::-1][:, :config.n_freq]
        topics_freq_idxs = topics_freq_indices
        topic_freq_tokens = {topic_idx: [idx_to_word[idx] for idx in topic_freq_idxs] for topic_idx, topic_freq_idxs in zip(model.topic_idxs, topics_freq_idxs)}
        topic_prob_topic = {topic_idx: prob_topic for topic_idx, prob_topic in zip(model.topic_idxs, probs_topic_dev)}
        # For every topic, sum the probabilities of the topics returned by get_descendant_idxs.
        descendant_idxs = {parent_idx: get_descendant_idxs(model, parent_idx) for parent_idx in model.topic_idxs}
        recur_prob_topic = {parent_idx: np.sum([topic_prob_topic[child_idx] for child_idx in recur_child_idxs]) for parent_idx, recur_child_idxs in descendant_idxs.items()}

        depth_specs = get_topic_specialization(sess, model, instances_train)
        hierarchical_affinities = get_hierarchical_affinity(sess, model)

        # log
        clear_output()
        time_log = int(time.time() - time_start)  # seconds since the previous log entry
        log_series = pd.Series([time_log, epoch, ct, \
                '%.2f'%loss_train, '%.0f'%ppl_train, '%.2f'%topic_loss_recon_train, '%.2f'%topic_loss_kl_train, '%.2f'%topic_loss_reg_train, \
                '%.2f'%loss_dev, '%.0f'%ppl_dev, '%.2f'%topic_loss_recon_dev, '%.2f'%topic_loss_kl_dev, '%.2f'%topic_loss_reg_dev, \
                '%.2f'%loss_test, '%.0f'%ppl_test, \
                '%.2f'%depth_specs[1], '%.2f'%depth_specs[2], '%.2f'%depth_specs[3], \
                '%.2f'%hierarchical_affinities[0], '%.2f'%hierarchical_affinities[1]],
                index=log_df.columns)
        log_df.loc[global_step_log] = log_series
        display(log_df)
        with open(config.path_log, 'wb') as log_file:
            cPickle.dump(log_df, log_file)
        print_topic_sample(sess, model, topic_prob_topic=topic_prob_topic, recur_prob_topic=recur_prob_topic, topic_freq_tokens=topic_freq_tokens)

# NOTE: the dynamic tree update below is disabled in this notebook.
#         # update tree
#         if not config.static:
#             config.tree_idxs, update_tree_flg = model.update_tree(topic_prob_topic, recur_prob_topic)
#             if update_tree_flg:
#                 print(config.tree_idxs)
#                 name_variables = {tensor.name: variable for tensor, variable in zip(tf.global_variables(), sess.run(tf.global_variables()))} # store paremeters
#                 if 'sess' in globals(): sess.close()
#                 model = HierarchicalNeuralTopicModel(config)
#                 sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))
#                 name_tensors = {tensor.name: tensor for tensor in tf.global_variables()}
#                 sess.run([name_tensors[name].assign(variable) for name, variable in name_variables.items()]) # restore parameters
#                 saver = tf.train.Saver(max_to_keep=1)

        # Restart the timer so the Time column measures each logging interval.
        time_start = time.time()

    epoch += 1

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,TRAIN:,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,VALID:,Unnamed: 10_level_0,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0,TEST:,Unnamed: 15_level_0,SPEC:,Unnamed: 17_level_0,Unnamed: 18_level_0,HIER:,Unnamed: 20_level_0
Unnamed: 0_level_1,Time,Ep,Ct,LOSS,PPL,NLL,KL,REG,LOSS,PPL,NLL,KL,REG,LOSS,PPL,1,2,3,CHILD,OTHER
500,9,499,0,102269.55,1830,102121.49,21.56,0.37,101475.84,1445,101355.08,12.92,0.38,0.0,0,0.9,0.42,0.21,0.66,0.6
1000,6,999,0,101912.95,1645,101785.2,16.87,0.36,101415.38,1488,101315.83,12.44,0.32,0.0,0,0.85,0.45,0.24,0.66,0.56
1500,7,1499,0,101782.69,1583,101665.1,15.04,0.35,101403.56,1505,101310.88,10.53,0.3,0.0,0,0.8,0.44,0.26,0.63,0.53
2000,6,1999,0,101714.04,1553,101602.77,13.72,0.35,101431.58,1516,101341.09,9.69,0.32,0.0,0,0.88,0.46,0.27,0.66,0.59
2500,6,2499,0,101670.35,1535,101563.13,12.83,0.34,101464.07,1414,101373.42,8.76,0.28,0.0,0,0.87,0.5,0.29,0.64,0.51
3000,6,2999,0,101640.24,1522,101535.95,12.28,0.34,101482.88,1415,101394.48,8.79,0.31,0.0,0,0.81,0.45,0.28,0.68,0.57
3500,6,3499,0,101616.98,1512,101515.08,11.76,0.34,101487.05,1649,101397.85,8.8,0.36,0.0,0,0.89,0.51,0.25,0.7,0.54
4000,6,3999,0,101599.09,1504,101499.18,11.31,0.33,101397.84,1411,101309.31,8.1,0.26,0.0,0,0.91,0.49,0.24,0.65,0.49
4500,6,4499,0,101584.87,1499,101486.41,10.95,0.33,101380.46,1511,101291.54,8.24,0.37,0.0,0,0.84,0.51,0.27,0.72,0.59


0 R: 1.000 P: 0.023 building_applications leads aerospace_applications building_sector fuel_cell_applications electric_wire insulation building_materials cables etfe_applications
   1 R: 0.654 P: 0.510 vascular_grafts lubricants greases vascular_access hemodialysis_access chordal_replacement mitral_valve_repair femoropopliteal_bypass binder electrode
     11 R: 0.144 P: 0.144 sensor binder sensors hollow_fiber_membranes electrodes actuators ultrafiltration separator hollow_fiber_membrane lithium-ion_batteries
     12 R: 0.000 P: 0.000 binder electrodes electrode cathode sensor vascular_grafts hollow_fiber_membranes direct_contact_membrane_distillation membrane_distillation biomedical_applications
   2 R: 0.184 P: 0.000 cathode separator hollow_fiber_membrane polymer_electrolyte_membranes proton_exchange_membranes cells direct_methanol_fuel_cell energy_harvesting binder hollow_fiber_membranes
     21 R: 0.183 P: 0.183 binder sensor sensors electrodes actuators hollow_fiber_membranes ele

KeyboardInterrupt: 