In [1]:
%%javascript
IPython.notebook.kernel.execute('nb_name = "' + IPython.notebook.notebook_name + '"')

<IPython.core.display.Javascript object>

In [15]:
# Load the autoreload extension, then invoke %autoreload without an argument.
# NOTE(review): per the IPython docs the argument-less form reloads modules once, right now;
# `%autoreload 2` would be needed to reload before every cell execution — confirm intent.
%load_ext autoreload
%autoreload
from IPython.display import clear_output

import os
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so assigning it here
# presumably only affects child processes, not the running kernel — confirm.
os.environ['PYTHONHASHSEED'] = '0'
# Set the NumExpr / MKL / OpenMP thread-count variables to 1. They are assigned before
# numpy and tensorflow are imported further down in this cell.
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import argparse
import glob
import pdb
import random
import shutil
import subprocess
import sys
import time
import _pickle as cPickle

import matplotlib.pyplot as plt

%matplotlib inline

import numpy as np
import pandas as pd
import tensorflow as tf

from data_structure import get_batches
from gsm import GaussianSoftmaxModel
from rsm import RecurrentStickbreakingModel
from evaluation import validate, print_flat_topic_sample
from configure import get_config

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# load data & set config

In [16]:
# Build the run configuration from the notebook's file name.
# NOTE(review): `nb_name` is assigned by the %%javascript cell at the top of the notebook;
# if that script has not reached the kernel yet (e.g. during a quick "Run All"), this line
# would raise NameError — confirm and re-run the cell if so.
config = get_config(nb_name)

In [17]:
# Export the configured GPU id(s) through CUDA_VISIBLE_DEVICES, then seed the
# global Python and NumPy random number generators with the configured seed.
os.environ.update(CUDA_VISIBLE_DEVICES=config.gpu)
random.seed(config.seed)
np.random.seed(config.seed)

In [18]:
# Load the pickled dataset and derive the data-dependent configuration fields.
# NOTE: unpickling can execute arbitrary code; only load data files from a trusted source.
# The file is opened in a `with` block so the handle is closed as soon as loading finishes.
with open(config.path_data, 'rb') as f_data:
    instances_train, categories, word_to_idx, idx_to_word, embeddings = cPickle.load(f_data)

# The batch size is set to the size of the whole training set (presumably so that
# get_batches returns a single full batch; only train_batches[0] is used later on).
config.batch_size = len(instances_train)
train_batches = get_batches(instances_train, config.batch_size, iterator=False)

# Vocabulary size and embedding width are taken from the loaded data.
config.dim_bow = len(idx_to_word)
config.dim_emb = embeddings.shape[-1]

In [19]:
def debug(variable, sample_batch=None, sample=False):
    """Evaluate `variable` in the current session on a single batch.

    Relies on the notebook-level `model`, `sess` and `train_batches`.

    Args:
        variable: fetch (or list of fetches) handed to `sess.run`.
        sample_batch: batch to feed; when None, `train_batches[0][1]` is used.
        sample: accepted for compatibility; not used by this function.

    Returns:
        Whatever `sess.run` returns for `variable`.
    """
    batch = train_batches[0][1] if sample_batch is None else sample_batch
    return sess.run(variable, feed_dict=model.get_feed_dict(batch, mode='eval'))

# run

## initialize log

In [34]:
# Reset the bookkeeping used by the training loop.
checkpoint = []      # paths of the checkpoints currently kept on disk (oldest first)
losses_train = []    # one [loss, recon, kl, reg] row per training step
ppls_train = []      # per-document perplexity terms collected across steps
ppl_min = np.inf     # best perplexity seen so far; decides when a checkpoint is saved
epoch = 0

# Empty log table with a two-level column header: (section, metric).
log_sections = ['','','','TRAIN:','','','','','VALID:','','','','','TEST:','']
log_metrics = ['Time','Ep','Ct','LOSS','PPL','NLL','KL','REG','LOSS','PPL','NLL','KL','REG','LOSS','PPL']
log_df = pd.DataFrame(columns=pd.MultiIndex.from_tuples(list(zip(log_sections, log_metrics))))

# Start from an empty model directory. Using shutil/os instead of shelling out to
# `rm -r` / `mkdir` keeps this working for paths containing whitespace, on platforms
# without those commands, and when the directory does not exist yet.
shutil.rmtree(config.dir_model, ignore_errors=True)
os.makedirs(config.dir_model, exist_ok=True)

def update_checkpoint(config, checkpoint, global_step):
    """Register a newly saved checkpoint and prune the oldest one if over the limit.

    Args:
        config: object providing `path_model` (checkpoint path prefix),
            `max_to_keep` (maximum number of checkpoints retained) and
            `path_checkpoint` (file where the checkpoint list is pickled).
        checkpoint: list of checkpoint paths, oldest first; modified in place.
        global_step: integer step appended to `config.path_model` as `-<step>`.

    Side effects:
        Deletes every file matching `<oldest checkpoint>.*` when the list grows
        beyond `config.max_to_keep`, and rewrites `config.path_checkpoint` with
        the pickled, updated list.
    """
    checkpoint.append(config.path_model + '-%i' % global_step)
    if len(checkpoint) > config.max_to_keep:
        # Escape the path so glob metacharacters in it ('[', '*', '?') are matched
        # literally; only the trailing '.*' acts as a wildcard.
        stale_pattern = glob.escape(checkpoint.pop(0)) + '.*'
        for stale_path in glob.glob(stale_pattern):
            os.remove(stale_path)
    # Context manager guarantees the list is flushed and the handle closed.
    with open(config.path_checkpoint, 'wb') as f_checkpoint:
        cPickle.dump(checkpoint, f_checkpoint)

## initialize model

In [35]:
# Per-run overrides applied on top of the loaded configuration.
# presumably: n_topic is the number of topics and train_emb=False keeps the
# pre-trained embeddings fixed — confirm against the model classes.
config.n_topic = 20
config.train_emb = False

In [36]:
# Close a session left over from an earlier run of this cell before creating a new one.
if 'sess' in globals(): sess.close()

# Resolve the model class from the configuration. An unknown name now fails with an
# explicit error instead of a NameError (or silently reusing a stale `Model` left in
# the kernel by a previous run).
model_classes = {'gsm': GaussianSoftmaxModel, 'rsm': RecurrentStickbreakingModel}
if config.model not in model_classes:
    raise ValueError('unknown config.model %r; expected one of %s' % (config.model, sorted(model_classes)))
Model = model_classes[config.model]
model = Model(config)

# Single-threaded session; variables are initialised first, then the bag-of-words
# embedding variable is overwritten with the embeddings loaded from the data file.
sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))
sess.run(tf.global_variables_initializer())
sess.run(model.bow_embeddings.assign(embeddings))
saver = tf.train.Saver(max_to_keep=config.max_to_keep)
update_tree_flg = False

## train & validate model

In [37]:
# Training-schedule overrides: the loop below runs while epoch < n_epochs and
# writes a log row whenever the global step is a multiple of log_period.
config.n_epochs = 10000
config.log_period = 500

In [38]:
# Training loop. Every iteration runs one optimisation step on the first entry of
# train_batches; every `config.log_period` global steps it logs averaged training
# metrics, saves a checkpoint if the training perplexity improved, and prints the
# current top words of each topic.
# Note: losses_train / ppls_train are never cleared here, so the logged training
# metrics are averages over all steps since the log was initialised.
time_start = time.time()
while epoch < config.n_epochs:
    ct, batch = train_batches[0]
    
    feed_dict = model.get_feed_dict(batch)
    _, loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch, ppls_batch, global_step_log = \
    sess.run([model.opt, model.loss, model.topic_loss_recon, model.topic_loss_kl, model.topic_loss_reg, model.topic_ppls, tf.train.get_global_step()], feed_dict = feed_dict)

    losses_train += [[loss_batch, topic_loss_recon_batch, topic_loss_kl_batch, topic_loss_reg_batch]]
    ppls_train += list(ppls_batch)

    if global_step_log % config.log_period == 0:
        # validate
        loss_train, topic_loss_recon_train, topic_loss_kl_train, topic_loss_reg_train = np.mean(losses_train, 0)
        ppl_train = np.exp(np.mean(ppls_train))
        # No validation or test evaluation is run in this notebook; these columns are
        # zero placeholders. They are bound on every log step so the log row below can
        # always be built, even when no checkpoint is saved.
        loss_dev, topic_loss_recon_dev, topic_loss_kl_dev, topic_loss_reg_dev, ppl_dev, probs_topic_dev = 0, 0, 0, 0, 0, 0
        loss_test, ppl_test = 0, 0

        # checkpoint when the training perplexity improves
        if ppl_train < ppl_min:
            # Track the metric that is actually compared. Storing the placeholder
            # ppl_dev (always 0) here would block every later checkpoint.
            ppl_min = ppl_train
            saver.save(sess, config.path_model, global_step=global_step_log)
            with open(config.path_config % global_step_log, 'wb') as f_config:
                cPickle.dump(config, f_config)
            update_checkpoint(config, checkpoint, global_step_log)

        # visualize topic: per topic, the config.n_freq highest-weighted vocabulary entries
        topics_freq_indices = np.argsort(sess.run(model.topic_bow), 1)[:, ::-1][:, :config.n_freq]
        topics_freq_idxs = topics_freq_indices
        topics_freq_tokens = [[idx_to_word[idx] for idx in topic_freq_idxs] for topic_freq_idxs in topics_freq_idxs]

        # log
        clear_output()
        time_log = int(time.time() - time_start)
        log_series = pd.Series([time_log, epoch, ct, \
                '%.2f'%loss_train, '%.0f'%ppl_train, '%.2f'%topic_loss_recon_train, '%.2f'%topic_loss_kl_train, '%.2f'%topic_loss_reg_train, \
                '%.2f'%loss_dev, '%.0f'%ppl_dev, '%.2f'%topic_loss_recon_dev, '%.2f'%topic_loss_kl_dev, '%.2f'%topic_loss_reg_dev, \
                '%.2f'%loss_test, '%.0f'%ppl_test],
                index=log_df.columns)
        log_df.loc[global_step_log] = log_series
        display(log_df)
        with open(config.path_log, 'wb') as f_log:
            cPickle.dump(log_df, f_log)
        print_flat_topic_sample(sess, model, topics_freq_tokens=topics_freq_tokens)

        # Disabled tree-update code, kept for reference. It refers to
        # `update_sequence_flg`, which is not defined in the cells shown here.
#         # update tree
#         if not config.static:

#             if update_sequence_flg:
#                 print(config.tree_idxs)
#                 name_variables = {tensor.name: variable for tensor, variable in zip(tf.global_variables(), sess.run(tf.global_variables()))} # store parameters
#                 if 'sess' in globals(): sess.close()
#                 model = Model(config)
#                 sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1))
#                 name_tensors = {tensor.name: tensor for tensor in tf.global_variables()}
#                 sess.run([name_tensors[name].assign(variable) for name, variable in name_variables.items()]) # restore parameters
#                 saver = tf.train.Saver(max_to_keep=1)

        # Restart the timer so the Time column measures one logging period.
        time_start = time.time()

    epoch += 1

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,TRAIN:,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,VALID:,Unnamed: 10_level_0,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0,TEST:,Unnamed: 15_level_0
Unnamed: 0_level_1,Time,Ep,Ct,LOSS,PPL,NLL,KL,REG,LOSS,PPL,NLL,KL,REG,LOSS,PPL
500,4,499,0,8694.49,3652,8671.99,22.5,0.01,0.0,0,0.0,0.0,0.0,0.0,0
1000,3,999,0,8622.62,3332,8600.63,21.98,0.01,0.0,0,0.0,0.0,0.0,0.0,0
1500,3,1499,0,8584.94,3154,8563.31,21.63,0.01,0.0,0,0.0,0.0,0.0,0.0,0
2000,3,1999,0,8559.49,3031,8537.99,21.49,0.01,0.0,0,0.0,0.0,0.0,0.0,0
2500,3,2499,0,8540.52,2940,8519.04,21.46,0.01,0.0,0,0.0,0.0,0.0,0.0,0
3000,3,2999,0,8525.68,2869,8504.24,21.43,0.01,0.0,0,0.0,0.0,0.0,0.0,0
3500,3,3499,0,8513.64,2812,8492.24,21.39,0.01,0.0,0,0.0,0.0,0.0,0.0,0
4000,3,3999,0,8503.58,2765,8482.23,21.34,0.01,0.0,0,0.0,0.0,0.0,0.0,0
4500,3,4499,0,8495.0,2726,8473.7,21.3,0.01,0.0,0,0.0,0.0,0.0,0.0,0
5000,3,4999,0,8487.57,2693,8466.31,21.25,0.01,0.0,0,0.0,0.0,0.0,0.0,0


0 gtr gbr guided_bone_regeneration mitral_valve dorsal_augmentation mitral_chordal_substitute guided_tissue_regeneration lumbar_discectomy nasal_augmentation mitral_valve_repair
1 water_purification mbr absorbent membrane_bioreactor pvdf_membrane waste_water_purification pvdf_membranes pure_water membrane_separation pre-treatment
2 valves piston_rings high-pressure_oxygen_systems seals greases gaskets microwave_devices circuit_boards power_systems durable_bond
3 binder electrodes electrode cathode anode cathodes binders electrolyte battery batteries
4 ferro-actuator mf direct_methanol_fuel_cell dmfc vanadium_redox_flow_battery direct_methanol_fuel_cells pumps liquid_crystal_display cantilever-type_ferro-actuators dmfcs
5 lubricants magnetic_disks magnetic_disk lubricant_fluids hard_disk_lubricants lubricant pfpe_lubricants greases hard_disk_drive hard_disk_drives
6 sensors sensor actuator binder transducer transducers separator actuators binders electrode
7 mea membrane_electrode_assem

# visualize topics

In [48]:
def print_freq_uses(category_index, n_topics=5, n_freq=100, n_known=10, n_unknown=10):
    """Print the most probable topics of one category together with their top uses.

    Relies on the notebook-level `sess`, `model`, `train_batches`, `categories`
    and `idx_to_word`.

    Args:
        category_index: position of the category in `categories` and of its
            instance in `train_batches[0][1]`.
        n_topics: number of highest-probability topics to print.
        n_freq: number of top-weighted vocabulary entries considered per topic.
        n_known: number of top entries printed per topic.
        n_unknown: number of printed entries (wrapped in `<...>`) that have a zero
            count in the instance's bag-of-words.

    Prints the category name, then one line per topic: topic index, topic
    probability, the top `n_known` entries, and up to `n_unknown` entries absent
    from the instance.
    """
    # Per topic: vocabulary indices by descending weight, truncated to the top n_freq
    # columns. (Slicing rows with [:n_freq] would drop topics instead of words.)
    topics_use_indices = np.argsort(sess.run(model.topic_bow), 1)[:, ::-1][:, :n_freq]

    instance = train_batches[0][1][category_index]
    # A set gives constant-time membership tests in the filter below.
    known_use_indices = set(np.where(instance.bow>0)[0].tolist())
    category = categories[category_index]

    prob_topics = sess.run(model.prob_topic, feed_dict=model.get_feed_dict(batch=[instance], mode='infer'))[0]
    sorted_topic_indices = np.argsort(prob_topics)[::-1]

    print(category)
    for topic_index in sorted_topic_indices[:n_topics]:
        prob_topic = prob_topics[topic_index]
        topic_use_indices = topics_use_indices[topic_index]
        unknown_topic_use_indices = [index for index in topic_use_indices if index not in known_use_indices]
        print(topic_index, '%.3f'%prob_topic, \
                  ' '.join([idx_to_word[index] for index in topic_use_indices[:n_known]]), \
                  ' '.join(['<%s>'%idx_to_word[index] for index in unknown_topic_use_indices[:n_unknown]]))

In [49]:
# Print the top topics and their uses for the category named 'pvdf'.
print_freq_uses(category_index=list(categories).index('pvdf'))

pvdf
17 0.247 md nanogenerator ultrasonic_transducers support mf generator nanogenerators stress_gauge novel_nanogenerator cheap_ultrasonic_transducers <gauge> <multi-layer_transducer> <actuator_transducers> <power> <flexible_force_sensor> <flexible_transducer_array> <nuclear_waste> <conductive_layers> <shock_waves> <khz_transducer>
11 0.152 cell hydrophone cells hydrophones charge_amplifier transmitter potential_electronic_active_polymers lipase amplifier pseudo_capacitor <nerve_cells> <charge_capacitors> <base_polymer_film> <catalyst_binder> <medical_polymer> <biocatalytic_active_polymer_membranes> <piezoelectric_polymers> <water-dispersible_binder> <representative_polymeric> <piezoelectric_polymer_hydrophone>
19 0.095 electrolytes lithium-ion_batteries electrolyte lithium_ion_batteries libs polymer_electrolytes polymer_electrolyte lithium_batteries lithium-ion_battery li-ion_batteries <lithium-> <polymer_lithium_ion_batteries> <polymer_electrolyte_separator> <efficient_electronic_de

In [50]:
# Print the top topics and their uses for the category named 'ptfe'.
print_freq_uses(category_index=list(categories).index('ptfe'))

ptfe
8 0.377 gaskets seal gasket seals femoropopliteal_bypasses femoropopliteal_bypass femoropopliteal knee_femoropopliteal_bypasses insulation gas_circuit_breakers <ak_femoropopliteal_bypass> <urinary_incontinence> <femorocrural_bypasses> <femoropopliteal_grafts> <femorocrural_bypass> <bk_bypasses> <femoral-popliteal_bypass> <vesicorenal_reflux> <below-knee_bypasses> <stress_urinary_incontinence>
7 0.155 mea membrane_electrode_assembly fuel_cell pemfcs pemfc proton_exchange_membrane_fuel_cell mfc ptfe_membrane catalyst_ink mfcs <pem_fuel_cell> <optimal_membrane_electrode_assembly> <cardiovascular_surgery> <pefc> <polytetrafluoroethylene-incorporated_membrane_electrode_assembly> <membrane_electrodes> <dielectric_barrier_discharge> <proton-exchange_membrane> <various_electromechanical_devices> <tribological_materials>
16 0.127 vascular_access hemodialysis_access haemodialysis_access hemodialysis ptfe_graft ptfe_grafts hemodialysis_vascular_access blood_access ptfe_graft_hemodialysis_acc

In [51]:
# Print the top topics and their uses for the first category in `categories`.
print_freq_uses(category_index=0)

eptfe
0 0.640 gtr gbr guided_bone_regeneration mitral_valve dorsal_augmentation mitral_chordal_substitute guided_tissue_regeneration lumbar_discectomy nasal_augmentation mitral_valve_repair <nasal_augmentation> <surgical_bone_defects> <hernia_meshes> <mitral_repair> <nasal_dorsal_augmentation> <guided_periodontal_tissue_regeneration> <dural_defect> <neochordal_implantation> <mitral_valves> <surgical_barrier>
16 0.184 vascular_access hemodialysis_access haemodialysis_access hemodialysis ptfe_graft ptfe_grafts hemodialysis_vascular_access blood_access ptfe_graft_hemodialysis_access haemodialysis <haemodialysis_access> <ptfe_graft> <blood_access> <ptfe_graft_hemodialysis_access> <limb_salvage> <vascular_accesses> <long-term_vascular_access> <vascular> <hemodialysis_vascular_access_grafts> <uk_hemodialysis_population>
18 0.133 vascular_grafts vascular_graft vascular_prostheses vascular_prosthesis grafts vascular_graft_material ®_grafts stent_grafts vascular_graft_materials filters <®_graft