
Clearer indication for overridden method in generation #12212

@ktangri

Description


The expectation that subclasses override `prepare_inputs_for_generation` could be made clearer by changing

```python
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
    """
    Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to prepare inputs in the
    generate method.
    """
    return {"input_ids": input_ids}
```

so that it raises a `NotImplementedError` carrying the information currently given in the method's docstring.
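A minimal sketch of what the proposed change might look like. It uses plain Python classes so it runs without `transformers` or `torch` installed. `GenerationMixinSketch`, `ModelWithOverride`, and `ModelWithoutOverride` are hypothetical names for illustration, not the library's classes, and the error message wording is an assumption modeled on the existing docstring.

```python
from typing import Any, Dict


class GenerationMixinSketch:
    """Hypothetical stand-in for the class that defines the generate() method."""

    def prepare_inputs_for_generation(self, input_ids: Any, **kwargs: Any) -> Dict[str, Any]:
        # Fail loudly instead of silently returning a default, so the
        # requirement to override is visible as soon as the method is called.
        # Message wording is an assumption, adapted from the current docstring.
        raise NotImplementedError(
            f"{type(self).__name__}: implement `prepare_inputs_for_generation` in subclasses "
            "of `PreTrainedModel` for custom behavior to prepare inputs in the `generate` method."
        )


class ModelWithOverride(GenerationMixinSketch):
    """Hypothetical subclass that supplies its own implementation."""

    def prepare_inputs_for_generation(self, input_ids: Any, **kwargs: Any) -> Dict[str, Any]:
        # Same behavior as the current default: pass input_ids through unchanged.
        return {"input_ids": input_ids}


class ModelWithoutOverride(GenerationMixinSketch):
    """Hypothetical subclass that does not override the method."""


if __name__ == "__main__":
    print(ModelWithOverride().prepare_inputs_for_generation([[0, 1, 2]]))
    try:
        ModelWithoutOverride().prepare_inputs_for_generation([[0, 1, 2]])
    except NotImplementedError as err:
        print(f"NotImplementedError: {err}")
```

With this change, any model that currently relies on the default `{"input_ids": input_ids}` return value would need its own override, as `ModelWithOverride` does. A model without one fails with an explicit message instead of silently using the default.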

@patrickvonplaten

Metadata


Labels

WIP: Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress

