Commit

[api] Fix singular positive logit (microsoft#167)
* Save work

* Correctly handle.
stephenroller committed Jun 22, 2022
1 parent f5442a1 commit 1aa40c2
Showing 3 changed files with 30 additions and 9 deletions.
1 change: 1 addition & 0 deletions metaseq/hub_utils.py
@@ -628,6 +628,7 @@ def generate(
tokens, scores, distributions
)
prompt_len = lengths[i]

if echo:
# don't cut off prompt
tokens = tokens[: prompt_len + max_tokens[i]]
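For context, a small sketch (toy values, not from the repo) of what the echo path above returns: the prompt is kept intact and only the continuation is capped by max_tokens:

    # assumed toy values for illustration
    prompt_len, max_new_tokens = 3, 2
    tokens = [101, 42, 7, 55, 89, 0, 0]  # 3 prompt tokens + 2 generated + padding

    # with echo=True the prompt is not cut off; generation is capped after it
    echoed = tokens[: prompt_len + max_new_tokens]
    print(echoed)  # [101, 42, 7, 55, 89]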
17 changes: 16 additions & 1 deletion metaseq/search.py
@@ -22,7 +22,13 @@ def __init__(self, tgt_dict):
self.stop_on_max_len = False

def step(
-        self, step, lprobs, scores, prev_output_tokens=None, original_batch_idxs=None
+        self,
+        step,
+        lprobs,
+        scores,
+        offset=None,
+        prev_output_tokens=None,
+        original_batch_idxs=None,
):
"""Take a single search step.
@@ -32,6 +38,9 @@ def step(
the model's log-probabilities over the vocabulary at the current step
scores: (bsz x input_beam_size x step)
the historical model scores of each hypothesis up to this point
+            offset: (bsz x input_beam_size)
+                the NLL of the prompt; used on the first step to maintain
+                consistent cumulative sums
prev_output_tokens: (bsz x step)
the previously generated output tokens
original_batch_idxs: (bsz)
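To make the new offset argument concrete, here is a minimal sketch (not metaseq code; shapes and values are assumed) of how folding the prompt's score into the first step keeps the cumulative sums consistent:

    import torch

    torch.manual_seed(0)
    bsz, vocab = 2, 5

    # summed log-probability of each prompt (the docstring's "NLL of the
    # prompt"), one value per hypothesis: (bsz x input_beam_size)
    offset = torch.tensor([[-3.2], [-1.7]])

    # log-probabilities for the first generated token: (bsz x vocab)
    lprobs = torch.log_softmax(torch.randn(bsz, vocab), dim=-1)

    # without the offset, step-0 scores restart from zero and the running sum
    # breaks at the prompt/generation boundary; with it, every candidate
    # continues the prompt's cumulative score
    lprobs = lprobs + offset  # broadcasts across the vocab axis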
@@ -105,6 +114,7 @@ def step(
step: int,
lprobs,
scores: Optional[Tensor],
+        offset: Optional[Tensor] = None,
prev_output_tokens: Optional[Tensor] = None,
original_batch_idxs: Optional[Tensor] = None,
):
@@ -113,6 +123,8 @@
if step == 0:
# at the first step all hypotheses are equally likely, so use
# only the first beam
+            if offset is not None:
+                lprobs += offset
lprobs = lprobs[:, ::beam_size, :].contiguous()
else:
# make probs contain cumulative scores for each hypothesis
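As an aside, the [:, ::beam_size, :] slice in the step-0 branch above keeps only the first beam per sentence, since every beam starts from the same prompt; a toy illustration with assumed shapes:

    import torch

    bsz, beam_size, vocab = 2, 3, 5
    lprobs = torch.randn(bsz, beam_size, vocab)

    # scoring all beams at step 0 would yield beam_size copies of the same
    # candidates, so keep one beam per sentence
    first_beam_only = lprobs[:, ::beam_size, :].contiguous()
    print(first_beam_only.shape)  # torch.Size([2, 1, 5])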
@@ -198,6 +210,7 @@ def step(
step: int,
lprobs,
scores,
+        offset: Optional[Tensor] = None,
prev_output_tokens: Optional[Tensor] = None,
original_batch_idxs: Optional[Tensor] = None,
):
@@ -252,6 +265,8 @@ def step(

if step == 0:
beams_buf = indices_buf.new_zeros(bsz, beam_size)
+            if offset is not None:
+                scores_buf.add_(offset)
else:
beams_buf = torch.arange(0, beam_size).to(indices_buf).repeat(bsz, 1)
# make scores cumulative
21 changes: 13 additions & 8 deletions metaseq/sequence_generator.py
@@ -256,6 +256,9 @@ def _generate(
# finally, scores is actually stored as the cumulative NLL, but we have
# individual NLL scores right now
scores = scores.cumsum(dim=1)
+        # the first step of beam search also needs lprobs to be cumulative
+        # in order to keep the running sum correct
+        first_offset = scores[:, -1:]

# start from previous timestep because we still have to do beam search
# bookkeeping (i.e. finalize the hypothesis if it's the final token)
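A hedged illustration of the hunk above (toy values, single hypothesis): scores holds per-token log-probabilities for the prompt, cumsum turns them into running totals, and the last column is the total that the search's step 0 must inherit:

    import torch

    # per-token log-probabilities of a 3-token prompt: (bsz*beam x prompt_len)
    scores = torch.tensor([[-0.5, -1.0, -0.25]])
    scores = scores.cumsum(dim=1)  # tensor([[-0.5000, -1.5000, -1.7500]])
    first_offset = scores[:, -1:]  # tensor([[-1.7500]]), kept 2-D so it broadcasts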
@@ -314,12 +317,14 @@ def _generate(
cand_scores, cand_indices, cand_beams = self.search.step(
# underlying search indexes from first token being generated,
# so we need to account for the size of the prompt.
-                step - start_step + 1,
-                lprobs.view(bsz, -1, self.vocab_size),
-                scores[:, start_step - 1 : step].view(bsz, beam_size, -1),
-                tokens[:, start_step - 1 : step + 1],
-                original_batch_idxs,
+                step=step - start_step + 1,
+                lprobs=lprobs.view(bsz, -1, self.vocab_size),
+                scores=scores[:, start_step - 1 : step].view(bsz, beam_size, -1),
+                offset=first_offset,
+                prev_output_tokens=tokens[:, start_step - 1 : step + 1],
+                original_batch_idxs=original_batch_idxs,
             )
+            first_offset = None  # reset after the first step

# cand_bbsz_idx contains beam indices for the top candidate
# hypotheses, with a range of values: [0, bsz*beam_size),
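The step=step - start_step + 1 arithmetic above lines up with the offset handling: per the earlier comment, the loop starts one timestep early for bookkeeping, so on its first iteration the search sees step 0, exactly the branch where offset is consumed. A sketch with assumed values:

    start_step = 5  # assumed: first position after a 5-token prompt

    # the generation loop is assumed to begin at start_step - 1
    for step in range(start_step - 1, start_step + 2):
        print(step, step - start_step + 1)
    # 4 0   <- search step 0: first_offset folded into the cumulative scores
    # 5 1   <- first_offset already reset to None; scores are cumulative now
    # 6 2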
@@ -445,14 +450,14 @@ def _generate(
# the prompt tokens here for ease of bookkeeping.

# Set the tokens for each beam (can select the same row more than once)
-            tokens[:, start_step : step + 1] = torch.index_select(
-                tokens[:, start_step : step + 1], dim=0, index=active_bbsz_idx
+            tokens[:, start_step:step] = torch.index_select(
+                tokens[:, start_step:step], dim=0, index=active_bbsz_idx
)
# Select the next token for each of them
tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
cand_indices, dim=1, index=active_hypos
)
-            if step > start_step:
+            if step >= start_step:
scores[:, start_step:step] = torch.index_select(
scores[:, start_step:step], dim=0, index=active_bbsz_idx
)
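For readers less familiar with this bookkeeping, a toy example (shapes assumed) of how index_select along dim 0 lets several surviving beams copy the same source row:

    import torch

    tokens = torch.tensor([[7, 8], [9, 10]])  # (bsz*beam x slice_len)
    active_bbsz_idx = torch.tensor([0, 0])    # both beams continue from row 0
    print(torch.index_select(tokens, dim=0, index=active_bbsz_idx))
    # tensor([[7, 8],
    #         [7, 8]])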
