# This notebook is a test notebook for developing the Beam Search function
It should probably be deleted once the function has been moved into the utilities folder.

In [48]:
from  heapq import heappush, heappop, nsmallest
import numpy as np

In [45]:
def beam_search(inputs_to_ids_fn, starting_state, starting_id, stop_condition_fn, beam_width=3,
                num_states_returned=3, top_k=10, top_p=1.0, num_dist_samples=-1, temperature=1.0):
    """
    Searches a space using a beam-constrained, best-first tree to find the most likely outcomes.

    inputs_to_ids_fn: (state, id)->(new_state, [float]): the prediction function returning the
                                            new state and an array of probabilities, one per
                                            valid next id (index == id)
    starting_state: an object containing the state of the prediction fn before being passed a new id
    starting_id: int: the integer used to prime the first distribution
    stop_condition_fn: (state)->bool: a function determining if the sequence should be stopped,
                                      this could include maximum length reached.
    beam_width: int: maximum number of active (unfinished) hypotheses kept after each expansion;
                     also the number of children expanded per step when not sampling.
    num_states_returned: int: number of terminated states to collect before the search stops.
    top_k: int: at most this many of the most probable ids are considered at each step.
    top_p: float: nucleus threshold; ids are considered in order of decreasing probability until
                  their cumulative probability reaches top_p.
    num_dist_samples: int: if < 1 the most probable candidates are expanded deterministically,
                           otherwise this many ids are sampled from the candidates (duplicate
                           draws are merged so the same hypothesis is not queued twice).
    temperature: float: > 0, sharpens (< 1) or flattens (> 1) the sampling distribution;
                        only used when num_dist_samples >= 1.

    returns: [(float, state)]: (summed log probability, terminated state) pairs, most probable
                               first. May hold fewer than num_states_returned entries if the
                               search runs out of hypotheses.
    """
    terminated_states = []
    # heapq is a min-heap, so entries are keyed on the NEGATIVE log probability: the most
    # probable hypothesis is popped first. The running counter breaks ties so that heapq
    # never has to compare two (possibly unorderable) state objects.
    # Entry layout: (cost, tie_breaker, log_prob, state, next_id); log(p == 1) = 0.
    active_states = [(0.0, 0, 0.0, starting_state, starting_id)]
    pushed = 1

    while active_states and len(terminated_states) < num_states_returned:
        _, _, log_prob, state, next_id = heappop(active_states)
        new_state, next_id_probs = inputs_to_ids_fn(state, next_id)

        if stop_condition_fn(new_state):
            terminated_states.append((log_prob, new_state))
            continue

        # Rank ids from most to least probable (stable, so ties keep index order).
        probs = np.asarray(next_id_probs, dtype=float)
        ranked_ids = np.argsort(-probs, kind="stable")
        ranked_probs = probs[ranked_ids]

        # Nucleus (top-p) filter: keep an id while the mass of the ids ranked before it is
        # still below top_p, so the id that crosses the threshold is included. Ids with zero
        # probability can never be part of a likely outcome and would yield log(0).
        mass_before = np.cumsum(ranked_probs) - ranked_probs
        keep = (mass_before < top_p) & (ranked_probs > 0.0)
        candidate_ids = ranked_ids[keep][:top_k]
        if candidate_ids.size == 0:
            continue  # dead end: nothing to expand from this hypothesis

        if num_dist_samples < 1:
            chosen_ids = candidate_ids[:beam_width]
        else:
            # Temperature-scaled softmax over the candidates (max subtracted for stability).
            logits = np.log(probs[candidate_ids]) / temperature
            weights = np.exp(logits - np.max(logits))
            weights /= np.sum(weights)
            # Sample ids (not probabilities) and merge repeated draws.
            chosen_ids = np.unique(np.random.choice(candidate_ids, num_dist_samples, p=weights))

        for idx in chosen_ids:
            child_log_prob = log_prob + float(np.log(probs[idx]))
            # The child continues from new_state, i.e. the state AFTER consuming next_id.
            heappush(active_states, (-child_log_prob, pushed, child_log_prob, new_state, int(idx)))
            pushed += 1

        # Prune to the beam_width most probable hypotheses. nsmallest returns a sorted list,
        # which already satisfies the heap invariant.
        if len(active_states) > beam_width:
            active_states = nsmallest(beam_width, active_states)

    return terminated_states

In [18]:
# Scratch check: np.argsort orders ascending, so this cumulative sum starts from the
# smallest probability and accumulates up to the total (see the recorded output below).
a = np.array([0.4,0.2,0.3,0.1])
np.cumsum(a[np.argsort(a)])

array([0.1, 0.3, 0.6, 1. ])

In [22]:
# Scratch check: `in` on a NumPy array tests against its elements; this passes because
# no element of the array equals 0.
assert 0 not in np.array([0.4,0.1])

In [44]:
# Scratch check: two independent experiments of 2 draws each over 3 outcomes.
# The probabilities must sum to 1. NumPy ignores the last entry and uses the remainder
# 1 - sum(pvals[:-1]) instead, so the value actually in effect (0.2) is spelled out here.
pvals = [0.3, 0.5, 0.2]
np.random.multinomial(2, pvals, 2)

array([[0, 2, 0],
       [2, 0, 0]])

In [49]:
# Scratch check: heapq.nsmallest(n, iterable) returns the n smallest items in ascending
# order (see the recorded output below).
b = [3,5,4,2]
nsmallest(2,b)

[2, 3]

In [62]:
from anytree import Node, RenderTree, AsciiStyle, find_by_attr
# Build a two-node tree: root "f" with child "b". The extra keyword foo=None shows up as
# an attribute in the node's printed repr further down. Note this rebinds the name `b`.
f = Node("f")
b = Node("b", parent=f, foo=None)

In [63]:
# Search the tree rooted at f for a node whose attribute "foo" equals None.
# NOTE(review): presumably this returns node b; the result is never displayed - confirm.
a = find_by_attr(f, name="foo", value=None)

In [64]:
# Show the child node's repr: its path from the root plus any extra attributes.
print(b)

Node('/f/b', foo=None)
