In [1]:
# This is a sketch of the constrained decoding algorithm by Hokamp & Liu

In [2]:
import numpy as np
from sortedcontainers import SortedListWithKey

In [3]:
# imagine a max_source_len x constraint_len+1 grid
# at the top left corner, there is a triangle with constraint_len-1 sides cut out
# at the bottom right corner, there is a triangle with constraint_len sides cut out

# we move to the right, filling the beams in each column starting with the bottommost, and moving upwards
#        - filling the beams in a column can be done in parallel, since there are no dependencies within the column

# the horizontal (t) axis represents time
#     - every hypothesis in a column t has the same number of tokens
# the vertical axis (j) represents coverage of constraints
#     - every hypothesis in a row j covers the same number of constraint tokens

# FILLING CELL (i,j) -- i is the time index (the t axis above), j is the constraint-coverage index
# there are two source beams from which we can generate hypotheses:
# LEFT (cell (i-1, j))
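#    - this cell can only generate, i.e. extend a hypothesis without adding a constraint token
#    - hypotheses from this beam never update the constraint coverage (they stay in row j)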
# BELOW+LEFT (cell (i-1, j-1))
#    - this cell can add constraints in two ways:
#      (1) constraints which are unfinished _MUST_ be continued
#      (2) new constraints can be started
#    - hypotheses from this beam always update the constraint coverage

# Generating constraint hypotheses
# the hypothesis object holds all of the states needed to generate the n-best continuations at the next timestep


In [4]:
# NOTES:
# without a special feature, generating a word and using the same word from a constraint have the same score,
# thus we need a way to decide whether we are generating a word, or starting a new constraint which begins with 
# that word
# - the constraint pointer model is one way of scoring hypotheses from the different sources differently

In [5]:
# toy vocabulary: the ten integer symbols 0..9
vocabulary = list(range(10))

# two toy constraints, each a sequence of symbols taken from `vocabulary`
sample_constraints = [[1, 2], [5, 6, 7]]



In [None]:
# Thinking buffer
def init_coverage(self, constraints):
    """Build an all-zero coverage structure shaped like `constraints`.

    Args:
        self: unused by the body
        constraints (list of sequences): one sequence per constraint

    Returns:
        list: one int16 numpy vector of zeros per constraint, each as long as
            the constraint it corresponds to (same order as `constraints`)
    """
    return [np.zeros(len(constraint), dtype='int16') for constraint in constraints]

In [None]:
class ConstraintHypothesis:
    """A (partial) hypothesis that also carries constraint-coverage bookkeeping.

    Every constructor argument is stored unchanged on the instance under the
    same name.

    Args:
        token (unicode): surface form of this hypothesis
        score (float): score of this hypothesis; larger means better
        coverage (list of lists): which area of the constraints this hypothesis has covered
        constraints (list of lists): the constraints available to this hypothesis
        payload (:obj:): optional extra data travelling with the hypothesis. Functions may
            require certain data to be present here, such as previous states, glimpses, etc.
        backpointer (:obj:`ConstraintHypothesis`): the hypothesis that generated this one
        unfinished_constraint (bool): True when this hypothesis is inside a constraint
            that has not been completed yet

    """

    def __init__(self, token, score, coverage, constraints, payload=None, backpointer=None,
                 unfinished_constraint=False):
        # what this hypothesis is and how good it is
        self.token, self.score = token, score

        # constraint bookkeeping
        self.constraints = constraints
        self.coverage = coverage
        self.unfinished_constraint = unfinished_constraint

        # lineage and decoder-specific extras
        self.backpointer = backpointer
        self.payload = payload
        
        
        
            
class AbstractBeam():
    """A bounded container which keeps only the best-scoring hypotheses.

    Hypotheses are held in a sorted list ordered from highest to lowest score.
    Whenever an insertion pushes the beam past `size`, the lowest-scoring
    hypothesis is discarded.

    Args:
        size (int): the maximum number of hypotheses kept on the beam
    """

    def __init__(self, size):
        # note: here we assume bigger scores are better, hence the negated sort key
        self.hypotheses = SortedListWithKey(key=self._sort_key)
        self.size = size

    @staticmethod
    def _sort_key(hyp):
        """Return the negated score of `hyp`, so that higher scores sort first.

        Hypothesis objects expose the score as the attribute `hyp.score`
        (as `ConstraintHypothesis` does); mapping-style hypotheses exposing
        `hyp['score']` are still accepted.
        """
        try:
            score = hyp.score
        except AttributeError:
            # fall back to item access for dict-like hypotheses
            score = hyp['score']
        return -score

    def add(self, hyp):
        """Insert `hyp`, then drop the worst hypothesis if the beam overflowed."""
        self.hypotheses.add(hyp)
        if len(self.hypotheses) > self.size:
            # hypotheses are added one at a time, so we can overflow by at most one
            assert len(self.hypotheses) == self.size + 1
            # the last position holds the lowest score (largest negated score)
            del self.hypotheses[-1]

    def __len__(self):
        """Return the number of hypotheses currently on the beam."""
        return len(self.hypotheses)

    def __iter__(self):
        """Yield the hypotheses from best (highest score) to worst."""
        for hyp in self.hypotheses:
            yield hyp
        



In [None]:
def get_generation_hyps(beam, hyp_generation_func):
    """Lazily produce every generated continuation of the hyps on `beam`.

    Hyps whose `unfinished_constraint` flag is set are skipped.

    `hyp_generation_func` maps `(hyp) --> continuations`, and its results are
    flattened into one stream, in beam order.

    the coverage vector of the parent hyp is not modified in each child
    """

    return (
        child
        for parent in beam
        if not parent.unfinished_constraint
        for child in hyp_generation_func(parent)
    )

        
def get_new_constraint_hyps(beam, constraints, constraint_hyp_func):
    """Lazily produce every hyp that starts a new constraint from a hyp on `beam`.

    Hyps whose `unfinished_constraint` flag is set are skipped.

    `constraint_hyp_func` maps `(hyp, constraints) --> continuations`, and its
    results are flattened into one stream, in beam order.

    the coverage vector of the parent hyp is modified in each child
    """

    return (
        child
        for parent in beam
        if not parent.unfinished_constraint
        for child in constraint_hyp_func(parent, constraints)
    )


def get_continued_constraint_hyps(beam, constraints, constraint_hyp_func):
    """Lazily produce every hyp that continues an unfinished constraint on `beam`.

    Only hyps whose `unfinished_constraint` flag is set are expanded.

    `constraint_hyp_func` maps `(hyp, constraints) --> forced_continuations`,
    and its results are flattened into one stream, in beam order.

    the coverage vector of the parent hyp is modified in each child

    """
    return (
        child
        for parent in beam
        if parent.unfinished_constraint
        for child in constraint_hyp_func(parent, constraints)
    )
    

In [None]:
# the implementations of the generate, start_constraint, and continue_constraint functions depend upon the decoder


In [None]:
class Beam:
    """Empty placeholder: no attributes or methods are defined yet."""