Addition of a DialoguePipeline #5516
Conversation
Addition of input processing and history concatenation
…ion_pipeline # Conflicts: # tests/test_pipelines.py
This is a back-port of guillaume-be/rust-bert#57. I did not implement the ConversationManager as I felt it did not quite fit the general API of this library. I did, however, add the concept of a Conversation object whose state can be displayed with print(conversation).
(ps: note that this example is the response of the model)
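A minimal sketch of the Conversation usage described in this PR (hedged: the task and method names follow the thread below; the printed format is only illustrative):

    from transformers import Conversation, pipeline

    conversational_pipeline = pipeline("conversational")

    conversation = Conversation("Going to the movies tonight - any suggestions?")
    conversational_pipeline(conversation)  # generates the bot's reply in place

    conversation.add_user_input("Is it an action movie?")
    conversational_pipeline(conversation)

    print(conversation)  # prints the conversation id followed by each user/bot turn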
Codecov Report
@@            Coverage Diff             @@
##           master    #5516      +/-   ##
==========================================
+ Coverage   78.35%   79.85%    +1.49%
==========================================
  Files         146      146
  Lines       26454    26568     +114
==========================================
+ Hits        20729    21215     +486
+ Misses       5725     5353     -372

Continue to review full report at Codecov.
src/transformers/pipelines.py (outdated)

    conversation_1.add_user_input("Is it an action movie?")

    conversation_pipeline([conversation_1, conversation_2])

Suggested change:
-    conversation_pipeline([conversation_1, conversation_2])
+    dialogue_pipeline([conversation_1, conversation_2])
src/transformers/pipelines.py (outdated)

    self.pad_token_id = self.tokenizer.eos_token_id
    self.min_response_allowed_length = kwargs.get("min_response_allowed_length", 32)

    def __call__(self, *args, clean_up_tokenization_spaces=True, **generate_kwargs):

Can we call `args` => `conversations`?
src/transformers/pipelines.py (outdated)

        [conversation.new_user_input for conversation in active_conversations]
    )
    histories = [conversation.history for conversation in active_conversations]
    max_length = generate_kwargs.get("max_length", 1000)

I think I would prefer not to set the default `max_length` to 1000 here. The user can set the default value for each model individually in the model's config (under task-specific params) => compare for example with XLNet here: https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json
Suggested change:
-    max_length = generate_kwargs.get("max_length", 1000)
+    max_length = generate_kwargs.get("max_length", self.model.config.max_length)
I agree the value of 1000 is arbitrary (taken from the illustrative example in the model card). The issue is that the DialoGPT configuration files do not set a `max_length`. To my understanding, this means that without specifying it to the `generate` method, it will be set to the GPT2 default, that is 20. This seems very low for a conversation pipeline, as the input alone is likely to exceed this value. I am not sure defaulting to the configuration value is going to be a good user experience. Maybe the way to go is to force the user to provide a value? Or maybe update the DialoGPT configuration files?
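A hedged sketch of what reading the default from task-specific config params could look like (the "conversational" key and the fallback chain are assumptions for illustration, not the code this PR ships):

    # Illustrative only: prefer an explicit kwarg, then a per-task config entry,
    # then the config-wide max_length (which defaults to 20 for GPT2-style configs).
    task_params = getattr(self.model.config, "task_specific_params", None) or {}
    default_max_length = task_params.get("conversational", {}).get(
        "max_length", self.model.config.max_length
    )
    max_length = generate_kwargs.get("max_length", default_max_length)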
src/transformers/pipelines.py (outdated)

            input_length, max_length
        )
    )
    generate_kwargs["max_length"] = max_length

Suggested change (remove the line):
-    generate_kwargs["max_length"] = max_length
src/transformers/pipelines.py (outdated)

    cleaned_history = self._clean_padding_history(generated_responses)
    if isinstance(args[0], Conversation):
        args[0].mark_processed()

Would be nice to rename `args` to `conversations`.
src/transformers/pipelines.py (outdated)

    Builds an input prepended by the history for this conversation, allowing multi-turn conversation with context
    """
    outputs = []
    for input, history in zip(inputs, histories):

Very nice! I like it
src/transformers/pipelines.py (outdated)

    on the associated CUDA device id.
    """

    def __init__(self, *args, **kwargs):

I think as it is implemented now, `dialogue_pipeline = pipeline("dialogue", min_response_allowed_length=32)` would throw an error because the kwarg is passed on to `super().__init__(*args, **kwargs)` => can we just change it to:

    def __init__(self, min_response_allowed_length=32, *args, **kwargs):
        super().__init__(*args, **kwargs)
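An equivalent sketch using `kwargs.pop`, which keeps the original `*args, **kwargs` signature (illustrative only; the parameter name follows the rename discussed just below):

    def __init__(self, *args, **kwargs):
        # Consume the pipeline-specific kwarg before the base class sees it,
        # since the base __init__ would otherwise reject it as unexpected.
        self.min_length_for_response = kwargs.pop("min_length_for_response", 32)
        super().__init__(*args, **kwargs)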
And maybe the name `min_length_for_response` is better here. The word "allowed" is confusing me a bit. What do you think @guillaume-be?
Good catch and agreed - will update
tests/test_pipelines.py (outdated)

    @require_torch
    def test_integration_torch_dialogue(self):
        # When
        nlp = pipeline(task="dialogue", device=DEFAULT_DEVICE_NUM)

Maybe pass `min_response_allowed_length` (or the IMO clearer name `min_length_for_response`) here to test it.
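A sketch of what that test invocation might look like with the parameter passed (the value 24 is arbitrary, chosen only to differ from the default of 32):

    nlp = pipeline(task="dialogue", min_length_for_response=24, device=DEFAULT_DEVICE_NUM)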
I like the PR! Thanks a lot @guillaume-be!

This makes the dialogue pipeline a bit different from the other pipelines in that it expects a `Conversation` object instead of a string, but that's OK IMO.

One other option would be to integrate the conversation class under-the-hood into the `DialoguePipeline`, so that the user would always input either a string or a list of strings. This way the user would not have to use the predefined `Conversation` object, but could just input strings. The advantage is that we don't need to expose `Conversation` this way - the disadvantage is that we would need an additional function for `DialoguePipeline` that shows the current conversation.

IMO, the logic / design is good as it is now. Since Dialogue is a special pipeline, the user will have to create a `Conversation` object first - which is OK for me. What do you think @julien-c @mfuntowicz (also considering the connection to the API)?

One thing we should still add here @guillaume-be: add `Conversation` (or `DialoguePipelineConversation`, so it's clear that the class is related to pipelines) to `__init__`, so that the user can import it directly from `transformers` (we more or less expose all classes in `transformers` as far as I know).
- Added `min_length_for_response` as an initialization parameter
- Renamed `*args` to `conversations`, `conversations` being a `Conversation` or a `List[Conversation]`
- Updated truncation to truncate entire segments of conversations, instead of cutting in the middle of a user/bot input
- Removed hardcoded default value of 1000 and use `config.max_length` instead
- Added `append_response` and `set_history` methods to the `Conversation` class to avoid direct field mutation (see the sketch below)
- Fixed bug in history truncation method (otherwise a ValueError is raised)
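A hedged sketch of the `Conversation` mutation helpers listed above (simplified; the field names are assumptions based on this thread, not the exact implementation):

    class Conversation:
        def __init__(self, text=None):
            self.past_user_inputs = []
            self.generated_responses = []
            self.new_user_input = text
            self.history = []

        def append_response(self, response):
            # Record a bot response without callers mutating fields directly.
            self.generated_responses.append(response)

        def set_history(self, history):
            # Replace the token-id history, e.g. after truncation.
            self.history = history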
LGTM
…ion_pipeline # Conflicts: # src/transformers/pipelines.py
@julien-c @patrickvonplaten I believe all comments have been addressed - please let me know if I have missed anything. Just resolved the conflict with master. Getting an error with …
Given that …
This is really cool! I left a few comments, as I think there are a few remaining bugs.

I think this would greatly benefit from documentation covering both the pipeline and the `Conversation` class. The pipeline already has some documentation, but it would need to be added to the pipelines.rst file. You could add it beneath the `GenerationPipeline`, alongside a bit of docs for the `Conversation` class.
src/transformers/pipelines.py (outdated)

    "tf": TFAutoModelWithLMHead if is_tf_available() else None,
    "pt": AutoModelWithLMHead if is_torch_available() else None,

Nitpick, those two classes are deprecated.

Suggested change:
-    "tf": TFAutoModelWithLMHead if is_tf_available() else None,
-    "pt": AutoModelWithLMHead if is_torch_available() else None,
+    "tf": TFAutoModelForCausalLM if is_tf_available() else None,
+    "pt": AutoModelForCausalLM if is_torch_available() else None,
src/transformers/pipelines.py (outdated)

    conversational_pipeline = pipeline("conversational")

    conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
    conversation_2 = Conversation("What's the last book you have read?")

    conversational_pipeline([conversation_1, conversation_2])

    conversation_1.add_user_input("Is it an action movie?")

    conversational_pipeline([conversation_1, conversation_2])

This example does not work for me. It fails with the following:

    ValueError: Conversation with UUID <class 'uuid.UUID'> does not contain new user input to process. Add user inputs with the conversation's `add_user_input` method

A user input must be added to the second conversation as well for that example to work.

Suggested change:
     conversational_pipeline = pipeline("conversational")
     conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
     conversation_2 = Conversation("What's the last book you have read?")
     conversational_pipeline([conversation_1, conversation_2])
     conversation_1.add_user_input("Is it an action movie?")
+    conversation_2.add_user_input("What is the genre of this book?")
     conversational_pipeline([conversation_1, conversation_2])
src/transformers/pipelines.py (outdated)

    Usage::
        conversational_pipeline = pipeline("conversational")

Needs a line return to be correctly rendered in the docs.

Suggested change (add a blank line after `Usage::`):
     Usage::
+
         conversational_pipeline = pipeline("conversational")
src/transformers/pipelines.py (outdated)

    Builds an input prepended by the history for this conversation, allowing multi-turn conversation with context
    """
    outputs = []
    for input, history in zip(inputs, histories):

Would it be possible to change `input` to something not shadowing the built-in `input`?
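A one-line sketch of the rename (the name `new_input` is just an example, not necessarily the name the PR settled on):

    for new_input, history in zip(inputs, histories):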
src/transformers/pipelines.py (outdated)

    cutoff_eos_index = input[cutoff_eos_index:].index(self.tokenizer.eos_token_id)
    if cutoff_eos_index == 0 or cutoff_eos_index == len(input) - 1:
        break
    else:
        input = input[cutoff_eos_index + 1 :]

This should also break when `cutoff_eos_index` is larger than the length of the remaining input. Otherwise it fails, because `input[cutoff_eos_index:]` returns an empty list and calling `.index(self.tokenizer.eos_token_id)` on it raises a ValueError.

An easy fix is the following, which could probably be made cleaner.

Suggested change:
+    if cutoff_eos_index >= len(input):
+        break
     cutoff_eos_index = input[cutoff_eos_index:].index(self.tokenizer.eos_token_id)
     if cutoff_eos_index == 0 or cutoff_eos_index == len(input) - 1:
         break
     else:
         input = input[cutoff_eos_index + 1 :]
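A quick, pipeline-independent illustration of the failure mode being fixed here (the token id 50256 is GPT-2's eos_token_id, shown only as an example):

    >>> [].index(50256)
    Traceback (most recent call last):
      ...
    ValueError: 50256 is not in list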
…ion_pipeline # Conflicts: # src/transformers/pipelines.py # tests/test_pipelines.py
…, addition of docstrings for Conversation, added both to the docs
@LysandreJik Thank you very much for the review. Good catch on the behaviour of the …
Great, thanks for iterating @guillaume-be!
Thanks for the PR @guillaume-be
Updated docstring following review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@sgugger Thank you for the review - was indeed a typo on my end. The tests got triggered again and unfortunately a hash verification on torch fails. Could you please restart the build if you have a chance?

This is awesome, congrats everyone on shipping this! 🔥

Will take a look.

Should be fixed in #7970
microsoft/DialoGPT-medium