In [4]:
%%javascript
// Ask the Python kernel to define `nb_name` as this notebook's file name;
// a later cell passes `nb_name` to get_config().
// NOTE(review): relies on the front-end exposing a global `IPython.notebook` object — confirm this
// is available in the Jupyter front-end being used.
IPython.notebook.kernel.execute('nb_name = "' + IPython.notebook.notebook_name + '"')

<IPython.core.display.Javascript object>

In [5]:
%load_ext autoreload
%autoreload
from IPython.display import clear_output

import os
os.environ['PYTHONHASHSEED'] = '0'
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import argparse
import glob
import pdb
import random
import shutil
import subprocess
import sys
import time
import _pickle as cPickle

import matplotlib.pyplot as plt

%matplotlib inline

import numpy as np
import pandas as pd
import tensorflow as tf

from data_structure import get_batches
from gsm import GaussianSoftmaxModel
from rsm import RecurrentStickbreakingModel
from evaluation import validate, print_flat_topic_sample
from configure import get_config

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# load data & set config

In [6]:
# `nb_name` is defined in the kernel by the %%javascript cell at the top of the notebook,
# so that cell must have executed before this one.
config = get_config(nb_name)

In [7]:
# Expose only the CUDA device(s) named in config.gpu to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu
# Seed NumPy's and Python's global random number generators from the same config value.
# NOTE(review): no TensorFlow seed is set in this cell — confirm whether the model classes set one.
np.random.seed(config.seed)
random.seed(config.seed)

In [8]:
# Load the pickled dataset: training instances, category names, vocabulary maps and embeddings.
# NOTE(review): unpickling can execute arbitrary code — only point config.path_data at trusted files.
with open(config.path_data, 'rb') as data_file:  # context manager closes the handle after loading
    instances_train, categories, word_to_idx, idx_to_word, embeddings = cPickle.load(data_file)

# Use one batch containing every training instance.
config.batch_size = len(instances_train)
train_batches = get_batches(instances_train, config.batch_size, iterator=False)

# Model dimensions are derived from the loaded data: vocabulary size and embedding width.
config.dim_bow = len(idx_to_word)
config.dim_emb = embeddings.shape[-1]

In [9]:
def debug(variable, sample_batch=None, sample=False):
    """Evaluate `variable` in the current session and return the result.

    Args:
        variable: whatever `sess.run` accepts as a fetch.
        sample_batch: batch used to build the feed dict; when None, the batch
            stored in the first entry of the global `train_batches` is used.
        sample: accepted for call compatibility; not used by this function.

    Returns:
        The value returned by `sess.run` for `variable`, fed with the model's
        feed dict built in 'eval' mode.
    """
    batch = train_batches[0][1] if sample_batch is None else sample_batch
    return sess.run(variable, feed_dict=model.get_feed_dict(batch, mode='eval'))

# run

## initialize log

In [10]:
# Training-state accumulators; the training loop reads and mutates these globals.
checkpoint = []      # paths of the checkpoints currently kept (oldest first)
losses_train = []    # one [loss, recon, kl, reg] row per training step
ppls_train = []      # values fetched from model.topic_ppls, extended every step
ppl_min = np.inf     # lowest perplexity seen so far
epoch = 0

# Log table with a two-level column header: group label on top, metric name below.
log_df = pd.DataFrame(columns=pd.MultiIndex.from_tuples(
                    list(zip(*[['','','','TRAIN:','','','','','VALID:','','','','','TEST:',''],
                            ['Time','Ep','Ct','LOSS','PPL','NLL','KL','REG','LOSS','PPL','NLL','KL','REG','LOSS','PPL']]))))

# Start from an empty model directory. Done in Python rather than by shelling out to
# `rm -r` / `mkdir`, so paths containing whitespace work and no external commands are needed.
# A missing directory is not an error; missing parent directories are created.
shutil.rmtree(config.dir_model, ignore_errors=True)
os.makedirs(config.dir_model, exist_ok=True)

def update_checkpoint(config, checkpoint, global_step):
    """Register a new checkpoint, prune the oldest one past the limit, and persist the list.

    Args:
        config: object providing `path_model` (checkpoint path prefix),
            `max_to_keep` (maximum number of checkpoints to retain) and
            `path_checkpoint` (file the checkpoint list is pickled to).
        checkpoint: list of checkpoint path prefixes, oldest first; mutated in place.
        global_step: integer step appended to `path_model` as '-<step>'.

    Side effects:
        When the list grows past `max_to_keep`, every file named
        '<oldest prefix>.<anything>' is deleted. The updated list is then
        written to `config.path_checkpoint`.
    """
    checkpoint.append(config.path_model + '-%i' % global_step)
    if len(checkpoint) > config.max_to_keep:
        oldest_prefix = checkpoint.pop(0)
        # Escape the prefix so glob metacharacters in the path ('[', '?', '*') are
        # matched literally; otherwise stale files would silently be left behind.
        for stale_path in glob.glob(glob.escape(oldest_prefix) + '.*'):
            os.remove(stale_path)
    # Context manager guarantees the pickle is flushed and the handle closed.
    with open(config.path_checkpoint, 'wb') as checkpoint_file:
        cPickle.dump(checkpoint, checkpoint_file)

## initialize model

In [11]:
# Per-run overrides written onto the loaded config before the model is constructed from it.
config.n_topic = 20
config.train_emb = False

In [12]:
# Close any session left over from a previous run of this cell.
# NOTE(review): the default TF graph is not reset here — confirm whether the model classes
# handle re-construction when this cell is re-run in the same kernel.
if 'sess' in globals(): sess.close()

# Resolve the configured model name to its class. An unknown name raises immediately instead of
# failing with a NameError or silently reusing a `Model` left over from an earlier run.
model_classes = {'gsm': GaussianSoftmaxModel, 'rsm': RecurrentStickbreakingModel}
if config.model not in model_classes:
    raise ValueError('unknown config.model %r; expected one of %s' % (config.model, sorted(model_classes)))
Model = model_classes[config.model]
model = Model(config)

# Single-threaded session, matching the *_NUM_THREADS=1 environment settings at the top of the notebook.
sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))
sess.run(tf.global_variables_initializer())
# Overwrite the model's embedding variable with the embeddings loaded from the data file.
sess.run(model.bow_embeddings.assign(embeddings))
saver = tf.train.Saver(max_to_keep=config.max_to_keep)
update_tree_flg = False

## train & validate model

In [13]:
# Training-loop settings: the loop runs while epoch < n_epochs and writes a log row
# whenever the global step is a multiple of log_period.
config.n_epochs = 10000
config.log_period = 500

In [14]:
# Training loop: each iteration runs model.opt once on the first entry of train_batches,
# and every config.log_period global steps it logs metrics, shows the top words per topic
# and saves a checkpoint if the training perplexity improved.
time_start = time.time()
while epoch < config.n_epochs:
    ct, batch = train_batches[0]

    feed_dict = model.get_feed_dict(batch)
    _, loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch, ppls_batch, global_step_log = \
    sess.run([model.opt, model.loss, model.topic_loss_recon, model.topic_loss_kl, model.topic_loss_reg, model.topic_ppls, tf.train.get_global_step()], feed_dict = feed_dict)

    losses_train.append([loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch])
    ppls_train.extend(ppls_batch)

    if global_step_log % config.log_period == 0:
        # train metrics
        # NOTE(review): losses_train / ppls_train are never reset in this loop, so these are
        # averages over every step so far, and the lists grow for the whole run — confirm intended.
        loss_train, topic_loss_recon_train, topic_loss_kl_train, topic_loss_reg_train = np.mean(losses_train, 0)
        ppl_train = np.exp(np.mean(ppls_train))

        # No validation or test set is evaluated here; zeros keep the VALID/TEST log columns filled.
        # Bound before the checkpoint branch so they exist even when no checkpoint is saved.
        loss_dev, topic_loss_recon_dev, topic_loss_kl_dev, topic_loss_reg_dev, ppl_dev, probs_topic_dev = 0, 0, 0, 0, 0, 0
        loss_test, ppl_test = 0, 0

        # checkpoint when the training perplexity improves
        if ppl_train < ppl_min:
            # Track the value that was compared. Storing the placeholder ppl_dev (0) here
            # made every later comparison fail, so only the first checkpoint was ever saved.
            ppl_min = ppl_train
            saver.save(sess, config.path_model, global_step=global_step_log)
            with open(config.path_config % global_step_log, 'wb') as config_file:
                cPickle.dump(config, config_file)
            update_checkpoint(config, checkpoint, global_step_log)

        # visualize topic: indices of the config.n_freq highest-weighted words of each topic
        topics_freq_indices = np.argsort(sess.run(model.topic_bow), 1)[:, ::-1][:, :config.n_freq]
        topics_freq_idxs = topics_freq_indices
        topics_freq_tokens = [[idx_to_word[idx] for idx in topic_freq_idxs] for topic_freq_idxs in topics_freq_idxs]

        # log
        clear_output()
        time_log = int(time.time() - time_start)  # seconds since the previous log row
        log_series = pd.Series([time_log, epoch, ct, \
                '%.2f'%loss_train, '%.0f'%ppl_train, '%.2f'%topic_loss_recon_train, '%.2f'%topic_loss_kl_train, '%.2f'%topic_loss_reg_train, \
                '%.2f'%loss_dev, '%.0f'%ppl_dev, '%.2f'%topic_loss_recon_dev, '%.2f'%topic_loss_kl_dev, '%.2f'%topic_loss_reg_dev, \
                '%.2f'%loss_test, '%.0f'%ppl_test],
                index=log_df.columns)
        log_df.loc[global_step_log] = log_series
        display(log_df)
        with open(config.path_log, 'wb') as log_file:
            cPickle.dump(log_df, log_file)
        print_flat_topic_sample(sess, model, topics_freq_tokens=topics_freq_tokens)

#         # update tree
#         if not config.static:

#             if update_sequence_flg:
#                 print(config.tree_idxs)
#                 name_variables = {tensor.name: variable for tensor, variable in zip(tf.global_variables(), sess.run(tf.global_variables()))} # store parameters
#                 if 'sess' in globals(): sess.close()
#                 model = Model(config)
#                 sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))
#                 name_tensors = {tensor.name: tensor for tensor in tf.global_variables()}
#                 sess.run([name_tensors[name].assign(variable) for name, variable in name_variables.items()]) # restore parameters
#                 saver = tf.train.Saver(max_to_keep=1)

        # restart the timer so the Time column measures one logging period
        time_start = time.time()

    epoch += 1

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,TRAIN:,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,VALID:,Unnamed: 10_level_0,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0,TEST:,Unnamed: 15_level_0
Unnamed: 0_level_1,Time,Ep,Ct,LOSS,PPL,NLL,KL,REG,LOSS,PPL,NLL,KL,REG,LOSS,PPL
500,3,499,0,110234.72,5779,110135.62,99.07,0.01,0.0,0,0.0,0.0,0.0,0.0,0
1000,2,999,0,109280.47,5058,109183.59,96.77,0.01,0.0,0,0.0,0.0,0.0,0.0,0
1500,2,1499,0,108824.63,4680,108731.34,93.05,0.01,0.0,0,0.0,0.0,0.0,0.0,0
2000,2,1999,0,108539.26,4446,108449.98,89.16,0.01,0.0,0,0.0,0.0,0.0,0.0,0
2500,2,2499,0,108337.69,4277,108251.81,85.69,0.01,0.0,0,0.0,0.0,0.0,0.0,0
3000,2,2999,0,108185.61,4141,108102.53,82.88,0.01,0.0,0,0.0,0.0,0.0,0.0,0
3500,2,3499,0,108064.65,4032,107984.1,80.34,0.01,0.0,0,0.0,0.0,0.0,0.0,0
4000,2,3999,0,107965.11,3942,107886.95,77.98,0.01,0.0,0,0.0,0.0,0.0,0.0,0
4500,2,4499,0,107879.38,3871,107803.23,75.89,0.01,0.0,0,0.0,0.0,0.0,0.0,0
5000,2,4999,0,107803.99,3813,107729.83,73.99,0.01,0.0,0,0.0,0.0,0.0,0.0,0


0 vascular_grafts vascular_access vascular_graft hemodialysis_access vascular_prostheses vascular_prosthesis bone_regeneration stent-grafts guided_bone_regeneration guided_tissue_regeneration
1 gtr gbr limb_salvage tract_reconstruction seal tract_reconstructions vein_graft dorsal_augmentation lumbar_discectomy ptfe_/_vein_cuff
2 chordal_replacement femoropopliteal_bypass greases femoropopliteal_bypasses gaskets seals mitral_valves mitral_valve_repair mitral_valve valves
3 bearings tips treatment implants grease hepatic_tissue solid_lubricant journal_bearings solid_lubricants solid_sorbent
4 electrodes electrode cathode electrolyte cell cells anode cathodes binder devices
5 fuel_cell pemfc ptfe_transfer_film ptfe_grafts fuel_cell_performance proton_exchange_membrane_fuel_cell pemfcs pem_fuel_cell catalyst_ink ptfe_/_nafion_membrane_electrode_assembly
6 membrane_distillation direct_contact_membrane_distillation water_treatment vacuum_membrane_distillation wastewater_treatment proton_exch

# visualize topics

In [57]:
def print_freq_uses(category_index, n_topics=10, n_freq=100, n_known=10, n_unknown=10):
    """Print the most probable topics for one training instance with each topic's top words.

    Reads the globals `sess`, `model`, `train_batches`, `categories` and `idx_to_word`.

    Args:
        category_index: index into both `categories` and the first training batch.
        n_topics: maximum number of topics to print, most probable first.
        n_freq: number of highest-weighted words kept per topic; the printed
            words are all drawn from this list.
        n_known: number of top topic words printed without brackets.
        n_unknown: number of top topic words absent from the instance's
            bag-of-words, printed as '<word>'.

    Output:
        The category name, then one line per topic:
        topic index, topic probability, top words, bracketed unseen words.
        Topics are listed by decreasing probability and listing stops at the
        first topic whose probability is below 0.01.
    """
    # Rank every topic's words by descending weight and keep the first n_freq COLUMNS.
    # (Slicing rows here would drop topics instead of words and raise IndexError
    # whenever n_freq is smaller than the number of topics.)
    topics_use_indices = np.argsort(sess.run(model.topic_bow), 1)[:, ::-1][:, :n_freq]

    instance = train_batches[0][1][category_index]
    # Word indices with a positive count in this instance; a set gives O(1) membership tests.
    known_use_indices = set(np.where(instance.bow > 0)[0])
    category = categories[category_index]

    prob_topics = sess.run(model.prob_topic, feed_dict=model.get_feed_dict(batch=[instance], mode='infer'))[0]
    sorted_topic_indices = np.argsort(prob_topics)[::-1]

    print(category)
    for topic_index in sorted_topic_indices[:n_topics]:
        prob_topic = prob_topics[topic_index]
        if prob_topic < 0.01: break  # topics are sorted, so all remaining ones are below the cutoff too

        topic_use_indices = topics_use_indices[topic_index]
        unknown_topic_use_indices = [index for index in topic_use_indices if index not in known_use_indices]
        print(topic_index, '%.3f'%prob_topic, \
                  ' '.join([idx_to_word[index] for index in topic_use_indices[:n_known]]), \
                  ' '.join(['<%s>'%idx_to_word[index] for index in unknown_topic_use_indices[:n_unknown]]))

In [58]:
# Negative index: the last entry of `categories` and of the training batch.
print_freq_uses(category_index=-1)

pvdf
17 0.459 mf md structural_pvdf_sensor libs cnt_/_pvdf_membrane pvdf_sensors da resultant_nanofiber_pvdf_membranes health_monitoring pure_pvdf_sensors <da> <nerve_cells> <sensing_membranes> <dopamine_sensor> <cell_system> <optical_system> <multifunctional_supports> <evmd> <single_cell_evolution> <phase_separators>
8 0.188 ferro-actuator dssc transducers transducer dmfc actuator actuators anode dsscs direct_methanol_fuel_cell <charge_capacitors> <electric_conductor> <zinc_electrode> <ecs> <plate_materials> <saline_organic_wastewater> <polymer_electrolyte_fuel_cells> <roofs> <organic_electric_insulator> <electret_transducers>
19 0.087 lithium-ion_batteries lithium_ion_batteries energy_harvesting applications lithium-ion_battery lithium_batteries energy_harvesting_applications energy_harvesters energy_harvester lithium_ion_battery <biomaterial_applications> <polymer_electrolyte_fuel_cells> <tissue_engineering> <lithium-> <space_applications> <biomedical_industry> <ic_application> <ele

In [59]:
# Negative index: the second-to-last entry of `categories` and of the training batch.
print_freq_uses(category_index=-2)

ptfe
1 0.309 gtr gbr limb_salvage tract_reconstruction seal tract_reconstructions vein_graft dorsal_augmentation lumbar_discectomy ptfe_/_vein_cuff <repair_patch> <catheter> <iol> <pd> <protect_motoneurons> <posterior_chamber_iols> <toec> <barriers> <scpcs> <collision_surface>
3 0.239 bearings tips treatment implants grease hepatic_tissue solid_lubricant journal_bearings solid_lubricants solid_sorbent <thrust_ball_bearings> <new_solid_lubricant> <microwave_absorbers> <pre-treatment> <food_packages> <semisolid_lubricants> <meshed_implants> <open_valves> <vapor_phase_condensation_reflow> <liquid_lubricant_research>
2 0.131 chordal_replacement femoropopliteal_bypass greases femoropopliteal_bypasses gaskets seals mitral_valves mitral_valve_repair mitral_valve valves <fluorination_agent> <electricity> <gas_industry> <modifier> <pervaporative_separation> <bond_papers> <oilfield> <transistors> <high_pressure_valves> <triboelectric_devices>
5 0.098 fuel_cell pemfc ptfe_transfer_film ptfe_graft

In [60]:
# First entry of `categories` and of the training batch.
print_freq_uses(category_index=0)

etfe
10 0.630 building_sector building_façade building_roofs building_envelopes building_applications building_materials buildings smart_structure photovoltaics civil_building <smart_structure> <photovoltaics> <thermal_insulation> <photovoltaic_cell> <solar_panels> <polymer_electrolyte_fuel_cell> <electrical_wires> <gmf_fuel_cell> <dmfcs> <polymer_electrolyte_fuel_cells>
18 0.193 leads etfe_cushion_roof etfe_foils etfe_cushion_models etfe_cushion_model pv-etfe_cushion etfe_foil_spring_cushion air_systems etfe_foil_cushions cushion_structures <air_systems> <"_at-risk_"_systems> <"_at-risk_"_high-pressure_oxygen> <hydraulic_fluid> <optical_window> <helicopters> <high_pressure_valves> <jackets> <ppts> <valve_seats>
12 0.134 fuel_cell_membranes pumps fuel_cells coatings ion_exchange_membranes surface_activation_treatments display proton_exchange_membranes laser_welding diodes <diodes> <greases> <aems> <proton> <ion-exchange_membranes> <other_applications> <thickeners> <eyeglasses> <anticor