Addition of a DialoguePipeline #5516

Merged
merged 26 commits on Jul 30, 2020

Commits (26)
a398564
initial commit for pipeline implementation
guillaume-be Jul 1, 2020
c08fb5f
Conversation pipeline tested and working for single & multiple conver…
guillaume-be Jul 4, 2020
29cc23c
Added docstrings for dialogue pipeline
guillaume-be Jul 4, 2020
4c3d5c4
Addition of dialogue pipeline integration tests
guillaume-be Jul 4, 2020
00160a1
Merge remote-tracking branch 'remotes/upstream/master' into conversat…
guillaume-be Jul 4, 2020
629dde9
Delete test_t5.py
guillaume-be Jul 4, 2020
d17cf21
Fixed max code length
guillaume-be Jul 4, 2020
f2c19cb
Updated styling
guillaume-be Jul 4, 2020
ce28bfb
Fixed test broken by formatting tools
guillaume-be Jul 4, 2020
210aeee
Removed unused import
guillaume-be Jul 4, 2020
8c5853d
Added unit test for DialoguePipeline
guillaume-be Jul 4, 2020
0f1e33e
Fixed Tensorflow compatibility
guillaume-be Jul 4, 2020
7d40145
Fixed multi-framework support using framework flag
guillaume-be Jul 4, 2020
f69c4d8
- Fixed docstring
guillaume-be Jul 15, 2020
55f4f6b
- renamed pipeline name from dialogue to conversational
guillaume-be Jul 16, 2020
4371c51
- Updated ConversationalPipeline to accept only active conversations …
guillaume-be Jul 16, 2020
d26242c
- Simplified input tensor conversion
guillaume-be Jul 16, 2020
92c042b
- Updated attention_mask value for Tensorflow compatibility
guillaume-be Jul 16, 2020
da3c68b
- Updated last dialogue reference to conversational & fixed integrati…
guillaume-be Jul 16, 2020
77d1826
Merge remote-tracking branch 'remotes/upstream/master' into conversat…
guillaume-be Jul 23, 2020
8d6288b
Fixed conflict with master
guillaume-be Jul 23, 2020
b295de4
Merge remote-tracking branch 'remotes/upstream/master' into conversat…
guillaume-be Jul 30, 2020
7f6ff5e
Updates following review comments
guillaume-be Jul 30, 2020
d1918be
Updated formatting
guillaume-be Jul 30, 2020
396c4c4
Added Conversation and ConversationalPipeline to the library __init__…
guillaume-be Jul 30, 2020
9734829
Update src/transformers/pipelines.py
guillaume-be Jul 30, 2020
8 changes: 8 additions & 0 deletions docs/source/main_classes/pipelines.rst
@@ -71,3 +71,11 @@ TextGenerationPipeline
==========================================

.. autoclass:: transformers.TextGenerationPipeline


ConversationalPipeline
==========================================

.. autoclass:: transformers.Conversation

.. autoclass:: transformers.ConversationalPipeline
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -104,6 +104,8 @@

# Pipelines
from .pipelines import (
Conversation,
ConversationalPipeline,
CsvPipelineDataFormat,
FeatureExtractionPipeline,
FillMaskPipeline,
326 changes: 325 additions & 1 deletion src/transformers/pipelines.py
@@ -20,11 +20,13 @@
import os
import pickle
import sys
import uuid
from abc import ABC, abstractmethod
from contextlib import contextmanager
from itertools import chain
from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from uuid import UUID

import numpy as np

@@ -36,7 +38,7 @@
from .tokenization_auto import AutoTokenizer
from .tokenization_bert import BasicTokenizer
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils_base import PaddingStrategy
from .tokenization_utils_base import BatchEncoding, PaddingStrategy


if is_tf_available():
@@ -51,6 +53,7 @@
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TFAutoModelForCausalLM,
)

if is_torch_available():
@@ -1895,6 +1898,321 @@ def __call__(
return results


class Conversation:
Member:
This abstraction seems a bit weird to me because we are mutating so many of its properties from the pipeline below, i.e. it's not very self-contained. I'm wondering if we should move some of the mutating methods to the class directly.

Contributor Author:
I created some class methods to avoid direct field mutation - hopefully it goes in the right direction.

"""
Utility class containing a conversation and its history. This class is meant to be used as an input to the
:obj:`~transformers.ConversationalPipeline`. The conversation contains a number of utility functions to manage the addition of new
user input and generated model responses. A conversation needs to contain an unprocessed user input before being
passed to the :obj:`~transformers.ConversationalPipeline`. This user input is either created when the class is instantiated, or by calling
`add_user_input("input")` after a conversation turn.

Usage::

conversation = Conversation("Going to the movies tonight - any suggestions?")

# Steps usually performed by the model when generating a response:
# 1. Mark the user input as processed (moved to the history)
conversation.mark_processed()
# 2. Append a model response
conversation.append_response("The Big lebowski.")

conversation.add_user_input("Is it good?")

Arguments:
text (:obj:`str`, `optional`, defaults to :obj:`None`):
The initial user input to start the conversation.
If :obj:`None`, a user input needs to be provided manually using `add_user_input` before the conversation can begin.
conversation_id (:obj:`uuid.UUID`, `optional`, defaults to :obj:`None`):
Unique identifier for the conversation
If :obj:`None`, a random UUID4 id will be assigned to the conversation.
"""

def __init__(self, text: str = None, conversation_id: UUID = None):
if not conversation_id:
conversation_id = uuid.uuid4()
self.uuid: UUID = conversation_id
self.past_user_inputs: List[str] = []
self.generated_responses: List[str] = []
self.history: List[int] = []
self.new_user_input: Optional[str] = text

def add_user_input(self, text: str, overwrite: bool = False):
"""
Add a user input to the conversation for the next round. This populates the internal `new_user_input` field.

Args:
text: str, the user input for the next conversation round
overwrite: bool, flag indicating if existing and unprocessed user input should be overwritten when this function is called

"""
if self.new_user_input:
if overwrite:
logger.warning(
'User input added while unprocessed input was existing: "{}" was overwritten with: "{}".'.format(
self.new_user_input, text
)
)
self.new_user_input = text
else:
logger.warning(
'User input added while unprocessed input was existing: "{}" new input ignored: "{}". '
"Set `overwrite` to True to overwrite unprocessed user input".format(self.new_user_input, text)
)
else:
self.new_user_input = text

def mark_processed(self):
"""
Mark the conversation as processed (moves the content of `new_user_input` to `past_user_inputs`) and empties the
`new_user_input` field.
"""
if self.new_user_input:
self.past_user_inputs.append(self.new_user_input)
self.new_user_input = None

def append_response(self, response: str):
"""
Append a response to the list of generated responses.

Args:
response: str, the model generated response
"""
self.generated_responses.append(response)

def set_history(self, history: List[int]):
"""
Updates the value of the history of the conversation. The history is represented by a list of `token_ids`. The
history is used by the model to generate responses based on the previous conversation turns.

Args:
history: (list of int), history of tokens provided and generated for this conversation
"""
self.history = history

def __repr__(self):
"""
Generates a string representation of the conversation.

Return:
:obj:`str`: the string representation of the conversation.

Example:
Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114
user >> Going to the movies tonight - any suggestions?
bot >> The Big Lebowski
"""
output = "Conversation id: {} \n".format(self.uuid)
for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses):
output += "user >> {} \n".format(user_input)
output += "bot >> {} \n".format(generated_response)
if self.new_user_input is not None:
output += "user >> {} \n".format(self.new_user_input)
return output
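
# Illustration only (not part of the diff): a minimal usage sketch of the state-management
# methods discussed in the review thread above. add_user_input refuses to silently replace
# a pending, unprocessed input unless overwrite=True is passed.
conv = Conversation("Going to the movies tonight - any suggestions?")
conv.add_user_input("What's playing?")                  # logs a warning, new input is ignored
conv.add_user_input("What's playing?", overwrite=True)  # logs a warning, pending input is replaced
conv.mark_processed()                                   # moves the input into past_user_inputs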


class ConversationalPipeline(Pipeline):
"""
Multi-turn conversational pipeline.

Usage::

conversational_pipeline = pipeline("conversational")

conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
conversation_2 = Conversation("What's the last book you have read?")

conversational_pipeline([conversation_1, conversation_2])

conversation_1.add_user_input("Is it an action movie?")
conversation_2.add_user_input("What is the genre of this book?")

conversational_pipeline([conversation_1, conversation_2])

The models that this pipeline can use are models that have been fine-tuned on a multi-turn conversational task,
currently: "microsoft/DialoGPT-small", "microsoft/DialoGPT-medium", "microsoft/DialoGPT-large"
See the up-to-date list of available models on
`huggingface.co/models <https://huggingface.co/models?filter=conversational>`__.

Arguments:
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
checkpoint identifier or an actual pre-trained model inheriting from
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
TensorFlow.
If :obj:`None`, the default of the pipeline will be loaded.
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
:class:`~transformers.PreTrainedTokenizer`.
If :obj:`None`, the default of the pipeline will be loaded.
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
Model card attributed to the model for this pipeline.
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
installed.
If no framework is specified, will default to the one currently installed. If no framework is specified
and both frameworks are installed, will default to PyTorch.
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
Reference to the object in charge of parsing supplied pipeline parameters.
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, >=0 will run the model
on the associated CUDA device id.
"""

def __init__(self, min_length_for_response=32, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.tokenizer.eos_token_id is not None, "DialoguePipeline tokenizer should have an EOS token set"
if self.tokenizer.pad_token_id is not None:
self.pad_token_id = self.tokenizer.pad_token_id
else:
self.pad_token_id = self.tokenizer.eos_token_id
self.min_length_for_response = min_length_for_response

def __call__(
self,
conversations: Union[Conversation, List[Conversation]],
clean_up_tokenization_spaces=True,
**generate_kwargs
):
r"""
Args:
conversations: (:class:`Conversation` or list of :class:`Conversation`) Conversations to generate responses for
**generate_kwargs: extra kwargs passed to `self.model.generate`_

Returns:
list of conversations with updated generated responses for those containing a new user input
"""

# Input validation
if isinstance(conversations, list):
for conversation in conversations:
assert isinstance(
conversation, Conversation
), "DialoguePipeline expects a Conversation or list of Conversations as an input"
if conversation.new_user_input is None:
Member:
Do we really need the concept of an active conversation? Or would it be handled in the application before entering the pipeline anyway? It seems to make the code longer and more complex (with active_indexes, etc.).

Contributor Author:
I like the concept of new_user_input because it plays well with the history kept in the conversation itself. Without new_user_input, the part of the user inputs that needs to be tokenized before constructing the model input may be ambiguous.
Because of this I believe the pipeline needs to be able to handle any conversation passed to it. Conversations without user input should not be passed through the model (and therefore will not participate in the batch creation). The model outputs then need to be reallocated to the right conversations, requiring the indices.
Would you prefer the pipeline to raise a ValueError if an inactive conversation is passed to it? This would indeed simplify the code, but be less forgiving to potential users. I am open to both - or maybe you had another idea in mind?

Member:
Personally I would indeed raise a ValueError in case of an inactive conversation, but up to you :)

Contributor Author:
Makes sense (I took this from the Rust implementation, which used a ConversationManager instead of single Conversations passed to it - where inactive conversations have to be dealt with). In this case I agree the responsibility can be passed on to the user. Will update and simplify accordingly.

raise ValueError(
"Conversation with UUID {} does not contain new user input to process. "
"Add user inputs with the conversation's `add_user_input` method".format(
conversation.uuid
)
)
assert (
self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None
), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input"
elif isinstance(conversations, Conversation):
conversations = [conversations]
else:
raise ValueError("DialoguePipeline expects a Conversation or list of Conversations as an input")

with self.device_placement():

inputs = self._parse_and_tokenize([conversation.new_user_input for conversation in conversations])
histories = [conversation.history for conversation in conversations]
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
inputs = self._concat_inputs_history(inputs, histories, max_length)

if self.framework == "pt":
inputs = self.ensure_tensor_on_device(**inputs)
input_length = inputs["input_ids"].shape[-1]

elif self.framework == "tf":
input_length = tf.shape(inputs["input_ids"])[-1].numpy()

if input_length > 0.9 * max_length:
logger.warning(
"Longest conversation length: {} is bigger than 0.9 * max_length: {}. "
"You might consider trimming the early phase of the conversation".format(input_length, max_length)
)
generated_responses = self.model.generate(
inputs["input_ids"], attention_mask=inputs["attention_mask"], **generate_kwargs,
)

cleaned_history = self._clean_padding_history(generated_responses)
output = []
for conversation_index, conversation in enumerate(conversations):
conversation.mark_processed()
conversation.generated_responses.append(
self.tokenizer.decode(
cleaned_history[conversation_index][input_length:],
skip_special_tokens=True,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
)
conversation.set_history(cleaned_history[conversation_index])
output.append(conversation)
if len(output) == 1:
return output[0]
else:
return output

def _parse_and_tokenize(self, *args, **kwargs):
"""
Parse arguments and tokenize, adding an EOS token at the end of the user input
"""
# Parse arguments
inputs = self._args_parser(*args, **kwargs)
inputs = self.tokenizer.batch_encode_plus(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
for input in inputs:
input.append(self.tokenizer.eos_token_id)
return inputs

def _clean_padding_history(self, generated_tensor) -> List[List[int]]:
"""
Cleans the padding history. Padding may be generated in two places when multiple conversations are provided as
an input:
- at the end of the concatenated history and new user input, so that all input to the model have the same
length
- at the end of the generated response, as some responses will be longer than others
This method cleans up these padding tokens so that the history for each conversation is not impacted by the
batching process.
"""
outputs = []
for sequence in generated_tensor:
sequence_tokens = []
is_previous_pad = False
for token in sequence:
if token == self.pad_token_id:
if is_previous_pad:
continue
else:
is_previous_pad = True
else:
is_previous_pad = False
if self.framework == "pt":
sequence_tokens.append(token.item())
else:
sequence_tokens.append(int(token.numpy()))

outputs.append(sequence_tokens)
return outputs
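
# Illustration only (not part of the diff): with pad_token_id == 0, the loop above keeps the
# first pad token of a run and drops the rest, so a generated sequence such as
#     [15, 7, 0, 0, 0, 22, 0]
# is cleaned to
#     [15, 7, 0, 22, 0]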

def _concat_inputs_history(self, inputs: List[List[int]], histories: List[Optional[List[int]]], max_length: int):
"""
Builds an input prepended by the history for this conversation, allowing multi-turn conversation with context
"""
outputs = []
for new_input, history in zip(inputs, histories):
if history is not None:
new_input = history + new_input
if len(new_input) > max_length - self.min_length_for_response:
cutoff_eos_index = 0
while len(new_input) - cutoff_eos_index > max_length - self.min_length_for_response:
if cutoff_eos_index >= len(new_input):
break
cutoff_eos_index = new_input[cutoff_eos_index:].index(self.tokenizer.eos_token_id)
if cutoff_eos_index == 0 or cutoff_eos_index == len(new_input) - 1:
break
else:
new_input = new_input[cutoff_eos_index + 1 :]
outputs.append(new_input)
max_len = max([len(item) for item in outputs])
outputs = [output + [self.pad_token_id] * (max_len - len(output)) for output in outputs]
outputs = BatchEncoding(
{"input_ids": outputs, "attention_mask": [1] * len(outputs)}, tensor_type=self.framework
)
Member:
If I understand correctly, outputs is already a list of lists of ints before this line, so why do we pass it to self.tokenizer? Is it because it's the simplest way to tensorize to PyTorch vs TensorFlow? (If it is, the intent is not super explicit.)

Contributor Author:
Yes - this was the intent. I didn't find a utility class in the library that would conveniently convert a list of integers to a Torch or TensorFlow tensor the way this class did. Is there a better way?

Member:
Can you create a helper to do this in this file?

Contributor Author:
I now use BatchEncoding directly on the generated input_ids. This cleans up quite a bit and better represents the intent - thank you for the suggestion.

return outputs
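
# Illustration only (not part of the diff): BatchEncoding is used above purely as a convenient
# way to turn plain Python lists into tensors for the configured framework. A minimal sketch,
# assuming a PyTorch install:
from transformers.tokenization_utils_base import BatchEncoding

batch = BatchEncoding(
    {"input_ids": [[31373, 50256], [10919, 50256]], "attention_mask": [[1, 1], [1, 1]]},
    tensor_type="pt",  # "tf" would produce TensorFlow tensors instead
)
print(type(batch["input_ids"]))  # <class 'torch.Tensor'> with shape (2, 2)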


# Register all the supported tasks here
SUPPORTED_TASKS = {
"feature-extraction": {
@@ -1979,6 +2297,12 @@ def __call__(
"tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
},
},
"conversational": {
"impl": ConversationalPipeline,
"tf": TFAutoModelForCausalLM if is_tf_available() else None,
"pt": AutoModelForCausalLM if is_torch_available() else None,
"default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
},
}


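The "conversational" entry above makes the new pipeline available through the top-level pipeline() factory. A minimal end-to-end usage sketch (assuming this branch is installed and the default microsoft/DialoGPT-medium weights can be downloaded):

from transformers import Conversation, pipeline

conversational_pipeline = pipeline("conversational")  # defaults to microsoft/DialoGPT-medium

conversation = Conversation("Going to the movies tonight - any suggestions?")
conversation = conversational_pipeline(conversation)
print(conversation.generated_responses[-1])

conversation.add_user_input("Is it an action movie?")
conversation = conversational_pipeline(conversation)
print(conversation)  # prints the full transcript via Conversation.__repr__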