Addition of a DialoguePipeline #5516
@@ -20,11 +20,13 @@
import os
import pickle
import sys
import uuid
from abc import ABC, abstractmethod
from contextlib import contextmanager
from itertools import chain
from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from uuid import UUID

import numpy as np
@@ -63,7 +65,6 @@
from .modeling_utils import PreTrainedModel
from .modeling_tf_utils import TFPreTrainedModel


logger = logging.getLogger(__name__)
@@ -1664,6 +1665,270 @@ def __call__(
        return results

class Conversation:
    def __init__(self, text: str = None, conversation_id: UUID = None):
        if not conversation_id:
            conversation_id = uuid.uuid4()
        self.uuid: UUID = conversation_id
        self.past_user_inputs: List[str] = []
        self.generated_responses: List[str] = []
        self.history: List[int] = []
        self.new_user_input: Optional[str] = text

    def add_user_input(self, text: str, overwrite: bool = False):
        if self.new_user_input:
            if overwrite:
                logger.warning(
                    'User input added while unprocessed input was existing: "{}" was overwritten with: "{}".'.format(
                        self.new_user_input, text
                    )
                )
                self.new_user_input = text
            else:
                logger.warning(
                    'User input added while unprocessed input was existing: "{}" new input ignored: "{}". '
                    "Set `overwrite` to True to overwrite unprocessed user input".format(self.new_user_input, text)
                )
        else:
            self.new_user_input = text

    def mark_processed(self):
        if self.new_user_input:
            self.past_user_inputs.append(self.new_user_input)
        self.new_user_input = None

    def __repr__(self):
        output = "Conversation id: {} \n".format(self.uuid)
        for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses):
            output += "user >> {} \n".format(user_input)
            output += "bot >> {} \n".format(generated_response)
        if self.new_user_input is not None:
            output += "user >> {} \n".format(self.new_user_input)
        return output

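As a quick illustration of how the Conversation object above behaves, here is an indicative usage snippet; the warnings and printed output shown in comments follow from the methods in this diff:

    conversation = Conversation("Going to the movies tonight - any suggestions?")
    conversation.add_user_input("What else is playing?")                   # ignored: an unprocessed input is already pending (warning logged)
    conversation.add_user_input("What else is playing?", overwrite=True)   # pending input replaced (warning logged)
    print(conversation)
    # Conversation id: <some uuid>
    # user >> What else is playing?
    conversation.mark_processed()   # pending input moved to past_user_inputs, new_user_input reset to None
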
class DialoguePipeline(Pipeline):
""" | ||||||||||
Multi-turn dialogue pipeline. | ||||||||||
|
||||||||||
Usage:: | ||||||||||
dialogue_pipeline = pipeline("dialogue") | ||||||||||
|
||||||||||
conversation_1 = Conversation("Going to the movies tonight - any suggestions?") | ||||||||||
conversation_2 = Conversation("What's the last book you have read?") | ||||||||||
|
||||||||||
conversation_pipeline([conversation_1, conversation_2]) | ||||||||||

        conversation_1.add_user_input("Is it an action movie?")

        dialogue_pipeline([conversation_1, conversation_2])

    The models that this pipeline can use are models that have been fine-tuned on a multi-turn dialogue task,
    currently: "microsoft/DialoGPT-small", "microsoft/DialoGPT-medium", "microsoft/DialoGPT-large".
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?filter=conversational>`__.

    Arguments:
        model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
            The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
            checkpoint identifier or an actual pre-trained model inheriting from
            :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
            TensorFlow.
            If :obj:`None`, the default of the pipeline will be loaded.
        tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
            The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
            a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
            :class:`~transformers.PreTrainedTokenizer`.
            If :obj:`None`, the default of the pipeline will be loaded.
        modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
            Model card attributed to the model for this pipeline.
        framework (:obj:`str`, `optional`, defaults to :obj:`None`):
            The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
            installed.
            If no framework is specified, will default to the one currently installed. If no framework is specified
            and both frameworks are installed, will default to PyTorch.
        args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
            Reference to the object in charge of parsing supplied pipeline parameters.
        device (:obj:`int`, `optional`, defaults to :obj:`-1`):
            Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, >=0 will run the model
            on the associated CUDA device id.
    """

    def __init__(self, *args, **kwargs):
I think as it is implemented now

    def __init__(self, min_response_allowed_length=32, *args, **kwargs):
        super().__init__(*args, **kwargs)

and maybe the name ...

Good catch and agreed - will update
        super().__init__(*args, **kwargs)
        assert self.tokenizer.eos_token_id is not None, "DialoguePipeline tokenizer should have an EOS token set"
        if self.tokenizer.pad_token_id is not None:
            self.pad_token_id = self.tokenizer.pad_token_id
        else:
            self.pad_token_id = self.tokenizer.eos_token_id
        self.min_response_allowed_length = kwargs.get("min_response_allowed_length", 32)

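For reference, a minimal sketch of the constructor with the explicit keyword argument suggested in the review thread above; an illustration of the proposed change, not the code in this diff:

    def __init__(self, min_response_allowed_length=32, *args, **kwargs):
        # Hypothetical signature per the review suggestion: the keyword is consumed here
        # instead of being fished out of **kwargs after calling the parent constructor.
        super().__init__(*args, **kwargs)
        assert self.tokenizer.eos_token_id is not None, "DialoguePipeline tokenizer should have an EOS token set"
        self.pad_token_id = (
            self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
        )
        self.min_response_allowed_length = min_response_allowed_length
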
    def __call__(self, *args, clean_up_tokenization_spaces=True, **generate_kwargs):
Can we call ...
r""" | ||||||||||
Args: | ||||||||||
*args: (list of `:class:Conversation`) Conversations to generate responses for | ||||||||||
**generate_kwargs: extra kwargs passed to `self.model.generate`_ | ||||||||||
|
||||||||||
Returns: | ||||||||||
list of conversations with updated generated responses for those containing a new user input | ||||||||||
""" | ||||||||||
|
||||||||||
        active_conversations = []
        # Input validation
        if isinstance(args[0], list):
            active_conversations_indices = dict()
            active_index = 0
            for conversation_index, conversation in enumerate(args[0]):
                assert isinstance(
                    conversation, Conversation
                ), "DialoguePipeline expects a Conversation or list of Conversations as an input"
                if conversation.new_user_input is None:
Do we really need the concept of active conversation? Or would it be handled in the application before entering the pipeline anyways? It seems to be making the code kinda longer and complex (with active_indexes, etc.)

I like the concept of ... Because of this I believe the pipeline needs to be able to handle any conversation passed to it. Conversations without user input should not be passed through the model (and therefore will not participate in the batch creation). The model outputs then need to be reallocated to the right conversation, requiring the indices. Would you prefer the application to raise a ValueError if an inactive conversation has been passed to the pipeline? This would indeed simplify the code, but be less forgiving to potential users. I am open to both - or maybe you had another idea in mind?

Personally I would raise a ValueError.

Makes sense (I took this from the Rust implementation that was using a ConversationManager instead of single Conversations passed to it - where inactive conversations have to be dealt with). In this case I agree the responsibility could be passed on to the user. Will update and simplify accordingly
                    logger.warning(
                        "Conversation with id {} does not contain new user input and will not be updated".format(
                            conversation.uuid
                        )
                    )
                else:
                    active_conversations_indices[conversation_index] = active_index
                    active_conversations.append(conversation)
                    active_index += 1
            assert (
                self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None
            ), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input"
        elif isinstance(args[0], Conversation):
            active_conversations.append(args[0])
        else:
            raise ValueError("DialoguePipeline expects a Conversation or list of Conversations as an input")
        if len(active_conversations) > 0:
            with self.device_placement():

                inputs = self._parse_and_tokenize(
                    [conversation.new_user_input for conversation in active_conversations]
                )
                histories = [conversation.history for conversation in active_conversations]
                max_length = generate_kwargs.get("max_length", 1000)
I think I would prefer to not set the default ...

I agree the value of 1000 is arbitrary (taken from the illustrative example in the model card). The issue is that the ...
                inputs = self._concat_inputs_history(inputs, histories, max_length)

                if self.framework == "pt":
                    inputs = self.ensure_tensor_on_device(**inputs)
                    input_length = inputs["input_ids"].shape[-1]

                elif self.framework == "tf":
                    input_length = tf.shape(inputs["input_ids"])[-1].numpy()

                if input_length > 0.9 * max_length:
                    logger.warning(
                        "Longest conversation length: {} is bigger than 0.9 * max_length: {}. "
                        "You might consider trimming the early phase of the conversation".format(
                            input_length, max_length
                        )
                    )
                generate_kwargs["max_length"] = max_length
                generated_responses = self.model.generate(
                    inputs["input_ids"], attention_mask=inputs["attention_mask"], **generate_kwargs,
                )

                cleaned_history = self._clean_padding_history(generated_responses)
                if isinstance(args[0], Conversation):
                    args[0].mark_processed()
would be nice to rename ...
                    args[0].generated_responses.append(
                        self.tokenizer.decode(
                            cleaned_history[0][input_length:],
                            skip_special_tokens=True,
                            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                        )
                    )
                    args[0].history = cleaned_history[0]
                    output = args[0]
                else:
                    output = []
                    for conversation_index, conversation in enumerate(args[0]):
                        if conversation_index in active_conversations_indices:
                            conversation.mark_processed()
                            active_index = active_conversations_indices[conversation_index]
                            conversation.generated_responses.append(
                                self.tokenizer.decode(
                                    cleaned_history[active_index][input_length:],
                                    skip_special_tokens=True,
                                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                                )
                            )
                            conversation.history = cleaned_history[active_index]
                        output.append(conversation)
                return output
        else:
            logger.warning(
                "No active conversation provided for generating response. "
                "Add user input to the conversations by calling `conversation.add_user_input(...)`"
            )
            return args[0]

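For completeness, a minimal single-conversation call, assuming the "dialogue" task name registered further down in this diff (illustrative):

    dialogue_pipeline = pipeline("dialogue")
    conversation = Conversation("Going to the movies tonight - any suggestions?")
    conversation = dialogue_pipeline(conversation)   # returns the same Conversation, updated in place
    print(conversation.generated_responses[-1])      # latest generated reply
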
    def _parse_and_tokenize(self, *args, **kwargs):
        """
        Parse arguments and tokenize, adding an EOS token at the end of the user input
        """
        # Parse arguments
        inputs = self._args_parser(*args, **kwargs)
        inputs = self.tokenizer.batch_encode_plus(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
        for input in inputs:
            input.append(self.tokenizer.eos_token_id)
        return inputs

    def _clean_padding_history(self, generated_tensor) -> List[List[int]]:
        """
        Cleans the padding history. Padding may be generated in two places when multiple conversations are provided
        as an input:
        - at the end of the concatenated history and new user input, so that all inputs to the model have the same
          length
        - at the end of the generated response, as some responses will be longer than others
        This method cleans up these padding tokens so that the history for each conversation is not impacted by the
        batching process.
        """
        outputs = []
        for sequence in generated_tensor:
            sequence_tokens = []
            is_previous_pad = False
            for token in sequence:
                if token == self.pad_token_id:
                    if is_previous_pad:
                        continue
                    else:
                        is_previous_pad = True
                else:
                    is_previous_pad = False
                if self.framework == "pt":
                    sequence_tokens.append(token.item())
                else:
                    sequence_tokens.append(int(token.numpy()))

            outputs.append(sequence_tokens)
        return outputs

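To make the cleanup above concrete, here is a framework-free restatement of the same collapse rule on plain Python lists (illustrative only, not part of the diff):

    def collapse_repeated_padding(sequence, pad_token_id):
        # Consecutive padding tokens are collapsed into a single one; everything else is kept.
        cleaned, is_previous_pad = [], False
        for token in sequence:
            if token == pad_token_id:
                if is_previous_pad:
                    continue
                is_previous_pad = True
            else:
                is_previous_pad = False
            cleaned.append(token)
        return cleaned

    # collapse_repeated_padding([318, 50256, 50256, 50256, 257], pad_token_id=50256) -> [318, 50256, 257]
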
    def _concat_inputs_history(self, inputs: List[List[int]], histories: List[Optional[List[int]]], max_length: int):
        """
        Builds an input prepended by the history for this conversation, allowing multi-turn conversation with context
        """
        outputs = []
        for input, history in zip(inputs, histories):
Very nice! I like it

Would it be possible to change ...
            if history is not None:
                concatenated_input = history + input
                if len(concatenated_input) > max_length - self.min_response_allowed_length:
                    concatenated_input = concatenated_input[
                        len(concatenated_input) - max_length + self.min_response_allowed_length :
                    ]
                outputs.append(concatenated_input)
            else:
                if len(input) > max_length - self.min_response_allowed_length:
                    input = input[len(input) - max_length + self.min_response_allowed_length :]
                outputs.append(input)
        max_len = max([len(item) for item in outputs])
        outputs = [output + [self.pad_token_id] * (max_len - len(output)) for output in outputs]
        outputs = self.tokenizer.batch_encode_plus(
            outputs, add_special_tokens=False, is_pretokenized=True, return_tensors=self.framework, padding=False
        )
If I understand correctly outputs is already a list of lists of ints before this line, so why do we pass it to batch_encode_plus?

Yes - this was the intent. I didn't find a utility class in the library that would conveniently convert a list of integers to a Torch or Tensorflow tensor the way this class did. Is there a better way?

Can you create a helper to do this in this file?

I actually use the ...
        return outputs

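One possible shape for the helper discussed in the thread above, converting the already-padded List[List[int]] into framework tensors without going back through the tokenizer. The name and behaviour are assumptions for illustration, not what the PR ends up using:

    def _lists_to_framework_tensors(self, input_ids: List[List[int]]):
        # Hypothetical helper: all sequences are already padded to the same length, so the
        # attention mask is all ones, mirroring what the batch_encode_plus call above produces.
        attention_mask = [[1] * len(sequence) for sequence in input_ids]
        if self.framework == "pt":
            return {"input_ids": torch.tensor(input_ids), "attention_mask": torch.tensor(attention_mask)}
        return {"input_ids": tf.constant(input_ids), "attention_mask": tf.constant(attention_mask)}
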
# Register all the supported tasks here
SUPPORTED_TASKS = {
    "feature-extraction": {
@@ -1738,6 +2003,12 @@ def __call__(
        "pt": AutoModelWithLMHead if is_torch_available() else None,
        "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
    },
"dialogue": { | ||||||||||
Can we call this "conversational" to be compatible out of the box with the models that are already tagged on https://huggingface.co/models?filter=conversational ? I think dialogue might be the better term for the NLP research community but conversational is better for ML practitioners (who are the target audience of pipelines)
        "impl": DialoguePipeline,
        "tf": TFAutoModelWithLMHead if is_tf_available() else None,
        "pt": AutoModelWithLMHead if is_torch_available() else None,
Nitpick, those two classes are deprecated.
        "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
    },
}
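For reference, the registry entry might look like the following if both review suggestions above were applied (the "conversational" task name and the non-deprecated auto classes); a sketch, not the code in this diff:

    "conversational": {
        "impl": DialoguePipeline,
        "tf": TFAutoModelForCausalLM if is_tf_available() else None,
        "pt": AutoModelForCausalLM if is_torch_available() else None,
        "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
    },
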
This abstraction seems a bit weird to me because we are mutating so much of its properties from the pipeline below, i.e. it's not very self-contained. I'm wondering if we should move some of the mutating methods to the class directly.

I created some class functions to avoid direct field mutation - hopefully it goes in the right direction
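A sketch of what such helpers could look like on the Conversation class; the method names are illustrative and not necessarily the ones added in the follow-up commits:

    class Conversation:
        ...

        def append_response(self, response: str):
            # Record a generated response without the pipeline touching the field directly
            self.generated_responses.append(response)

        def set_history(self, history: List[int]):
            # Replace the token-id history after a generation step
            self.history = history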