In [1]:
import numpy as np
import theano
import theano.tensor as T
import time

import lasagne as L

import sys
sys.path.insert(0, '../rnn_ex/')

from HRED import HRED
from mt_load import load_mt, get_mt_voc, get_w2v_embs
from load_subtle import load_subtle

In [2]:
# Reminder: padding reuses the <utt_end> token, so there is no separate pad index.

pad_value = -1 # index -1 addresses the last embedding vector, which belongs to <utt_end>

In [3]:
# Load the Subtle dataset and print how many seconds the load took.
# NOTE(review): the recorded run of this cell ended with an IOError (the .pkl file could not be
# read as a pickle), so `train_subtle` may not exist in a fresh session.
t0 = time.time()
subtle_path = "/pio/data/data/mtriples/"

train_subtle = load_subtle(subtle_path, split=True, trim=200)
print time.time() - t0

IOError: Failed to interpret file '/pio/data/data/mtriples/Subtle_Dataset.triples.pkl' as a pickle

In [3]:
# Alternative dataset location, kept for reference:
# mt_path = "/pio/data/data/mtriples/"
mt_path = "../DATA/MovieTriples_Dataset/"

# Load the three data splits, then the vocabulary tables (index->word, word->index, vocabulary
# size and word frequencies) that the helper functions below read as globals.
train, valid, test = load_mt(path=mt_path, split=True, trim=200)
idx_to_w, w_to_idx, voc_size, freqs = get_mt_voc(path=mt_path, train_len=len(train))

In [4]:
# Load the embedding matrix and its companion mask from the dataset directory.
word2vec_embs, word2vec_embs_mask = get_w2v_embs(path=mt_path)

# Row indices whose first mask column equals 1.
# NOTE(review): presumably the rows that should stay trainable — confirm against get_w2v_embs.
w2v_train_mask = np.where(word2vec_embs_mask[:,0] == 1)[0]

In [5]:
# Build the HRED model: every size is 300, the sampled softmax draws 200 samples using the
# vocabulary frequencies, and the embeddings are initialised from the loaded word2vec matrix.
hred_net = HRED(voc_size=voc_size,
                emb_size=300,
                lv1_rec_size=300, 
                lv2_rec_size=300, 
                out_emb_size=300, 
                num_sampled=200,
                ssoft_probs=freqs,
                emb_init=word2vec_embs)

Building the model...
Compiling theano functions...
Building a network for generating...
Done


In [6]:
# Restore previously trained weights from disk (path is relative to the notebook directory).
hred_net.load_params('trained_models/subtleFixed_300_300_300_300_ssoft200unigr_bs30_cut200_early5.npz')

In [7]:
def print_utt(utt):
    """Render a sequence of word ids as a space-separated string.

    The id `voc_size - 1` is shown as '<utt_end>'; every other id is looked up
    in the global `idx_to_w` table.
    """
    utt_end_idx = voc_size - 1
    tokens = []
    for word_id in utt:
        if word_id == utt_end_idx:
            tokens.append('<utt_end>')
        else:
            tokens.append(idx_to_w[word_id])
    return ' '.join(tokens)

def rnd_next_word(probs, size=1):
    """Draw `size` word ids at random according to the distribution `probs`.

    The candidate ids are 0 .. len(probs) - 2 followed by -1, i.e. the last
    probability entry is reported as id -1 instead of len(probs) - 1.
    Returns an int32 array of length `size`.
    """
    regular_ids = np.arange(probs.shape[0] - 1)
    candidate_ids = np.concatenate((regular_ids, np.array([-1]))).astype(np.int32)
    return np.random.choice(candidate_ids, size=size, p=probs)

def utt_to_array(utt):
    """Convert a list of word strings into a (1, len(utt)) int32 array of ids.

    Ids come from the global `w_to_idx`; any id equal to `-voc_size` is
    rewritten to -1.
    NOTE(review): `-voc_size` is presumably how the vocabulary encodes
    <utt_end> — confirm against get_mt_voc.
    """
    word_ids = [w_to_idx[word] for word in utt]
    row = np.array(word_ids)[np.newaxis].astype(np.int32)
    needs_remap = (row == -voc_size)
    row[needs_remap] = -1
    return row

def context_summary(context, lookup=True):
    """Fold a list of utterances into one context vector.

    Starts from a zero float32 state of shape (1, hred_net.lv2_rec_size) and
    feeds each utterance in turn through `hred_net.get_new_con_init_fn`.
    When `lookup` is true the utterances are lists of words and are encoded
    with `utt_to_array` first; otherwise they are passed through unchanged.
    """
    state = np.zeros((1, hred_net.lv2_rec_size), dtype=np.float32)
    for utt in context:
        if lookup:
            encoded = utt_to_array(utt)
        else:
            encoded = utt
        state = hred_net.get_new_con_init_fn(encoded, state)
    return state

In [136]:
def diverse_beam_search(beam, gs, dec_init, init_seq=np.array([[1]]), rank_penalty=0, group_diversity_penalty=1,
                        seq_diversity_penalty=1, verbose_log=False, max_len=50):
    """Group-wise diverse beam search over the decoder of the global `hred_net`.

    The `beam` hypotheses are split into `beam // gs` groups of `gs` rows.
    At every step the groups are expanded one after another. A group is
    discouraged from choosing words that earlier groups already chose at this
    step (Hamming diversity, https://arxiv.org/pdf/1610.02424.pdf) and from
    repeating words that already occur in its own hypotheses.

    Parameters
    ----------
    beam : int
        Total number of hypotheses; must be a multiple of `gs`.
    gs : int
        Number of hypotheses per group.
    dec_init : array
        Initial decoder state with one row per hypothesis (the cells in this
        notebook build it with `np.repeat(..., beam_size, axis=0)`).
    init_seq : int array with a single row
        Word ids every hypothesis starts with.
    rank_penalty : float
        Weight of the sibling-rank term (https://arxiv.org/abs/1611.08562).
        The `gs` continuations of one hypothesis are ranked by ascending
        score (worst = 1, best = gs) and `rank * rank_penalty` is added to
        each score. With the default 0 the term vanishes.
    group_diversity_penalty : float
        Subtracted from the log-probability of every word already chosen by
        an earlier group at the current step.
    seq_diversity_penalty : float
        Subtracted from the log-probability of every word already present in
        the hypothesis being extended.
    verbose_log : bool
        Print the live hypotheses and their scores after every step.
    max_len : int
        Decoding stops once the live sequences contain this many tokens
        (the tokens of `init_seq` included).

    Returns
    -------
    list of (ndarray, float)
        Every hypothesis that emitted '</s>', as `(word_ids, score)`, in the
        order they finished. Hypotheses that are still unfinished when
        decoding stops are not returned.
    """
    assert not beam % gs
    num_groups = beam // gs

    # First step: all rows are identical, so take the `beam` best words of row 0
    # and give each hypothesis a different one.
    seq = np.repeat(init_seq.astype(np.int32), beam, axis=0)
    probs, dec_init = hred_net.get_probs_and_new_dec_init_fn(seq, dec_init)

    words = probs[0].argpartition(-beam)[-beam:].astype(np.int32)
    words[words == voc_size-1] = pad_value
    scores = np.log(probs[0][words])
    seq = np.hstack([seq, words[:, np.newaxis]])

    finished = []
    end_token = w_to_idx['</s>']
    # Column vector of row numbers; broadcasts against (gs, k) index arrays.
    rows = np.arange(gs)[:, np.newaxis]

    while seq.shape[1] < max_len:
        all_probs, all_dec_init = hred_net.get_probs_and_new_dec_init_fn(seq[:, -1:], dec_init)

        new_seq = np.zeros((0, seq.shape[1] + 1), dtype=np.int32)
        new_dec_inits = []
        next_scores = []

        for g in range(num_groups):
            g_idx = slice(gs * g, gs * (g + 1))
            log_probs = np.log(all_probs[g_idx])
            group_dec_init = all_dec_init[g_idx]

            # Hamming diversity: `new_seq` holds what earlier groups picked at this step.
            # The in-place fancy-indexed subtraction applies the penalty once per
            # distinct word, however many earlier hypotheses chose it.
            log_probs[:, new_seq[:, -1]] -= group_diversity_penalty

            # Penalize repeating words in the same sequence.
            log_probs[rows, seq[g_idx]] -= seq_diversity_penalty

            # Top `gs` continuations of every hypothesis in the group (unordered).
            words = log_probs.argpartition(-gs, axis=1)[:, -gs:].astype(np.int32)
            next_word_scores = log_probs[rows, words]

            new_scores = next_word_scores + scores[g_idx, np.newaxis]

            # Rank penalty: argsort alone yields the sorting permutation, not the rank
            # of each entry; applying argsort twice converts it into ranks.
            sibling_rank = new_scores.argsort(axis=1).argsort(axis=1) + 1
            new_scores = (new_scores + sibling_rank * rank_penalty).ravel()

            order = (-new_scores).argsort().astype(np.int32)

            for idx in order:
                # Stop as soon as this group has contributed its `gs` live hypotheses.
                if new_seq.shape[0] == (g + 1) * gs:
                    break

                i, j = divmod(idx, gs)

                extended_seq = np.concatenate([seq[gs * g + i], np.array([words[i, j]])])
                if extended_seq[-1] == end_token:
                    finished.append((extended_seq, new_scores[idx]))
                else:
                    new_seq = np.vstack([new_seq, extended_seq])
                    new_dec_inits.append(group_dec_init[i])
                    next_scores.append(new_scores[idx])

        if not new_seq.size:
            print('Ending...')
            break

        seq = new_seq
        scores = np.array(next_scores)
        dec_init = np.array(new_dec_inits)

        if verbose_log:
            print(' '.join(['Length ', str(seq.shape[1]), '\n']))
            for utt, s in zip(seq, scores):
                print('{:.4f} {}'.format(s, print_utt(utt)))
                print('')
            print('#############\n')

    return finished

In [132]:
# Use two consecutive training utterances as the dialogue context, each as a (1, n) int32 array.
context = [np.array(u, dtype=np.int32)[np.newaxis] for u in train[99:101]]
lookup = False  # the utterances are handed to context_summary without the word lookup
for u in context:
    print(print_utt(u[0]))

<s> i couldn ' t say . </s>
<s> you were a prosecutor . </s>


In [121]:
# Hand-written two-turn context given as words, so the vocabulary lookup is required.
context = list(map(str.split, ['<s> hi . </s>', '<s> hello , what \' s up ? </s>']))
lookup = True

In [123]:
# Hand-written two-turn context given as words, so the vocabulary lookup is required.
context = list(map(str.split, ['<s> yeah , okay . </s>', '<s> well , i guess i \' ll be going now . </s>']))
lookup = True

In [125]:
# Hand-written single-turn context given as words, so the vocabulary lookup is required.
context = list(map(str.split, ['<s> what would the table think about if it could think ? </s>']))
lookup = True

In [127]:
# Hand-written single-turn context given as words, so the vocabulary lookup is required.
context = list(map(str.split, ['<s> i saw a pretty good movie yesterday . </s>']))
lookup = True

In [31]:
# Hand-written single-turn context given as words, so the vocabulary lookup is required.
context = list(map(str.split, ['<s> hi . </s>']))
lookup = True

In [139]:
# Run diverse beam search for the currently selected `context` and list every finished
# hypothesis, best first.
beam_size = 20
group_size = 2
con_init = context_summary(context, lookup)

# Fetch the parameter values once; entries 31 and 32 are used as the weight and bias of the
# projection from the context summary to the initial decoder state.
# NOTE(review): the indices are tied to the layer order of hred_net.train_net — confirm after
# any architecture change.
W, b = L.layers.get_all_param_values(hred_net.train_net)[31:33]
dec_init = np.repeat(con_init.dot(W) + b, beam_size, axis=0)

mean = True

beamsearch = diverse_beam_search(beam_size, group_size, dec_init, init_seq=utt_to_array('<s> '.split()),
                                 rank_penalty=0, group_diversity_penalty=1, seq_diversity_penalty=1,
                                 verbose_log=True)


def len_bonus(size):
    """Length bonus added to a hypothesis score; currently disabled (np.log(size)**2 was tried)."""
    return 0


def fn_score(x, y, mean=mean, len_bonus=len_bonus):
    """Final score of hypothesis `x` (word-id array) with accumulated score `y`.

    Adds the length bonus and, when `mean` is true, divides by `x.size - 1`.
    """
    denom = (x.size - 1) if mean else 1
    return (y + len_bonus(x.size)) / denom


# sort1: best final score first. sort2: alphabetical by the rendered utterance.
sort1 = sorted(beamsearch, key=lambda hyp: fn_score(*hyp), reverse=True)
sort2 = sorted(beamsearch, key=lambda hyp: print_utt(hyp[0][1:-1]))

for utt, scr in sort1:
    # utt[1:-1] drops the first and last token of the hypothesis before rendering.
    print('{:.3f}  '.format(fn_score(utt, scr)) + ' ' + print_utt(utt[1:-1]))
    print('')

Length  3 

-5.1909 <s> he was

-5.3055 <s> <unk> .

-4.4524 <s> a <unk>

-6.0636 <s> in a

-5.7318 <s> it was

-5.7568 <s> so i

-4.1023 <s> yes .

-4.6445 <s> yes ,

-4.5991 <s> and you

-4.8679 <s> and i

-4.3893 <s> <person> .

-4.5492 <s> <person> ?

-3.6010 <s> i was

-4.1649 <s> i '

-4.3903 <s> you were

-4.9893 <s> that was

-5.7513 <s> my father

-6.4807 <s> my mother

-6.0574 <s> oh .

-6.1127 <s> oh ,

#############

Length  4 

-6.7884 <s> he was a

-6.8248 <s> he was .

-6.1384 <s> a <unk> ?

-6.3011 <s> a <unk> .

-7.3158 <s> so i was

-8.4983 <s> it was my

-5.9421 <s> yes , i

-6.8896 <s> yes , but

-5.5384 <s> and you were

-6.6294 <s> and i was

-7.8967 <s> <person> . i

-9.2663 <s> <person> ? <continued_utterance>

-4.8607 <s> i ' m

-5.1046 <s> i ' d

-7.2894 <s> you were in

-7.4859 <s> that was the

-7.7402 <s> my father was

-7.9366 <s> my father died

-8.3053 <s> oh , yes

-8.8196 <s> oh , no

#############

Length  5 

-8.5892 <s> he was a <unk>

-9.5944 <s> h

In [120]:
# For a sample of training contexts, print the two context utterances followed by the single
# best-scoring finished hypothesis from diverse beam search.
beam_size = 50
group_size = 2

# Loop-invariant set-up, done once instead of on every iteration.
# Entries 31 and 32 are used as the weight and bias of the projection from the context summary
# to the initial decoder state.
# NOTE(review): the indices are tied to the layer order of hred_net.train_net — confirm after
# any architecture change.
W, b = L.layers.get_all_param_values(hred_net.train_net)[31:33]

mean = True


def len_bonus(size):
    """Length bonus added to a hypothesis score; currently disabled (np.log(size)**2 was tried)."""
    return 0


def fn_score(x, y, mean=mean, len_bonus=len_bonus):
    """Final score of hypothesis `x` (word-id array) with accumulated score `y`.

    Adds the length bonus and, when `mean` is true, divides by `x.size - 1`.
    """
    denom = (x.size - 1) if mean else 1
    return (y + len_bonus(x.size)) / denom


for i in range(0, 100, 3):
    # Two consecutive entries of `train` starting at every third index form the context.
    context = [np.array(u, dtype=np.int32)[np.newaxis] for u in train[i:i+2]]
    lookup = False
    for u in context:
        print(print_utt(u[0]))

    con_init = context_summary(context, lookup)
    dec_init = np.repeat(con_init.dot(W) + b, beam_size, axis=0)

    beamsearch = diverse_beam_search(beam_size, group_size, dec_init,
                                     init_seq=utt_to_array('<s> '.split()))

    # Only the top-scoring hypothesis is shown; nothing is printed if none finished.
    for utt, scr in sorted(beamsearch, key=lambda hyp: fn_score(*hyp), reverse=True)[:1]:
        print('{:.3f}  '.format(fn_score(utt, scr)) + ' ' + print_utt(utt[1:-1]))
        print('')
        
# for utt in beamsearch:
#     print_utt(utt)
#     print ''

<s> you lied to me so many times -- </s>
<s> reggie -- trust me once more -- please . </s>
-1.256   what do you want me to do ?

<s> even by modern male standards you ' re a <unk> immature little shit . <unk> with the kind of money you have access to , that ' s deadly . <person> may not have a four hour stand up routine about the <unk> building , but she ' s a solid girl who will look after you . </s>
<s> i have you for that . </s>
-1.468   what do you want me to do ?

<s> is that what you think i was thinking ? </s>
<s> no -- that ' s what i know you were thinking . <continued_utterance> how often do you make love to your wife , <person> ? once a week ? sometimes twice ? there once was passion , wasn ' t there ? but now it ' s <unk> , predictable . tell me , when you do it -- do you always think of her ? or do you wonder what it would be like to be with someone else ? someone wild . someone who would force you to lose control . <continued_utterance> there ' s nothing wrong in admittin

In [10]:
# Train for one epoch.
# NOTE(review): `train_subtle2` is not assigned in any cell of this notebook (the loading cell
# assigns `train_subtle`), so this only runs with state left over from an earlier session.
# The second argument is presumably the batch size — confirm against HRED.train_one_epoch.
hred_net.train_one_epoch(train_subtle2, 60)

Done 10 batches in 0.86s	training loss:	7.176166
Done 20 batches in 1.69s	training loss:	6.175796
Done 30 batches in 2.44s	training loss:	5.697121
Done 40 batches in 3.47s	training loss:	5.374195
Done 50 batches in 4.27s	training loss:	5.207509
Done 60 batches in 5.11s	training loss:	5.063584
Done 70 batches in 5.92s	training loss:	4.934647
Done 80 batches in 6.69s	training loss:	4.840843


KeyboardInterrupt: 

In [7]:
# Evaluate the model on the test split; the returned value is displayed as the cell output.
# The second argument is presumably the batch size — confirm against HRED.validate.
hred_net.validate(test, 30)

Done 100 batches in 4.07s
Done 200 batches in 8.39s
Done 300 batches in 12.95s
Done 400 batches in 17.68s
Done 500 batches in 21.70s
Done 600 batches in 26.34s
Done 700 batches in 31.03s
Done 800 batches in 35.56s
Done 900 batches in 40.01s
Done 1000 batches in 44.49s
Done 1100 batches in 48.81s
Done 1200 batches in 53.21s
Done 1300 batches in 57.49s
Done 1400 batches in 61.62s
Done 1500 batches in 66.13s
Done 1600 batches in 70.65s
Done 1700 batches in 74.90s
Done 1800 batches in 79.24s
Done 1900 batches in 83.80s
Done 2000 batches in 88.35s
Done 2100 batches in 92.88s
Done 2200 batches in 97.45s
Done 2300 batches in 101.66s
Done 2400 batches in 106.09s


3.2819798801888118

In [8]:
'''full softmax, bs=30'''
# train, 1 dir, 1 epoch: 3.485554076321884
# val: 3.455356876018342

# train, 2 dir, concat, 1 epoch: 3.4864403798772239
# val: 3.4579001751897063

# train, 2 dir, L2 + concat, 1 epoch: 3.4881669768474675
# val: 3.4584704095551695
# training time: ~4700s

'''sampled softmax'''
# bs=30
# train, 2 dir, L2 + concat, 1 epoch: 3.486180601246621
# val: 3.4811877499289308
# training time: ~2300s

# bs=60
# train, 2 dir, L2 + concat, 1 epoch: 3.5235153449672456
# val: 3.5063306987542759
# training time: ~1900s