From 559ba03818634bb85443e9d99826140be70de54f Mon Sep 17 00:00:00 2001 From: anakin87 <44616784+anakin87@users.noreply.github.com> Date: Thu, 4 Aug 2022 12:16:03 +0200 Subject: [PATCH 1/8] extend query classifier in one commit --- docs/_src/api/api/query_classifier.md | 23 +++- .../haystack-pipeline-master.schema.json | 24 ++++ .../nodes/query_classifier/transformers.py | 105 ++++++++++++------ test/nodes/test_query_classifier.py | 94 ++++++++++++++++ 4 files changed, 210 insertions(+), 36 deletions(-) create mode 100644 test/nodes/test_query_classifier.py diff --git a/docs/_src/api/api/query_classifier.md b/docs/_src/api/api/query_classifier.md index e56e2095fb..3d8683e564 100644 --- a/docs/_src/api/api/query_classifier.md +++ b/docs/_src/api/api/query_classifier.md @@ -95,10 +95,15 @@ queries or statement vs question queries. class TransformersQueryClassifier(BaseQueryClassifier) ``` -A node to classify an incoming query into one of two categories using a (small) BERT transformer model. + + +#### outgoing\_edges + +A node to classify an incoming query into categories using a transformer model. Depending on the result, the query flows to a different branch in your pipeline and the further processing -can be customized. You can define this by connecting the further pipeline to either `output_1` or `output_2` +can be customized. You can define this by connecting the further pipeline to `output_1`, `output_2`, ..., `output_n` from this node. +This node also supports zero-shot-classification. **Example**: @@ -119,7 +124,7 @@ from this node. Models: - Pass your own `Transformer` binary classification model from file/huggingface or use one of the following + Pass your own `Transformer` classification/zero-shot-classification model from file/huggingface or use one of the following pretrained ones hosted on Huggingface: 1) Keywords vs. Questions/Statements (Default) model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection" @@ -142,11 +147,19 @@ from this node. #### TransformersQueryClassifier.\_\_init\_\_ ```python -def __init__(model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", use_gpu: bool = True, batch_size: Optional[int] = None) +def __init__(model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, task: str = "text-classification", labels: Optional[List[str]] = ["LABEL_1", "LABEL_0"], batch_size: Optional[int] = None) ``` **Arguments**: -- `model_name_or_path`: Transformer based fine tuned mini bert model for query classification +- `model_name_or_path`: Directory of a saved model or the name of a public model, for example 'shahrukhx01/bert-mini-finetune-question-detection'. +See [Hugging Face models](https://huggingface.co/models) for a full list of available models. +- `model_version`: The version of the model to use from the Hugging Face model hub. This can be a tag name, a branch name, or a commit hash. +- `tokenizer`: The name of the tokenizer (usually the same as model). - `use_gpu`: Whether to use GPU (if available). +- `task`: Specifies the type of classification. Possible values: 'text-classification' or 'zero-shot-classification'. +- `labels`: If the task is 'text-classification' and an ordered list of labels is provided, the first label corresponds to output_1, +the second label to output_2, and so on. The labels must match the model labels; only the order can differ. 
Otherwise, model labels are considered. +If the task is 'zero-shot-classification', these are the candidate labels. +- `batch_size`: The number of queries to be processed at a time. diff --git a/haystack/json-schemas/haystack-pipeline-master.schema.json b/haystack/json-schemas/haystack-pipeline-master.schema.json index 16644ec408..92613fa56e 100644 --- a/haystack/json-schemas/haystack-pipeline-master.schema.json +++ b/haystack/json-schemas/haystack-pipeline-master.schema.json @@ -4726,11 +4726,35 @@ } ] }, + "model_version": { + "title": "Model Version", + "type": "string" + }, + "tokenizer": { + "title": "Tokenizer", + "type": "string" + }, "use_gpu": { "title": "Use Gpu", "default": true, "type": "boolean" }, + "task": { + "title": "Task", + "default": "text-classification", + "type": "string" + }, + "labels": { + "title": "Labels", + "default": [ + "LABEL_1", + "LABEL_0" + ], + "type": "array", + "items": { + "type": "string" + } + }, "batch_size": { "title": "Batch Size", "type": "integer" diff --git a/haystack/nodes/query_classifier/transformers.py b/haystack/nodes/query_classifier/transformers.py index 2b73c268da..e6aa12e5b4 100644 --- a/haystack/nodes/query_classifier/transformers.py +++ b/haystack/nodes/query_classifier/transformers.py @@ -1,8 +1,8 @@ import logging from pathlib import Path -from typing import Union, List, Optional, Dict +from typing import Union, List, Optional -from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline +from transformers import pipeline from haystack.nodes.query_classifier.base import BaseQueryClassifier from haystack.modeling.utils import initialize_device_settings @@ -11,11 +11,15 @@ class TransformersQueryClassifier(BaseQueryClassifier): + + outgoing_edges: int = 10 + """ - A node to classify an incoming query into one of two categories using a (small) BERT transformer model. + A node to classify an incoming query into categories using a transformer model. Depending on the result, the query flows to a different branch in your pipeline and the further processing - can be customized. You can define this by connecting the further pipeline to either `output_1` or `output_2` + can be customized. You can define this by connecting the further pipeline to `output_1`, `output_2`, ..., `output_n` from this node. + This node also supports zero-shot-classification. Example: ```python @@ -35,7 +39,7 @@ class TransformersQueryClassifier(BaseQueryClassifier): Models: - Pass your own `Transformer` binary classification model from file/huggingface or use one of the following + Pass your own `Transformer` classification/zero-shot-classification model from file/huggingface or use one of the following pretrained ones hosted on Huggingface: 1) Keywords vs. Questions/Statements (Default) model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection" @@ -57,43 +61,82 @@ class TransformersQueryClassifier(BaseQueryClassifier): def __init__( self, model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", + model_version: Optional[str] = None, + tokenizer: Optional[str] = None, use_gpu: bool = True, + task: str = "text-classification", + labels: Optional[List[str]] = ["LABEL_1", "LABEL_0"], batch_size: Optional[int] = None, ): """ - :param model_name_or_path: Transformer based fine tuned mini bert model for query classification + :param model_name_or_path: Directory of a saved model or the name of a public model, for example 'shahrukhx01/bert-mini-finetune-question-detection'. 
+ See [Hugging Face models](https://huggingface.co/models) for a full list of available models. + :param model_version: The version of the model to use from the Hugging Face model hub. This can be a tag name, a branch name, or a commit hash. + :param tokenizer: The name of the tokenizer (usually the same as model). :param use_gpu: Whether to use GPU (if available). + :param task: Specifies the type of classification. Possible values: 'text-classification' or 'zero-shot-classification'. + :param labels: If the task is 'text-classification' and an ordered list of labels is provided, the first label corresponds to output_1, + the second label to output_2, and so on. The labels must match the model labels; only the order can differ. Otherwise, model labels are considered. + If the task is 'zero-shot-classification', these are the candidate labels. + :param batch_size: The number of queries to be processed at a time. """ super().__init__() - - self.devices, _ = initialize_device_settings(use_cuda=use_gpu) + devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False) + device = 0 if devices[0].type == "cuda" else -1 + + self.model = pipeline( + task=task, model=model_name_or_path, tokenizer=tokenizer, device=device, revision=model_version + ) + + self.labels = labels + if task == "zero-shot-classification": + if labels is None or len(labels) == 0: + raise ValueError("Candidate labels must be provided for task zero-shot-classification") + elif task == "text-classification": + labels_from_model = [label for label in self.model.model.config.id2label.values()] + if labels is None: + self.labels = labels_from_model + elif set(labels) != set(labels_from_model): + self.labels = labels_from_model + logger.warning( + f"The provided labels do not match the model labels, then the model labels are used.\n" + f"Provided labels: {labels}\n" + f"Model labels: {labels_from_model}" + ) + else: + raise ValueError( + f"Task not supported: {task}.\n" + f"Possible task values are: 'text-classification' or 'zero-shot-classification'" + ) + self.task = task self.batch_size = batch_size - device = 0 if self.devices[0].type == "cuda" else -1 - - model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path) - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) - self.query_classification_pipeline = TextClassificationPipeline(model=model, tokenizer=tokenizer, device=device) + def _get_edge_number(self, label): + return self.labels.index(label) + 1 - def run(self, query): - is_question: bool = self.query_classification_pipeline(query)[0]["label"] == "LABEL_1" - - if is_question: - return {}, "output_1" - else: - return {}, "output_2" + def run(self, query: str): # type: ignore + if self.task == "zero-shot-classification": + prediction = self.model([query], candidate_labels=self.labels, truncation=True) + label = prediction[0]["labels"][0] + elif self.task == "text-classification": + prediction = self.model([query], truncation=True) + label = prediction[0]["label"] + return {}, f"output_{self._get_edge_number(label)}" def run_batch(self, queries: List[str], batch_size: Optional[int] = None): # type: ignore if batch_size is None: batch_size = self.batch_size - - split: Dict[str, Dict[str, List]] = {"output_1": {"queries": []}, "output_2": {"queries": []}} - - predictions = self.query_classification_pipeline(queries, batch_size=batch_size) - for query, pred in zip(queries, predictions): - if pred["label"] == "LABEL_1": - split["output_1"]["queries"].append(query) - else: - 
split["output_2"]["queries"].append(query) - - return split, "split" + if self.task == "zero-shot-classification": + predictions = self.model(queries, candidate_labels=self.labels, truncation=True, batch_size=batch_size) + elif self.task == "text-classification": + predictions = self.model(queries, truncation=True, batch_size=batch_size) + + results = {f"output_{self._get_edge_number(label)}": {"queries": []} for label in self.labels} # type: ignore + for query, prediction in zip(queries, predictions): + if self.task == "zero-shot-classification": + label = prediction["labels"][0] + elif self.task == "text-classification": + label = prediction["label"] + results[f"output_{self._get_edge_number(label)}"]["queries"].append(query) + + return results, "split" diff --git a/test/nodes/test_query_classifier.py b/test/nodes/test_query_classifier.py new file mode 100644 index 0000000000..bca56f12dd --- /dev/null +++ b/test/nodes/test_query_classifier.py @@ -0,0 +1,94 @@ +import pytest +from haystack.nodes.query_classifier.transformers import TransformersQueryClassifier + + +@pytest.fixture +def transformers_query_classifier(): + return TransformersQueryClassifier( + model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection", + use_gpu=False, + task="text-classification", + labels=["LABEL_1", "LABEL_0"], + ) + + +@pytest.fixture +def zero_shot_transformers_query_classifier(): + return TransformersQueryClassifier( + model_name_or_path="typeform/distilbert-base-uncased-mnli", + use_gpu=False, + task="zero-shot-classification", + labels=["happy", "unhappy", "neutral"], + ) + + +def test_transformers_query_classifier(transformers_query_classifier): + output = transformers_query_classifier.run(query="morse code") + assert output == ({}, "output_2") + + output = transformers_query_classifier.run(query="How old is John?") + assert output == ({}, "output_1") + + +def test_transformers_query_classifier_batch(transformers_query_classifier): + queries = ["morse code", "How old is John?"] + output = transformers_query_classifier.run_batch(queries=queries) + + assert output[0] == {"output_2": {"queries": ["morse code"]}, "output_1": {"queries": ["How old is John?"]}} + + +def test_zero_shot_transformers_query_classifier(zero_shot_transformers_query_classifier): + output = zero_shot_transformers_query_classifier.run(query="What's the answer?") + assert output == ({}, "output_3") + + output = zero_shot_transformers_query_classifier.run(query="Would you be so kind to tell me the answer?") + assert output == ({}, "output_1") + + output = zero_shot_transformers_query_classifier.run(query="Can you give me the right answer for once??") + assert output == ({}, "output_2") + + +def test_zero_shot_transformers_query_classifier_batch(zero_shot_transformers_query_classifier): + queries = [ + "What's the answer?", + "Would you be so kind to tell me the answer?", + "Can you give me the right answer for once??", + ] + + output = zero_shot_transformers_query_classifier.run_batch(queries=queries) + + assert output[0] == { + "output_3": {"queries": ["What's the answer?"]}, + "output_1": {"queries": ["Would you be so kind to tell me the answer?"]}, + "output_2": {"queries": ["Can you give me the right answer for once??"]}, + } + + +def test_transformers_query_classifier_wrong_labels(): + with pytest.warns(None, match="The provided labels do not match the model labels"): + query_classifier = TransformersQueryClassifier( + model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection", + use_gpu=False, + 
task="text-classification", + labels=["WRONG_LABEL_1", "WRONG_LABEL_2", "WRONG_LABEL_3"], + ) + + +def test_zero_shot_transformers_query_classifier_no_labels(): + with pytest.raises(ValueError): + query_classifier = TransformersQueryClassifier( + model_name_or_path="typeform/distilbert-base-uncased-mnli", + use_gpu=False, + task="zero-shot-classification", + labels=None, + ) + + +def test_transformers_query_classifier_unsupported_task(): + with pytest.raises(ValueError): + query_classifier = TransformersQueryClassifier( + model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection", + use_gpu=False, + task="summarization", + labels=["LABEL_1", "LABEL_0"], + ) From 6ba187b1d6dea514eb01380e8a54ae8d5f35bec9 Mon Sep 17 00:00:00 2001 From: anakin87 <44616784+anakin87@users.noreply.github.com> Date: Thu, 4 Aug 2022 14:58:00 +0200 Subject: [PATCH 2/8] variable number of outgoing edges --- docs/_src/api/api/query_classifier.md | 4 ---- .../nodes/query_classifier/transformers.py | 20 +++++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/_src/api/api/query_classifier.md b/docs/_src/api/api/query_classifier.md index b95ab11f00..1fd1695689 100644 --- a/docs/_src/api/api/query_classifier.md +++ b/docs/_src/api/api/query_classifier.md @@ -95,10 +95,6 @@ queries or statement vs question queries. class TransformersQueryClassifier(BaseQueryClassifier) ``` - - -#### outgoing\_edges - A node to classify an incoming query into categories using a transformer model. Depending on the result, the query flows to a different branch in your pipeline and the further processing can be customized. You can define this by connecting the further pipeline to `output_1`, `output_2`, ..., `output_n` diff --git a/haystack/nodes/query_classifier/transformers.py b/haystack/nodes/query_classifier/transformers.py index 6d549c1d0b..de7c242bd2 100644 --- a/haystack/nodes/query_classifier/transformers.py +++ b/haystack/nodes/query_classifier/transformers.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Union, List, Optional +from typing import Union, List, Optional, Dict, Any from transformers import pipeline from haystack.nodes.query_classifier.base import BaseQueryClassifier @@ -11,9 +11,6 @@ class TransformersQueryClassifier(BaseQueryClassifier): - - outgoing_edges: int = 10 - """ A node to classify an incoming query into categories using a transformer model. 
Depending on the result, the query flows to a different branch in your pipeline and the further processing @@ -111,7 +108,14 @@ def __init__( self.task = task self.batch_size = batch_size - def _get_edge_number(self, label): + @classmethod + def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int: + labels = component_params.get("labels", None) + if labels is None: + return 2 + return len(labels) + + def _get_edge_number_from_label(self, label): return self.labels.index(label) + 1 def run(self, query: str): # type: ignore @@ -121,7 +125,7 @@ def run(self, query: str): # type: ignore elif self.task == "text-classification": prediction = self.model([query], truncation=True) label = prediction[0]["label"] - return {}, f"output_{self._get_edge_number(label)}" + return {}, f"output_{self._get_edge_number_from_label(label)}" def run_batch(self, queries: List[str], batch_size: Optional[int] = None): # type: ignore if batch_size is None: @@ -131,12 +135,12 @@ def run_batch(self, queries: List[str], batch_size: Optional[int] = None): # ty elif self.task == "text-classification": predictions = self.model(queries, truncation=True, batch_size=batch_size) - results = {f"output_{self._get_edge_number(label)}": {"queries": []} for label in self.labels} # type: ignore + results = {f"output_{self._get_edge_number_from_label(label)}": {"queries": []} for label in self.labels} # type: ignore for query, prediction in zip(queries, predictions): if self.task == "zero-shot-classification": label = prediction["labels"][0] elif self.task == "text-classification": label = prediction["label"] - results[f"output_{self._get_edge_number(label)}"]["queries"].append(query) + results[f"output_{self._get_edge_number_from_label(label)}"]["queries"].append(query) return results, "split" From a49c750a915ad7a8ce0832540c6a0c93a198313a Mon Sep 17 00:00:00 2001 From: anakin87 <44616784+anakin87@users.noreply.github.com> Date: Thu, 4 Aug 2022 16:16:05 +0200 Subject: [PATCH 3/8] improve tests --- haystack/nodes/query_classifier/transformers.py | 9 ++------- test/nodes/test_query_classifier.py | 4 ++-- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/haystack/nodes/query_classifier/transformers.py b/haystack/nodes/query_classifier/transformers.py index de7c242bd2..bbd40b66e6 100644 --- a/haystack/nodes/query_classifier/transformers.py +++ b/haystack/nodes/query_classifier/transformers.py @@ -55,6 +55,8 @@ class TransformersQueryClassifier(BaseQueryClassifier): See also the [tutorial](https://haystack.deepset.ai/tutorials/pipelines) on pipelines. 
""" + outgoing_edges = 10 + def __init__( self, model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", @@ -108,13 +110,6 @@ def __init__( self.task = task self.batch_size = batch_size - @classmethod - def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int: - labels = component_params.get("labels", None) - if labels is None: - return 2 - return len(labels) - def _get_edge_number_from_label(self, label): return self.labels.index(label) + 1 diff --git a/test/nodes/test_query_classifier.py b/test/nodes/test_query_classifier.py index bca56f12dd..57b0a922f5 100644 --- a/test/nodes/test_query_classifier.py +++ b/test/nodes/test_query_classifier.py @@ -75,7 +75,7 @@ def test_transformers_query_classifier_wrong_labels(): def test_zero_shot_transformers_query_classifier_no_labels(): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Candidate labels must be provided for task zero-shot-classification"): query_classifier = TransformersQueryClassifier( model_name_or_path="typeform/distilbert-base-uncased-mnli", use_gpu=False, @@ -85,7 +85,7 @@ def test_zero_shot_transformers_query_classifier_no_labels(): def test_transformers_query_classifier_unsupported_task(): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Task not supported"): query_classifier = TransformersQueryClassifier( model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection", use_gpu=False, From 8a3e6b869e02a4861bf671118d93955573187430 Mon Sep 17 00:00:00 2001 From: anakin87 <44616784+anakin87@users.noreply.github.com> Date: Thu, 4 Aug 2022 16:23:43 +0200 Subject: [PATCH 4/8] fix unused import --- haystack/nodes/query_classifier/transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/nodes/query_classifier/transformers.py b/haystack/nodes/query_classifier/transformers.py index bbd40b66e6..8d1b78ee75 100644 --- a/haystack/nodes/query_classifier/transformers.py +++ b/haystack/nodes/query_classifier/transformers.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Union, List, Optional, Dict, Any +from typing import Union, List, Optional from transformers import pipeline from haystack.nodes.query_classifier.base import BaseQueryClassifier From 780d5a1df41345efaed8ca137165adfbb33da253 Mon Sep 17 00:00:00 2001 From: anakin87 <44616784+anakin87@users.noreply.github.com> Date: Fri, 5 Aug 2022 12:53:44 +0200 Subject: [PATCH 5/8] lightweight approach --- docs/_src/api/api/query_classifier.md | 4 +-- .../nodes/query_classifier/transformers.py | 31 +++++++++---------- test/nodes/test_query_classifier.py | 10 +++--- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/docs/_src/api/api/query_classifier.md b/docs/_src/api/api/query_classifier.md index 1fd1695689..3a405e1599 100644 --- a/docs/_src/api/api/query_classifier.md +++ b/docs/_src/api/api/query_classifier.md @@ -143,7 +143,7 @@ This node also supports zero-shot-classification. 
#### TransformersQueryClassifier.\_\_init\_\_

```python
-def __init__(model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, task: str = "text-classification", labels: Optional[List[str]] = ["LABEL_1", "LABEL_0"], batch_size: int = 16)
+def __init__(model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, task: str = "text-classification", labels: List[str] = ["LABEL_1", "LABEL_0"], batch_size: int = 16)
```

**Arguments**:

@@ -155,7 +155,7 @@ See [Hugging Face models](https://huggingface.co/models) for a full list of available models.
- `use_gpu`: Whether to use GPU (if available).
- `task`: Specifies the type of classification. Possible values: 'text-classification' or 'zero-shot-classification'.
- `labels`: If the task is 'text-classification' and an ordered list of labels is provided, the first label corresponds to output_1,
-the second label to output_2, and so on. The labels must match the model labels; only the order can differ. Otherwise, model labels are considered.
+the second label to output_2, and so on. The labels must match the model labels; only the order can differ.
If the task is 'zero-shot-classification', these are the candidate labels.
- `batch_size`: The number of queries to be processed at a time.
""" @@ -88,21 +86,15 @@ def __init__( ) self.labels = labels - if task == "zero-shot-classification": - if labels is None or len(labels) == 0: - raise ValueError("Candidate labels must be provided for task zero-shot-classification") - elif task == "text-classification": + if task == "text-classification": labels_from_model = [label for label in self.model.model.config.id2label.values()] - if labels is None: - self.labels = labels_from_model - elif set(labels) != set(labels_from_model): - self.labels = labels_from_model - logger.warning( - f"The provided labels do not match the model labels, then the model labels are used.\n" + if set(labels) != set(labels_from_model): + raise ValueError( + f"For text-classification, the provided labels must match the model labels; only the order can differ.\n" f"Provided labels: {labels}\n" f"Model labels: {labels_from_model}" ) - else: + if task not in ["text-classification", "zero-shot-classification"]: raise ValueError( f"Task not supported: {task}.\n" f"Possible task values are: 'text-classification' or 'zero-shot-classification'" @@ -110,6 +102,13 @@ def __init__( self.task = task self.batch_size = batch_size + @classmethod + def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int: + labels = component_params.get("labels", None) + if labels is None or len(labels) == 0: + raise ValueError("The labels must be provided") + return len(labels) + def _get_edge_number_from_label(self, label): return self.labels.index(label) + 1 diff --git a/test/nodes/test_query_classifier.py b/test/nodes/test_query_classifier.py index 57b0a922f5..a96eec594e 100644 --- a/test/nodes/test_query_classifier.py +++ b/test/nodes/test_query_classifier.py @@ -65,7 +65,7 @@ def test_zero_shot_transformers_query_classifier_batch(zero_shot_transformers_qu def test_transformers_query_classifier_wrong_labels(): - with pytest.warns(None, match="The provided labels do not match the model labels"): + with pytest.raises(ValueError, match="For text-classification, the provided labels must match the model labels"): query_classifier = TransformersQueryClassifier( model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection", use_gpu=False, @@ -74,12 +74,12 @@ def test_transformers_query_classifier_wrong_labels(): ) -def test_zero_shot_transformers_query_classifier_no_labels(): - with pytest.raises(ValueError, match="Candidate labels must be provided for task zero-shot-classification"): +def test_transformers_query_classifier_no_labels(): + with pytest.raises(ValueError, match="The labels must be provided"): query_classifier = TransformersQueryClassifier( - model_name_or_path="typeform/distilbert-base-uncased-mnli", + model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection", use_gpu=False, - task="zero-shot-classification", + task="text-classification", labels=None, ) From f59f396533c6d5173ea56fd2aa6716a8d42db7a8 Mon Sep 17 00:00:00 2001 From: anakin87 <44616784+anakin87@users.noreply.github.com> Date: Fri, 5 Aug 2022 15:08:01 +0200 Subject: [PATCH 6/8] fix _calculate_outgoing_edges --- docs/_src/api/api/query_classifier.md | 2 +- haystack/nodes/query_classifier/transformers.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/_src/api/api/query_classifier.md b/docs/_src/api/api/query_classifier.md index 3a405e1599..23ee398b3b 100644 --- a/docs/_src/api/api/query_classifier.md +++ b/docs/_src/api/api/query_classifier.md @@ -143,7 +143,7 @@ This node also supports zero-shot-classification. 
#### TransformersQueryClassifier.\_\_init\_\_ ```python -def __init__(model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, task: str = "text-classification", labels: List[str] = ["LABEL_1", "LABEL_0"], batch_size: int = 16) +def __init__(model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, task: str = "text-classification", labels: List[str] = DEFAULT_LABELS, batch_size: int = 16) ``` **Arguments**: diff --git a/haystack/nodes/query_classifier/transformers.py b/haystack/nodes/query_classifier/transformers.py index 2b4b6cfe70..c59e99ac3f 100644 --- a/haystack/nodes/query_classifier/transformers.py +++ b/haystack/nodes/query_classifier/transformers.py @@ -9,6 +9,8 @@ logger = logging.getLogger(__name__) +DEFAULT_LABELS = ["LABEL_1", "LABEL_0"] + class TransformersQueryClassifier(BaseQueryClassifier): """ @@ -62,7 +64,7 @@ def __init__( tokenizer: Optional[str] = None, use_gpu: bool = True, task: str = "text-classification", - labels: List[str] = ["LABEL_1", "LABEL_0"], + labels: List[str] = DEFAULT_LABELS, batch_size: int = 16, ): """ @@ -86,6 +88,8 @@ def __init__( ) self.labels = labels + if labels is None or len(labels) == 0: + raise ValueError("The labels must be provided") if task == "text-classification": labels_from_model = [label for label in self.model.model.config.id2label.values()] if set(labels) != set(labels_from_model): @@ -104,7 +108,7 @@ def __init__( @classmethod def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int: - labels = component_params.get("labels", None) + labels = component_params.get("labels", DEFAULT_LABELS) if labels is None or len(labels) == 0: raise ValueError("The labels must be provided") return len(labels) From d87c960c2eeaee2836a48ba9087b9d17dffcc290 Mon Sep 17 00:00:00 2001 From: anakin87 <44616784+anakin87@users.noreply.github.com> Date: Fri, 5 Aug 2022 15:57:29 +0200 Subject: [PATCH 7/8] remove duplicate label validation --- haystack/nodes/query_classifier/transformers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/haystack/nodes/query_classifier/transformers.py b/haystack/nodes/query_classifier/transformers.py index c59e99ac3f..6bf5226eca 100644 --- a/haystack/nodes/query_classifier/transformers.py +++ b/haystack/nodes/query_classifier/transformers.py @@ -88,8 +88,6 @@ def __init__( ) self.labels = labels - if labels is None or len(labels) == 0: - raise ValueError("The labels must be provided") if task == "text-classification": labels_from_model = [label for label in self.model.model.config.id2label.values()] if set(labels) != set(labels_from_model): From 2bd4cbe7eaecd90f7cb4c7928f6433701050d67b Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> Date: Mon, 8 Aug 2022 23:51:58 +0200 Subject: [PATCH 8/8] Remove print --- haystack/nodes/query_classifier/transformers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/haystack/nodes/query_classifier/transformers.py b/haystack/nodes/query_classifier/transformers.py index c4c2fb7af5..4b92c840a6 100644 --- a/haystack/nodes/query_classifier/transformers.py +++ b/haystack/nodes/query_classifier/transformers.py @@ -149,7 +149,6 @@ def run_batch(self, queries: List[str], batch_size: Optional[int] = None): # ty desc="Classifying queries", ): all_predictions.extend([predictions]) - 
print(all_predictions) results = {f"output_{self._get_edge_number_from_label(label)}": {"queries": []} for label in self.labels} # type: ignore for query, prediction in zip(queries, all_predictions): if self.task == "zero-shot-classification":
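
For readers following the series end to end, here is a brief usage sketch of the node as these commits leave it. This is an editorial illustration, not part of the patches: the model name, task, and labels are taken verbatim from the tests added in PATCH 1/8, and the expected edge names follow from `_get_edge_number_from_label` (the first label maps to `output_1`, the second to `output_2`, and so on).

```python
from haystack.nodes.query_classifier.transformers import TransformersQueryClassifier

# Zero-shot classification: the candidate labels map, in order, to
# output_1 ("happy"), output_2 ("unhappy"), and output_3 ("neutral").
classifier = TransformersQueryClassifier(
    model_name_or_path="typeform/distilbert-base-uncased-mnli",
    use_gpu=False,
    task="zero-shot-classification",
    labels=["happy", "unhappy", "neutral"],
)

# run() returns a ({}, "output_n") tuple; the edge name selects the pipeline branch.
_, edge = classifier.run(query="Would you be so kind to tell me the answer?")
print(edge)  # the tests expect "output_1" (the "happy" branch) for this query

# run_batch() buckets the queries per predicted label and routes them on the
# special "split" edge; every label gets a bucket, even if it stays empty.
results, _ = classifier.run_batch(
    queries=["What's the answer?", "Can you give me the right answer for once??"]
)
print(results)
# Expected, following the tests:
# {"output_1": {"queries": []},
#  "output_2": {"queries": ["Can you give me the right answer for once??"]},
#  "output_3": {"queries": ["What's the answer?"]}}
```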