In [1]:
%%javascript
// Ask the Python kernel to execute `nb_name = "<notebook file name>"`, so that
// later Python cells can read the notebook's own name from the variable `nb_name`.
// NOTE(review): relies on the front-end exposing a global `IPython.notebook`
// object — presumably only the classic Notebook UI does; confirm before running elsewhere.
IPython.notebook.kernel.execute('nb_name = "' + IPython.notebook.notebook_name + '"')

<IPython.core.display.Javascript object>

In [2]:
# Load the autoreload extension and trigger a reload of imported modules.
# NOTE(review): a bare `%autoreload` (no mode argument) appears to reload once, at this
# point, rather than before every cell — confirm this is the intended mode.
%load_ext autoreload
%autoreload
from IPython.display import clear_output

import os
# NOTE(review): PYTHONHASHSEED is presumably only read at interpreter start-up, so
# setting it here may affect child processes only, not this kernel — confirm.
os.environ['PYTHONHASHSEED'] = '0'
# Request single-threaded numexpr / MKL / OpenMP. These are set before numpy, pandas
# and tensorflow are imported further down in this cell.
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import _pickle as cPickle
import glob
import pdb
import shutil
import subprocess
import time

import random
import numpy as np
import pandas as pd
import tensorflow as tf

from collections import defaultdict, Counter
from rcrp import Topic, Doc, init, sample, get_perplexity, get_topic_specialization, get_hierarchical_affinities, get_freq_tokens_rcrp, get_docs
from configure import get_config

# load config & data 

In [3]:
# Build the run configuration from the notebook's own name (`nb_name` is injected
# into the kernel by the %%javascript cell at the top of the notebook).
config = get_config(nb_name)

# Seed both the stdlib and the numpy global RNGs with the configured seed.
random.seed(config.seed)
np.random.seed(config.seed)

In [4]:
# Load the pickled dataset: train/dev/test instances, the two vocabulary mappings
# and the bag-of-words index array.
# NOTE: unpickling can execute arbitrary code — only load files from a trusted source.
# The `with` block closes the file handle as soon as the load finishes.
with open(config.path_data, 'rb') as data_file:
    (instances_train, instances_dev, instances_test,
     word_to_idx, idx_to_word, bow_idxs) = cPickle.load(data_file)

In [5]:
# Record the number of training documents and the vocabulary size on the config,
# then echo both as the cell output.
config.n_doc, config.n_vocab = len(instances_train), len(bow_idxs)
(config.n_doc, config.n_vocab)

(9006, 1995)

In [6]:
# One bag-of-words count vector per training document.
docs_bow = [instance.bow for instance in instances_train]

# For every document: one list per vocabulary index with a positive count, holding
# that index repeated `count` times.
docs_raw = []
for doc_bow in docs_bow:
    present_idxs = np.where(doc_bow > 0)[0]
    docs_raw.append([[bow_index] * int(doc_bow[bow_index]) for bow_index in present_idxs])

# Flatten the per-index lists into a single token-index sequence per document.
docs_words = [[idx for idxs in doc for idx in idxs] for doc in docs_raw]

# Total number of tokens in the training set (cell output).
np.sum([len(doc_words) for doc_words in docs_words])

695746

# run

## initialize log

In [7]:
# Run state shared with the sampling-loop cell below.
checkpoint = []     # paths of the model snapshots currently kept on disk
ppl_min = np.inf    # lowest validation perplexity seen so far
ppl_test = np.nan   # test perplexity at the best validation epoch; the default keeps
                    # the log row well-defined if no epoch has improved on ppl_min yet
epoch = 0

# Start from an empty model directory. Using shutil/os instead of shelling out to
# `rm -r` / `mkdir` keeps paths containing whitespace intact and works on any OS.
# Removal stays best-effort (a missing directory is not an error).
shutil.rmtree(config.dir_model, ignore_errors=True)
os.makedirs(config.dir_model, exist_ok=True)

# Empty log table with a two-level column header: group label on top, metric below.
log_df = pd.DataFrame(columns=pd.MultiIndex.from_tuples(list(zip(
    ['', '', '', 'TRAIN:', 'VALID:', 'TEST:', 'SPEC:', '', '', 'HIER:', ''],
    ['Time', 'Ep', 'Ct', 'PPL', 'PPL', 'PPL', '1', '2', '3', 'CHILD', 'OTHER'],
))))

def update_checkpoint(config, checkpoint, epoch):
    """Register the snapshot of ``epoch`` and prune the oldest one when over the limit.

    Args:
        config: object providing ``path_model`` (snapshot path prefix),
            ``max_to_keep`` (maximum number of snapshots to retain) and
            ``path_checkpoint`` (where the list of retained paths is pickled).
        checkpoint: list of retained snapshot paths, oldest first; mutated in place.
        epoch: integer epoch whose snapshot path ``path_model-<epoch>`` is appended.

    Returns:
        None. The updated ``checkpoint`` list is pickled to ``config.path_checkpoint``.
    """
    checkpoint.append(config.path_model + '-%i' % epoch)
    if len(checkpoint) > config.max_to_keep:
        # Drop the oldest entry and delete whatever files match its path.
        path_model = checkpoint.pop(0)
        for p in glob.glob(path_model):
            os.remove(p)
    # Close the file explicitly so the list is flushed to disk before returning.
    with open(config.path_checkpoint, 'wb') as checkpoint_file:
        cPickle.dump(checkpoint, checkpoint_file)

## initialize data

In [8]:
# Root of the topic tree; every document set below is built against this same root.
topic_root = Topic(idx='0', sibling_idx=0, parent=None, depth=1, config=config)

# Build the train / dev / test document collections, in that order.
train_docs, dev_docs, test_docs = (
    get_docs(instances, config, topic_root=topic_root)
    for instances in (instances_train, instances_dev, instances_test)
)

# Initialise all three document sets together with the topic tree.
init(train_docs, dev_docs, test_docs, topic_root)

## run

In [None]:
# Main loop: one sampling pass per epoch, followed by evaluation, snapshotting and logging.
# `epoch`, `ppl_min`, `checkpoint` and `log_df` come from the "initialize log" cell.
time_start = time.time()
while epoch < config.n_epochs:
    sample(train_docs, dev_docs, test_docs, topic_root)
    # Training perplexity is skipped and 0. is logged in its place; to restore it use
    #     ppl_train = get_perplexity(train_docs, topic_root)
    ppl_train = 0.
    ppl_dev = get_perplexity(dev_docs, topic_root)

    # New best validation perplexity: refresh the test score and snapshot the model.
    # On other epochs `ppl_test` keeps the value from the last improvement.
    if ppl_dev < ppl_min:
        ppl_min = ppl_dev
        ppl_test = get_perplexity(test_docs, topic_root)
        # Close the snapshot file before update_checkpoint registers it.
        with open(config.path_model + '-%i' % epoch, 'wb') as model_file:
            cPickle.dump([test_docs, topic_root], model_file)
        update_checkpoint(config, checkpoint, epoch)

    depth_spec = get_topic_specialization(test_docs, topic_root)
    hierarchical_affinities = get_hierarchical_affinities(topic_root)

    clear_output()
    # Seconds elapsed since the previous log row; the timer restarts for the next epoch.
    time_log = int(time.time() - time_start)
    time_start = time.time()
    # One log row, in the column order of log_df. Depths 2 and 3 may be missing from
    # depth_spec, in which case 0 is logged for them.
    log_series = pd.Series(
        [time_log, epoch, 0,
         '%.0f' % ppl_train, '%.0f' % ppl_dev, '%.0f' % ppl_test,
         '%.2f' % depth_spec[1],
         '%.2f' % depth_spec[2] if 2 in depth_spec else 0,
         '%.2f' % depth_spec[3] if 3 in depth_spec else 0,
         '%.2f' % hierarchical_affinities[0], '%.2f' % hierarchical_affinities[1]],
        index=log_df.columns)
    log_df.loc[epoch] = log_series
    display(log_df)
    get_freq_tokens_rcrp(topic_root, idx_to_word, bow_idxs)

    # Persist the log after every epoch so an interrupted run keeps its history.
    with open(config.path_log, 'wb') as log_file:
        cPickle.dump(log_df, log_file)
    epoch += 1

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,TRAIN:,VALID:,TEST:,SPEC:,Unnamed: 8_level_0,Unnamed: 9_level_0,HIER:,Unnamed: 11_level_0
Unnamed: 0_level_1,Time,Ep,Ct,PPL,PPL,PPL,1,2,3,CHILD,OTHER
0,1198,0,0,0,1060,796,0.02,0.49,0.54,0.85,0.32
1,2843,1,0,0,1018,751,0.02,0.48,0.55,0.63,0.34


   0 : ['0-1', '0-2', '0-3', '0-4', '0-5', '0-6', '0-7', '0-8', '0-9', '0-10', '0-11', '0-12', '0-13', '0-14', '0-15', '0-16', '0-17', '0-18', '0-19', '0-20', '0-21', '0-23', '0-24', '0-25', '0-26', '0-27', '0-28', '0-29', '0-30', '0-31', '0-32', '0-33', '0-34', '0-35', '0-36', '0-37', '0-38', '0-39', '0-40', '0-41', '0-42', '0-43'] 2038 48039.0 ['use', 'write', 'one', 'get', 'article', 'know', 'like', 'make', 'say', 'go']
     0-1 : [] 882 63484.0 ['god', 'one', 'write', 'say', 'people', 'think', 'article', 'make', 'know', 'believe']
     0-2 : [] 276 23543.0 ['god', 'jesus', 'say', 'one', 'people', 'know', 'write', 'christian', 'christ', 'come']
     0-3 : [] 709 36636.0 ['use', 'image', 'file', 'available', 'software', 'write', 'also', 'ftp', 'include', 'version']
     0-4 : ['0-4-1'] 103 6803.0 ['msg', 'food', 'write', 'one', 'article', 'people', 'get', 'use', 'know', 'make']
       0-4-1 : [] 3 165.0 ['doug', 'build', 'error', 'western', 'problem', 'apr', 'gmt', 'say', 'load', 'li