Upgrade to Transformers 4.40 #1027
Conversation
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
review in progress, adding some comments for now
eos_token_id = eos_token_id[0] if eos_token_id else None
if eos_token_id is None and self.generation_config.eos_token_id is not None:
    eos_token_id = self.generation_config.eos_token_id
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
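A torch-free sketch of the fallback logic in the diff above (function name is illustrative, not part of the PR): take the first id from a possibly-empty list of EOS token ids, then fall back to the value from the model's generation config.

```python
def resolve_eos_token_id(eos_token_id, generation_config_eos):
    # Take the first id if a non-empty list was passed, else None.
    eos = eos_token_id[0] if eos_token_id else None
    # Fall back to the generation config's eos_token_id when available.
    if eos is None and generation_config_eos is not None:
        eos = generation_config_eos
    return eos

assert resolve_eos_token_id([2, 3], None) == 2      # explicit list wins
assert resolve_eos_token_id([], 50256) == 50256     # empty list -> config fallback
assert resolve_eos_token_id(None, None) is None     # nothing to resolve
```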
TODO: add a warning here, something like:

if generation_config.static_shapes and (generation_config.pad_token_id == generation_config.eos_token_id):
    logger.warning(
        f"For Gaudi, input_ids are padded with pad_token_id (= {generation_config.pad_token_id}), which is equal to "
        "eos_token_id for this model, and the EosTokenCriteria stopping criterion is requested, so this option "
        "(ignore_eos=False) isn't available. Try setting `--ignore_eos` to True."
    )
We can add this in a separate PR later, just making a note/todo here.
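A deliberately simplified, hypothetical simulation of the failure mode described above (the real `EosTokenCriteria` in transformers inspects only the most recently generated token, and all names here are illustrative): when prompts are padded to a static shape with a `pad_token_id` equal to `eos_token_id`, an EOS-based stopping check can fire on padding rather than on generated text.

```python
def eos_criterion_triggered(input_ids, eos_token_id):
    # Toy stand-in for an EOS stopping check: does the sequence contain EOS?
    return eos_token_id in input_ids

PAD = EOS = 2          # pad_token_id == eos_token_id, the case the warning targets
prompt = [5, 6, 7]
padded = [PAD] * 5 + prompt  # static-shape padding fills the sequence with PAD (== EOS)

assert eos_criterion_triggered(padded, EOS)        # spurious stop caused by padding
assert not eos_criterion_triggered(prompt, EOS)    # the unpadded prompt is fine
```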
eos_token_id = eos_token_id[0] if eos_token_id else None
if eos_token_id is None and self.generation_config.eos_token_id is not None:
    eos_token_id = self.generation_config.eos_token_id
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
Same stopping-criteria warning applies here as mentioned in greedy().
@@ -2355,6 +2530,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
     "transo_xl",
     "xlnet",
     "cpm",
+    "jamba",
is this supported?
eos_token_id = eos_token_id[0] if eos_token_id else None
if eos_token_id is None and self.generation_config.eos_token_id is not None:
    eos_token_id = self.generation_config.eos_token_id
stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
same comment as greedy()
- context_mask = 1 - (col_indices >= row_indices + diagonal).int().expand_as(mask)  # Expand to match mask shape
+ context_mask = (col_indices <= row_indices + diagonal).bool().expand_as(mask)  # Expand to match mask shape
We have a `.bool()` here instead of `.int()`. We don't support bool tensors, and I'm not sure whether this will cause a fallback to CPU, etc.
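If bool tensors turn out to be unsupported on the target backend, one possible workaround (sketched here in plain Python with illustrative names; the real code operates on torch tensors) is to keep the mask in an integer dtype, since a 0/1 int mask carries the same information as a bool mask:

```python
def make_context_mask_int(rows, cols, diagonal):
    # Position (r, c) is visible when c <= r + diagonal; encode visibility as 0/1 ints.
    return [[int(c <= r + diagonal) for c in range(cols)] for r in range(rows)]

def make_context_mask_bool(rows, cols, diagonal):
    # Same predicate, but encoded as bools.
    return [[c <= r + diagonal for c in range(cols)] for r in range(rows)]

int_mask = make_context_mask_int(3, 3, 0)
bool_mask = make_context_mask_bool(3, 3, 0)
# The two masks agree element-wise; only the dtype differs.
assert all(bool(i) == b for ri, rb in zip(int_mask, bool_mask) for i, b in zip(ri, rb))
```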
Address the comments in a different PR.
fill_value=1,
dtype=torch.bool,
)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
Check why this was removed.
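For context, the removed buffer precomputed a strictly upper-triangular mask (`torch.triu` with `diagonal=1` over a tensor of ones). A torch-free sketch of that computation, with an illustrative function name:

```python
def triu_ones(n, diagonal=1):
    # Mimic torch.triu on an n x n tensor of ones:
    # keep a 1 where col >= row + diagonal, zero out the rest.
    return [[1 if c >= r + diagonal else 0 for c in range(n)] for r in range(n)]

# With diagonal=1 this marks the strict upper triangle, i.e. the positions a
# causal attention mask blocks (each token may not attend to future tokens).
mask = triu_ones(4)
```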
What does this PR do?
As per title.
Before submitting