Load the trained models and the vocabulary (word ↔ index) dictionary.

In [1]:
import theano
import theano.tensor as T
import numpy as np
import cPickle
import random
from utils import *
from state import *
from title_model import TitleModel
from whenst_hour_model import WhenstHourModel
from whened_hour_model import WhenedHourModel
from dur_hour_model import DurHourModel
from whenst_min_model import WhenstMinModel
from whened_min_model import WhenedMinModel
from dur_min_model import DurMinModel

theano.config.floatX = 'float32'


def _make_state(state_or_factory):
    """Return a model state, calling the factory only if it has not been called yet.

    Each ``*_state`` factory name below is rebound to the state it returns. Without
    this guard, a second run of this cell would try to call the state object itself
    and fail with a TypeError. Assumes state objects are not callable (they are
    indexed like dicts later, e.g. ``title_state['seq_len_in']``).
    """
    return state_or_factory() if callable(state_or_factory) else state_or_factory


def _load_model(model_cls, state, path):
    """Instantiate ``model_cls`` in test mode for ``state`` and load its weights from ``path``."""
    model = model_cls(state, test_mode=True)
    model.load(path)
    return model


title_model_name = 'model/title_emb256_h256_f32_tok10_model.npz'
title_state = _make_state(title_state)
title_model = _load_model(TitleModel, title_state, title_model_name)

whenst_hour_model_name = 'model/whenst_hour_emb256_h256_f32_model.npz'
whenst_hour_state = _make_state(whenst_hour_state)
whenst_hour_model = _load_model(WhenstHourModel, whenst_hour_state, whenst_hour_model_name)

whenst_min_model_name = 'model/whenst_min_emb256_h256_f32_model.npz'
whenst_min_state = _make_state(whenst_min_state)
whenst_min_model = _load_model(WhenstMinModel, whenst_min_state, whenst_min_model_name)

whened_hour_model_name = 'model/whened_hour_emb256_h256_f32_model.npz'
whened_hour_state = _make_state(whened_hour_state)
whened_hour_model = _load_model(WhenedHourModel, whened_hour_state, whened_hour_model_name)

whened_min_model_name = 'model/whened_min_emb256_h256_f32_model.npz'
whened_min_state = _make_state(whened_min_state)
whened_min_model = _load_model(WhenedMinModel, whened_min_state, whened_min_model_name)

dur_hour_model_name = 'model/dur_hour_emb256_h256_f32_model.npz'
dur_hour_state = _make_state(dur_hour_state)
dur_hour_model = _load_model(DurHourModel, dur_hour_state, dur_hour_model_name)

dur_min_model_name = 'model/dur_min_emb256_h256_f32_model.npz'
dur_min_state = _make_state(dur_min_state)
dur_min_model = _load_model(DurMinModel, dur_min_state, dur_min_model_name)

# Vocabulary lookup tables: index -> word and word -> index (the other three
# pickled entries are not used here). Opened in binary mode and closed promptly.
# NOTE(review): unpickling can execute arbitrary code — only load a trusted dict.pkl.
with open('data/dict.pkl', 'rb') as dict_file:
    (ind2word, word2ind, _, _, _) = cPickle.load(dict_file)

print('Data loaded.')

Data loaded.


Now we manually provide some input sentences. Any out-of-vocabulary word is replaced with a special `<TOKn>` placeholder token, which is mapped back to the original word when the output is printed.

In [2]:
def restoreW(ind_lst, ind2word, tmp_map):
    """Convert a sequence of word indices back into a space-separated string.

    Indices found in ``tmp_map`` (placeholder tokens that replaced
    out-of-vocabulary words) are rendered as the word they stood in for.
    Every other index is looked up in ``ind2word``, which is consulted only
    for non-placeholder indices.
    """
    words = [tmp_map[idx] if idx in tmp_map else ind2word[idx] for idx in ind_lst]
    return ' '.join(words)
    

# Demo sentences. Each one is encoded, decoded into an event title, and run
# through the start/end/duration hour and quarter-hour classifiers.
test_sents = ['i would like to go hiking at one pm tomorrow until five in the evening',
              'i would like to go hiking at one pm tomorrow until five',
              'i want to sleep all day to seven pm',
              'i want to play video games for two hours at eleven in the morning',
              'i want to visit my dental . call me at one pm',
              'i will wash my car this afternoon from half past two to a quarter past four pm',
              'i want to play chess this afternoon . call me at three pm . i will be back at five pm',
              'i want to go jogging this afternoon . i will be back at five thirty pm',
              'have a meeting today for three hours']

START_IND, END_IND = 1, 0   # indices used for the <START> / <END> markers
NUM_OOV_TOKENS = 10         # placeholders <TOK0>..<TOK9> available per sentence
MAX_TITLE_LEN = 10          # stop decoding once the title (incl. <START>) is longer than this
NOT_MENTIONED_HOUR = 24     # hour class meaning "this time slot is not in the sentence"
MIN_STEP = 15               # minute classes are quarter-hour steps

# Placeholder tokens are assigned at random; seed so repeated runs agree.
random.seed(1234)


def _encode_sentence(sentence, word2ind, num_tokens=NUM_OOV_TOKENS):
    """Encode a whitespace-tokenised sentence as [<START>, word ids..., <END>].

    Each out-of-vocabulary word is replaced by a distinct, randomly chosen
    ``<TOKn>`` placeholder and reported on stdout.

    Returns ``(coded, oov_map)``. ``oov_map`` maps a placeholder's index back
    to the original word, in the form ``restoreW`` expects.
    Raises ValueError if there are more OOV words than placeholders.
    """
    free_tokens = list(range(num_tokens))  # list() so .remove works on Python 3 too
    coded = [START_IND]
    oov_map = {}
    for word in sentence.split():
        if word in word2ind:
            coded.append(word2ind[word])
            continue
        if not free_tokens:
            raise ValueError('more than %d out-of-vocabulary words in %r'
                             % (num_tokens, sentence))
        tok_ind = random.choice(free_tokens)
        free_tokens.remove(tok_ind)
        tok_s = '<TOK%d>' % tok_ind
        coded.append(word2ind[tok_s])
        oov_map[word2ind[tok_s]] = word
        print('  out of vocab: %s, replaced with %s' % (word, tok_s))
    coded.append(END_IND)
    return coded, oov_map


def _to_batch(coded, seq_len):
    """Zero-pad ``coded`` to ``seq_len`` rows and copy it into both columns of a 2-wide batch.

    Returns an int32 index matrix and a float32 mask (1 on real tokens), both
    shaped (seq_len, 2). Raises ValueError if the sentence is too long.
    """
    if len(coded) > seq_len:
        raise ValueError('encoded sentence has %d tokens, model accepts at most %d'
                         % (len(coded), seq_len))
    coded_mat = np.zeros((seq_len, 2), dtype='int32')
    mask = np.zeros((seq_len, 2), dtype='float32')
    # NOTE(review): the same sentence fills both batch columns; presumably the
    # compiled graphs expect a batch of 2 — confirm before changing.
    coded_mat[:len(coded), :] = np.asarray(coded, dtype='int32')[:, None]
    mask[:len(coded), :] = 1
    return coded_mat, mask


def _decode_title(pred_fn, coded_mat, mask, max_len=MAX_TITLE_LEN):
    """Greedily decode a title, feeding each predicted index back in as the next input.

    Returns the index list starting with <START>. It ends with <END> unless
    ``max_len`` was exceeded first. Call ``title_model.gen_reset()`` beforehand.
    """
    title = [START_IND]
    prev_ind = START_IND
    while True:
        prev_mat = np.zeros((2,), dtype='int32') + prev_ind
        p_t, _, _ = pred_fn(coded_mat, mask, prev_mat)
        probs = np.asarray(p_t[0], dtype='float64')
        # Normalise with one sum per step (not one sum per element), then take the argmax.
        ind = int((probs / probs.sum()).argmax())
        title.append(ind)
        if ind == END_IND or len(title) > max_len:
            return title
        prev_ind = ind


def _format_slot(hour_res, min_res):
    """Render a predicted (hour class, quarter-hour class) pair as HH:MM, or 'Not mentioned'."""
    if hour_res == NOT_MENTIONED_HOUR:
        return 'Not mentioned'
    return "%02d:%02d" % (hour_res, min_res * MIN_STEP % 60)


def _compile_classifier(model):
    """Compile ``model.ot`` as a function of (x_data, xmask). The graph has no updates."""
    return theano.function([model.x_data, model.xmask], model.ot)


# These graphs are pure (no updates), so one compilation serves every sentence.
whenst_hour_pred = _compile_classifier(whenst_hour_model)
whenst_min_pred = _compile_classifier(whenst_min_model)
whened_hour_pred = _compile_classifier(whened_hour_model)
whened_min_pred = _compile_classifier(whened_min_model)
dur_hour_pred = _compile_classifier(dur_hour_model)
dur_min_pred = _compile_classifier(dur_min_model)

seq_len_in = title_state['seq_len_in']

for test_sent in test_sents:
    print('Test sent: %s' % test_sent)

    # Title
    nat_coded, tmp_map = _encode_sentence(test_sent, word2ind)
    nat_coded_mat, nat_mask = _to_batch(nat_coded, seq_len_in)
    # NOTE(review): rebuilt per sentence as before; could be hoisted if
    # gen_reset() alone fully resets decoder state — confirm first.
    pred_fn = title_model.build_gen_function()
    title_model.gen_reset()
    res = _decode_title(pred_fn, nat_coded_mat, nat_mask)
    print('')
    print(restoreW(res, ind2word, tmp_map))

    # Each time slot = an hour classifier plus a quarter-hour classifier.
    whenst_hour_res = whenst_hour_pred(nat_coded_mat, nat_mask)[0]
    whenst_min_res = whenst_min_pred(nat_coded_mat, nat_mask)[0]
    whenst_str = _format_slot(whenst_hour_res, whenst_min_res)
    print('Start time: %s' % whenst_str)

    whened_hour_res = whened_hour_pred(nat_coded_mat, nat_mask)[0]
    whened_min_res = whened_min_pred(nat_coded_mat, nat_mask)[0]
    whened_str = _format_slot(whened_hour_res, whened_min_res)
    print('End time: %s' % whened_str)

    dur_hour_res = dur_hour_pred(nat_coded_mat, nat_mask)[0]
    dur_min_res = dur_min_pred(nat_coded_mat, nat_mask)[0]
    dur_str = _format_slot(dur_hour_res, dur_min_res)
    print('Duration: %s' % dur_str)

    print('')
    print('')

Test sent: i would like to go hiking at one pm tomorrow until five in the evening
  out of vocab: hiking, replaced with <TOK6>
  out of vocab: tomorrow, replaced with <TOK5>

<START> go hiking <END>
Start time: 13:00
End time: 17:15
Duration: Not mentioned


Test sent: i would like to go hiking at one pm tomorrow until five
  out of vocab: hiking, replaced with <TOK4>
  out of vocab: tomorrow, replaced with <TOK9>

<START> go hiking <END>
Start time: 13:00
End time: 05:00
Duration: Not mentioned


Test sent: i want to sleep all day to seven pm
  out of vocab: sleep, replaced with <TOK1>

<START> sleep out <END>
Start time: 19:00
End time: 18:00
Duration: Not mentioned


Test sent: i want to play video games for two hours at eleven in the morning
  out of vocab: play, replaced with <TOK4>
  out of vocab: video, replaced with <TOK8>
  out of vocab: games, replaced with <TOK9>

<START> take video <END>
Start time: 11:00
End time: Not mentioned
Duration: 02:00


Test sent: i want to visit 