Problem
Currently TransformersQueryClassifier is built very closely around the question/keywords/statement classifier model used in the tutorials ("shahrukhx01/bert-mini-finetune-question-detection"). In practice, it can only handle models that output binary labels, one of which must be called LABEL_1.
I believe this limitation makes it unsuitable for any model other than the one used in the example.
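To make the limitation concrete, here is a rough sketch of what works today versus what breaks; the sentiment model below is just an illustrative example, not something tied to the current implementation:

```python
from haystack.nodes import TransformersQueryClassifier

# Works: the tutorial model emits LABEL_0/LABEL_1, so questions are routed to output_1
# and keyword queries to output_2.
classifier = TransformersQueryClassifier(
    model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection"
)

# Breaks down: a sentiment model emits labels like "POSITIVE"/"NEGATIVE", so the hard-coded
# LABEL_1 check never matches and every query ends up on the same edge.
classifier = TransformersQueryClassifier(
    model_name_or_path="distilbert-base-uncased-finetuned-sst-2-english"
)
```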
Solution
HuggingFace now hosts a wide array of zero-shot text classification models that could be nicely applied to query classification, for example for sentiment/emotion analysis or for topic classification. With limited changes, TransformersQueryClassifier could be improved to use these models effectively.
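For context, this is roughly what the underlying transformers zero-shot pipeline returns (the model and labels here are just examples); the custom node below simply routes on the top-ranked label:

```python
from transformers import pipeline

# Any NLI-based zero-shot checkpoint from the Hub should behave similarly.
zero_shot = pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")

result = zero_shot(
    "Can you give me the right answer for once??",
    candidate_labels=["happy", "unhappy", "neutral"],
)
# result is a dict like {"sequence": ..., "labels": [...], "scores": [...]},
# with labels sorted by descending score.
print(result["labels"][0])  # the top-ranked label can drive the routing decision
```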
Note
Currently it's possible to write custom nodes for this use case. Here is an example.
```python
from pprint import pprint
from typing import Optional, List

from transformers import pipeline

from haystack import Pipeline, Answer
from haystack.nodes.base import BaseComponent
from haystack.nodes import TransformersQueryClassifier


class ZeroshotQueryClassifier(TransformersQueryClassifier):
    outgoing_edges: int = 10

    def __init__(
        self,
        model_name_or_path: str,
        labels: List[str],
        model_version: Optional[str] = None,
        tokenizer: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: Optional[int] = None,
    ):
        """
        :param model_name_or_path: accepts most zero-shot-classification models from HuggingFace
            (https://huggingface.co/models?pipeline_tag=zero-shot-classification)
        :param labels: the labels for zero-shot classification (for example a list of the emotions
            to classify by, or something like ["happy", "unhappy", "neutral"])
        """
        super().__init__(use_gpu=use_gpu, batch_size=batch_size)
        if tokenizer is None:
            tokenizer = model_name_or_path
        self.model = pipeline(
            task="zero-shot-classification",
            model=model_name_or_path,
            tokenizer=tokenizer,
            device=0 if self.devices[0].type == "cuda" else -1,
            revision=model_version,
        )
        self.labels = labels

    def _get_edge_number(self, label):
        # Labels map to edges in the order they were given: first label -> output_1, and so on.
        return self.labels.index(label) + 1

    def run(self, query: str):
        prediction = self.model([query], candidate_labels=self.labels, truncation=True)
        label = prediction[0]["labels"][0]  # top-ranked label
        return {"output": query}, f"output_{self._get_edge_number(label)}"

    def run_batch(self, queries: List[str]):
        predictions = self.model(queries, candidate_labels=self.labels, truncation=True)
        results = {f"output_{self._get_edge_number(label)}": {"queries": []} for label in self.labels}
        for query, prediction in zip(queries, predictions):
            label = prediction["labels"][0]
            results[f"output_{self._get_edge_number(label)}"]["queries"].append(query)
        return results, "split"


## Usage as a single node
#
query_classifier = ZeroshotQueryClassifier(
    model_name_or_path="typeform/distilbert-base-uncased-mnli",
    labels=["happy", "unhappy", "neutral"],
)

queries = [
    "What's the answer?",
    "Would you be so kind to tell me the answer?",
    "Can you give me the right answer for once??",
]

# Processing all queries in a single call
output = query_classifier.run_batch(queries=queries)
print()
pprint(output)
print()

# Processing one query at a time
for query in queries:
    output = query_classifier.run(query=query)
    pprint(output)
    print()


## Usage in a pipeline (with stub nodes)
#
class HappyAnswer(BaseComponent):
    outgoing_edges = 1  # BaseComponent subclasses declare their number of outgoing edges

    def run(self, query: str):
        return {"answers": [Answer(answer="We're glad you like it!")]}, "output_1"

    def run_batch(self, queries: List[str]):
        return {"answers": [Answer(answer="We're glad you like it!")] * len(queries)}, "output_1"


class UnhappyAnswer(BaseComponent):
    outgoing_edges = 1

    def run(self, query: str):
        return {"answers": [Answer(answer="We're so sorry you're not happy :(")]}, "output_1"

    def run_batch(self, queries: List[str]):
        return {"answers": [Answer(answer="We're so sorry you're not happy :(")] * len(queries)}, "output_1"


class NeutralAnswer(BaseComponent):
    outgoing_edges = 1

    def run(self, query: str):
        return {"answers": [Answer(answer="Thanks for your feedback.")]}, "output_1"

    def run_batch(self, queries: List[str]):
        return {"answers": [Answer(answer="Thanks for your feedback.")] * len(queries)}, "output_1"


pipeline = Pipeline()
pipeline.add_node(component=query_classifier, name="classifier", inputs=["Query"])
pipeline.add_node(component=HappyAnswer(), name="happy", inputs=["classifier.output_1"])
pipeline.add_node(component=UnhappyAnswer(), name="unhappy", inputs=["classifier.output_2"])
pipeline.add_node(component=NeutralAnswer(), name="neutral", inputs=["classifier.output_3"])
pipeline.draw("pipeline.png")

for query in queries:
    output = pipeline.run(query=query)
    pprint(output)
    print()
```
@ZanSara just one question to better understand your opinion...
Should we make TransformersQueryClassifier more general and suitable for handling non-binary output labels?
Hey @anakin87! Yes, that's the aim. Right now the node is highly tailored to the specific model I mentioned above, so even other binary models would not work. First of all, I think any binary model should work as a QueryClassifier, but honestly I think it's worth taking this occasion to really improve it. If it were able to handle a generic text classification model, it would be really cool 😊
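For instance (just a sketch of what a more general interface could look like; the task and labels parameters are assumptions, not the current API), mapping each label to its own outgoing edge would cover both binary and multi-class models:

```python
# Hypothetical generalized constructor: labels[i] is routed to output_{i+1},
# so any number of classes is supported.
classifier = TransformersQueryClassifier(
    model_name_or_path="typeform/distilbert-base-uncased-mnli",
    task="zero-shot-classification",
    labels=["music", "cinema", "sports"],
)
```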
By the way: feel free to go for a heavy rewrite if you believe it's a good call. Just make sure that it is still compatible with the tutorial.