In [None]:
import numpy as np
import random
import pickle

import os
import time


In [None]:
# Location and file names of the pickled inputs read by ProcessData.
DATA_DIR = './'

DATA_FILE_NAME = 'v1_data.pkl'            # unpickles to (train, valid)
TEST_DATA_FILE_NAME = 'v1_data_test.pkl'  # unpickles to (train, valid, test); contains only 500 samples

VOCA_FILE_NAME = 'v1_dic.pkl'             # word -> id vocabulary dict
#GLOVE_FILE_NAME = 'v2_glove.pkl'


In [None]:
class ProcessData:
    """Load pickled dialogue data and serve padded mini-batches.

    On construction the data split(s) and vocabulary are read from DATA_DIR
    and converted into ``train_set`` / ``valid_set``: lists of
    ``[source_ids, target_ids, label]`` rows, where ``source_ids`` is a list of
    per-turn token-id lists obtained by splitting the flat context on the
    ``'./SF'`` token.
    """

    def __init__(self, is_test):
        """
        Args:
            is_test: when truthy, read TEST_DATA_FILE_NAME (train/valid/test)
                instead of DATA_FILE_NAME (train/valid).
        """
        print('IS_TEST = {}'.format(str(is_test)))

        self.is_test = is_test
        self.voca = None
        self.pad_index = 0
        self.index2word = {}

        self.train_set = []
        self.valid_set = []

        self.load_data()
        self.create_train_set()
        self.create_valid_set()

    def load_data(self):
        """Read the data split(s) and the vocabulary pickles from DATA_DIR.

        Sets ``train_data``, ``valid_data`` (plus ``test_data`` in test mode),
        ``voca`` (word -> id), ``pad_index`` (id of ``'_PAD_'``) and
        ``index2word`` (id -> word).
        """
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # files produced by this project's own preprocessing.
        if self.is_test:
            # Binary mode is required: pickle cannot read a text-mode file.
            with open(DATA_DIR + TEST_DATA_FILE_NAME, 'rb') as f:
                self.train_data, self.valid_data, self.test_data = pickle.load(f)
        else:
            with open(DATA_DIR + DATA_FILE_NAME, 'rb') as f:
                self.train_data, self.valid_data = pickle.load(f)

        with open(DATA_DIR + VOCA_FILE_NAME, 'rb') as f:
            self.voca = pickle.load(f)

        self.pad_index = self.voca['_PAD_']

        for word, idx in self.voca.items():
            self.index2word[idx] = word

        print("voca size (include _PAD_, _UNK_): {}".format(str(len(self.voca))))

    @staticmethod
    def _split_turns(context_ids, sep_id):
        """Split a flat token-id sequence into turns at each ``sep_id`` token.

        Tokens are compared as whole tokens (by their ``str`` form, as before),
        so a separator id such as 12 no longer splits inside ids like 112.
        The separator itself is dropped, and so are empty turns (from a
        leading, trailing or repeated separator). Token ids are returned as int.
        """
        sep = str(sep_id)
        turns = []
        current = []
        for tok in context_ids:
            tok_str = str(tok)
            if tok_str == sep:
                if current:
                    turns.append(current)
                current = []
            else:
                current.append(int(tok_str))
        if current:
            turns.append(current)
        return turns

    def _convert_dataset(self, data):
        """Turn a raw split dict into ``[source_ids, target_ids, label]`` rows.

        ``data['c']`` (flat context ids), ``data['r']`` (target ids) and
        ``data['y']`` (label) are read index-aligned; the label becomes a float.
        """
        sep_id = self.voca['./SF']
        rows = []
        for index in range(len(data['c'])):
            source_ids = self._split_turns(data['c'][index], sep_id)
            target_ids = data['r'][index]
            label = float(data['y'][index])
            rows.append([source_ids, target_ids, label])
        return rows

    # create train set : context -> split by the './SF' token -> multiple sentences
    # converted to [source, target, label]
    def create_train_set(self):
        """Append the converted rows of ``self.train_data`` to ``self.train_set``."""
        self.train_set.extend(self._convert_dataset(self.train_data))
        print("[completed] create train set : {}".format(str(len(self.train_set))))

    # create valid set : context -> split by the './SF' token -> multiple sentences
    # converted to [source, target, label]
    def create_valid_set(self):
        """Append the converted rows of ``self.valid_data`` to ``self.valid_set``."""
        self.valid_set.extend(self._convert_dataset(self.valid_data))
        print("[completed] create valid set : {}".format(str(len(self.valid_set))))

    def get_batch(self, data, batch_size, encoder_size, context_size, encoderR_size, is_test, start_index=0, target_index=1):
        """Assemble one padded mini-batch.

        Args:
            data: list of ``[source_ids, target_ids, label]`` rows.
            batch_size: number of samples in the batch.
            encoder_size: each context turn is padded/truncated to this length.
            context_size: each context is padded/truncated to this many turns;
                the LAST ``context_size`` turns are kept and pad turns are
                appended after the real ones.
            encoderR_size: each target sequence is padded/truncated to this length.
            is_test: falsy -> rows drawn with ``random.choice`` (with
                replacement); truthy -> ``batch_size`` consecutive rows starting
                at ``start_index``.
            start_index: first row read when ``is_test`` is truthy.
            target_index: unused; kept for backward compatibility.

        Returns:
            ``(encoder_inputs, encoderR_inputs, encoder_seq, context_seq,
            encoderR_seq, target_labels)``. ``encoder_inputs`` and
            ``encoder_seq`` hold ``batch_size * context_size`` entries (turns
            of all samples concatenated); ``target_labels`` is an int array of
            shape ``(batch_size, 1)``.
        """
        encoder_inputs, encoderR_inputs, encoder_seq, context_seq, encoderR_seq, target_labels = [], [], [], [], [], []
        index = start_index

        for _ in range(batch_size):

            # Truthiness (not `is False`) so that is_test=0 also means random sampling.
            if not is_test:
                list_encoder_input, encoderR_input, target_label = random.choice(data)
            else:
                row = data[index]
                list_encoder_input, encoderR_input, target_label = row[0], row[1], row[2]
                index += 1

            list_len = len(list_encoder_input)
            tmp_encoder_inputs = []
            tmp_encoder_seq = []

            for en_input in list_encoder_input:
                encoder_pad = [self.pad_index] * (encoder_size - len(en_input))
                # list() so tuples / arrays are concatenated, not broadcast-added
                tmp_encoder_inputs.append((list(en_input) + encoder_pad)[:encoder_size])
                tmp_encoder_seq.append(min(len(en_input), encoder_size))

            # pad the context with empty turns up to context_size
            for _pad_turn in range(context_size - list_len):
                tmp_encoder_inputs.append([self.pad_index] * encoder_size)
                tmp_encoder_seq.append(0)

            # keep only the last context_size turns
            encoder_inputs.extend(tmp_encoder_inputs[-context_size:])
            encoder_seq.extend(tmp_encoder_seq[-context_size:])

            context_seq.append(min(list_len, context_size))

            # encoderR inputs are padded / truncated to encoderR_size
            encoderR_pad = [self.pad_index] * (encoderR_size - len(encoderR_input))
            encoderR_inputs.append((list(encoderR_input) + encoderR_pad)[:encoderR_size])

            encoderR_seq.append(min(len(encoderR_input), encoderR_size))

            target_labels.append(int(target_label))

        return encoder_inputs, encoderR_inputs, encoder_seq, context_seq, encoderR_seq, np.reshape(target_labels, (batch_size, 1))


In [None]:
def create_dir(dir_name):
    """Create ``dir_name`` (including any missing parents) if it does not exist.

    An existing directory is left untouched, and there is no check-then-create
    race. Raises FileExistsError if the path exists but is not a directory.
    """
    os.makedirs(dir_name, exist_ok=True)
        

In [None]:
# ---- batch / sequence geometry ----
batch_size = 256
encoder_size = 80       # tokens kept per context turn (padded / truncated)
context_size = 15       # turns kept per context (padded / truncated)
encoderR_size = 160     # tokens kept per target (encoderR) sequence

# ---- siamese (turn / target) RNN ----
num_layer = 1
hidden_dim = 300

# ---- context RNN ----
num_layer_con = 1
hidden_dim_con = 300

# ---- embeddings & optimisation ----
embed_size = 300
use_glove = 0
lr = 0.001
num_train_steps = 100000
valid_freq = 500        # run validation every valid_freq training steps

# ---- run bookkeeping ----
is_save = 1
is_test = 0
graph_prefix = 'HRDE_LTC_korquad_v1_'

# ---- dropout ----
# NOTE(review): these are fed to the model's dropout placeholders and 1.0 is
# used as "no dropout" during evaluation, so they appear to be keep
# probabilities — confirm in the model definition.
dr = 0.3
dr_con = 1.0
memory_dr = 0.8

# ---- latent topic memory ----
memory_dim = 256
topic_size = 3


In [None]:
# Encode every hyper-parameter into the run name so checkpoints ('save/') and
# TensorBoard logs ('graph/') of different configurations do not collide.
graph_dir_name = (graph_prefix
                  + '_b' + str(batch_size)
                  + '_es' + str(encoder_size)
                  + '_eRs' + str(encoderR_size)
                  + '_cs' + str(context_size)
                  + '_L' + str(num_layer)
                  + '_H' + str(hidden_dim)
                  + '_Lc' + str(num_layer_con)
                  + '_Hc' + str(hidden_dim_con)
                  + '_G' + str(use_glove)
                  + '_dr' + str(dr)
                  + '_drc' + str(dr_con)
                  + '_drM' + str(memory_dr)
                  + '_M' + str(memory_dim)
                  + '_T' + str(topic_size))

# Truth test instead of `is 1`: identity comparison with an int literal is
# implementation-dependent and raises a SyntaxWarning on Python 3.8+.
if is_save:
    create_dir('save/')
    create_dir('save/' + graph_dir_name)

create_dir('graph/')
create_dir('graph/' + graph_dir_name)


In [None]:
# Load the pickled data/vocabulary and build the train/valid sets.
batch_gen = ProcessData(is_test)


In [None]:
from HRDE_Model_mem_v1 import *
from HRDE_evaluation import *


In [None]:
# Instantiate the hierarchical dual-encoder; vocabulary size is taken from the
# loaded dictionary, everything else from the hyper-parameter cell.
model = HRDualEncoderModel(
    # vocabulary / embeddings
    voca_size=len(batch_gen.voca),
    embed_size=embed_size,
    use_glove=use_glove,
    # batch & sequence geometry
    batch_size=batch_size,
    encoder_size=encoder_size,
    context_size=context_size,
    encoderR_size=encoderR_size,
    # siamese RNN
    num_layer=num_layer,
    hidden_dim=hidden_dim,
    # context RNN
    num_layer_con=num_layer_con,
    hidden_dim_con=hidden_dim_con,
    # optimisation / dropout
    lr=lr,
    dr=dr,
    dr_con=dr_con,
    memory_dr=memory_dr,
    # latent topic memory
    memory_dim=memory_dim,
    topic_size=topic_size,
)


In [None]:
# Build the model's TensorFlow graph. Presumably this creates the placeholders
# and ops (optimizer, loss, summary_op, batch_prob, ...) that the training cell
# reads from `model` — confirm in HRDE_Model_mem_v1.
model.build_graph()


In [None]:
# Train with periodic validation. A checkpoint is saved whenever validation
# accuracy improves, and training stops after MAX_EARLY_STOP_COUNT
# consecutive validations without improvement.
saver = tf.train.Saver()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True

summary = None
val_summary = None

CAL_ACCURACY_FROM = 1       # validations at step index <= this are not used for model selection
MAX_EARLY_STOP_COUNT = 14   # non-improving validations tolerated before stopping


def _make_feed(raw_batch, dr_prob, dr_prob_con, dr_memory_prob):
    """Map a get_batch() result plus dropout values onto the model placeholders."""
    (raw_encoder_inputs, raw_encoderR_inputs, raw_encoder_seq,
     raw_context_seq, raw_encoderR_seq, raw_target_label) = raw_batch
    return {
        model.encoder_inputs: raw_encoder_inputs,
        model.encoderR_inputs: raw_encoderR_inputs,
        model.encoder_seq_length: raw_encoder_seq,
        model.context_seq_length: raw_context_seq,
        model.encoderR_seq_length: raw_encoderR_seq,
        model.y_label: raw_target_label,
        model.dr_prob: dr_prob,
        model.dr_prob_con: dr_prob_con,
        model.dr_memory_prob: dr_memory_prob,
    }


with tf.Session(config=config) as sess:

    sess.run(tf.global_variables_initializer())
    early_stop_count = MAX_EARLY_STOP_COUNT

    ckpt = tf.train.get_checkpoint_state(os.path.dirname('save/' + graph_dir_name + '/'))
    if ckpt and ckpt.model_checkpoint_path:
        print('from check point!!!')
        saver.restore(sess, ckpt.model_checkpoint_path)

    writer = tf.summary.FileWriter('./graph/' + graph_dir_name, sess.graph)

    initial_time = time.time()

    max_accr = 0

    try:
        for index in range(num_train_steps):

            # ---- train step ----
            try:
                train_batch = batch_gen.get_batch(
                    data=batch_gen.train_set,
                    batch_size=model.batch_size,
                    encoder_size=model.encoder_size,
                    context_size=model.context_size,
                    encoderR_size=model.encoderR_size,
                    is_test=False)

                input_feed = _make_feed(train_batch, model.dr, model.dr_con, model.memory_dr)

                _, summary, loss = sess.run([model.optimizer, model.summary_op, model.loss], input_feed)

                writer.add_summary(summary, global_step=model.global_step.eval())

            except Exception as e:
                # Best effort: skip the failing batch but report why. A bare
                # `except:` would also swallow KeyboardInterrupt.
                print("excepetion occurs in train step : {}".format(e))

            # ---- validation ----
            if (index + 1) % valid_freq == 0:

                print('=======')

                num_corr = 0
                num_evaluated = 0
                sum_loss = 0.0
                num_loss_batches = 0

                # Only full batches are evaluated (get_batch reads batch_size
                # consecutive rows starting at start_index).
                itr_loop = len(batch_gen.valid_set) // model.batch_size

                for test_itr in range(itr_loop):

                    valid_batch = batch_gen.get_batch(
                        data=batch_gen.valid_set,
                        batch_size=model.batch_size,
                        encoder_size=model.encoder_size,
                        context_size=model.context_size,
                        encoderR_size=model.encoderR_size,
                        is_test=True,
                        start_index=(test_itr * model.batch_size))
                    raw_target_label = valid_batch[-1]

                    # keep value 1.0 everywhere: no dropout while evaluating
                    input_feed = _make_feed(valid_batch, 1.0, 1.0, 1.0)

                    try:
                        bprob, b_loss, lo = sess.run([model.batch_prob, model.batch_loss, model.loss], input_feed)
                    except Exception as e:
                        # Skip the batch entirely; otherwise the previous
                        # batch's (or an undefined) bprob / lo would be counted.
                        print("excepetion occurs in valid step : {} ({})".format(str(test_itr), e))
                        continue

                    for idx, prob in enumerate(bprob):
                        predicted = 1 if prob > 0.5 else 0
                        if raw_target_label[idx] == predicted:
                            num_corr = num_corr + 1

                    num_evaluated = num_evaluated + model.batch_size
                    sum_loss = sum_loss + lo
                    num_loss_batches = num_loss_batches + 1

                if num_loss_batches == 0:
                    # Avoid ZeroDivisionError (validation set smaller than one
                    # batch, or every batch failed).
                    print("no validation batch evaluated; skipping this validation")
                    continue

                avg_ce = sum_loss / num_loss_batches
                avg_accr = num_corr / num_evaluated

                model.valid_loss = avg_ce
                model.accuracy = avg_accr

                value1 = summary_pb2.Summary.Value(tag="valid_loss", simple_value=avg_ce)
                value2 = summary_pb2.Summary.Value(tag="valid_accuracy", simple_value=avg_accr)
                val_summary = summary_pb2.Summary(value=[value1, value2])

                writer.add_summary(val_summary, global_step=model.global_step.eval())

                end_time = time.time()

                print(avg_ce)
                print(avg_accr)

                accr = avg_accr

                if index > CAL_ACCURACY_FROM:

                    if accr > max_accr:
                        max_accr = accr

                        # save best result
                        if is_save:
                            saver.save(sess, 'save/' + graph_dir_name + '/', model.global_step.eval())

                        early_stop_count = MAX_EARLY_STOP_COUNT

                    else:
                        # early stopping
                        if early_stop_count == 0:
                            print("early stopped")
                            break

                        early_stop_count = early_stop_count - 1
    finally:
        # Flush and close the event file even on error or interrupt.
        writer.close()

    print('Total steps : {}'.format(model.global_step.eval()))


In [None]:
#print(ce)
#print(accr)