From 1aa40c2ae65d6643785a6663ffb5ad71e6931a36 Mon Sep 17 00:00:00 2001
From: Stephen Roller
Date: Tue, 21 Jun 2022 21:46:13 -0400
Subject: [PATCH] [api] Fix singular positive logit (#167)

* Save work

* Correctly handle.
---
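Why the offset matters, in miniature: scores are stored as running (cumulative)
sums of per-token NLL, and per-token values are later recovered by differencing
those sums, so the first generated step has to continue the prompt's running
sum rather than restart at zero. Below is a minimal sketch of that invariant;
it assumes only torch, and all names and values are illustrative rather than
metaseq code.

import torch

prompt_scores = torch.tensor([0.9, 1.4, 0.7])  # per-token NLL of the prompt
gen_scores = torch.tensor([0.5, 0.3])          # per-token NLL of generated tokens

prompt_cum = prompt_scores.cumsum(0)
offset = prompt_cum[-1:]                       # plays the role of first_offset

# Correct: the generated running sum continues from the prompt total,
# so differencing recovers every per-token score exactly.
full_cum = torch.cat([prompt_cum, gen_scores.cumsum(0) + offset])
recovered = torch.cat([full_cum[:1], full_cum[1:] - full_cum[:-1]])
assert torch.allclose(recovered, torch.cat([prompt_scores, gen_scores]))

# Broken: restarting the sum at the first generated step makes the first
# recovered value gen_scores[0] - prompt_cum[-1], which can flip sign;
# that sign flip is the "singular positive logit" in the title.
bad_cum = torch.cat([prompt_cum, gen_scores.cumsum(0)])
assert bad_cum[3] - bad_cum[2] != gen_scores[0]

first_offset in the diff below is exactly this prompt total, handed to the
search so that its step-0 scores are already cumulative.
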
 metaseq/hub_utils.py          |  1 +
 metaseq/search.py             | 17 ++++++++++++++++-
 metaseq/sequence_generator.py | 21 +++++++++++++--------
 3 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/metaseq/hub_utils.py b/metaseq/hub_utils.py
index 7eac74bbe9..af36798ff5 100644
--- a/metaseq/hub_utils.py
+++ b/metaseq/hub_utils.py
@@ -628,6 +628,7 @@ def generate(
                     tokens, scores, distributions
                 )
                 prompt_len = lengths[i]
+
                 if echo:
                     # don't cut off prompt
                     tokens = tokens[: prompt_len + max_tokens[i]]
diff --git a/metaseq/search.py b/metaseq/search.py
index c0cd2ce70c..66ddd7b5f1 100644
--- a/metaseq/search.py
+++ b/metaseq/search.py
@@ -22,7 +22,13 @@ def __init__(self, tgt_dict):
         self.stop_on_max_len = False
 
     def step(
-        self, step, lprobs, scores, prev_output_tokens=None, original_batch_idxs=None
+        self,
+        step,
+        lprobs,
+        scores,
+        offset=None,
+        prev_output_tokens=None,
+        original_batch_idxs=None,
     ):
         """Take a single search step.
 
@@ -32,6 +38,9 @@ def step(
                 the model's log-probabilities over the vocabulary at the current step
             scores: (bsz x input_beam_size x step)
                 the historical model scores of each hypothesis up to this point
+            offset: (bsz x input_beam_size)
+                the NLL of the prompt. used on the first step to maintain
+                consistent cumulative sums
             prev_output_tokens: (bsz x step)
                 the previously generated output tokens
             original_batch_idxs: (bsz)
@@ -105,6 +114,7 @@ def step(
         step: int,
         lprobs,
         scores: Optional[Tensor],
+        offset: Optional[Tensor] = None,
         prev_output_tokens: Optional[Tensor] = None,
         original_batch_idxs: Optional[Tensor] = None,
     ):
@@ -113,6 +123,8 @@ def step(
         if step == 0:
             # at the first step all hypotheses are equally likely, so use
             # only the first beam
+            if offset is not None:
+                lprobs += offset
             lprobs = lprobs[:, ::beam_size, :].contiguous()
         else:
             # make probs contain cumulative scores for each hypothesis
@@ -198,6 +210,7 @@ def step(
         step: int,
         lprobs,
         scores,
+        offset: Optional[Tensor] = None,
        prev_output_tokens: Optional[Tensor] = None,
         original_batch_idxs: Optional[Tensor] = None,
     ):
@@ -252,6 +265,8 @@ def step(
 
         if step == 0:
             beams_buf = indices_buf.new_zeros(bsz, beam_size)
+            if offset is not None:
+                scores_buf.add_(offset)
         else:
             beams_buf = torch.arange(0, beam_size).to(indices_buf).repeat(bsz, 1)
             # make scores cumulative
diff --git a/metaseq/sequence_generator.py b/metaseq/sequence_generator.py
index b73286db2f..c5334a90ed 100644
--- a/metaseq/sequence_generator.py
+++ b/metaseq/sequence_generator.py
@@ -256,6 +256,9 @@ def _generate(
         # finally, scores is actually stored as the cumulative NLL, but we have
         # individual NLL scores right now
         scores = scores.cumsum(dim=1)
+        # the first step of beam search also needs lprobs to be cumulative
+        # in order to keep the running sum correct
+        first_offset = scores[:, -1:]
 
         # start from previous timestep because we still have to do beam search
         # bookkeeping (i.e. finalize the hypothesis if it's the final token)
@@ -314,12 +317,14 @@ def _generate(
             cand_scores, cand_indices, cand_beams = self.search.step(
                 # underlying search indexes from first token being generated,
                 # so we need to account for the size of the prompt.
-                step - start_step + 1,
-                lprobs.view(bsz, -1, self.vocab_size),
-                scores[:, start_step - 1 : step].view(bsz, beam_size, -1),
-                tokens[:, start_step - 1 : step + 1],
-                original_batch_idxs,
+                step=step - start_step + 1,
+                lprobs=lprobs.view(bsz, -1, self.vocab_size),
+                scores=scores[:, start_step - 1 : step].view(bsz, beam_size, -1),
+                offset=first_offset,
+                prev_output_tokens=tokens[:, start_step - 1 : step + 1],
+                original_batch_idxs=original_batch_idxs,
             )
+            first_offset = None  # reset after the first step
 
             # cand_bbsz_idx contains beam indices for the top candidate
             # hypotheses, with a range of values: [0, bsz*beam_size),
@@ -445,14 +450,14 @@ def _generate(
             # the prompt tokens here for ease of bookkeeping.
 
             # Set the tokens for each beam (can select the same row more than once)
-            tokens[:, start_step:step] = torch.index_select(
-                tokens[:, start_step:step], dim=0, index=active_bbsz_idx
+            tokens[:, start_step : step + 1] = torch.index_select(
+                tokens[:, start_step : step + 1], dim=0, index=active_bbsz_idx
             )
             # Select the next token for each of them
             tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
                 cand_indices, dim=1, index=active_hypos
             )
-            if step >= start_step:
+            if step > start_step:
                 scores[:, start_step:step] = torch.index_select(
                     scores[:, start_step:step], dim=0, index=active_bbsz_idx
                 )
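
A companion sketch for the slice fix in the last hunk: once beams are
reshuffled, the column written on the previous iteration (index step) has to
follow its beam too, which is why the reorder slice must run through step + 1.
As above, this assumes only torch, with made-up shapes and values rather than
metaseq code.

import torch

start_step, step = 2, 4                 # two prompt columns; column 5 is written next
tokens = torch.tensor([
    [7, 7, 10, 11, 12, 0],              # beam 0 (prompt: 7 7)
    [7, 7, 20, 21, 22, 0],              # beam 1 (same prompt, other continuation)
])
active_bbsz_idx = torch.tensor([1, 1])  # search kept two copies of beam 1

good = tokens.clone()
good[:, start_step : step + 1] = torch.index_select(
    tokens[:, start_step : step + 1], dim=0, index=active_bbsz_idx
)
assert good[0, step].item() == 22       # the last generated token followed its beam

buggy = tokens.clone()
buggy[:, start_step:step] = torch.index_select(
    tokens[:, start_step:step], dim=0, index=active_bbsz_idx
)
assert buggy[0, step].item() == 12      # column `step` left behind: wrong last token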