In [1]:
import numpy as np
import random

vocab = {
    0: '早安你好',
    1: '我覺得今天天氣不錯',
    2: '有事嗎',
    3: '我覺得你心情不錯',
    4: '晚安',
    5: 'BOS',
    6: 'EOS'
}
reverse_vocab = dict([(v,k) for k,v in vocab.items()]) #轉為dict
vocab_size = len(vocab.items())


In [2]:
reverse_vocab

{'BOS': 5,
 'EOS': 6,
 '我覺得今天天氣不錯': 1,
 '我覺得你心情不錯': 3,
 '早安你好': 0,
 '晚安': 4,
 '有事嗎': 2}

In [3]:
def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

def reduce_mul(l): 
    out = 1.0
    for x in l:
        out *= x
    return out

def check_all_done(seqs):
    for seq in seqs:
        if not seq[-1]:
            return False
    return True

In [4]:
def decode_step(encoder_context, input_seq):    
    #encoder_context包含encoder，ouput_step包含文字的機率
    words_prob = [random.random() for _ in range(vocab_size)]
    words_prob[reverse_vocab['BOS']] = 0.0 #將開始的符號的機率設為0
    words_prob = softmax(words_prob) #將數值轉非負的exp，並進行標準化
    ouput_step = [(idx,prob) for idx,prob in enumerate(words_prob)] #建立index     
    ouput_step = sorted(ouput_step, key=lambda x: x[1], reverse=True) #降幕排序descending，越排越低
    return ouput_step

#seq: [[word,word],[word,word],[word,word]]
#output: [[word,word,word],[word,word,word],[word,word,word]]

In [5]:
def beam_search_step(encoder_context, top_seqs, k):       
    all_seqs = []
    for seq in top_seqs:
        seq_score = reduce_mul([_score for _,_score in seq]) #將分數數值都乘以1.0
        if seq[-1][0] == reverse_vocab['EOS']:
            all_seqs.append((seq, seq_score, True))
            continue
        #獲得這個階段的資料
        current_step = decode_step(encoder_context, seq)
        for i,word in enumerate(current_step):    
            if i >= k:
                break
            word_index = word[0]
            word_score = word[1]   
            score = seq_score * word_score
            rs_seq = seq + [word]
            done = (word_index == reverse_vocab['EOS'])            
            all_seqs.append((rs_seq, score, done))            
    all_seqs = sorted(all_seqs, key = lambda seq: seq[1], reverse=True)        
    topk_seqs = [seq for seq,_,_ in all_seqs[:k]]
    all_done = check_all_done(topk_seqs)
    return topk_seqs, all_done

In [6]:
def beam_search(encoder_context):
    beam_size = 3 #不超過輸入seq長度
    max_len = 10 #不超過seq長度
    #開始
    top_seqs = [[(reverse_vocab['BOS'],1.0)]]
    #迴圈
    for _ in range(max_len):        
        top_seqs, all_done = beam_search_step(encoder_context, top_seqs, beam_size)
        if all_done:            
            break        
    return top_seqs

In [7]:
if __name__ == '__main__':
    encoder_context = None
    top_seqs = beam_search(encoder_context)
    for i,seq in enumerate(top_seqs): 
        print (('Path[%d]: ') % (i))
        for word in seq[1:]:
            word_index = word[0]
            print("word_index",word_index)
            word_prob = word[1]
            print("word_prob",word_prob)
            print (('%s(%.4f)') % (vocab[word_index], word_prob),)
            if word_index == reverse_vocab['EOS']:
                break
        print ('\n')



Path[0]: 
word_index 4
word_prob 0.1701881067161776
晚安(0.1702)


Path[1]: 
word_index 3
word_prob 0.1697865960377839
我覺得你心情不錯(0.1698)


Path[2]: 
word_index 1
word_prob 0.1510122269450021
我覺得今天天氣不錯(0.1510)




In [9]:
top_seqs = beam_search(encoder_context)

In [62]:
beam_search(encoder_context)
for i,seq in enumerate(top_seqs):
    print (('Path[%d]: ') % (i))

Path[0]: 
Path[1]: 
Path[2]: 
Path[3]: 


In [57]:
beam_size = 3
max_len = 10
    #開始
top_seqs = [[(reverse_vocab['BOS'],1.0)]]
    #迴圈
for _ in range(max_len):        
    top_seqs, all_done = beam_search_step(encoder_context, top_seqs, beam_size) 
    print("top_seqs",top_seqs)
    print("=================")
    print("all_done",all_done)

top_seqs [[(5, 1.0), (4, 0.21519921648542806)], [(5, 1.0), (3, 0.1666729881798719)], [(5, 1.0), (0, 0.1376148844031451)]]
all_done True
top_seqs [[(5, 1.0), (4, 0.21519921648542806), (1, 0.1901838032221608)], [(5, 1.0), (4, 0.21519921648542806), (6, 0.1729106676268892)], [(5, 1.0), (4, 0.21519921648542806), (0, 0.16900314835737126)]]
all_done True
top_seqs [[(5, 1.0), (4, 0.21519921648542806), (6, 0.1729106676268892)], [(5, 1.0), (4, 0.21519921648542806), (1, 0.1901838032221608), (4, 0.23599408608406133)], [(5, 1.0), (4, 0.21519921648542806), (0, 0.16900314835737126), (1, 0.23124380016713914)]]
all_done True
top_seqs [[(5, 1.0), (4, 0.21519921648542806), (6, 0.1729106676268892)], [(5, 1.0), (4, 0.21519921648542806), (1, 0.1901838032221608), (4, 0.23599408608406133), (6, 0.25984466702242376)], [(5, 1.0), (4, 0.21519921648542806), (1, 0.1901838032221608), (4, 0.23599408608406133), (1, 0.15859631018242)]]
all_done True
top_seqs [[(5, 1.0), (4, 0.21519921648542806), (6, 0.1729106676268892)

None
