In [1]:
import numpy as np

In [2]:
def P(state): #The sentence example where we have maskings for each state (non-markovian assumption)
    if state == "START":
        return [("the","a"), (0.5, 0.5)]
    elif state == "the":
        return [("cat","dog","mouse" ),(0.4, 0.3, 0.3)]
    elif state == "a":
        return [("dog","cat" ), (0.5, 0.5)]
    elif state == "cat":
        return [("jumped","slept",'ran' ), (0.4, 0.5,0.1)]
    elif state == "dog":
        return [("barked","slept" ), (0.6, 0.4)]
    elif state == "mouse":
        return [("squeaked","slept" ), (0.7, 0.3)]
    elif state == "jumped":
        return [("over","on",'near' ), (0.3, 0.4, 0.3)]
    elif state == "on":
        return [("the","a" ), (0.6, 0.4)]
    elif state == "squeaked":
        return [("under","near" ), (0.8, 0.2)]
    elif state == "over":
        return [("the","a" ), (0.6, 0.4)]
    elif state == "barked":
        return [("loudly","quietly",'violently' ), (0.4, 0.4,0.2)]
    elif state == "loudly":
        return [("EOS","quietly" ), (0.5, 0.5)]
    elif state == "quietly":
        return [("EOS","loudly" ), (0.5, 0.5)]
    elif state == "slept":
        return [("for","while" ), (0.5, 0.5)]
    elif state == "for":
        return [("a","the" ), (0.6, 0.4)]
    elif state == "while":
        return [("a","the" ), (0.4, 0.6)]
    elif state == "under":
        return [("a","the" ), (0.7, 0.3)]
    elif state == "near":
        return [("a","the" ), (0.5, 0.5)]
    elif state == 'ran':
        return [("over","on" ), (0.3, 0.7)]
    elif state == 'violently':
        return [("at","while" ), (0.1, 0.9)]
    elif state == 'at':
        return [("the" ), (1)]
    else:
        return []

In [3]:
def score_func(strip,length_penalty): #Function to score the current sequence
    curr_tkn = strip[0]
    score = 0
    for pos in range(0,len(strip)-1):
        nxt_tkn = strip[pos+1]
        jump_score = [i for i in range(len(P(curr_tkn)[0])) if P(curr_tkn)[0][i] == nxt_tkn][0]
        score = score - np.log(P(curr_tkn)[1][jump_score])/(pos+1) ** length_penalty
        curr_tkn = nxt_tkn
    return score


In [4]:
class BeamSearch: #vanilla version of beamsearch (similar to what is implemented in the Transformer)
    def __init__(self, decoder_func, beam_size=2, max_length=10, length_penalty=0.6):
        self.decoder_func = decoder_func #Function that returns next tokens and scores
        self.beam_size = beam_size #Width k of the beam
        self.max_length = max_length #Max length of the sequence
        self.length_penalty = length_penalty #Parameter to penalize the length of the sequence

    def beam_search(self, initial_state):
        beams = [(np.array([initial_state]), 0)] #sequence, prob

        for _ in range(self.max_length):
            new_beams = []

            for current_seq, score in beams:
                if current_seq[-1] == 'EOS':
                    new_beams.append((current_seq,score))
                    continue
                
                tokens,probs = self.decoder_func(current_seq[-1])
                log_probs = np.log(probs)
                max_args = np.argsort(log_probs)[-self.beam_size:]

                for arg in max_args:
                    next_token = tokens[arg]
                    new_seq = np.append(current_seq, next_token)
                    new_score = score - log_probs[arg] / len(new_seq) ** self.length_penalty
                    new_beams.append((new_seq, new_score))

            beams = sorted(new_beams, key=lambda x: x[1])[:self.beam_size]

        return beams


In [12]:
import heapq

class A_Star_BeamSearch: #The proposed method for this deliverable (an enhanced Beam search)
    def __init__(self, x,score, heuristic_func,stop, beam_size=2, max_length=10, length_penalty=0.6):
        self.x = x #Function that returns next tokens and scores
        self.score = score #score function assuming additive costs
        self.heuristic_func = heuristic_func #Function to work with possible modification of scores
        self.stop = stop #Stopping function returns True or False, and takes Q as input
        self.beam_size = beam_size #Width k of the beam
        self.max_length = max_length #Max length of the sequence
        self.length_penalty = length_penalty #Parameter to penalize the length of the sequence
        self.Q = []

    def beam_search(self, initial_state):
        heapq.heappush(self.Q,(0,[initial_state]))
        POPS = {}
        while (len(self.Q) > 0) and (self.stop(self.Q) == False):
            s_h,y = heapq.heappop(self.Q)
            s_h = s_h*-1
            if tuple(y) not in POPS.keys():
                POPS[tuple(y)] = 0
            if (POPS[tuple(y)] >= self.beam_size) or (len(y) > self.max_length):
                continue
            POPS[tuple(y)] += 1
            if y[-1] == 'EOS':
                y.append('EOS')
                heapq.heappush(self.Q,(-s_h,y))
            else:
                tokens,_ = self.x(y[-1])
                for token in tokens:
                    y_temp = y+[token]
                    s = score_func(y_temp,self.length_penalty)
                    s_h = s+ self.heuristic_func(y_temp)
                    heapq.heappush(self.Q,(-s_h,y_temp))

        return self.Q

def stop_funct(Q):
    if Q[0][1][-1] == 'EOS':
        return True
    else:
        return False

def heuristic_func(y):
    return 0

abeamSearch = A_Star_BeamSearch(P,score_func,heuristic_func,stop_funct,1,5,0.6)
sequence = abeamSearch.beam_search('START')
sequence

[(-2.0774357808251827, ['START', 'a', 'dog', 'barked', 'loudly', 'EOS']),
 (-2.0774357808251827, ['START', 'a', 'dog', 'barked', 'loudly', 'quietly']),
 (-1.8135333087486334, ['START', 'a', 'dog', 'barked', 'quietly']),
 (-0.6931471805599453, ['START', 'the'])]

In [13]:
# Example usage
beam_search = BeamSearch(P, beam_size=1, max_length=5, length_penalty=0.6)
initial_state = "START"  # initial state
decoded_sequence = beam_search.beam_search(initial_state)

print("Decoded sequence:", decoded_sequence)

Decoded sequence: [(array(['START', 'a', 'cat', 'slept', 'while', 'the'], dtype='<U5'), 1.555805293055767)]
