In [None]:
# This is a sketch of the constrained decoding algorithm by Hokamp & Liu

In [1]:
import copy

from collections import defaultdict, OrderedDict
import numpy as np
from sortedcontainers import SortedListWithKey

In [2]:
# imagine a max_source_len x constraint_len+1 grid
# at the top left corner, there is a triangle with constraint_len-1 sides cut out
# at the bottom right corner, there is a triangle with constraint_len sides cut out

# we move to the right, filling the beams in each column starting with the bottommost, and moving upwards
#        - filling the beams in a column can be done in parallel, since there are no dependencies within the column

# the horizontal (t) axis represents time
#     - every hypothesis in a column t has the same number of tokens
# the vertical axis (j) represents coverage of constraints
#     - every hypothesis in a row j covers the same number of constraint tokens

# FILLING CELL (i,j)
# there are two source beams from which we can generate hypotheses:
# LEFT (cell (i-1, j))
#    - this cell can only generate 
# BELOW+LEFT (cell (i-1, j-1))
#    - this cell can add constraints in two ways:
#      (1) constraints which are unfinished _MUST_ be continued
#      (2) new constraints can be started
#    - hypotheses from this beam always update the constraint coverage

# Generating constraint hypotheses
# the hypothesis object holds all of the states needed to generate the n-best continuations at the next timestep
# this is the purpuse of the `payload` attribute of hypothesis objects


In [3]:
# NOTES:
# without a special feature, generating a word and using the same word from a constraint have the same score,
# thus we need a way to decide whether we are generating a word, or starting a new constraint which begins with 
# that word
# - the constraint pointer model is one way of scoring hypotheses from the different sources differently

In [4]:
vocabulary = [0,1,2,3,4,5,6,7,8,9]

sample_constraints = [
    [1,2],
    [5,6,7]
]

In [5]:
# Thinking buffer
def init_coverage(constraints):
    coverage = []
    for c in constraints:
        coverage.append(np.zeros(len(c), dtype='int16'))
    return coverage


    

In [40]:
class ConstraintHypothesis:
    """A (partial) hypothesis which maintains an additional constraint coverage object
    
    Args:
        token (unicode): the surface form of this hypothesis
        score (float): the score of this hypothesis (higher is better)
        coverage (list of lists): a representation of the area of the constraints covered by this hypothesis
        constraints (list of lists): the constraints that may be used with this hypothesis
        payload (:obj:): additional data that comes with this hypothesis. Functions may 
            require certain data to be present in the payload, such as the previous states, glimpses, etc...
        backpointer (:obj:`ConstraintHypothesis`): a pointer to the hypothesis object which generated this one
        constraint_index (tuple): if this hyp is part of a constraint, the index into `self.constraints` which
            is covered by this hyp `(constraint_idx, token_idx)`
        unfinished_constraint (bool): a flag which indicates whether this hyp is inside an unfinished constraint
    
    """
    
    def __init__(self, token, score, coverage, constraints, payload=None, backpointer=None,
                 constraint_index=None, unfinished_constraint=False):
        self.token = token
        self.score = score
        
        assert len(coverage) == len(constraints), 'constraints and coverage length must match'
        assert all(len(cov) == len(cons) for cov, cons in zip(coverage, constraints)), \
            'each coverage and constraint vector must match'
        
        self.coverage = coverage
        self.constraints = constraints
        self.backpointer = backpointer
        self.payload = payload
        self.constraint_index = constraint_index
        self.unfinished_constraint = unfinished_constraint
        
    def __str__(self):
        return u'token: {}, sequence: {}, score: {}, coverage: {}, constraints: {},'.format(
            self.token, self.sequence, self.score, self.coverage, self.constraints)
    
    def __getitem__(self, key):
        return getattr(self, key)
    
    @property
    def sequence(self):
        sequence = []
        current_hyp = self
        while current_hyp.backpointer is not None:
            sequence.append((current_hyp.token, current_hyp.constraint_index))
            current_hyp = current_hyp.backpointer
        sequence.append((current_hyp.token, current_hyp.constraint_index))
        return sequence[::-1]
    
    def constraint_candidates(self):
        available_constraints = []
        for idx in range(len(self.coverage)):
            if self.coverage[idx][0] == 0:
                available_constraints.append(idx)
            
        return available_constraints
    
    
        
        
class AbstractBeam():
    
    def __init__(self, size):
        # note: here we assume bigger scores are better
        self.hypotheses = SortedListWithKey(key=lambda x: -x['score'])
        self.size = size
    
    def add(self, hyp):
        self.hypotheses.add(hyp)
        if len(self.hypotheses) > self.size:
            assert len(self.hypotheses) == self.size + 1
            del self.hypotheses[-1]
    
    def __len__(self):
        return len(self.hypotheses)
    
    def __iter__(self):
        for hyp in self.hypotheses:
            yield hyp
        


In [41]:
# FUNCTIONS USED BY THE CONSTRAINED DECODER
# Note: hyps on the top level may be finished (End with EOS), or may be continuing (haven't gotten an EOS yet)

class ConstrainedDecoder(object):
    
    def __init__(self, hyp_generation_func, constraint_generation_func, continue_constraint_func,
                 beam_implementation=AbstractBeam):
        self.hyp_generation_func = hyp_generation_func
        self.constraint_generation_func = constraint_generation_func
        self.continue_constraint_func = continue_constraint_func
        self.beam_implementation = beam_implementation
        
    # QUESTION: are mid-constraint hyps allowed to fall off of the beam or not?
    def search(self, start_hyp, constraints, max_source_len, beam_size=10):
        """create a constrained search
            - fill the search grid
        """
        
        # the total number of constraint tokens determines the height of the grid
        grid_height = sum(len(c) for c in constraints)
        
        search_grid = OrderedDict()
        
        # a beam with one hyp starts the search
        start_beam = self.beam_implementation(size=1)
        start_beam.add(start_hyp)
        
        search_grid[(0,0)] = start_beam
        
        for i in range(1, max_source_len + 1):
            print('TIME: {}'.format(i+1))
            j_start = max(i - (max_source_len - grid_height), 0)
            j_end = min(i, grid_height) + 1
            beams_in_i = j_end - j_start
            for j in range(j_start, min(i, grid_height) + 1):
                # create the new beam
                new_beam = self.beam_implementation(size=beam_size)
                # generate hyps from (i-1, j-1), and (i-1, j), and add them to the beam
                # cell to the left generates
                if (i-1, j) in search_grid:
                    generation_hyps = self.get_generation_hyps(search_grid[(i-1, j)])
                    for hyp in generation_hyps:
                        new_beam.add(hyp)
                # lower left diagonal cell adds hyps from constraints
                if (i-1, j-1) in search_grid:
                    new_constraint_hyps = self.get_new_constraint_hyps(search_grid[(i-1, j-1)])
                    continued_constraint_hyps = self.get_continued_constraint_hyps(search_grid[(i-1, j-1)])
                    for hyp in new_constraint_hyps:
                        new_beam.add(hyp)
                    for hyp in continued_constraint_hyps:
                        new_beam.add(hyp)
                
                search_grid[(i,j)] = new_beam
            
                print('index: {}'.format((i,j)))
        
        return search_grid
    
# FILLING CELL (i,j)
# there are two source beams from which we can generate hypotheses:
# LEFT (cell (i-1, j))
#    - this cell can only generate 
# BELOW+LEFT (cell (i-1, j-1))
#    - this cell can add constraints in two ways:
#      (1) constraints which are unfinished _MUST_ be continued
#      (2) new constraints can be started
#    - hypotheses from this beam always update the constraint coverage
    def get_generation_hyps(self, beam):
        """return all hyps which are continuations of the hyps on this beam
    
        hyp_generation_func maps `(hyp) --> continuations`
    
        the coverage vector of the parent hyp is not modified in each child
        """
    
        continuations = (self.hyp_generation_func(hyp) for hyp in beam if not hyp.unfinished_constraint)
    
        # flatten
        return (new_hyp for hyp_list in continuations for new_hyp in hyp_list)
        
    def get_new_constraint_hyps(self, beam):
        """return all hyps which start a new constraint from the hyps on this beam
    
        constraint_hyp_func maps `(hyp) --> continuations`
    
        the coverage vector of the parent hyp is modified in each child
        """
    
        continuations = (self.constraint_generation_func(hyp)
                         for hyp in beam if not hyp.unfinished_constraint)
    
        # flatten
        return (new_hyp for hyp_list in continuations for new_hyp in hyp_list)


    def get_continued_constraint_hyps(self, beam):
        """return all hyps which continue the unfinished constraints on this beam
    
        constraint_hyp_func maps `(hyp, constraints) --> forced_continuations`
    
        the coverage vector of the parent hyp is modified in each child

        """
        continuations = (self.continue_constraint_func(hyp)
                         for hyp in beam if hyp.unfinished_constraint)
    
        return continuations
    

In [43]:
# the implementations of the generate, start_constraint, and continue_constraint functions depend upon the decoder

# to generate, we query the decoder with a hypothesis
# the decoder uses the payload to compute the next most probable continuations

# to start new constrained hypotheses, we need to:
# constrained_hyps = []
# for constraint_idx in constraint_candidates(hyp):
#     new_hyp = build_hyp(hyp, constraint_idx, constraints)
#     constrained_hyps.append(new_hyp)


# to continue a constrained hypothesis, we need to:
# (1) find the constraint to continue via the coverage object
# (2) find the next token in the constraint
# (3) 
# (4) get the data for this token (score, states, etc...) -- i.e. forced decode for one step



In [44]:
# DUMBEST POSSIBLE IMPLEMENTATION of generation functions
# Note that generation and search are done by _different_ classes

class DumbTranslationModel(object):
    
    def __init__(self, vocabulary):
        self.vocabulary = vocabulary
    
    def dumb_generate(self, hyp, n_best=1):
        # make k_best random hyp objects
        next_tokens = np.random.choice(self.vocabulary, size=n_best)
        next_scores = np.random.random(size=n_best)
        
        new_hyps = []
        for i in range(n_best):
            new_hyp = ConstraintHypothesis(token=next_tokens[i],
                                           score=next_scores[i],
                                           coverage=copy.deepcopy(hyp.coverage),
                                           constraints=hyp.constraints,
                                           payload=None,
                                           backpointer=hyp,
                                           constraint_index=None,
                                           unfinished_constraint=False
                                          )
            new_hyps.append(new_hyp)
        
        return new_hyps
    
    def dumb_generate_from_constraints(self, hyp):
        """Look at the coverage of the hyp to get constraint candidates"""
        
        assert hyp.unfinished_constraint is not True, 'hyp must not be part of an unfinished constraint'
        new_constraint_hyps = []
        available_constraints = hyp.constraint_candidates()
        for idx in available_constraints:
            # starting a new constraint
            constraint_token = hyp.constraints[idx][0]
            # this should come from the model
            score = np.random.random()
            coverage = copy.deepcopy(hyp.coverage)
            coverage[idx][0] = 1
            if len(coverage[idx]) > 1:
                unfinished_constraint = True
            else:
                unfinished_constraint = False
                
            new_hyp = ConstraintHypothesis(token=constraint_token,
                                           score=score,
                                           coverage=coverage,
                                           constraints=hyp.constraints,
                                           payload=None,
                                           backpointer=hyp,
                                           constraint_index=(idx, 0),
                                           unfinished_constraint=unfinished_constraint
                                          )
            new_constraint_hyps.append(new_hyp)
        
        return new_constraint_hyps
        
    
    def dumb_continue_unfinished_constraint(self, hyp):
        assert hyp.unfinished_constraint is True, 'hyp must be part of an unfinished constraint'
        
        # this should come from the model
        score = np.random.random()
        
        constraint_row_index = hyp.constraint_index[0]
        # the index of the next token in the constraint
        constraint_tok_index = hyp.constraint_index[1] + 1
        constraint_index = (constraint_row_index, constraint_tok_index)
        
        continued_constraint_token = hyp.constraints[constraint_index[0]][constraint_index[1]]
        
        coverage = copy.deepcopy(hyp.coverage)
        coverage[constraint_row_index][constraint_tok_index] = 1
        
        if len(hyp.constraints[constraint_row_index]) > constraint_tok_index + 1:
            unfinished_constraint = True
        else:
            unfinished_constraint = False
        
        new_hyp = ConstraintHypothesis(token=continued_constraint_token,
                                       score=score,
                                       coverage=coverage,
                                       constraints=hyp.constraints,
                                       payload=None,
                                       backpointer=hyp,
                                       constraint_index=constraint_index,
                                       unfinished_constraint=unfinished_constraint
                                      )
        return new_hyp

        

In [45]:
START_TOKEN = u'<S>'
P_START = 1.0
N_BEST = 5

In [46]:
dumb_tm = DumbTranslationModel(vocabulary)


In [47]:
decoder = ConstrainedDecoder(hyp_generation_func=dumb_tm.dumb_generate,
                             constraint_generation_func=dumb_tm.dumb_generate_from_constraints,
                             continue_constraint_func=dumb_tm.dumb_continue_unfinished_constraint,
                             beam_implementation=AbstractBeam)

#     def __init__(self, hyp_generation_func, constraint_generation_func, continue_constraint_func,
#                  beam_implementation=AbstractBeam):

In [48]:
start_hyp = ConstraintHypothesis(token=START_TOKEN, score=P_START,
                                 coverage=init_coverage(sample_constraints),
                                 constraints=sample_constraints,
                                 payload=None,
                                 backpointer=None,
                                 unfinished_constraint=False
                                )

In [49]:
output_grid = decoder.search(start_hyp=start_hyp, constraints=sample_constraints, max_source_len=15, beam_size=5)

TIME: 2
index: (1, 0)
index: (1, 1)
TIME: 3
index: (2, 0)
index: (2, 1)
index: (2, 2)
TIME: 4
index: (3, 0)
index: (3, 1)
index: (3, 2)
index: (3, 3)
TIME: 5
index: (4, 0)
index: (4, 1)
index: (4, 2)
index: (4, 3)
index: (4, 4)
TIME: 6
index: (5, 0)
index: (5, 1)
index: (5, 2)
index: (5, 3)
index: (5, 4)
index: (5, 5)
TIME: 7
index: (6, 0)
index: (6, 1)
index: (6, 2)
index: (6, 3)
index: (6, 4)
index: (6, 5)
TIME: 8
index: (7, 0)
index: (7, 1)
index: (7, 2)
index: (7, 3)
index: (7, 4)
index: (7, 5)
TIME: 9
index: (8, 0)
index: (8, 1)
index: (8, 2)
index: (8, 3)
index: (8, 4)
index: (8, 5)
TIME: 10
index: (9, 0)
index: (9, 1)
index: (9, 2)
index: (9, 3)
index: (9, 4)
index: (9, 5)
TIME: 11
index: (10, 0)
index: (10, 1)
index: (10, 2)
index: (10, 3)
index: (10, 4)
index: (10, 5)
TIME: 12
index: (11, 1)
index: (11, 2)
index: (11, 3)
index: (11, 4)
index: (11, 5)
TIME: 13
index: (12, 2)
index: (12, 3)
index: (12, 4)
index: (12, 5)
TIME: 14
index: (13, 3)
index: (13, 4)
index: (13, 5)
TIME:

In [50]:
[k for k in output_grid]
for k in output_grid:
    for hyp in output_grid[k]:
        print('key: {} hyp: {}'.format(k, hyp.sequence))

key: (0, 0) hyp: [(u'<S>', None)]
key: (1, 0) hyp: [(u'<S>', None), (6, None)]
key: (1, 1) hyp: [(u'<S>', None), (5, (1, 0))]
key: (1, 1) hyp: [(u'<S>', None), (1, (0, 0))]
key: (2, 0) hyp: [(u'<S>', None), (6, None), (3, None)]
key: (2, 1) hyp: [(u'<S>', None), (6, None), (5, (1, 0))]
key: (2, 1) hyp: [(u'<S>', None), (6, None), (1, (0, 0))]
key: (2, 2) hyp: [(u'<S>', None), (5, (1, 0)), (6, (1, 1))]
key: (2, 2) hyp: [(u'<S>', None), (1, (0, 0)), (2, (0, 1))]
key: (3, 0) hyp: [(u'<S>', None), (6, None), (3, None), (7, None)]
key: (3, 1) hyp: [(u'<S>', None), (6, None), (3, None), (1, (0, 0))]
key: (3, 1) hyp: [(u'<S>', None), (6, None), (3, None), (5, (1, 0))]
key: (3, 2) hyp: [(u'<S>', None), (6, None), (1, (0, 0)), (2, (0, 1))]
key: (3, 2) hyp: [(u'<S>', None), (6, None), (5, (1, 0)), (6, (1, 1))]
key: (3, 2) hyp: [(u'<S>', None), (1, (0, 0)), (2, (0, 1)), (6, None)]
key: (3, 3) hyp: [(u'<S>', None), (1, (0, 0)), (2, (0, 1)), (5, (1, 0))]
key: (3, 3) hyp: [(u'<S>', None), (5, (1, 0)

In [39]:
len(output_grid)

66

In [None]:
next_hyps = dumb_tm.dumb_generate(start_hyp, n_best=N_BEST)
next_new_constraint_hyps = dumb_tm.dumb_generate_from_constraints(start_hyp)
len(next_new_constraint_hyps)

In [None]:
next_continued_constraint_hyps = [dumb_tm.dumb_continue_unfinished_constraint(hyp)
                                  for hyp in next_new_constraint_hyps]

In [None]:
len(next_continued_constraint_hyps)

In [None]:
print(next_continued_constraint_hyps[1].unfinished_constraint)

In [None]:
print(next_new_constraint_hyps[3])

In [None]:
next_next_hyps = [dumb_tm.dumb_generate(h, n_best=N_BEST) for h in next_hyps]
next_next_hyps = [h for s in next_next_hyps for h in s]

In [None]:
len(next_next_hyps)

In [None]:
[h.sequence for h in next_next_hyps]

In [None]:
print(next_hyps[1])

In [None]:
#     def __init__(self, token, score, coverage, constraints, payload=None, backpointer=None,
#                  unfinished_constraint=False):
start_hyp = 

In [None]:
def build_hyp(previous_hyp, next_token=None, constraint_idx=None, constraint_token_idx=None)
    """Utility function to create a hypothesis in different ways"""
    pass