In [1]:
import numpy as np
import math

In [2]:
model_prediction = [[0.1, 0.7, 0.1, 0.1],
                    [0.7, 0.1, 0.1, 0.1],
                    [0.1, 0.1, 0.6, 0.2],
                    [0.1, 0.1, 0.1, 0.7],
                    [0.4, 0.3, 0.2, 0.1]]

In [3]:
def greedy_search_decoder(predictions):
    
    #select token with the maximum probability for each prediction
    output_sequence = [np.argmax(prediction) for prediction in predictions]
    
    #storing token probabilities
    token_probabilities = [np.max(prediction) for prediction in predictions]
    
    #multiply individaul token-level probabilities to get overall sequence probability
    sequence_probability = np.product(token_probabilities)
    
    return output_sequence, sequence_probability

In [4]:
greedy_search_decoder(model_prediction)

([1, 0, 2, 3, 0], 0.08231999999999998)

In [5]:
def beam_search_decoder(predictions, top_k = 3):
    output_sequences = [([], 1.0)]
    
    #looping through all the predictions
    for token_probs in predictions:
        new_sequences = []
        
        #append new tokens to old sequences and re-score
        for old_seq, old_score in output_sequences:
            for char_index in range(len(token_probs)):
                new_seq = old_seq + [char_index]
                #considering log-likelihood for scoring
                new_score = old_score * (-math.log(token_probs[char_index]))
                new_sequences.append((new_seq, new_score))
                
        #sort all new sequences in increasing order of score
        output_sequences = sorted(new_sequences, key = lambda val: val[1])
        
        #select top-k based on score 
        # *Note- best sequence is with the smallest score as we are 
        #        calculating log-likelihood
        output_sequences = output_sequences[:top_k]
        
    return output_sequences

In [6]:
beam_search_decoder(model_prediction, top_k = 5)

[([1, 0, 2, 3, 0], 0.0212384966700603),
 ([1, 0, 2, 3, 1], 0.02790661468682351),
 ([1, 0, 2, 3, 2], 0.037304799180916594),
 ([1, 0, 2, 3, 3], 0.05337110169177288),
 ([1, 0, 3, 3, 0], 0.0669152841079077)]