In [38]:
%load_ext autoreload
%autoreload
from IPython.display import clear_output

import os
os.environ['PYTHONHASHSEED'] = '0'
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import sys
import argparse
import subprocess
import pdb
import time
import random
import _pickle as cPickle
import matplotlib.pyplot as plt

%matplotlib inline

import numpy as np
import pandas as pd
import tensorflow as tf

from data_structure import get_batches
from hntm import HierarchicalNeuralTopicModel
from tree import get_descendant_idxs
from evaluation import validate, get_topic_specialization, get_hierarchical_affinity, print_topic_sample
from coherence import compute_word_count, compute_coherence
from configure import get_config
from ncrp import get_docs, get_freq_tokens_ncrp

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Expose only GPU 0 to this kernel (presumably needs to happen before TensorFlow first initializes the GPU).
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [50]:
def load_model(config, name_model, nb_name, index=-1):
    """Restore a trained HNTM model or a pickled nCRP result.

    The run directory is ``model/<config.data>/<name_model>/<run>``, where
    ``<run>`` is ``nb_name`` with its first whitespace-separated token dropped
    and the remaining tokens concatenated (e.g. ``'3 bags -tree 33 -min'`` ->
    ``'bags-tree33-min'``). Its ``checkpoint`` file is a pickled, indexable
    collection of saved paths; ``index`` picks one (default: the last).

    Args:
        config: dataset config; only ``config.data`` is read here.
        name_model: ``'hntm'`` or ``'ncrp'``.
        nb_name: run name used to locate the run directory (and, for nCRP,
            passed to ``get_config``).
        index: which entry of the checkpoint collection to restore.

    Returns:
        For ``'hntm'``: ``(sess, model, config)``, where ``config`` is the one
        pickled next to the checkpoint, not the ``config`` argument.
        For ``'ncrp'``: ``(test_docs, topic_root, config)``.

    Raises:
        ValueError: if ``name_model`` is not ``'hntm'`` or ``'ncrp'``.
    """
    if name_model not in ('hntm', 'ncrp'):
        # Previously this fell through and returned None, failing only at the caller's unpacking.
        raise ValueError("unknown name_model %r; expected 'hntm' or 'ncrp'" % (name_model,))

    dir_model = os.path.join('model', config.data, name_model, ''.join(nb_name.split()[1:]))
    # NOTE: unpickling can execute arbitrary code -- only load checkpoints you produced.
    with open(os.path.join(dir_model, 'checkpoint'), 'rb') as f:
        ckpt = cPickle.load(f)
    path_restore = ckpt[index]

    if name_model == 'hntm':
        with open(path_restore + '.config', 'rb') as f:
            config = cPickle.load(f)
        model = HierarchicalNeuralTopicModel(config)
        sess = tf.Session()
        saver = tf.train.Saver()
        try:
            saver.restore(sess, path_restore)
        except Exception:
            # Don't leak the session if the restore fails.
            sess.close()
            raise
        return sess, model, config

    # name_model == 'ncrp'
    with open(path_restore, 'rb') as f:
        test_docs, topic_root = cPickle.load(f)
    config = get_config(nb_name)
    return test_docs, topic_root, config

In [48]:
def get_freq_tokens(sess, model, topic_freq_tokens=None, parent_idx=0, depth=0,
                    idx_to_word=None, bow_idxs=None, topn=10):
    """Print the topic tree with each topic's top words and return those words.

    At ``depth == 0`` the matrix ``model.topic_bow`` is evaluated once. For each
    row, the ``topn`` largest columns are mapped through ``bow_idxs`` and then
    ``idx_to_word``, and the root topic ``parent_idx`` is printed. The tree in
    ``model.tree_idxs`` (parent index -> child indices) is then walked
    depth-first, printing each child indented by its depth.

    Args:
        sess: session used to evaluate ``model.topic_bow``.
        model: object exposing ``topic_bow``, ``topic_idxs`` (the topic index
            of each row of ``topic_bow``) and ``tree_idxs``.
        topic_freq_tokens: precomputed mapping; passed down by recursive calls.
        parent_idx: topic whose subtree is printed (the root at the top call).
        depth: current tree depth; callers should leave this at 0.
        idx_to_word: vocabulary index -> word. Defaults to a notebook-global
            ``idx_to_word`` (the original, implicit behaviour).
        bow_idxs: maps ``topic_bow`` columns to vocabulary indices. Defaults to
            a notebook-global ``bow_idxs``.
        topn: number of top words kept per topic.

    Returns:
        dict mapping topic index -> list of its ``topn`` top words.

    Raises:
        NameError: if the vocabulary lookups are neither passed nor defined globally.
    """
    if depth == 0:
        # These used to be read implicitly from the notebook namespace, which
        # fails in a fresh kernel where only the *_bags variants exist.
        namespace = globals()
        if idx_to_word is None:
            idx_to_word = namespace.get('idx_to_word')
        if bow_idxs is None:
            bow_idxs = namespace.get('bow_idxs')
        if idx_to_word is None or bow_idxs is None:
            raise NameError(
                "get_freq_tokens needs idx_to_word and bow_idxs: pass them as "
                "arguments or define them in the notebook namespace")

        topic_bow = sess.run(model.topic_bow)
        # Column indices of each row's topn largest weights, largest first.
        topics_freq_indices = np.argsort(topic_bow, 1)[:, ::-1][:, :topn]
        topics_freq_idxs = np.asarray(bow_idxs)[topics_freq_indices]
        topic_freq_tokens = {
            topic_idx: [idx_to_word[idx] for idx in topic_freq_idxs]
            for topic_idx, topic_freq_idxs in zip(model.topic_idxs, topics_freq_idxs)
        }

        # print root
        print(parent_idx, ' '.join(topic_freq_tokens[parent_idx]))

    child_idxs = model.tree_idxs[parent_idx]
    depth += 1
    for child_idx in child_idxs:
        print('  ' * depth, child_idx, ' '.join(topic_freq_tokens[child_idx]))
        # Recurse only into children that themselves have children.
        if child_idx in model.tree_idxs:
            get_freq_tokens(sess, model, topic_freq_tokens=topic_freq_tokens,
                            parent_idx=child_idx, depth=depth)

    return topic_freq_tokens

# bags

## load data

In [34]:
nb_name_base = '0 bags'
config_bags = get_config(nb_name_base)
# Close the data pickle deterministically instead of leaving the handle to the GC.
with open(config_bags.path_data, 'rb') as f_data:
    _, _, instances_bags, word_to_idx_bags, idx_to_word_bags, bow_idxs_bags = cPickle.load(f_data)
bags_batches = get_batches(instances_bags, batch_size=config_bags.batch_size)

## Restore the trained HNTM (hierarchical neural topic model)

In [62]:
# Close any session from a previous run of this cell and start from an empty
# default graph: presumably HierarchicalNeuralTopicModel adds its variables to
# the default graph, so without a reset a re-run would build a second copy and
# tf.train.Saver would try to restore both.
if 'sess' in globals():
    sess.close()
tf.reset_default_graph()
# Other runs that were compared:
# sess, model_bags_hntm, config_bags_hntm = load_model(config=config_bags, name_model = 'hntm', nb_name = '1  bags -tree 33 -temp 1 -seed 0', index=-1)
# sess, model_bags_hntm, config_bags_hntm = load_model(config=config_bags, name_model = 'hntm', nb_name = '2 bags -tree 33 -temp 10 -seed 0', index=-1)
sess, model_bags_hntm, config_bags_hntm = load_model(config=config_bags, name_model = 'hntm', nb_name = '3 bags -tree 33 -temp 10 -seed 0 -min', index=-1)
with open(model_bags_hntm.config.path_log, 'rb') as f_log:
    log_bags_hntm = cPickle.load(f_log)
log_bags_hntm[-10:]

INFO:tensorflow:Restoring parameters from model/bags/hntm/bags-tree33-temp10-seed0-min/model-480000


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,TRAIN:,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,VALID:,Unnamed: 10_level_0,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0,TEST:,Unnamed: 15_level_0,SPEC:,Unnamed: 17_level_0,Unnamed: 18_level_0,HIER:,Unnamed: 20_level_0
Unnamed: 0_level_1,Time,Ep,Ct,LOSS,PPL,NLL,KL,REG,LOSS,PPL,NLL,KL,REG,LOSS,PPL,1,2,3,CHILD,OTHER
450000,59,901,400,109.75,410,107.39,2.34,0.01,103.0,399,100.6,2.4,0.0,101.14,395,0.34,0.61,0.64,0.38,0.06
455000,59,911,410,109.74,410,107.39,2.34,0.01,103.01,399,100.6,2.41,0.0,101.14,395,0.34,0.6,0.64,0.39,0.06
460000,59,921,420,109.74,410,107.39,2.35,0.01,102.99,399,100.58,2.41,0.0,101.14,395,0.34,0.6,0.64,0.39,0.06
465000,59,931,430,109.74,410,107.38,2.35,0.01,102.97,398,100.56,2.41,0.0,101.14,395,0.34,0.6,0.64,0.38,0.06
470000,59,941,440,109.74,410,107.38,2.35,0.01,102.98,399,100.57,2.41,0.0,101.14,395,0.34,0.6,0.64,0.39,0.06
475000,59,951,450,109.74,409,107.37,2.35,0.01,102.97,399,100.58,2.39,0.0,101.14,395,0.34,0.6,0.64,0.39,0.06
480000,59,961,460,109.73,409,107.37,2.35,0.01,102.96,398,100.55,2.41,0.0,101.15,395,0.34,0.6,0.64,0.38,0.06
485000,59,971,470,109.73,409,107.37,2.36,0.01,102.97,398,100.55,2.41,0.0,101.15,395,0.34,0.6,0.64,0.37,0.06
490000,59,981,480,109.73,409,107.36,2.36,0.01,102.97,398,100.55,2.42,0.0,101.15,395,0.34,0.6,0.64,0.37,0.06
495000,59,991,490,109.73,409,107.36,2.36,0.01,102.97,399,100.56,2.4,0.0,101.15,395,0.34,0.6,0.64,0.37,0.06


In [63]:
# get_freq_tokens looks up `idx_to_word` / `bow_idxs` in the notebook globals,
# but only the *_bags variants are defined above -- bind them explicitly so this
# cell works in a fresh kernel (mirrors the explicit args in the nCRP cell).
idx_to_word = idx_to_word_bags
bow_idxs = bow_idxs_bags
freq_tokens_bags_hntm = get_freq_tokens(sess, model_bags_hntm)
coherence_bags_hntm = compute_coherence(freq_tokens_bags_hntm.values(), config_bags.dir_corpus, topns=[5, 10])

0 bought quality price 'm ... time buy 've nice back
   1 strap pocket pockets side inside shoulder zipper small compartment nice
     11 carry pockets room compartments back pack comfortable books space straps
     12 mouse power netbook cord pocket drive room usb cords external
     14 ipad room carry perfect charger pocket small love extra front
   6 sleeve protection air pro inside smell protect inch zipper snug
     62 & ; perfectly size big hp laptops tablet dell room
     63 sleeve perfectly protection inch chromebook inside protect nice bought perfect
   4 mac ! pro love air book recommend perfectly protect easy
     42 ! love perfect recommend color loves absolutely buy awesome ...
   2 bottom cover top scratches plastic apple speck easily hard feet
     21 color cover keyboard love perfectly blue picture pink screen protector
Average Topic Coherence = 0.104
Median Topic Coherence = 0.097


## Restore the nCRP (nested Chinese restaurant process) model

In [29]:
# Load the pickled nCRP run: its test docs, the root of its topic tree, and its config.
bags_docs, topic_bags, config_bags_ncrp = load_model(
    config_bags,
    'ncrp',
    '0 bags -m ncrp -alp 10 5 1 -eta 1 -gam 0.01',
)

In [33]:
# Close the log pickle deterministically instead of leaving the handle to the GC.
with open(config_bags_ncrp.path_log, 'rb') as f_log:
    log_bags_ncrp = cPickle.load(f_log)
log_bags_ncrp

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,TRAIN:,VALID:,TEST:,SPEC:,Unnamed: 8_level_0,Unnamed: 9_level_0,HIER:,Unnamed: 11_level_0
Unnamed: 0_level_1,Time,Ep,Ct,PPL,PPL,PPL,PPL,1,2,3,CHILD
0,199,0,0,367,349,348,0.09,0.5,0.64,0.47,0.34
1,242,1,0,328,320,322,0.1,0.53,0.66,0.44,0.29
2,266,2,0,317,315,312,0.1,0.52,0.67,0.46,0.29
3,286,3,0,306,305,304,0.1,0.54,0.65,0.44,0.29


In [45]:
# Top words per nCRP topic, scored with the same coherence settings as the HNTM cell.
freq_tokens_bags_ncrp = get_freq_tokens_ncrp(
    topic_bags, idx_to_word_bags, bow_idxs_bags
)
coherence_bags_ncrp = compute_coherence(
    freq_tokens_bags_ncrp.values(),
    config_bags.dir_corpus,
    topns=[5, 10],
)

 0 31943 292848.0 ! bought nice quality price made perfectly perfect love room
   0-1 5038 32689.0 ! color love mac cover pro recommend apple air sleeve
     0-1-1 60 153.0 cover keyboard keys typing key protector board type marks light
     0-1-2 3035 4230.0 cover keyboard pink perfectly love purple received picture blue compliments
     0-1-3 336 693.0 cover keyboard shell mcover hard mbp highly clear purchased unibody
     0-1-4 1095 1474.0 blue & smell retina logo ; cover rubberized picture perfectly
     0-1-5 332 285.0 major searching ca shipped samsung paid past pop order feeling
     0-1-6 180 185.0 kuzy service saved protected notebooks issues felt returned heavy velcro
   0-2 3227 26145.0 carry pockets pack shoulder books comfortable straps pocket back travel
     0-2-1 386 622.0 : pros smaller cons pocket lap padded back started company
     0-2-2 2175 3164.0 comfortable security bags trip plenty carry stuff space items folders
     0-2-3 488 673.0 rolling heavy office files