
Allow ConversationalPipeline to receive string input #30958

Closed
mattdeeperinsights opened this issue May 22, 2024 · 3 comments

Labels
Core: Pipeline Internals of the library; Pipeline. Feature request Request for a new feature

Comments

mattdeeperinsights commented May 22, 2024

Feature request

Currently, ConversationalPipeline expects a Conversation object (or a list of them) as input.

Sometimes you may want to quickly feed a piece of text as a simple str into a conversation model without wrapping it up first.
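
For example, today the wrapping looks something like this (a minimal sketch; the checkpoint is just an example):

from transformers import Conversation, pipeline

chatbot = pipeline("conversational", model="google/gemma-2b-it")

# A bare string currently has to be wrapped in a Conversation first:
chatbot(Conversation("Hello!"))

# The request: allow this to work directly.
# chatbot("Hello!")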

Motivation

I think this would simplify trying out chat-based models and ensure that we call them with the correct format.

Your contribution

Currently, I can easily get around this by subclassing ConversationalPipeline and overriding __call__ to handle the case where conversations is a str:

from typing import List, Union

from transformers import Conversation, ConversationalPipeline

class CustomConversationalPipeline(ConversationalPipeline):
    def __call__(self, conversations: Union[str, Conversation, List[Conversation]], num_workers=0, **kwargs):
        r"""
        Generate responses for the conversation(s) given as inputs.

        Args:
            conversations (a :obj:`str`, a :class:`~transformers.Conversation` or a list of :class:`~transformers.Conversation`):
                Conversations to generate responses for. A bare :obj:`str` is treated as a single user message.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate method
                corresponding to your framework `here <./model.html#generative-models>`__).

        Returns:
            :class:`~transformers.Conversation` or a list of :class:`~transformers.Conversation`: Conversation(s) with
            updated generated responses for those containing a new user input.
        """
        if isinstance(conversations, str):
            # Wrap the bare string as a single user turn; the parent
            # __call__ converts a list of message dicts into a Conversation.
            conversations = [{"role": "user", "content": conversations}]

        # XXX: num_workers==0 is required to be backward compatible
        # Otherwise the threads will require a Conversation copy.
        # This will definitely hinder performance on GPU, but has to be opted
        # in because of this BC change.
        outputs = super().__call__(conversations, num_workers=num_workers, **kwargs)
        if isinstance(outputs, list) and len(outputs) == 1:
            return outputs[0]
        return outputs

The full working solution is then:

from transformers import AutoModelForCausalLM, AutoTokenizer, ConversationalPipeline

model_checkpoint = "google/gemma-2b-it"

MODEL = AutoModelForCausalLM.from_pretrained(
    model_checkpoint,
    ...
)

TOKENIZER = AutoTokenizer.from_pretrained(
    model_checkpoint
)

generator = CustomConversationalPipeline(
    model=MODEL,
    tokenizer=TOKENIZER
)

generator("Hello!")

amyeroberts added the Core: Pipeline and Feature request labels on May 22, 2024

amyeroberts (Collaborator) commented

cc @Rocketknight1

Rocketknight1 (Member) commented

Hi @mattdeeperinsights, we're actually deprecating the ConversationalPipeline in favour of the TextGenerationPipeline. ConversationalPipeline will be removed in the very near future, but you can now pass chats directly to TextGenerationPipeline instead.

Unfortunately, this means that your shortcut won't work that well - I think it would have been reasonably clean for ConversationalPipeline, but TextGenerationPipeline interprets a single string as a naked string to be completed, rather than a user message.
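
For illustration, a minimal sketch of the difference (the checkpoint is just the example used above):

from transformers import pipeline

pipe = pipeline("text-generation", model="google/gemma-2b-it")

# A bare string is treated as a plain prompt to be completed:
pipe("Hello!")

# A list of message dicts is treated as a chat; the model's chat
# template is applied before generation:
pipe([{"role": "user", "content": "Hello!"}])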

mattdeeperinsights (Author) commented

Okay, thanks, that makes sense.

It turns out that a later version of the pipeline has the expected behaviour anyway, so I will stick to that:

import transformers

generator = transformers.pipeline(
    task="text-generation",
    model=MODEL,
    tokenizer=TOKENIZER,
)

A string input here will be embedded into a conversation behind the scenes.
