In [1]:
%%javascript
// Push this notebook's file name into the Python kernel as `nb_name` (consumed by get_config(nb_name) below).
// NOTE(review): relies on the classic-Notebook `IPython.notebook` JS API, which JupyterLab does not
// provide — confirm the notebook is run under the classic Notebook interface.
IPython.notebook.kernel.execute('nb_name = "' + IPython.notebook.notebook_name + '"')

<IPython.core.display.Javascript object>

In [2]:
%load_ext autoreload
# NOTE(review): depending on the IPython version, a bare `%autoreload` may only reload modules
# once, right now, instead of before every cell; `%autoreload 2` explicitly enables the latter —
# confirm which was intended.
%autoreload
from IPython.display import clear_output

import os
# Reproducibility settings. Thread-count variables are generally read when numpy / MKL /
# TensorFlow load, which is why they are set here, before those libraries are imported below.
# NOTE(review): PYTHONHASHSEED is only consulted at interpreter start-up, so setting it inside an
# already-running kernel does not change hashing in this process (only in child processes).
os.environ['PYTHONHASHSEED'] = '0'
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import argparse
import glob
import pdb
import random
import shutil
import subprocess
import sys
import time
import _pickle as cPickle
import matplotlib.pyplot as plt

%matplotlib inline

import numpy as np
import pandas as pd
import tensorflow as tf

from data_structure import get_batches
from gsm import GaussianSoftmaxModel
from rsm import RecurrentStickbreakingModel
from evaluation import validate, print_flat_topic_sample
from configure import get_config

# load data & set config

In [3]:
# `nb_name` is injected into the kernel by the %%javascript cell at the top of the notebook.
config = get_config(nb_name)

In [4]:
# Seed the stdlib and NumPy RNGs from the config, then restrict TensorFlow to the configured GPU(s).
random.seed(config.seed)
np.random.seed(config.seed)
os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu

In [5]:
# Load the preprocessed corpus; the context manager closes the file right after loading.
# NOTE: unpickling can execute arbitrary code — only point config.path_data at files you produced.
with open(config.path_data, 'rb') as f_data:
    instances_train, categories, word_to_idx, idx_to_word, embeddings = cPickle.load(f_data)
# Full-batch training: a single batch holds every training instance.
config.batch_size = len(instances_train)
train_batches = get_batches(instances_train, config.batch_size, iterator=False)
config.dim_bow = len(idx_to_word)       # vocabulary size
config.dim_emb = embeddings.shape[-1]   # size of the last axis of the pretrained embeddings

In [6]:
def debug(variable, sample_batch=None, sample=False):
    """Evaluate `variable` in the current session on one batch and return the result.

    Args:
        variable: fetch(es) passed straight to ``sess.run``.
        sample_batch: batch to feed; defaults to the batch stored in ``train_batches[0]``.
        sample: unused; kept so existing calls keep working.

    Returns:
        Whatever ``sess.run`` returns for ``variable`` (evaluated with ``mode='eval'`` feeds).
    """
    batch = train_batches[0][1] if sample_batch is None else sample_batch
    feed = model.get_feed_dict(batch, mode='eval')
    return sess.run(variable, feed_dict=feed)

# run

## initialize log

In [7]:
# Training state shared with the training loop (the model cell does not reset these).
checkpoint = []      # checkpoint path prefixes still on disk, oldest first
losses_train = []    # one [loss, recon (NLL), KL, reg] entry per training step
ppls_train = []      # values from model.topic_ppls (averaged, then exp'd, for the PPL column)
ppl_min = np.inf     # best training perplexity so far; drives checkpointing
epoch = 0

log_df = pd.DataFrame(columns=pd.MultiIndex.from_tuples(
                    list(zip(*[['','','','TRAIN:','','','','','VALID:','','','','','TEST:',''],
                            ['Time','Ep','Ct','LOSS','PPL','NLL','KL','REG','LOSS','PPL','NLL','KL','REG','LOSS','PPL']]))))

# Start from an empty model directory. shutil/os replace shelling out to `rm -r` / `mkdir`
# built with str.split(), which broke on (and could delete the wrong) paths containing spaces;
# a missing directory is simply skipped, and missing parent directories are created.
if os.path.exists(config.dir_model):
    shutil.rmtree(config.dir_model)
os.makedirs(config.dir_model)

def update_checkpoint(config, checkpoint, global_step):
    """Register a newly saved checkpoint and prune the oldest ones from disk.

    Appends ``config.path_model + '-<global_step>'`` to ``checkpoint`` (in place). While more
    than ``config.max_to_keep`` entries are registered, removes the oldest entry and deletes
    every file matching ``<prefix>.*``. Finally pickles the updated list to
    ``config.path_checkpoint``.

    Args:
        config: object providing ``path_model``, ``max_to_keep`` and ``path_checkpoint``.
        checkpoint: list of checkpoint path prefixes, oldest first; mutated in place.
        global_step: integer step the checkpoint was saved at.
    """
    checkpoint.append(config.path_model + '-%i' % global_step)
    # `while` rather than `if`, so lowering max_to_keep between calls prunes every surplus checkpoint.
    while len(checkpoint) > config.max_to_keep:
        for path in glob.glob(checkpoint.pop(0) + '.*'):
            os.remove(path)
    # Context manager so the file is flushed and closed immediately.
    with open(config.path_checkpoint, 'wb') as f:
        cPickle.dump(checkpoint, f)

## initialize model

In [8]:
# Experiment overrides on top of get_config(): 20 topics; train_emb=False
# (presumably keeps the pretrained word embeddings fixed — confirm in the model code).
config.n_topic, config.train_emb = 20, False

In [9]:
# Build a fresh model and session. Close any session left from a previous run of this cell and
# reset the default graph, so re-running does not add a second copy of every variable.
# NOTE: training state (epoch, ppl_min, losses, checkpoint list) lives in the log-initialisation
# cell and is NOT reset here — re-run that cell too when starting over.
if 'sess' in globals():
    sess.close()
tf.reset_default_graph()

if config.model == 'gsm':
    Model = GaussianSoftmaxModel
elif config.model == 'rsm':
    Model = RecurrentStickbreakingModel
else:
    # Without this, an unknown name silently reused a `Model` left over from an earlier run
    # (or raised an unhelpful NameError on a fresh kernel).
    raise ValueError("config.model must be 'gsm' or 'rsm', got %r" % (config.model,))
model = Model(config)
# Single-threaded TF ops, matching the *_NUM_THREADS=1 settings in the setup cell.
sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))
sess.run(tf.global_variables_initializer())
# Overwrite the freshly initialised word embeddings with the pretrained ones from the data file.
sess.run(model.bow_embeddings.assign(embeddings))
saver = tf.train.Saver(max_to_keep=config.max_to_keep)
update_tree_flg = False

## train model (validation and test are currently disabled; their log columns show 0)

In [10]:
# Training schedule: up to 10000 loop iterations (one full-batch step each), logging every 500 global steps.
config.n_epochs, config.log_period = 10000, 500

In [11]:
# Full-batch training loop: one optimiser step per "epoch" on train_batches[0]. Every
# config.log_period global steps, log metrics, checkpoint on improvement and show topics.
time_start = time.time()
global_step = tf.train.get_global_step()  # graph lookup hoisted out of the loop
while epoch < config.n_epochs:
    ct, batch = train_batches[0]

    feed_dict = model.get_feed_dict(batch)
    _, loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch, ppls_batch, global_step_log = \
        sess.run([model.opt, model.loss, model.topic_loss_recon, model.topic_loss_kl, model.topic_loss_reg,
                  model.topic_ppls, global_step], feed_dict=feed_dict)

    losses_train += [[loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch]]
    ppls_train += list(ppls_batch)

    if global_step_log % config.log_period == 0:
        # Training metrics.
        # NOTE(review): losses_train / ppls_train are never cleared, so these are running means
        # since the start of training and the lists keep growing — confirm this is intended.
        loss_train, topic_loss_recon_train, topic_loss_kl_train, topic_loss_reg_train = np.mean(losses_train, 0)
        ppl_train = np.exp(np.mean(ppls_train))

        # Validation / test are disabled: their log columns are filled with 0 placeholders.
        # Defined unconditionally so the log row below never hits an undefined name.
        loss_dev, topic_loss_recon_dev, topic_loss_kl_dev, topic_loss_reg_dev, ppl_dev, probs_topic_dev = 0, 0, 0, 0, 0, 0
        loss_test, ppl_test = 0, 0

        # Checkpoint whenever the training perplexity improves. The stored best must be the value
        # that was compared (ppl_train); storing ppl_dev (always 0 here) made the comparison
        # false forever after the first save, so no later checkpoint was ever written.
        if ppl_train < ppl_min:
            ppl_min = ppl_train
            saver.save(sess, config.path_model, global_step=global_step_log)
            with open(config.path_config % global_step_log, 'wb') as f_config:
                cPickle.dump(config, f_config)
            update_checkpoint(config, checkpoint, global_step_log)

        # Top config.n_freq words of every topic, by descending topic_bow weight.
        topics_freq_idxs = np.argsort(sess.run(model.topic_bow), 1)[:, ::-1][:, :config.n_freq]
        topics_freq_tokens = [[idx_to_word[idx] for idx in topic_freq_idxs] for topic_freq_idxs in topics_freq_idxs]

        # Replace the previous output with the updated log table and topic sample.
        clear_output()
        time_log = int(time.time() - time_start)  # seconds since the previous log row
        log_series = pd.Series([time_log, epoch, ct,
                '%.2f'%loss_train, '%.0f'%ppl_train, '%.2f'%topic_loss_recon_train, '%.2f'%topic_loss_kl_train, '%.2f'%topic_loss_reg_train,
                '%.2f'%loss_dev, '%.0f'%ppl_dev, '%.2f'%topic_loss_recon_dev, '%.2f'%topic_loss_kl_dev, '%.2f'%topic_loss_reg_dev,
                '%.2f'%loss_test, '%.0f'%ppl_test],
                index=log_df.columns)
        log_df.loc[global_step_log] = log_series
        display(log_df)  # IPython's built-in rich display
        with open(config.path_log, 'wb') as f_log:
            cPickle.dump(log_df, f_log)
        print_flat_topic_sample(sess, model, topics_freq_tokens=topics_freq_tokens)

        # TODO(review): the disabled tree-update code below tests `update_sequence_flg`, but the
        # model cell defines `update_tree_flg`; reconcile the names before re-enabling it.
#         # update tree
#         if not config.static:

#             if update_sequence_flg:
#                 print(config.tree_idxs)
#                 name_variables = {tensor.name: variable for tensor, variable in zip(tf.global_variables(), sess.run(tf.global_variables()))} # store paremeters
#                 if 'sess' in globals(): sess.close()
#                 model = Model(config)
#                 sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))
#                 name_tensors = {tensor.name: tensor for tensor in tf.global_variables()}
#                 sess.run([name_tensors[name].assign(variable) for name, variable in name_variables.items()]) # restore parameters
#                 saver = tf.train.Saver(max_to_keep=1)

        time_start = time.time()

    epoch += 1

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,TRAIN:,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,VALID:,Unnamed: 10_level_0,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0,TEST:,Unnamed: 15_level_0
Unnamed: 0_level_1,Time,Ep,Ct,LOSS,PPL,NLL,KL,REG,LOSS,PPL,NLL,KL,REG,LOSS,PPL
500,3,499,0,110359.08,6178,110280.55,78.53,0.01,0.0,0,0.0,0.0,0.0,0.0,0
1000,2,999,0,109386.16,5498,109309.41,76.76,0.01,0.0,0,0.0,0.0,0.0,0.0,0
1500,2,1499,0,108921.2,5044,108846.84,74.46,0.01,0.0,0,0.0,0.0,0.0,0.0,0
2000,2,1999,0,108628.47,4768,108556.37,72.11,0.01,0.0,0,0.0,0.0,0.0,0.0,0
2500,2,2499,0,108421.54,4567,108351.52,70.04,0.01,0.0,0,0.0,0.0,0.0,0.0,0
3000,2,2999,0,108263.84,4415,108195.5,68.3,0.01,0.0,0,0.0,0.0,0.0,0.0,0
3500,2,3499,0,108136.66,4295,108069.9,66.76,0.01,0.0,0,0.0,0.0,0.0,0.0,0
4000,2,3999,0,108030.84,4196,107965.3,65.46,0.01,0.0,0,0.0,0.0,0.0,0.0,0
4500,2,4499,0,107939.57,4115,107875.29,64.34,0.01,0.0,0,0.0,0.0,0.0,0.0,0
5000,2,4999,0,107859.81,4044,107796.43,63.3,0.01,0.0,0,0.0,0.0,0.0,0.0,0


0 vascular_grafts vascular_access vascular_graft hemodialysis_access vascular_prostheses vascular_prosthesis bone_regeneration stent-grafts coatings guided_tissue_regeneration
1 gtr gbr lumbar_discectomy tract_reconstruction ptfe_/_vein_cuff tract_reconstructions infected_eptfe_mesh haemodialysis adherent ptfe_graft_reconstruction
2 femoropopliteal_bypass femoropopliteal_bypasses valves femoropopliteal_bypass_grafts femoropopliteal_bypass_graft bypasses seals bypass femoropopliteal_bypass_grafting valve
3 bearings mitral_valve_repair greases chordal_replacement gaskets tips solid_lubricant solid_lubricants repair vesicoureteral_reflux
4 electrodes electrode binder cathode cell electrolyte anode cathodes cells devices
5 fuel_cell ptfe_transfer_film fuel_cell_performance pemfc ptfe_grafts gas_diffusion_layer pemfcs gas_diffusion_layers mfc catalyst_ink
6 hollow_fiber_membranes hollow_fiber_membrane ptfe_hollow_fiber_membranes hollow_fibre_membranes pvdf_hollow_fiber_membranes ptfe_hollow

# visualize topics

In [12]:
def print_freq_uses(category_index, n_topics=10, n_freq=100, n_known=10, n_unknown=10, min_prob=0.01):
    """Print the most probable topics inferred for one training instance.

    Prints the instance's category. Then, for each of its ``n_topics`` most probable topics
    (stopping early at the first topic whose probability is below ``min_prob``), prints:
    the topic index, its probability, the topic's ``n_known`` highest-weighted words, and up to
    ``n_unknown`` words from the topic's top ``n_freq`` words that do not occur in the
    instance's bag of words (wrapped in ``<...>``).

    Args:
        category_index: index into ``categories`` and into the instances of ``train_batches[0]``
            (negative indices count from the end).
        n_topics: maximum number of topics to print.
        n_freq: number of top words per topic considered for the ``<unknown>`` list.
        n_known: number of top words printed per topic.
        n_unknown: maximum number of ``<unknown>`` words printed per topic.
        min_prob: probability below which topics stop being printed (previously hard-coded 0.01).
    """
    # Per topic: vocabulary indices by descending topic_bow weight, truncated to the top n_freq words.
    # (The previous `[:n_freq]` sliced topics, i.e. rows, so n_freq had no effect.)
    topics_use_indices = np.argsort(sess.run(model.topic_bow), 1)[:, ::-1][:, :n_freq]

    instance = train_batches[0][1][category_index]
    # Vocabulary indices present in the instance; a set gives O(1) membership tests below.
    # NOTE(review): assumes instance.bow is a 1-D vocabulary-sized count vector — confirm.
    known_use_indices = set(np.where(instance.bow > 0)[0].tolist())
    category = categories[category_index]

    prob_topics = sess.run(model.prob_topic, feed_dict=model.get_feed_dict(batch=[instance], mode='infer'))[0]
    sorted_topic_indices = np.argsort(prob_topics)[::-1]

    print(category)
    for topic_index in sorted_topic_indices[:n_topics]:
        prob_topic = prob_topics[topic_index]
        if prob_topic < min_prob:
            break

        topic_use_indices = topics_use_indices[topic_index]
        unknown_topic_use_indices = [index for index in topic_use_indices if index not in known_use_indices]
        print(topic_index, '%.3f' % prob_topic,
              ' '.join(idx_to_word[index] for index in topic_use_indices[:n_known]),
              ' '.join('<%s>' % idx_to_word[index] for index in unknown_topic_use_indices[:n_unknown]))

In [13]:
# Topics inferred for the last instance of the training batch.
print_freq_uses(-1)

pvdf
8 0.405 transducers transducer dssc pvdf_transducers pvdf-transducers dsscs pvdf_nanoweb_fibers pvdf_transducer ultrasonic_transducers acoustic_nanogenerator <power> <bio-_/_nano-_sensors> <dcs> <nir_dye> <nanofiber_separator> <films> <sensing_membranes> <ultrasonic_motors> <ultrasonic_motor> <concentrator>
12 0.165 mf ferro-actuator md dmfc dmfcs mfcs lib mfc vanadium_redox_flow_battery mbr <mfs> <alkaline_anion_exchange_membrane_water_electrolysis> <oilfield_water> <available_bench_model> <cathode_fouling> <solid-state_supercapacitor> <af> <ferroelectric_liquid_crystals_displays> <lithium_/_sulfur_battery> <high_sulfide_copper_ores>
19 0.121 lithium-ion_batteries lithium_ion_batteries applications polymer_electrolytes lithium-ion_battery lithium_batteries electrolytes li-ion_batteries polymer_electrolyte electronic_applications <lithium-> <biomedical_materials> <lithium_/> <tissue_engineering> <biomaterial_applications> <vivo_applications> <biomedical_material> <polymer_electrol

In [14]:
# Topics inferred for the second-to-last instance of the training batch.
print_freq_uses(-2)

ptfe
1 0.378 gtr gbr lumbar_discectomy tract_reconstruction ptfe_/_vein_cuff tract_reconstructions infected_eptfe_mesh haemodialysis adherent ptfe_graft_reconstruction <asc> <repair_patch> <scpcs> <pd> <fibrin-based_vascular_graft> <hard_disk_media> <dvs> <antibiotic_surface_modification> <spinal_cord_repair> <small_flows>
2 0.142 femoropopliteal_bypass femoropopliteal_bypasses valves femoropopliteal_bypass_grafts femoropopliteal_bypass_graft bypasses seals bypass femoropopliteal_bypass_grafting valve <bioreactors> <triboelectric_devices> <paper_industry> <field_coil_insulation_material> <fluorination_agent> <printer_paper> <modifier> <oilfield> <configuration-printed_circuit_board> <glow_discharge_polymerization_method>
3 0.126 bearings mitral_valve_repair greases chordal_replacement gaskets tips solid_lubricant solid_lubricants repair vesicoureteral_reflux <gas_industry> <liquid_lubricant> <bone_restoration> <replacement_lubricants> <mixed_lubricants> <dental_units> <bars> <abdominal

In [60]:
# Topics inferred for the first instance of the training batch.
print_freq_uses(0)

etfe
10 0.630 building_sector building_façade building_roofs building_envelopes building_applications building_materials buildings smart_structure photovoltaics civil_building <smart_structure> <photovoltaics> <thermal_insulation> <photovoltaic_cell> <solar_panels> <polymer_electrolyte_fuel_cell> <electrical_wires> <gmf_fuel_cell> <dmfcs> <polymer_electrolyte_fuel_cells>
18 0.193 leads etfe_cushion_roof etfe_foils etfe_cushion_models etfe_cushion_model pv-etfe_cushion etfe_foil_spring_cushion air_systems etfe_foil_cushions cushion_structures <air_systems> <"_at-risk_"_systems> <"_at-risk_"_high-pressure_oxygen> <hydraulic_fluid> <optical_window> <helicopters> <high_pressure_valves> <jackets> <ppts> <valve_seats>
12 0.134 fuel_cell_membranes pumps fuel_cells coatings ion_exchange_membranes surface_activation_treatments display proton_exchange_membranes laser_welding diodes <diodes> <greases> <aems> <proton> <ion-exchange_membranes> <other_applications> <thickeners> <eyeglasses> <anticor