In [49]:
import numpy as np

In [61]:
# --- Configuration for the toy RNN decoder and the beam search below ---
vocab_size = 100      # number of token ids; also the width of the output layer
hidden_size = 10      # width of the recurrent hidden state
beam_size = 3         # candidates kept per batch item after each pruning step
time_step = 10        # number of decoding iterations run by the search loop

# First-step input to the decoder: one row per batch item, one token id per row.
initial_y = np.array([[3],[6]],dtype=np.int32)
# Derive the batch size from the array defined above (the previous `len(y)`
# relied on a name that is not defined anywhere in the notebook).
batch_size = len(initial_y)

# Randomly initialised weights (no training happens in this notebook).
# NOTE: the spelling `weigth_xh` is kept because other cells reference it.
weigth_xh = np.random.randn(vocab_size, hidden_size)   # input  -> hidden
weight_hh = np.random.randn(hidden_size, hidden_size)  # hidden -> hidden
weight_ho = np.random.randn(hidden_size, vocab_size)   # hidden -> output logits

EOS_id = 0            # token id treated as end-of-sequence

In [62]:
def onehot(array, vocab_size):
    """
    One-hot encode integer token ids.

    array: [batch] (any leading shape works; a new last axis is added)
    vocab_size: number of classes, i.e. length of each one-hot vector
    labels_one_hot: [batch, vocab_size], dtype int32

    Ids outside [0, vocab_size) produce an all-zero row.
    """
    ids = np.asarray(array)
    # Add the class axis at the END (axis=-1). The previous hard-coded axis=2
    # is out of range for the documented 1-D input and raises on current NumPy.
    # Broadcasting against arange(vocab_size) then marks the matching class.
    labels_one_hot = (np.expand_dims(ids, -1) == np.arange(vocab_size)).astype(np.int32)

    return labels_one_hot

In [63]:
def softmax(x):
    """
    Row-wise softmax.

    x: [batch, n] array of logits
    returns: [batch, n] array; every row is non-negative and sums to 1
    """
    x = np.asarray(x)
    # Shift each row by ITS OWN maximum before exponentiating. Softmax is
    # invariant to a per-row shift, and this keeps the largest exponent at 0,
    # so no row can underflow to all zeros (which would give 0/0 = NaN when
    # the shift is a single global maximum and rows differ a lot in scale).
    e_x = np.exp(x - np.max(x, axis=1, keepdims=True))

    return e_x / e_x.sum(axis=1, keepdims=True)

In [64]:
def decode(y):
    """
    Run the toy RNN over a batch of token prefixes and score the next token.

    y: [batch, time] array of token ids
    returns: [batch, vocab_size] array of log-probabilities for the token
             that follows each prefix

    The recurrence is hidden_t = tanh(onehot(y_t) @ weigth_xh + hidden_{t-1} @ weight_hh),
    starting from an all-zero hidden state. Only the hidden state after the
    last time step is projected to the vocabulary.

    A prefix of length zero is scored from the zero initial state instead of
    falling through and returning None.
    """
    hidden = np.zeros((y.shape[0], hidden_size))

    for time in range(y.shape[1]):
        token = y[:, time]
        x_to_h = np.matmul(onehot(token, vocab_size), weigth_xh)
        h_to_h = np.matmul(hidden, weight_hh)
        hidden = np.tanh(x_to_h + h_to_h)

    # Output projection happens once, after the whole prefix has been consumed.
    outputs = np.matmul(hidden, weight_ho)
    probs = np.log(softmax(outputs))

    return probs

In [65]:
# Beam search over the toy decoder.
#
# State carried between iterations (rows are grouped by batch item,
# `beam_size` consecutive rows per item):
#   PREDS_k: token sequences generated so far, one row per live beam
#   PROBS_k: score of each beam (mean log-prob per generated token)
#   EOS_k:   True where the beam has already produced EOS_id
#
# The helpers are defined once, ahead of the loop, instead of being
# re-created on every iteration.

def _get_preds_and_probs(PREDS):
    """Return the top-`beam_size` next tokens and their log-probs for every row.

    Both results are flattened row by row, best candidate first, so the
    candidates of row i occupy positions [i*beam_size, (i+1)*beam_size).
    """
    probs = decode(y=PREDS)
    # Column indices of the best `beam_size` entries per row, descending.
    top_k = np.argsort(probs)[:, ::-1][:, :beam_size]
    row_idx = np.arange(probs.shape[0])[:, None]
    preds_k = top_k.flatten()
    # Read the scores at those same indices instead of sorting a second time.
    probs_k = probs[row_idx, top_k].flatten()
    return preds_k, probs_k


def logging(PREDS_k, EOS_k, PROBS_k):
    """Print every beam (tokens, finished flag, score), grouped by batch item."""
    per_batch = zip(np.split(PREDS_k, batch_size),
                    np.split(EOS_k, batch_size),
                    np.split(PROBS_k, batch_size))
    for i, (PREDS_k_batch, EOS_k_batch, PROBS_k_batch) in enumerate(per_batch):
        print("batch num=", i)
        for each_PREDS_k_batch, each_EOS_k_batch, each_PROBS_k_batch in zip(PREDS_k_batch, EOS_k_batch, PROBS_k_batch):
            print("{}\t{}\t{}".format(each_PREDS_k_batch, each_EOS_k_batch, each_PROBS_k_batch))


for t in range(time_step):
    print("="*10, "timesteps=", t, "="*10)

    if t == 0:
        # Seed the beams with the top-k continuations of the initial tokens.
        preds_k, probs_k = _get_preds_and_probs(initial_y)
        PREDS_k = np.expand_dims(preds_k, -1)
        PROBS_k = probs_k
        EOS_k = preds_k == EOS_id

        # logging
        logging(PREDS_k, EOS_k, PROBS_k)
        continue

    print("Expansion...")

    preds_kk, probs_kk = _get_preds_and_probs(PREDS_k)

    # preds for expanded beams: every beam is repeated beam_size times and
    # each copy gets one of its top-k next tokens appended
    PREDS_kk = np.repeat(PREDS_k, beam_size, axis=0)
    PREDS_kk = np.append(PREDS_kk, np.expand_dims(preds_kk, -1), -1)

    # eos for expanded beams
    eos_kk = preds_kk == EOS_id
    finished_before = np.repeat(EOS_k, beam_size, axis=0)
    EOS_kk = np.logical_or(finished_before, eos_kk)

    # probs for expanded beams: running mean of the per-token log-probs.
    # Only beams that were ALREADY finished keep their frozen score; a beam
    # that emits EOS at this step still pays for the EOS token (as it does at
    # t == 0). Otherwise ending a sequence would be free and always win.
    PROBS_kk = np.repeat(PROBS_k, beam_size, axis=0)
    normalized_probs = (PROBS_kk*t + probs_kk) / (t+1)
    PROBS_kk = np.where(finished_before, PROBS_kk, normalized_probs)
    # NOTE(review): an already-finished beam is still expanded into beam_size
    # identically-scored copies with extra tokens after EOS — confirm whether
    # that is intended.

    # logging
    logging(PREDS_kk, EOS_kk, PROBS_kk)

    print("Pruning ...")
    # Keep beam_size candidates per batch item, or just the single best one
    # on the final step.
    keep = 1 if t == time_step-1 else beam_size
    winners = []
    for j, prob_kk in enumerate(np.split(PROBS_kk, batch_size)):
        winner = np.argsort(prob_kk)[::-1][:keep]
        # offset local indices back into the flattened candidate arrays
        winners.extend(list(winner + j*len(prob_kk)))

    PREDS_k = PREDS_kk[winners]
    PROBS_k = PROBS_kk[winners]
    EOS_k = EOS_kk[winners]

    # logging
    logging(PREDS_k, EOS_k, PROBS_k)

batch num= 0
[68]	False	-1.2324655717618953
[46]	False	-2.6376396515813605
[13]	False	-2.7762188007971154
batch num= 1
[41]	False	-1.4037301427948343
[65]	False	-1.562554304895344
[23]	False	-1.9929396506896107
Expansion...
batch num= 0
[68 17]	False	-1.2150569704370096
[68  1]	False	-1.8612188630741735
[68 14]	False	-1.9552234187669355
[46 18]	False	-1.8027837122787977
[46  2]	False	-2.3372687785011177
[46 80]	False	-2.5915110543235547
[13 17]	False	-1.6862257913536134
[13 87]	False	-2.88853416569071
[13 66]	False	-3.1750479963021707
batch num= 1
[41 11]	False	-1.7342978017745594
[41 62]	False	-1.914181830439107
[41 88]	False	-1.9430399395098354
[65 11]	False	-1.7701307724864415
[65 95]	False	-1.8066065400195082
[65 67]	False	-2.089750237715706
[23  2]	False	-1.641263362943986
[23 71]	False	-1.9867958833490111
[23 80]	False	-2.19114839504747
Pruning ...
batch num= 0
[68 17]	False	-1.2150569704370096
[13 17]	False	-1.6862257913536134
[46 18]	False	-1.8027837122787977
batch num= 1
[23  

  
