In [4]:
%%javascript
// Runs in the browser front-end: asks the kernel to execute a Python assignment
// so that the Python variable `nb_name` holds this notebook's file name.
// NOTE(review): relies on the `IPython` JavaScript global, which presumably only
// exists in the classic Notebook UI — confirm before running elsewhere.
IPython.notebook.kernel.execute('nb_name = "' + IPython.notebook.notebook_name + '"')

<IPython.core.display.Javascript object>

In [2]:
%load_ext autoreload
%autoreload
from IPython.display import clear_output

import os
os.environ['PYTHONHASHSEED'] = '0'
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import sys
import argparse
import subprocess
import pdb
import time
import random
import _pickle as cPickle
import matplotlib.pyplot as plt
import glob

%matplotlib inline

import numpy as np
import pandas as pd
import tensorflow as tf

from data_structure import get_batches
from gsm import GaussianSoftmaxModel
from rsm import RecurrentStickbreakingModel
from evaluation import validate, print_flat_topic_sample
from configure import get_config

# load data & set config

In [5]:
# Build the run configuration from the notebook's file name. `nb_name` is not
# defined by any Python cell: it is injected into the kernel by the %%javascript
# cell at the top of the notebook, so that cell has to have run before this one.
config = get_config(nb_name)

In [6]:
# Expose only the GPU(s) named in the config to CUDA.
# NOTE(review): TensorFlow is already imported at this point; presumably the
# setting still applies because no session has been created yet — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu
# Seed NumPy's and Python's global random generators from the config.
# NOTE(review): no TensorFlow seed is set anywhere in the visible cells — confirm
# whether the model classes seed the graph themselves.
np.random.seed(config.seed)
random.seed(config.seed)

In [7]:
# Load the pickled dataset (train/dev/test instances, vocabulary maps and the
# bag-of-words index array) and pre-batch every split.
# NOTE: unpickling can execute arbitrary code; only point `config.path_data` at
# files you trust.
# The context manager guarantees the file handle is closed after loading.
with open(config.path_data, 'rb') as f_data:
    instances_train, instances_dev, instances_test, word_to_idx, idx_to_word, bow_idxs = cPickle.load(f_data)

train_batches = get_batches(instances_train, config.batch_size)
dev_batches = get_batches(instances_dev, config.batch_size)
test_batches = get_batches(instances_test, config.batch_size)

# The bag-of-words dimensionality is the number of entries in `bow_idxs`.
config.dim_bow = len(bow_idxs)

In [8]:
def debug_shape(variables, model):
    """Print the shape of each fetched value for one test batch.

    Evaluates `variables` in the global session `sess`, feeding the batch stored
    at `test_batches[0][1]`. Each shape is printed on its own line, prefixed by
    the fetch's `name` when it has one.
    """
    feed = model.get_feed_dict(test_batches[0][1])
    fetched = sess.run(variables, feed_dict=feed)
    for tensor, value in zip(variables, fetched):
        # only fetches that carry a `name` attribute get the "name :" prefix
        prefix = (tensor.name, ':') if hasattr(tensor, 'name') else ()
        print(*prefix, value.shape)

def debug_value(variables, model, return_value=False):
    """Evaluate `variables` on one test batch and return the results.

    Uses the global session `sess` and the batch stored at `test_batches[0][1]`.
    `return_value` is accepted but not used; the values are always returned.
    """
    batch = test_batches[0][1]
    return sess.run(variables, feed_dict=model.get_feed_dict(batch))

# run

## initialize log

In [9]:
# Fresh training state: saved-checkpoint list, running training statistics, the
# best dev perplexity seen so far, and the epoch counter.
checkpoint = []
losses_train = []
ppls_train = []
ppl_min = np.inf
epoch = 0
# Re-create the training batches with `iterator=True`; the training loop
# consumes them as (ct, batch) pairs.
train_batches = get_batches(instances_train, config.batch_size, iterator=True)

# Log table with two-level column headers: (group, metric).
log_df = pd.DataFrame(columns=pd.MultiIndex.from_tuples(
                    list(zip(['','','','TRAIN:','','','','','VALID:','','','','','TEST:',''],
                             ['Time','Ep','Ct','LOSS','PPL','NLL','KL','REG','LOSS','PPL','NLL','KL','REG','LOSS','PPL']))))

# Start from an empty model directory. The commands are passed as argument
# lists rather than whitespace-split strings, so a directory path containing
# spaces is handled as a single argument (and cannot make `rm -r` target
# several unrelated paths). Return codes are kept in `res` but, as before, a
# failure (e.g. directory did not exist yet) is not treated as an error.
cmd_rm = ['rm', '-r', config.dir_model]
res = subprocess.call(cmd_rm)
cmd_mk = ['mkdir', config.dir_model]
res = subprocess.call(cmd_mk)

def update_checkpoint(config, checkpoint, global_step):
    """Register a newly saved checkpoint and prune the oldest one if over the limit.

    Args:
        config: object providing `path_model` (checkpoint path prefix),
            `max_to_keep` (maximum number of checkpoints to retain) and
            `path_checkpoint` (file the checkpoint list is pickled to).
        checkpoint: list of checkpoint path prefixes, oldest first. It is
            mutated in place.
        global_step: step number; the new entry is `path_model + '-<step>'`.

    Side effects:
        Deletes every file matching `<oldest prefix>.*` when the list grows
        beyond `config.max_to_keep`, then rewrites `config.path_checkpoint`.
    """
    checkpoint.append(config.path_model + '-%i' % global_step)
    if len(checkpoint) > config.max_to_keep:
        # Escape the prefix so glob metacharacters in the path (e.g. "[1]")
        # are matched literally instead of being read as a pattern.
        stale_pattern = glob.escape(checkpoint.pop(0)) + '.*'
        for stale_path in glob.glob(stale_pattern):
            os.remove(stale_path)
    # Close the file deterministically so the list is flushed to disk.
    with open(config.path_checkpoint, 'wb') as f_checkpoint:
        cPickle.dump(checkpoint, f_checkpoint)

## initialize model

In [10]:
# Close a session left over from a previous run of this cell, if any.
if 'sess' in globals(): sess.close()

# Select the model class from the config. An unrecognised name is rejected
# explicitly; otherwise `Model` would be undefined on a fresh kernel or, worse,
# silently reuse the class chosen in an earlier run of this cell.
if config.model == 'gsm':
    Model = GaussianSoftmaxModel
elif config.model == 'rsm':
    Model = RecurrentStickbreakingModel
else:
    raise ValueError("unknown config.model: %r (expected 'gsm' or 'rsm')" % (config.model,))

model = Model(config)
# Single-threaded session (one intra-op and one inter-op thread).
sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(max_to_keep=config.max_to_keep)
update_tree_flg = False

## train & validate model

In [22]:
# Training driver. Runs until `epoch` reaches `config.n_epochs`; every
# `config.log_period` global steps it validates on the dev set, checkpoints and
# evaluates on the test set when dev perplexity improves, logs a row to
# `log_df`, and (unless `config.static`) lets the model adjust its topic count.
# State (`epoch`, `losses_train`, `ppls_train`, `ppl_min`, `log_df`, ...) lives in
# notebook globals initialised by the earlier cells.
time_start = time.time()
while epoch < config.n_epochs:
    # train
    for ct, batch in train_batches:
        feed_dict = model.get_feed_dict(batch)
        _, loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch, ppls_batch, global_step_log = \
        sess.run([model.opt, model.loss, model.topic_loss_recon, model.topic_loss_kl, model.topic_loss_reg, model.topic_ppls, tf.train.get_global_step()], feed_dict = feed_dict)

        # Running statistics are accumulated over the whole run (never reset).
        losses_train += [[loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch]]
        ppls_train += list(ppls_batch)

        if global_step_log % config.log_period == 0:
            # validate
            loss_train, topic_loss_recon_train, topic_loss_kl_train, topic_loss_reg_train = np.mean(losses_train, 0)
            ppl_train = np.exp(np.mean(ppls_train))
            loss_dev, topic_loss_recon_dev, topic_loss_kl_dev, topic_loss_reg_dev, ppl_dev, probs_topic_dev = validate(sess, dev_batches, model)

            # test
            # NOTE(review): `loss_test` / `ppl_test` are only assigned here; if the
            # first dev perplexity is not < inf (e.g. NaN) the log row below raises
            # NameError — confirm whether that case needs handling.
            if ppl_dev < ppl_min:
                ppl_min = ppl_dev
                loss_test, _, _, _, ppl_test, _ = validate(sess, test_batches, model)
                saver.save(sess, config.path_model, global_step=global_step_log)
                # close the config dump right away instead of leaking a handle per checkpoint
                with open(config.path_config % global_step_log, 'wb') as f_config:
                    cPickle.dump(config, f_config)
                update_checkpoint(config, checkpoint, global_step_log)

            # visualize topic
            # indices of the `config.n_freq` largest entries per row of `model.topic_bow`
            topics_freq_indices = np.argsort(sess.run(model.topic_bow), 1)[:, ::-1][:, :config.n_freq]
            topics_freq_idxs = bow_idxs[topics_freq_indices]
            topics_freq_tokens = [[idx_to_word[idx] for idx in topic_freq_idxs] for topic_freq_idxs in topics_freq_idxs]

            # log
            clear_output()
            time_log = int(time.time() - time_start)
            log_series = pd.Series([time_log, epoch, ct, \
                    '%.2f'%loss_train, '%.0f'%ppl_train, '%.2f'%topic_loss_recon_train, '%.2f'%topic_loss_kl_train, '%.2f'%topic_loss_reg_train, \
                    '%.2f'%loss_dev, '%.0f'%ppl_dev, '%.2f'%topic_loss_recon_dev, '%.2f'%topic_loss_kl_dev, '%.2f'%topic_loss_reg_dev, \
                    '%.2f'%loss_test, '%.0f'%ppl_test],
                    index=log_df.columns)
            log_df.loc[global_step_log] = log_series
            display(log_df)
            # persist the log table; the handle is closed on every iteration
            with open(config.path_log, 'wb') as f_log:
                cPickle.dump(log_df, f_log)
            print_flat_topic_sample(sess, model, topics_freq_tokens=topics_freq_tokens)

            # update tree
            if not config.static:
                config.n_topic, update_flg, diff = model.update_topic(sess, dev_batches)
                print('n_topic = Diff: %.6f' % diff)
                if update_flg:
                    print('Update to %i' % config.n_topic)
                    name_variables = {tensor.name: variable for tensor, variable in zip(tf.global_variables(), sess.run(tf.global_variables()))} # store parameters
                    if 'sess' in globals(): sess.close()
                    model = Model(config)
                    sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))
                    name_tensors = {tensor.name: tensor for tensor in tf.global_variables()}
                    sess.run([name_tensors[name].assign(variable) for name, variable in name_variables.items()]) # restore parameters
                    # NOTE(review): the initial saver uses `config.max_to_keep`, this one
                    # uses 1 — confirm whether the difference is intentional.
                    saver = tf.train.Saver(max_to_keep=1)

            # restart the timer so the Time column is seconds since the previous log row
            time_start = time.time()

    # the batch iterator is exhausted after one pass; rebuild it for the next epoch
    train_batches = get_batches(instances_train, config.batch_size, iterator=True)
    epoch += 1

display(log_df)
print_flat_topic_sample(sess, model, topics_freq_tokens=topics_freq_tokens)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,TRAIN:,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,VALID:,Unnamed: 10_level_0,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0,TEST:,Unnamed: 15_level_0
Unnamed: 0_level_1,Time,Ep,Ct,LOSS,PPL,NLL,KL,REG,LOSS,PPL,NLL,KL,REG,LOSS,PPL
5000,41,28,99,600.12,925,596.89,3.11,0.12,586.09,1065,582.68,3.33,0.08,586.31,1068
10000,39,57,24,601.01,925,597.35,3.55,0.1,574.02,968,570.27,3.68,0.06,574.07,968
15000,38,85,124,599.19,920,595.28,3.83,0.08,575.08,960,570.91,4.12,0.05,574.86,956
20000,40,114,49,598.95,916,594.85,4.02,0.08,572.73,958,568.6,4.08,0.05,572.85,957
25000,39,142,149,598.22,914,593.98,4.16,0.07,572.17,942,567.78,4.35,0.04,572.0,939
30000,42,171,74,597.72,911,593.39,4.26,0.07,572.94,947,568.53,4.35,0.06,572.0,939
35000,44,199,174,597.64,909,593.23,4.35,0.07,571.86,941,567.28,4.52,0.07,572.01,941
40000,42,228,99,597.05,908,592.58,4.41,0.07,575.08,983,570.56,4.46,0.06,572.01,941
45000,42,257,24,597.01,906,592.47,4.47,0.07,572.0,939,567.43,4.51,0.06,571.89,940
50000,43,285,124,596.58,905,592.0,4.52,0.07,573.39,948,568.75,4.59,0.05,571.89,940


0 use get write file one drive work know problem like
1 write get go article game year think one like play
2 god say one write people think know believe make jesus
3 say people israel one turkish armenian israeli kill jews come
4 key use chip encryption system government clipper one phone information
5 space launch nasa system satellite university include new datum research
6 go say president get think know work people make tax
7 article write get one food use cause like doctor disease
8 people write government article law state right make one rights
9 use number man report study homosexual child drug health gay
10 file gun firearm bill law weapon control use vote state
11 fbi fire koresh batf write people child compound start article
12 fire gun tank claim use weapon know well write believe
13 gun cop police thing safety gang know go get see
14 gun weapon people one crime use criminal get kill think
n_topic = Diff: -0.000016


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,TRAIN:,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,VALID:,Unnamed: 10_level_0,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0,TEST:,Unnamed: 15_level_0
Unnamed: 0_level_1,Time,Ep,Ct,LOSS,PPL,NLL,KL,REG,LOSS,PPL,NLL,KL,REG,LOSS,PPL
5000,41,28,99,600.12,925,596.89,3.11,0.12,586.09,1065,582.68,3.33,0.08,586.31,1068
10000,39,57,24,601.01,925,597.35,3.55,0.1,574.02,968,570.27,3.68,0.06,574.07,968
15000,38,85,124,599.19,920,595.28,3.83,0.08,575.08,960,570.91,4.12,0.05,574.86,956
20000,40,114,49,598.95,916,594.85,4.02,0.08,572.73,958,568.6,4.08,0.05,572.85,957
25000,39,142,149,598.22,914,593.98,4.16,0.07,572.17,942,567.78,4.35,0.04,572.0,939
30000,42,171,74,597.72,911,593.39,4.26,0.07,572.94,947,568.53,4.35,0.06,572.0,939
35000,44,199,174,597.64,909,593.23,4.35,0.07,571.86,941,567.28,4.52,0.07,572.01,941
40000,42,228,99,597.05,908,592.58,4.41,0.07,575.08,983,570.56,4.46,0.06,572.01,941
45000,42,257,24,597.01,906,592.47,4.47,0.07,572.0,939,567.43,4.51,0.06,571.89,940
50000,43,285,124,596.58,905,592.0,4.52,0.07,573.39,948,568.75,4.59,0.05,571.89,940


0 use get write file one drive work know problem like
1 write get go article game year think one like play
2 god say one write people think know believe make jesus
3 say people israel one turkish armenian israeli kill jews come
4 key use chip encryption system government clipper one phone information
5 space launch nasa system satellite university include new datum research
6 go say president get think know work people make tax
7 article write get one food use cause like doctor disease
8 people write government article law state right make one rights
9 use number man report study homosexual child drug health gay
10 file gun firearm bill law weapon control use vote state
11 fbi fire koresh batf write people child compound start article
12 fire gun tank claim use weapon know well write believe
13 gun cop police thing safety gang know go get see
14 gun weapon people one crime use criminal get kill think
