Improve TransformersQueryClassifier #2587

Closed
ZanSara opened this issue May 23, 2022 · 2 comments · Fixed by #2965
Labels
Contributions wanted! Looking for external contributions type:feature New feature or request

Comments

@ZanSara
Contributor

ZanSara commented May 23, 2022

Problem
Currently, TransformersQueryClassifier is built very closely around the question/keywords/statement classifier model used in the tutorials ("shahrukhx01/bert-mini-finetune-question-detection"). In practice, it can only handle models that output binary labels, one of which must be called LABEL_1.

I believe these limitations make it unsuitable for any model other than the one used in the example.

Solution
HuggingFace now hosts a wide array of zero-shot text classification models that could be applied nicely to query classification, for example for sentiment/emotion analysis or topic classification. With limited changes, TransformersQueryClassifier could be improved to use these models effectively.
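
As a rough illustration of why such models fit query routing: a `transformers` zero-shot pipeline returns, for each input, a dict whose `labels` list is sorted by descending score, so the top label can be mapped to an outgoing edge. The sketch below stubs the prediction with made-up scores instead of loading a model. `route_zero_shot` is a hypothetical helper, not part of Haystack.

```python
from typing import Dict, List, Tuple


def route_zero_shot(prediction: Dict[str, list], labels: List[str]) -> Tuple[str, str]:
    """Return (top label, edge name) for one zero-shot prediction dict."""
    # The first entry of "labels" is the highest-scoring candidate.
    top_label = prediction["labels"][0]
    # Edges are numbered from 1, in the order the candidate labels were given.
    edge = f"output_{labels.index(top_label) + 1}"
    return top_label, edge


labels = ["happy", "unhappy", "neutral"]

# Stubbed prediction with illustrative scores (not real model output).
fake_prediction = {
    "sequence": "Can you give me the right answer for once??",
    "labels": ["unhappy", "neutral", "happy"],
    "scores": [0.81, 0.12, 0.07],
}

top_label, edge = route_zero_shot(fake_prediction, labels)
print(top_label, edge)
```

With this mapping, the number of outgoing edges follows directly from the number of candidate labels rather than being fixed at two.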

Note
Currently it's possible to write custom nodes for this use case. Here is an example.

from pprint import pprint
from typing import Optional, List, Union

from transformers import pipeline

from haystack import Document, Pipeline, Answer
from haystack.nodes.base import BaseComponent
from haystack.nodes import TransformersQueryClassifier


class ZeroshotQueryClassifier(TransformersQueryClassifier):

    outgoing_edges: int = 10

    def __init__(
        self,
        model_name_or_path: str,
        labels: List[str],
        model_version: Optional[str] = None,
        tokenizer: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: Optional[int] = None,
    ):
        """
        :param model_name_or_path: accepts most zero-shot-classification models from HuggingFace (https://huggingface.co/models?pipeline_tag=zero-shot-classification)
        :param labels: the labels for zero-shot classification (for example a list of the emotions to classify by, such as ["happy", "unhappy", "neutral"])
        """
        super().__init__(use_gpu=use_gpu, batch_size=batch_size)
        if tokenizer is None:
            tokenizer = model_name_or_path
        self.model = pipeline(
            task="zero-shot-classification",
            model=model_name_or_path,
            tokenizer=tokenizer,
            device=0 if self.devices[0].type == "cuda" else -1,
            revision=model_version,
        )
        self.labels = labels

    def _get_edge_number(self, label):
        return self.labels.index(label)+1

    def run(self, query: str) -> tuple:
        # Returns (output dict, name of the outgoing edge chosen for this query)
        prediction = self.model([query], candidate_labels=self.labels, truncation=True)
        label = prediction[0]["labels"][0]
        return {"output": query}, f"output_{self._get_edge_number(label)}"

    def run_batch(self, queries: List[str]) -> tuple:
        # Returns (dict mapping each edge name to its queries, "split")
        predictions = self.model(queries, candidate_labels=self.labels, truncation=True)

        results = {f"output_{self._get_edge_number(label)}": {"queries": []} for label in self.labels}
        for query, prediction in zip(queries, predictions):
            label = prediction["labels"][0]
            results[f"output_{self._get_edge_number(label)}"]["queries"].append(query)

        return results, "split"


#
# Usage as a single node
#

query_classifier = ZeroshotQueryClassifier(
    model_name_or_path="typeform/distilbert-base-uncased-mnli", 
    labels=["happy", "unhappy", "neutral"]
)

queries = [
    "What's the answer?",
    "Would you be so kind to tell me the answer?",
    "Can you give me the right answer for once??",
]

# Processing all queries in a single call
output = query_classifier.run_batch(queries=queries)
print()
pprint(output)
print()

# Processing one query at a time
for query in queries:
    output = query_classifier.run(query=query)
    pprint(output)
    print()


#
# Usage in a pipeline (with stub nodes)
# 

class HappyAnswer(BaseComponent):

    outgoing_edges = 1

    def run(self, query: str):
        return {"answers": [Answer(answer="We're glad you like it!")]}, "output_1"

    def run_batch(self, queries: List[str]):
        return {"answers": [Answer(answer="We're glad you like it!")] * len(queries)}, "output_1"


class UnhappyAnswer(BaseComponent):

    outgoing_edges = 1

    def run(self, query: str):
        return {"answers": [Answer(answer="We're so sorry you're not happy :(")]}, "output_1"

    def run_batch(self, queries: List[str]):
        return {"answers": [Answer(answer="We're so sorry you're not happy :(")] * len(queries)}, "output_1"


class NeutralAnswer(BaseComponent):

    outgoing_edges = 1

    def run(self, query: str):
        return {"answers": [Answer(answer="Thanks for your feedback.")]}, "output_1"

    def run_batch(self, queries: List[str]):
        return {"answers": [Answer(answer="Thanks for your feedback.")] * len(queries)}, "output_1"


# Named `pipe` so it does not shadow the `pipeline` function imported from transformers
pipe = Pipeline()
pipe.add_node(component=query_classifier, name="classifier", inputs=["Query"])
pipe.add_node(component=HappyAnswer(), name="happy", inputs=["classifier.output_1"])
pipe.add_node(component=UnhappyAnswer(), name="unhappy", inputs=["classifier.output_2"])
pipe.add_node(component=NeutralAnswer(), name="neutral", inputs=["classifier.output_3"])

pipe.draw("pipeline.png")

for query in queries:
    output = pipe.run(query=query)
    pprint(output)
    print()
@ZanSara ZanSara added type:feature New feature or request journey:intermediate labels May 23, 2022
@ZanSara ZanSara added the Contributions wanted! Looking for external contributions label Jul 19, 2022
@anakin87
Member

@ZanSara just one question to better understand your opinion...
Should we make TransformersQueryClassifier more general and suitable for handling non-binary output labels?

@ZanSara
Contributor Author

ZanSara commented Jul 26, 2022

Hey @anakin87! Yes, that's the aim. Right now the node is highly tailored to the specific model I mentioned above, so even other binary models would not work. First of all, I think any binary model should work as a QueryClassifier, but honestly I think it's worth taking this occasion to really improve it. If it were able to handle a generic text classification model, it would be really cool 😊

By the way: feel free to go for a heavy rewrite if you believe it's a good call. Just make sure that it is still compatible with the tutorial.
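
One possible shape for the generalization discussed here, as a minimal sketch and not the implementation that was eventually merged: a `transformers` text-classification pipeline returns dicts of the form `{"label": ..., "score": ...}`, so a configurable list of labels can decide the outgoing edge. The helper `edge_for_prediction` and the `DEFAULT_LABELS` value are hypothetical. The default order (LABEL_1 to output_1, LABEL_0 to output_2) is an assumption about how compatibility with the tutorial model could be kept.

```python
from typing import Dict, List, Optional

# Assumed default for illustration: LABEL_1 -> output_1, LABEL_0 -> output_2.
DEFAULT_LABELS = ["LABEL_1", "LABEL_0"]


def edge_for_prediction(prediction: Dict[str, object], labels: Optional[List[str]] = None) -> str:
    """Map one text-classification prediction to an outgoing edge name."""
    labels = labels or DEFAULT_LABELS
    label = prediction["label"]
    if label not in labels:
        # Fail loudly when the model emits a label the node was not configured for.
        raise ValueError(f"Unexpected label {label!r}; expected one of {labels}")
    return f"output_{labels.index(label) + 1}"


# Binary model using the default label names (scores are illustrative).
print(edge_for_prediction({"label": "LABEL_0", "score": 0.97}))

# Multi-class model with custom label names.
print(edge_for_prediction({"label": "joy", "score": 0.88}, labels=["anger", "joy", "sadness"]))
```

Passing the label list explicitly means a binary model with differently named labels, or a model with more than two classes, can be routed without changing the node's code.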
