Commit b00efa3
refactor: unify the sequential run code logic between beam search and contrastive search
Saibo Geng committed Nov 13, 2023
1 parent a4087dc commit b00efa3
Showing 1 changed file with 11 additions and 37 deletions.
src/transformers/generation/utils.py — 48 changes: 11 additions & 37 deletions
@@ -2236,8 +2236,7 @@ def contrastive_search(
             model_kwargs["past_key_values"] = tuple(new_key_values)
 
         if sequential:
-            all_outputs = {key: [] for key in outputs}  # defined in first loop iteration
-            all_last_hstates, all_hstates, all_logits = [], [], []
+            all_outputs = []
             for i in range(top_k):
                 # compute the candidate tokens by the language model and collect their hidden_states
                 next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)
@@ -2248,33 +2247,8 @@ def contrastive_search(
                     output_hidden_states=True,
                     output_attentions=output_attentions,
                 )
-                for key in all_outputs:
-                    all_outputs[key].append(outputs[key])
-
-                if self.config.is_encoder_decoder:
-                    next_hidden = outputs.decoder_hidden_states[-1]
-                    full_hidden_states = outputs.decoder_hidden_states
-
-                else:
-                    next_hidden = outputs.hidden_states[-1]
-                    full_hidden_states = outputs.hidden_states
-
-                all_last_hstates.append(torch.squeeze(next_hidden, 0))
-                all_hstates.append(full_hidden_states)
-                all_logits.append(outputs.logits[:, -1, :])
-
-            # stack hidden states
-            next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0)
-            final_full_hstates = [0 for i in range(len(full_hidden_states))]
-            for layer in range(len(full_hidden_states)):
-                final_full_hstates[layer] = torch.stack(
-                    [torch.squeeze(all_hstates[i][layer], 0) for i in range(top_k)], dim=0
-                )
-            full_hidden_states = tuple(final_full_hstates)
-
-            # stack logits
-            logits = torch.cat(all_logits, dim=0)
-
+                all_outputs.append(outputs)
+            outputs = stack_model_outputs(all_outputs)
         else:
             # compute the candidate tokens by the language model and collect their hidden_states
             # assembles top_k_ids into batch of size k
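For context, the hunk above replaces the hand-rolled per-key accumulation and tensor stacking with a single call to stack_model_outputs, the helper the sequential (low_memory) beam search path relies on as well. Below is a minimal sketch of the idea, assuming every field is either a tensor or a per-layer tuple of tensors; the real helper in transformers handles more cases (e.g. nested past_key_values), and the function name here is a stand-in, not the library code.

from typing import List

import torch
from transformers.utils import ModelOutput

def stack_model_outputs_sketch(outputs_list: List[ModelOutput]) -> ModelOutput:
    # Merge per-candidate outputs (batch size 1 each) into one batched output,
    # as if a single forward pass had been run over all top_k candidates.
    merged = {}
    for key in outputs_list[0].keys():
        values = [out[key] for out in outputs_list]
        if isinstance(values[0], torch.Tensor):
            # e.g. logits: concatenate along the batch dimension
            merged[key] = torch.cat(values, dim=0)
        elif isinstance(values[0], tuple):
            # e.g. hidden_states: a tuple with one tensor per layer,
            # concatenated layer by layer along the batch dimension
            merged[key] = tuple(torch.cat(per_layer, dim=0) for per_layer in zip(*values))
        else:
            # non-tensor fields are assumed identical across candidates
            merged[key] = values[0]
    return outputs_list[0].__class__(**merged)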
@@ -2286,15 +2260,15 @@ def contrastive_search(
                 output_hidden_states=True,
                 output_attentions=output_attentions,
             )
-            # name is different for encoder-decoder and decoder-only models
-            if self.config.is_encoder_decoder:
-                next_hidden = outputs.decoder_hidden_states[-1]
-                full_hidden_states = outputs.decoder_hidden_states
-            else:
-                next_hidden = outputs.hidden_states[-1]
-                full_hidden_states = outputs.hidden_states
+        # name is different for encoder-decoder and decoder-only models
+        if self.config.is_encoder_decoder:
+            next_hidden = outputs.decoder_hidden_states[-1]
+            full_hidden_states = outputs.decoder_hidden_states
+        else:
+            next_hidden = outputs.hidden_states[-1]
+            full_hidden_states = outputs.hidden_states
 
-            logits = outputs.logits[:, -1, :]
+        logits = outputs.logits[:, -1, :]
 
         context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)
 
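After this change, the sequential branch — taken when generate() is called with low_memory=True — yields the same kind of outputs object as the batched branch, so the downstream hidden-state naming and logits slicing are shared by both paths. A usage sketch that exercises the sequential path; gpt2 is just an example checkpoint:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The whispered legend of the haunted mansion", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    penalty_alpha=0.6,  # penalty_alpha > 0 with top_k > 1 enables contrastive search
    top_k=4,
    max_new_tokens=32,
    low_memory=True,    # run the top_k candidate forward passes sequentially
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))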
