Skip to content

Commit

Permalink
Addition of a DialoguePipeline (#5516)
Browse files Browse the repository at this point in the history
* initial commit for pipeline implementation

Addition of input processing and history concatenation

* Conversation pipeline tested and working for single & multiple conversation inputs

* Added docstrings for dialogue pipeline

* Addition of dialogue pipeline integration tests

* Delete test_t5.py

* Fixed max code length

* Updated styling

* Fixed test broken by formatting tools

* Removed unused import

* Added unit test for DialoguePipeline

* Fixed Tensorflow compatibility

* Fixed multi-framework support using framework flag

* - Fixed docstring
- Added `min_length_for_response` as an initialization parameter
- Renamed `*args` to `conversations`, `conversations` being a `Conversation` or a `List[Conversation]`
- Updated truncation to truncate entire segments of conversations, instead of cutting in the middle of a user/bot input

* - renamed pipeline name from dialogue to conversational
- removed hardcoded default value of 1000 and use config.max_length instead
- added `append_response` and `set_history` method to the Conversation class to avoid direct fields mutation
- fixed bug in history truncation method

* - Updated ConversationalPipeline to accept only active conversations (otherwise a ValueError is raised)

* - Simplified input tensor conversion

* - Updated attention_mask value for Tensorflow compatibility

* - Updated last dialogue reference to conversational & fixed integration tests

* Fixed conflict with master

* Updates following review comments

* Updated formatting

* Added Conversation and ConversationalPipeline to the library __init__, addition of docstrings for Conversation, added both to the docs

* Update src/transformers/pipelines.py

Updated docsting following review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
  • Loading branch information
guillaume-be and sgugger committed Jul 30, 2020
1 parent ec02674 commit e642c78
Show file tree
Hide file tree
Showing 4 changed files with 428 additions and 3 deletions.
8 changes: 8 additions & 0 deletions docs/source/main_classes/pipelines.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,11 @@ TextGenerationPipeline
==========================================

.. autoclass:: transformers.TextGenerationPipeline


ConversationalPipeline
==========================================

.. autoclass:: transformers.Conversation

.. autoclass:: transformers.ConversationalPipeline
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@

# Pipelines
from .pipelines import (
Conversation,
ConversationalPipeline,
CsvPipelineDataFormat,
FeatureExtractionPipeline,
FillMaskPipeline,
Expand Down
326 changes: 325 additions & 1 deletion src/transformers/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
import os
import pickle
import sys
import uuid
from abc import ABC, abstractmethod
from contextlib import contextmanager
from itertools import chain
from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from uuid import UUID

import numpy as np

Expand All @@ -36,7 +38,7 @@
from .tokenization_auto import AutoTokenizer
from .tokenization_bert import BasicTokenizer
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils_base import PaddingStrategy
from .tokenization_utils_base import BatchEncoding, PaddingStrategy


if is_tf_available():
Expand All @@ -51,6 +53,7 @@
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TFAutoModelForCausalLM,
)

if is_torch_available():
Expand Down Expand Up @@ -1895,6 +1898,321 @@ def __call__(
return results


class Conversation:
"""
Utility class containing a conversation and its history. This class is meant to be used as an input to the
:obj:`~transformers.ConversationalPipeline`. The conversation contains a number of utility function to manage the addition of new
user input and generated model responses. A conversation needs to contain an unprocessed user input before being
passed to the :obj:`~transformers.ConversationalPipeline`. This user input is either created when the class is instantiated, or by calling
`append_response("input")` after a conversation turn.
Usage::
conversation = Conversation("Going to the movies tonight - any suggestions?")
# Steps usually performed by the model when generating a response:
# 1. Mark the user input as processed (moved to the history)
conversation.mark_processed()
# 2. Append a mode response
conversation.append_response("The Big lebowski.")
conversation.add_user_input("Is it good?")
Arguments:
text (:obj:`str`, `optional`, defaults to :obj:`None`):
The initial user input to start the conversation.
If :obj:`None`, a user input needs to be provided manually using `add_user_input` before the conversation can begin.
conversation_id (:obj:`uuid.UUID`, `optional`, defaults to :obj:`None`):
Unique identifier for the conversation
If :obj:`None`, the random UUID4 id will be assigned to the conversation.
"""

def __init__(self, text: str = None, conversation_id: UUID = None):
if not conversation_id:
conversation_id = uuid.uuid4()
self.uuid: UUID = conversation_id
self.past_user_inputs: List[str] = []
self.generated_responses: List[str] = []
self.history: List[int] = []
self.new_user_input: Optional[str] = text

def add_user_input(self, text: str, overwrite: bool = False):
"""
Add a user input to the conversation for the next round. This populates the internal `new_user_input` field.
Args:
text: str, the user input for the next conversation round
overwrite: bool, flag indicating if existing and unprocessed user input should be overwritten when this function is called
"""
if self.new_user_input:
if overwrite:
logger.warning(
'User input added while unprocessed input was existing: "{}" was overwritten with: "{}".'.format(
self.new_user_input, text
)
)
self.new_user_input = text
else:
logger.warning(
'User input added while unprocessed input was existing: "{}" new input ignored: "{}". '
"Set `overwrite` to True to overwrite unprocessed user input".format(self.new_user_input, text)
)
else:
self.new_user_input = text

def mark_processed(self):
"""
Mark the conversation as processed (moves the content of `new_user_input` to `past_user_inputs`) and empties the
`new_user_input` field.
"""
if self.new_user_input:
self.past_user_inputs.append(self.new_user_input)
self.new_user_input = None

def append_response(self, response: str):
"""
Append a response to the list of generated responses.
Args:
response: str, the model generated response
"""
self.generated_responses.append(response)

def set_history(self, history: List[int]):
"""
Updates the value of the history of the conversation. The history is represented by a list of `token_ids`. The
history is used by the model to generate responses based on the previous conversation turns.
Args:
history: (list of int), history of tokens provided and generated for this conversation
"""
self.history = history

def __repr__(self):
"""
Generates a string representation of the conversation.
Return:
:obj:`str` or :obj:`Dict`:
Example:
Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114
user >> Going to the movies tonight - any suggestions?
bot >> The Big Lebowski
"""
output = "Conversation id: {} \n".format(self.uuid)
for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses):
output += "user >> {} \n".format(user_input)
output += "bot >> {} \n".format(generated_response)
if self.new_user_input is not None:
output += "user >> {} \n".format(self.new_user_input)
return output


class ConversationalPipeline(Pipeline):
"""
Multi-turn conversational pipeline.
Usage::
conversational_pipeline = pipeline("conversational")
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
conversation_2 = Conversation("What's the last book you have read?")
conversational_pipeline([conversation_1, conversation_2])
conversation_1.add_user_input("Is it an action movie?")
conversation_2.add_user_input("What is the genre of this book?")
conversational_pipeline([conversation_1, conversation_2])
The models that this pipeline can use are models that have been fine-tuned on a multi-turn conversational task,
currently: "microsoft/DialoGPT-small", "microsoft/DialoGPT-medium", "microsoft/DialoGPT-large"
See the up-to-date list of available models on
`huggingface.co/models <https://huggingface.co/models?filter=conversational>`__.
Arguments:
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
checkpoint identifier or an actual pre-trained model inheriting from
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
TensorFlow.
If :obj:`None`, the default of the pipeline will be loaded.
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
:class:`~transformers.PreTrainedTokenizer`.
If :obj:`None`, the default of the pipeline will be loaded.
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
Model card attributed to the model for this pipeline.
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
installed.
If no framework is specified, will default to the one currently installed. If no framework is specified
and both frameworks are installed, will default to PyTorch.
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
Reference to the object in charge of parsing supplied pipeline parameters.
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, >=0 will run the model
on the associated CUDA device id.
"""

def __init__(self, min_length_for_response=32, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.tokenizer.eos_token_id is not None, "DialoguePipeline tokenizer should have an EOS token set"
if self.tokenizer.pad_token_id is not None:
self.pad_token_id = self.tokenizer.pad_token_id
else:
self.pad_token_id = self.tokenizer.eos_token_id
self.min_length_for_response = min_length_for_response

def __call__(
self,
conversations: Union[Conversation, List[Conversation]],
clean_up_tokenization_spaces=True,
**generate_kwargs
):
r"""
Args:
conversations: (list of :class:`~transformers.pipelines.Conversation`) Conversations to generate responses for
**generate_kwargs: extra kwargs passed to `self.model.generate`_
Returns:
list of conversations with updated generated responses for those containing a new user input
"""

# Input validation
if isinstance(conversations, list):
for conversation in conversations:
assert isinstance(
conversation, Conversation
), "DialoguePipeline expects a Conversation or list of Conversations as an input"
if conversation.new_user_input is None:
raise ValueError(
"Conversation with UUID {} does not contain new user input to process. "
"Add user inputs with the conversation's `add_user_input` method".format(
type(conversation.uuid)
)
)
assert (
self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None
), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input"
elif isinstance(conversations, Conversation):
conversations = [conversations]
else:
raise ValueError("DialoguePipeline expects a Conversation or list of Conversations as an input")

with self.device_placement():

inputs = self._parse_and_tokenize([conversation.new_user_input for conversation in conversations])
histories = [conversation.history for conversation in conversations]
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
inputs = self._concat_inputs_history(inputs, histories, max_length)

if self.framework == "pt":
inputs = self.ensure_tensor_on_device(**inputs)
input_length = inputs["input_ids"].shape[-1]

elif self.framework == "tf":
input_length = tf.shape(inputs["input_ids"])[-1].numpy()

if input_length > 0.9 * max_length:
logger.warning(
"Longest conversation length: {} is bigger than 0.9 * max_length: {}. "
"You might consider trimming the early phase of the conversation".format(input_length, max_length)
)
generated_responses = self.model.generate(
inputs["input_ids"], attention_mask=inputs["attention_mask"], **generate_kwargs,
)

cleaned_history = self._clean_padding_history(generated_responses)
output = []
for conversation_index, conversation in enumerate(conversations):
conversation.mark_processed()
conversation.generated_responses.append(
self.tokenizer.decode(
cleaned_history[conversation_index][input_length:],
skip_special_tokens=True,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
)
conversation.set_history(cleaned_history[conversation_index])
output.append(conversation)
if len(output) == 1:
return output[0]
else:
return output

def _parse_and_tokenize(self, *args, **kwargs):
"""
Parse arguments and tokenize, adding an EOS token at the end of the user input
"""
# Parse arguments
inputs = self._args_parser(*args, **kwargs)
inputs = self.tokenizer.batch_encode_plus(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
for input in inputs:
input.append(self.tokenizer.eos_token_id)
return inputs

def _clean_padding_history(self, generated_tensor) -> List[List[int]]:
"""
Cleans the padding history. Padding may be generated in two places when multiple conversations are provided as
an input:
- at the end of the concatenated history and new user input, so that all input to the model have the same
length
- at the end of the generated response, as some responses will be longer than others
This method cleans up these padding token so that the history for each conversation is not impacted by the
batching process.
"""
outputs = []
for sequence in generated_tensor:
sequence_tokens = []
is_previous_pad = False
for token in sequence:
if token == self.pad_token_id:
if is_previous_pad:
continue
else:
is_previous_pad = True
else:
is_previous_pad = False
if self.framework == "pt":
sequence_tokens.append(token.item())
else:
sequence_tokens.append(int(token.numpy()))

outputs.append(sequence_tokens)
return outputs

def _concat_inputs_history(self, inputs: List[List[int]], histories: List[Optional[List[int]]], max_length: int):
"""
Builds an input prepended by the history for this conversation, allowing multi-turn conversation with context
"""
outputs = []
for new_input, history in zip(inputs, histories):
if history is not None:
new_input = history + new_input
if len(new_input) > max_length - self.min_length_for_response:
cutoff_eos_index = 0
while len(new_input) - cutoff_eos_index > max_length - self.min_length_for_response:
if cutoff_eos_index >= len(new_input):
break
cutoff_eos_index = new_input[cutoff_eos_index:].index(self.tokenizer.eos_token_id)
if cutoff_eos_index == 0 or cutoff_eos_index == len(new_input) - 1:
break
else:
new_input = new_input[cutoff_eos_index + 1 :]
outputs.append(new_input)
max_len = max([len(item) for item in outputs])
outputs = [output + [self.pad_token_id] * (max_len - len(output)) for output in outputs]
outputs = BatchEncoding(
{"input_ids": outputs, "attention_mask": [1] * len(outputs)}, tensor_type=self.framework
)
return outputs


# Register all the supported tasks here
SUPPORTED_TASKS = {
"feature-extraction": {
Expand Down Expand Up @@ -1979,6 +2297,12 @@ def __call__(
"tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
},
},
"conversational": {
"impl": ConversationalPipeline,
"tf": TFAutoModelForCausalLM if is_tf_available() else None,
"pt": AutoModelForCausalLM if is_torch_available() else None,
"default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
},
}


Expand Down
Loading

0 comments on commit e642c78

Please sign in to comment.