In [1]:
%%javascript
// Push this notebook's file name into the kernel as the Python variable `nb_name`.
// JSON.stringify produces an escaped, double-quoted literal that is also a valid Python
// string literal, so names containing quotes or backslashes cannot break the statement.
IPython.notebook.kernel.execute('nb_name = ' + JSON.stringify(IPython.notebook.notebook_name));

<IPython.core.display.Javascript object>

In [2]:
%load_ext autoreload
%autoreload
from IPython.display import clear_output

import os
# Environment knobs for determinism and single-threaded numerics. They are set before
# numpy/tensorflow are imported further down, because thread pools are typically
# configured from these variables when those libraries initialise.
# NOTE(review): PYTHONHASHSEED only affects interpreters started after this point
# (e.g. subprocesses); the running kernel's hash seed was already fixed at startup.
os.environ.update({
    'PYTHONHASHSEED': '0',
    'NUMEXPR_NUM_THREADS': '1',
    'MKL_NUM_THREADS': '1',
    'OMP_NUM_THREADS': '1',
})

import argparse
import glob
import pdb
import random
import shutil
import subprocess
import sys
import time
import _pickle as cPickle

import matplotlib.pyplot as plt

%matplotlib inline

import numpy as np
import pandas as pd
import tensorflow as tf

from data_structure import get_batches
from gsm import GaussianSoftmaxModel
from rsm import RecurrentStickbreakingModel
from evaluation import validate, print_flat_topic_sample
from configure import get_config

# Set config & load data

In [3]:
# `nb_name` is injected into the kernel by the %%javascript cell at the top, which only works
# in the classic Notebook frontend. Under "Run All" that injection may arrive after this cell
# has already executed, so fail with an actionable message instead of a bare NameError.
if 'nb_name' not in globals():
    raise NameError("nb_name is not defined; run the %%javascript cell first (classic Notebook only), then re-run this cell")
config = get_config(nb_name)

In [4]:
# Restrict visible GPUs before TensorFlow initialises its devices (the session is created in a
# later cell). os.environ only accepts strings, so coerce in case config.gpu is an int.
os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu)
np.random.seed(config.seed)
random.seed(config.seed)
# Graph-level TensorFlow seed so TF random ops are reproducible as well.
# NOTE(review): the seed is stored on the current default graph; if Model(config) builds into a
# fresh graph (e.g. via tf.reset_default_graph()), it must be set there too — TODO confirm.
tf.set_random_seed(config.seed)

In [5]:
# Load the preprocessed corpus; the `with` block closes the file handle once loading finishes.
# NOTE(review): cPickle.load can execute arbitrary code — only load config.path_data from a trusted source.
with open(config.path_data, 'rb') as f_data:
    instances_train, instances_dev, instances_test, word_to_idx, idx_to_word, bow_idxs = cPickle.load(f_data)
# Batch lists; the debug helpers index them as (count, batch) pairs via test_batches[0][1].
# train_batches is rebuilt as an iterator in the "initialize log" section before training.
train_batches = get_batches(instances_train, config.batch_size)
dev_batches = get_batches(instances_dev, config.batch_size)
test_batches = get_batches(instances_test, config.batch_size)
# Size of the bag-of-words vocabulary, stored on config (presumably read by the model constructor).
config.dim_bow = len(bow_idxs)

In [6]:
def debug_value(variables, model, return_value=False):
    """Evaluate `variables` on the first test batch and return the fetched values.

    Args:
        variables: fetches accepted by `sess.run` (a tensor, or a list/tuple of tensors).
        model: object providing `get_feed_dict(batch)`.
        return_value: unused; kept for backward compatibility with existing calls.

    Returns:
        Whatever `sess.run(variables, ...)` returns (same structure as `variables`).

    Relies on the notebook-level `sess` and `test_batches`.
    """
    sample_batch = test_batches[0][1]
    feed_dict = model.get_feed_dict(sample_batch)
    return sess.run(variables, feed_dict=feed_dict)


def debug_shape(variables, model):
    """Print the shape of each of `variables` when evaluated on the first test batch.

    Args:
        variables: a tensor, or a list/tuple of tensors (or other `sess.run` fetches).
        model: object providing `get_feed_dict(batch)`.

    Entries that have a `.name` attribute are printed as `name : shape`, others as `shape`.
    """
    # Accept a single tensor as well as a sequence; zipping a bare graph tensor would raise.
    if not isinstance(variables, (list, tuple)):
        variables = [variables]
    for variable, value in zip(variables, debug_value(variables, model)):
        if hasattr(variable, 'name'):
            print(variable.name, ':', value.shape)
        else:
            print(value.shape)

# run

## initialize log

In [7]:
# Running state for the training loop; re-running this cell resets the log and the model dir.
checkpoint = []      # checkpoint path prefixes currently kept, oldest first (see update_checkpoint)
losses_train = []    # per-batch [loss, recon, kl, reg] values
ppls_train = []      # per-item values of model.topic_ppls (averaged then exponentiated when logging)
ppl_min = np.inf     # best dev perplexity seen so far
epoch = 0
train_batches = get_batches(instances_train, config.batch_size, iterator=True)

# Two-level header: section label (TRAIN/VALID/TEST) over metric name.
log_df = pd.DataFrame(columns=pd.MultiIndex.from_tuples(
                    list(zip(*[['','','','TRAIN:','','','','','VALID:','','','','','TEST:',''],
                            ['Time','Ep','Ct','LOSS','PPL','NLL','KL','REG','LOSS','PPL','NLL','KL','REG','LOSS','PPL']]))))

# Start from an empty model directory. shutil/os replace `rm -r` / `mkdir` run through
# str.split(), which broke on paths containing spaces and failed silently when a parent
# directory was missing.
shutil.rmtree(config.dir_model, ignore_errors=True)
os.makedirs(config.dir_model, exist_ok=True)

def update_checkpoint(config, checkpoint, global_step):
    """Record a newly saved checkpoint and prune the oldest ones beyond `config.max_to_keep`.

    Args:
        config: object with `path_model` (checkpoint prefix), `max_to_keep` and
            `path_checkpoint` (where the list of kept prefixes is pickled).
        checkpoint: list of kept checkpoint prefixes, oldest first; mutated in place.
        global_step: step the checkpoint was saved under (`<path_model>-<global_step>`).

    Side effects:
        Deletes every file matching `<pruned prefix>.*` for each prefix dropped from the list,
        then pickles the updated list to `config.path_checkpoint`.
    """
    checkpoint.append(config.path_model + '-%i' % global_step)
    # `while` rather than `if` so a list that is already over the limit is fully trimmed.
    while len(checkpoint) > config.max_to_keep:
        for path in glob.glob(checkpoint.pop(0) + '.*'):
            os.remove(path)
    # Close the file deterministically so the list is flushed to disk immediately.
    with open(config.path_checkpoint, 'wb') as f_checkpoint:
        cPickle.dump(checkpoint, f_checkpoint)

## initialize model

In [8]:
# Resolve the model class from config.model. An unknown name fails with a clear error instead
# of a NameError, or silently reusing a `Model` left in the kernel by an earlier run.
MODEL_CLASSES = {'gsm': GaussianSoftmaxModel, 'rsm': RecurrentStickbreakingModel}
if config.model not in MODEL_CLASSES:
    raise ValueError('unknown config.model %r; expected one of %s' % (config.model, sorted(MODEL_CLASSES)))
Model = MODEL_CLASSES[config.model]

# Close a session left over from a previous run of this cell before building a new one.
if 'sess' in globals():
    sess.close()
model = Model(config)
# One intra-op and one inter-op thread, in line with the *_NUM_THREADS settings at the top.
sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(max_to_keep=config.max_to_keep)
# NOTE(review): update_tree_flg is not read by any cell visible here.
update_tree_flg = False

## train & validate model

In [9]:
def _make_session():
    """Create a TF session with one intra-op and one inter-op thread (same settings as model initialisation)."""
    return tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))


def _top_topic_tokens(sess, model):
    """Return, for every row of `model.topic_bow`, its `config.n_freq` highest-scoring tokens.

    Each row is sorted in descending order; the top column indices are mapped through
    `bow_idxs` to vocabulary ids and then through `idx_to_word` to tokens.
    """
    # assumes topic_bow evaluates to a (n_topic, dim_bow) array — TODO confirm
    topics_freq_indices = np.argsort(sess.run(model.topic_bow), 1)[:, ::-1][:, :config.n_freq]
    topics_freq_idxs = bow_idxs[topics_freq_indices]
    return [[idx_to_word[idx] for idx in topic_freq_idxs] for topic_freq_idxs in topics_freq_idxs]


def _rebuild_model(sess):
    """Rebuild `Model(config)` after a topic update and copy every global variable over by name.

    Closes `sess` and returns the new `(model, sess, saver)`.
    """
    # Snapshot the current parameter values, keyed by variable name.
    name_variables = {tensor.name: value for tensor, value in zip(tf.global_variables(), sess.run(tf.global_variables()))}
    sess.close()
    model = Model(config)
    sess = _make_session()
    # Initialise first so a variable that exists only in the rebuilt graph is not left
    # uninitialised; the assignments below then overwrite every variable that existed before.
    sess.run(tf.global_variables_initializer())
    name_tensors = {tensor.name: tensor for tensor in tf.global_variables()}
    sess.run([name_tensors[name].assign(value) for name, value in name_variables.items()])
    # Same retention as the Saver created at initialisation, so it agrees with update_checkpoint.
    saver = tf.train.Saver(max_to_keep=config.max_to_keep)
    return model, sess, saver


time_start = time.time()
while epoch < config.n_epochs:
    # train: one pass over the batch iterator
    for ct, batch in train_batches:
        feed_dict = model.get_feed_dict(batch)
        _, loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch, ppls_batch, global_step_log = \
            sess.run([model.opt, model.loss, model.topic_loss_recon, model.topic_loss_kl, model.topic_loss_reg,
                      model.topic_ppls, tf.train.get_global_step()], feed_dict=feed_dict)

        losses_train += [[loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch]]
        ppls_train += list(ppls_batch)

        if global_step_log % config.log_period == 0:
            # validate. NOTE(review): losses_train / ppls_train are never cleared, so the TRAIN
            # columns are running averages over every batch since the "initialize log" cell.
            loss_train, topic_loss_recon_train, topic_loss_kl_train, topic_loss_reg_train = np.mean(losses_train, 0)
            ppl_train = np.exp(np.mean(ppls_train))
            loss_dev, topic_loss_recon_dev, topic_loss_kl_dev, topic_loss_reg_dev, ppl_dev, probs_topic_dev = validate(sess, dev_batches, model)

            # test and checkpoint only when dev perplexity improves; otherwise the TEST
            # columns repeat the values from the best model so far
            if ppl_dev < ppl_min:
                ppl_min = ppl_dev
                loss_test, _, _, _, ppl_test, _ = validate(sess, test_batches, model)
                saver.save(sess, config.path_model, global_step=global_step_log)
                with open(config.path_config % global_step_log, 'wb') as f_config:
                    cPickle.dump(config, f_config)
                update_checkpoint(config, checkpoint, global_step_log)

            # visualize topic
            topics_freq_tokens = _top_topic_tokens(sess, model)

            # log; Time is the number of seconds since the previous log entry (or cell start)
            clear_output()
            time_log = int(time.time() - time_start)
            log_series = pd.Series([time_log, epoch, ct,
                    '%.2f'%loss_train, '%.0f'%ppl_train, '%.2f'%topic_loss_recon_train, '%.2f'%topic_loss_kl_train, '%.2f'%topic_loss_reg_train,
                    '%.2f'%loss_dev, '%.0f'%ppl_dev, '%.2f'%topic_loss_recon_dev, '%.2f'%topic_loss_kl_dev, '%.2f'%topic_loss_reg_dev,
                    '%.2f'%loss_test, '%.0f'%ppl_test],
                    index=log_df.columns)
            log_df.loc[global_step_log] = log_series
            display(log_df)
            with open(config.path_log, 'wb') as f_log:
                cPickle.dump(log_df, f_log)
            print_flat_topic_sample(sess, model, topics_freq_tokens=topics_freq_tokens)

            # update the topic set (non-static models only); rebuild the graph when it changed
            if not config.static:
                config.n_topic, update_flg, diff = model.update_topic(sess, dev_batches)
                print('Diff: %.6f' % diff)
                if update_flg:
                    print('Update to %i' % config.n_topic)
                    model, sess, saver = _rebuild_model(sess)

            time_start = time.time()

    train_batches = get_batches(instances_train, config.batch_size, iterator=True)
    epoch += 1

# Final summary. Topics are recomputed so they reflect the current (possibly rebuilt) model,
# and this also works when no log step was reached.
display(log_df)
print_flat_topic_sample(sess, model, topics_freq_tokens=_top_topic_tokens(sess, model))

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,TRAIN:,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,VALID:,Unnamed: 10_level_0,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0,TEST:,Unnamed: 15_level_0
Unnamed: 0_level_1,Time,Ep,Ct,LOSS,PPL,NLL,KL,REG,LOSS,PPL,NLL,KL,REG,LOSS,PPL
5000,40,28,99,599.83,922,596.57,3.14,0.12,582.9,1020,579.4,3.41,0.09,583.0,1023
10000,37,57,24,600.72,920,597.07,3.55,0.1,574.45,959,570.78,3.6,0.07,574.56,961
15000,44,85,124,599.26,916,595.36,3.8,0.09,576.02,957,571.88,4.08,0.06,576.12,960
20000,49,114,49,599.04,913,594.94,4.02,0.08,573.07,952,568.71,4.3,0.06,572.82,949
25000,43,142,149,598.15,910,593.89,4.19,0.08,572.09,943,567.58,4.46,0.04,571.91,942
30000,45,171,74,597.53,907,593.16,4.3,0.07,576.07,957,571.56,4.46,0.05,571.91,942
35000,44,199,174,597.32,905,592.86,4.4,0.07,573.67,961,569.02,4.61,0.05,571.91,942
40000,47,228,99,596.61,903,592.08,4.47,0.07,581.02,987,576.39,4.57,0.06,571.91,942
45000,47,257,24,596.46,901,591.87,4.53,0.07,572.12,947,567.37,4.69,0.05,571.91,942
50000,47,285,124,595.96,900,591.32,4.58,0.07,577.28,960,572.6,4.64,0.05,571.91,942


0 use get write file one work drive also system thanks
1 write go get article say one think know people like
2 god one say christian people write believe jesus think know
3 game team play player win year hockey season league score
4 key use chip encryption government system clipper law information privacy
5 car use bike one get dod ride write article good
6 gun law government people state weapon crime use right firearm
7 israel turkish armenian israeli armenians people jews turkey say arab
8 space center use research information nasa mission health medical university
9 insurance health war south care private secret new nuclear tax
10 man homosexual gay power male government show sexual hitler number
11 people job work go want program young think president make
12 stephanopoulos president say go think know myers make work something
13 president administration government program russia official russian go fund think
14 water april president vote national washington american new energy go

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,TRAIN:,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,VALID:,Unnamed: 10_level_0,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0,TEST:,Unnamed: 15_level_0
Unnamed: 0_level_1,Time,Ep,Ct,LOSS,PPL,NLL,KL,REG,LOSS,PPL,NLL,KL,REG,LOSS,PPL
5000,40,28,99,599.83,922,596.57,3.14,0.12,582.9,1020,579.4,3.41,0.09,583.0,1023
10000,37,57,24,600.72,920,597.07,3.55,0.1,574.45,959,570.78,3.6,0.07,574.56,961
15000,44,85,124,599.26,916,595.36,3.8,0.09,576.02,957,571.88,4.08,0.06,576.12,960
20000,49,114,49,599.04,913,594.94,4.02,0.08,573.07,952,568.71,4.3,0.06,572.82,949
25000,43,142,149,598.15,910,593.89,4.19,0.08,572.09,943,567.58,4.46,0.04,571.91,942
30000,45,171,74,597.53,907,593.16,4.3,0.07,576.07,957,571.56,4.46,0.05,571.91,942
35000,44,199,174,597.32,905,592.86,4.4,0.07,573.67,961,569.02,4.61,0.05,571.91,942
40000,47,228,99,596.61,903,592.08,4.47,0.07,581.02,987,576.39,4.57,0.06,571.91,942
45000,47,257,24,596.46,901,591.87,4.53,0.07,572.12,947,567.37,4.69,0.05,571.91,942
50000,47,285,124,595.96,900,591.32,4.58,0.07,577.28,960,572.6,4.64,0.05,571.91,942


0 use get write file one work drive also system thanks
1 write go get article say one think know people like
2 god one say christian people write believe jesus think know
3 game team play player win year hockey season league score
4 key use chip encryption government system clipper law information privacy
5 car use bike one get dod ride write article good
6 gun law government people state weapon crime use right firearm
7 israel turkish armenian israeli armenians people jews turkey say arab
8 space center use research information nasa mission health medical university
9 insurance health war south care private secret new nuclear tax
10 man homosexual gay power male government show sexual hitler number
11 people job work go want program young think president make
12 stephanopoulos president say go think know myers make work something
13 president administration government program russia official russian go fund think
14 water april president vote national washington american new energy go