From 559ba03818634bb85443e9d99826140be70de54f Mon Sep 17 00:00:00 2001
From: anakin87 <44616784+anakin87@users.noreply.github.com>
Date: Thu, 4 Aug 2022 12:16:03 +0200
Subject: [PATCH 1/8] extend query classifier in one commit
---
docs/_src/api/api/query_classifier.md | 23 +++-
.../haystack-pipeline-master.schema.json | 24 ++++
.../nodes/query_classifier/transformers.py | 105 ++++++++++++------
test/nodes/test_query_classifier.py | 94 ++++++++++++++++
4 files changed, 210 insertions(+), 36 deletions(-)
create mode 100644 test/nodes/test_query_classifier.py
diff --git a/docs/_src/api/api/query_classifier.md b/docs/_src/api/api/query_classifier.md
index e56e2095fb..3d8683e564 100644
--- a/docs/_src/api/api/query_classifier.md
+++ b/docs/_src/api/api/query_classifier.md
@@ -95,10 +95,15 @@ queries or statement vs question queries.
class TransformersQueryClassifier(BaseQueryClassifier)
```
-A node to classify an incoming query into one of two categories using a (small) BERT transformer model.
+
+
+#### outgoing\_edges
+
+A node to classify an incoming query into categories using a transformer model.
Depending on the result, the query flows to a different branch in your pipeline and the further processing
-can be customized. You can define this by connecting the further pipeline to either `output_1` or `output_2`
+can be customized. You can define this by connecting the further pipeline to `output_1`, `output_2`, ..., `output_n`
from this node.
+This node also supports zero-shot-classification.
**Example**:
@@ -119,7 +124,7 @@ from this node.
Models:
- Pass your own `Transformer` binary classification model from file/huggingface or use one of the following
+ Pass your own `Transformer` classification/zero-shot-classification model from file/huggingface or use one of the following
pretrained ones hosted on Huggingface:
1) Keywords vs. Questions/Statements (Default)
model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection"
@@ -142,11 +147,19 @@ from this node.
#### TransformersQueryClassifier.\_\_init\_\_
```python
-def __init__(model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", use_gpu: bool = True, batch_size: Optional[int] = None)
+def __init__(model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, task: str = "text-classification", labels: Optional[List[str]] = ["LABEL_1", "LABEL_0"], batch_size: Optional[int] = None)
```
**Arguments**:
-- `model_name_or_path`: Transformer based fine tuned mini bert model for query classification
+- `model_name_or_path`: Directory of a saved model or the name of a public model, for example 'shahrukhx01/bert-mini-finetune-question-detection'.
+See [Hugging Face models](https://huggingface.co/models) for a full list of available models.
+- `model_version`: The version of the model to use from the Hugging Face model hub. This can be a tag name, a branch name, or a commit hash.
+- `tokenizer`: The name of the tokenizer (usually the same as model).
- `use_gpu`: Whether to use GPU (if available).
+- `task`: Specifies the type of classification. Possible values: 'text-classification' or 'zero-shot-classification'.
+- `labels`: If the task is 'text-classification' and an ordered list of labels is provided, the first label corresponds to output_1,
+the second label to output_2, and so on. The labels must match the model labels; only the order can differ. Otherwise, model labels are considered.
+If the task is 'zero-shot-classification', these are the candidate labels.
+- `batch_size`: The number of queries to be processed at a time.
diff --git a/haystack/json-schemas/haystack-pipeline-master.schema.json b/haystack/json-schemas/haystack-pipeline-master.schema.json
index 16644ec408..92613fa56e 100644
--- a/haystack/json-schemas/haystack-pipeline-master.schema.json
+++ b/haystack/json-schemas/haystack-pipeline-master.schema.json
@@ -4726,11 +4726,35 @@
}
]
},
+ "model_version": {
+ "title": "Model Version",
+ "type": "string"
+ },
+ "tokenizer": {
+ "title": "Tokenizer",
+ "type": "string"
+ },
"use_gpu": {
"title": "Use Gpu",
"default": true,
"type": "boolean"
},
+ "task": {
+ "title": "Task",
+ "default": "text-classification",
+ "type": "string"
+ },
+ "labels": {
+ "title": "Labels",
+ "default": [
+ "LABEL_1",
+ "LABEL_0"
+ ],
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
"batch_size": {
"title": "Batch Size",
"type": "integer"
diff --git a/haystack/nodes/query_classifier/transformers.py b/haystack/nodes/query_classifier/transformers.py
index 2b73c268da..e6aa12e5b4 100644
--- a/haystack/nodes/query_classifier/transformers.py
+++ b/haystack/nodes/query_classifier/transformers.py
@@ -1,8 +1,8 @@
import logging
from pathlib import Path
-from typing import Union, List, Optional, Dict
+from typing import Union, List, Optional
-from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
+from transformers import pipeline
from haystack.nodes.query_classifier.base import BaseQueryClassifier
from haystack.modeling.utils import initialize_device_settings
@@ -11,11 +11,15 @@
class TransformersQueryClassifier(BaseQueryClassifier):
+
+ outgoing_edges: int = 10
+
"""
- A node to classify an incoming query into one of two categories using a (small) BERT transformer model.
+ A node to classify an incoming query into categories using a transformer model.
Depending on the result, the query flows to a different branch in your pipeline and the further processing
- can be customized. You can define this by connecting the further pipeline to either `output_1` or `output_2`
+ can be customized. You can define this by connecting the further pipeline to `output_1`, `output_2`, ..., `output_n`
from this node.
+ This node also supports zero-shot-classification.
Example:
```python
@@ -35,7 +39,7 @@ class TransformersQueryClassifier(BaseQueryClassifier):
Models:
- Pass your own `Transformer` binary classification model from file/huggingface or use one of the following
+ Pass your own `Transformer` classification/zero-shot-classification model from file/huggingface or use one of the following
pretrained ones hosted on Huggingface:
1) Keywords vs. Questions/Statements (Default)
model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection"
@@ -57,43 +61,82 @@ class TransformersQueryClassifier(BaseQueryClassifier):
def __init__(
self,
model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection",
+ model_version: Optional[str] = None,
+ tokenizer: Optional[str] = None,
use_gpu: bool = True,
+ task: str = "text-classification",
+ labels: Optional[List[str]] = ["LABEL_1", "LABEL_0"],
batch_size: Optional[int] = None,
):
"""
- :param model_name_or_path: Transformer based fine tuned mini bert model for query classification
+ :param model_name_or_path: Directory of a saved model or the name of a public model, for example 'shahrukhx01/bert-mini-finetune-question-detection'.
+ See [Hugging Face models](https://huggingface.co/models) for a full list of available models.
+ :param model_version: The version of the model to use from the Hugging Face model hub. This can be a tag name, a branch name, or a commit hash.
+ :param tokenizer: The name of the tokenizer (usually the same as model).
:param use_gpu: Whether to use GPU (if available).
+ :param task: Specifies the type of classification. Possible values: 'text-classification' or 'zero-shot-classification'.
+ :param labels: If the task is 'text-classification' and an ordered list of labels is provided, the first label corresponds to output_1,
+ the second label to output_2, and so on. The labels must match the model labels; only the order can differ. Otherwise, model labels are considered.
+ If the task is 'zero-shot-classification', these are the candidate labels.
+ :param batch_size: The number of queries to be processed at a time.
"""
super().__init__()
-
- self.devices, _ = initialize_device_settings(use_cuda=use_gpu)
+ devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
+ device = 0 if devices[0].type == "cuda" else -1
+
+ self.model = pipeline(
+ task=task, model=model_name_or_path, tokenizer=tokenizer, device=device, revision=model_version
+ )
+
+ self.labels = labels
+ if task == "zero-shot-classification":
+ if labels is None or len(labels) == 0:
+ raise ValueError("Candidate labels must be provided for task zero-shot-classification")
+ elif task == "text-classification":
+ labels_from_model = [label for label in self.model.model.config.id2label.values()]
+ if labels is None:
+ self.labels = labels_from_model
+ elif set(labels) != set(labels_from_model):
+ self.labels = labels_from_model
+ logger.warning(
+ f"The provided labels do not match the model labels, then the model labels are used.\n"
+ f"Provided labels: {labels}\n"
+ f"Model labels: {labels_from_model}"
+ )
+ else:
+ raise ValueError(
+ f"Task not supported: {task}.\n"
+ f"Possible task values are: 'text-classification' or 'zero-shot-classification'"
+ )
+ self.task = task
self.batch_size = batch_size
- device = 0 if self.devices[0].type == "cuda" else -1
-
- model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)
- tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
- self.query_classification_pipeline = TextClassificationPipeline(model=model, tokenizer=tokenizer, device=device)
+ def _get_edge_number(self, label):
+ return self.labels.index(label) + 1
- def run(self, query):
- is_question: bool = self.query_classification_pipeline(query)[0]["label"] == "LABEL_1"
-
- if is_question:
- return {}, "output_1"
- else:
- return {}, "output_2"
+ def run(self, query: str): # type: ignore
+ if self.task == "zero-shot-classification":
+ prediction = self.model([query], candidate_labels=self.labels, truncation=True)
+ label = prediction[0]["labels"][0]
+ elif self.task == "text-classification":
+ prediction = self.model([query], truncation=True)
+ label = prediction[0]["label"]
+ return {}, f"output_{self._get_edge_number(label)}"
def run_batch(self, queries: List[str], batch_size: Optional[int] = None): # type: ignore
if batch_size is None:
batch_size = self.batch_size
-
- split: Dict[str, Dict[str, List]] = {"output_1": {"queries": []}, "output_2": {"queries": []}}
-
- predictions = self.query_classification_pipeline(queries, batch_size=batch_size)
- for query, pred in zip(queries, predictions):
- if pred["label"] == "LABEL_1":
- split["output_1"]["queries"].append(query)
- else:
- split["output_2"]["queries"].append(query)
-
- return split, "split"
+ if self.task == "zero-shot-classification":
+ predictions = self.model(queries, candidate_labels=self.labels, truncation=True, batch_size=batch_size)
+ elif self.task == "text-classification":
+ predictions = self.model(queries, truncation=True, batch_size=batch_size)
+
+ results = {f"output_{self._get_edge_number(label)}": {"queries": []} for label in self.labels} # type: ignore
+ for query, prediction in zip(queries, predictions):
+ if self.task == "zero-shot-classification":
+ label = prediction["labels"][0]
+ elif self.task == "text-classification":
+ label = prediction["label"]
+ results[f"output_{self._get_edge_number(label)}"]["queries"].append(query)
+
+ return results, "split"
diff --git a/test/nodes/test_query_classifier.py b/test/nodes/test_query_classifier.py
new file mode 100644
index 0000000000..bca56f12dd
--- /dev/null
+++ b/test/nodes/test_query_classifier.py
@@ -0,0 +1,94 @@
+import pytest
+from haystack.nodes.query_classifier.transformers import TransformersQueryClassifier
+
+
+@pytest.fixture
+def transformers_query_classifier():
+ return TransformersQueryClassifier(
+ model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
+ use_gpu=False,
+ task="text-classification",
+ labels=["LABEL_1", "LABEL_0"],
+ )
+
+
+@pytest.fixture
+def zero_shot_transformers_query_classifier():
+ return TransformersQueryClassifier(
+ model_name_or_path="typeform/distilbert-base-uncased-mnli",
+ use_gpu=False,
+ task="zero-shot-classification",
+ labels=["happy", "unhappy", "neutral"],
+ )
+
+
+def test_transformers_query_classifier(transformers_query_classifier):
+ output = transformers_query_classifier.run(query="morse code")
+ assert output == ({}, "output_2")
+
+ output = transformers_query_classifier.run(query="How old is John?")
+ assert output == ({}, "output_1")
+
+
+def test_transformers_query_classifier_batch(transformers_query_classifier):
+ queries = ["morse code", "How old is John?"]
+ output = transformers_query_classifier.run_batch(queries=queries)
+
+ assert output[0] == {"output_2": {"queries": ["morse code"]}, "output_1": {"queries": ["How old is John?"]}}
+
+
+def test_zero_shot_transformers_query_classifier(zero_shot_transformers_query_classifier):
+ output = zero_shot_transformers_query_classifier.run(query="What's the answer?")
+ assert output == ({}, "output_3")
+
+ output = zero_shot_transformers_query_classifier.run(query="Would you be so kind to tell me the answer?")
+ assert output == ({}, "output_1")
+
+ output = zero_shot_transformers_query_classifier.run(query="Can you give me the right answer for once??")
+ assert output == ({}, "output_2")
+
+
+def test_zero_shot_transformers_query_classifier_batch(zero_shot_transformers_query_classifier):
+ queries = [
+ "What's the answer?",
+ "Would you be so kind to tell me the answer?",
+ "Can you give me the right answer for once??",
+ ]
+
+ output = zero_shot_transformers_query_classifier.run_batch(queries=queries)
+
+ assert output[0] == {
+ "output_3": {"queries": ["What's the answer?"]},
+ "output_1": {"queries": ["Would you be so kind to tell me the answer?"]},
+ "output_2": {"queries": ["Can you give me the right answer for once??"]},
+ }
+
+
+def test_transformers_query_classifier_wrong_labels():
+ with pytest.warns(None, match="The provided labels do not match the model labels"):
+ query_classifier = TransformersQueryClassifier(
+ model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
+ use_gpu=False,
+ task="text-classification",
+ labels=["WRONG_LABEL_1", "WRONG_LABEL_2", "WRONG_LABEL_3"],
+ )
+
+
+def test_zero_shot_transformers_query_classifier_no_labels():
+ with pytest.raises(ValueError):
+ query_classifier = TransformersQueryClassifier(
+ model_name_or_path="typeform/distilbert-base-uncased-mnli",
+ use_gpu=False,
+ task="zero-shot-classification",
+ labels=None,
+ )
+
+
+def test_transformers_query_classifier_unsupported_task():
+ with pytest.raises(ValueError):
+ query_classifier = TransformersQueryClassifier(
+ model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
+ use_gpu=False,
+ task="summarization",
+ labels=["LABEL_1", "LABEL_0"],
+ )
From 6ba187b1d6dea514eb01380e8a54ae8d5f35bec9 Mon Sep 17 00:00:00 2001
From: anakin87 <44616784+anakin87@users.noreply.github.com>
Date: Thu, 4 Aug 2022 14:58:00 +0200
Subject: [PATCH 2/8] variable number of outgoing edges
---
docs/_src/api/api/query_classifier.md | 4 ----
.../nodes/query_classifier/transformers.py | 20 +++++++++++--------
2 files changed, 12 insertions(+), 12 deletions(-)
diff --git a/docs/_src/api/api/query_classifier.md b/docs/_src/api/api/query_classifier.md
index b95ab11f00..1fd1695689 100644
--- a/docs/_src/api/api/query_classifier.md
+++ b/docs/_src/api/api/query_classifier.md
@@ -95,10 +95,6 @@ queries or statement vs question queries.
class TransformersQueryClassifier(BaseQueryClassifier)
```
-
-
-#### outgoing\_edges
-
A node to classify an incoming query into categories using a transformer model.
Depending on the result, the query flows to a different branch in your pipeline and the further processing
can be customized. You can define this by connecting the further pipeline to `output_1`, `output_2`, ..., `output_n`
diff --git a/haystack/nodes/query_classifier/transformers.py b/haystack/nodes/query_classifier/transformers.py
index 6d549c1d0b..de7c242bd2 100644
--- a/haystack/nodes/query_classifier/transformers.py
+++ b/haystack/nodes/query_classifier/transformers.py
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
-from typing import Union, List, Optional
+from typing import Union, List, Optional, Dict, Any
from transformers import pipeline
from haystack.nodes.query_classifier.base import BaseQueryClassifier
@@ -11,9 +11,6 @@
class TransformersQueryClassifier(BaseQueryClassifier):
-
- outgoing_edges: int = 10
-
"""
A node to classify an incoming query into categories using a transformer model.
Depending on the result, the query flows to a different branch in your pipeline and the further processing
@@ -111,7 +108,14 @@ def __init__(
self.task = task
self.batch_size = batch_size
- def _get_edge_number(self, label):
+ @classmethod
+ def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
+ labels = component_params.get("labels", None)
+ if labels is None:
+ return 2
+ return len(labels)
+
+ def _get_edge_number_from_label(self, label):
return self.labels.index(label) + 1
def run(self, query: str): # type: ignore
@@ -121,7 +125,7 @@ def run(self, query: str): # type: ignore
elif self.task == "text-classification":
prediction = self.model([query], truncation=True)
label = prediction[0]["label"]
- return {}, f"output_{self._get_edge_number(label)}"
+ return {}, f"output_{self._get_edge_number_from_label(label)}"
def run_batch(self, queries: List[str], batch_size: Optional[int] = None): # type: ignore
if batch_size is None:
@@ -131,12 +135,12 @@ def run_batch(self, queries: List[str], batch_size: Optional[int] = None): # ty
elif self.task == "text-classification":
predictions = self.model(queries, truncation=True, batch_size=batch_size)
- results = {f"output_{self._get_edge_number(label)}": {"queries": []} for label in self.labels} # type: ignore
+ results = {f"output_{self._get_edge_number_from_label(label)}": {"queries": []} for label in self.labels} # type: ignore
for query, prediction in zip(queries, predictions):
if self.task == "zero-shot-classification":
label = prediction["labels"][0]
elif self.task == "text-classification":
label = prediction["label"]
- results[f"output_{self._get_edge_number(label)}"]["queries"].append(query)
+ results[f"output_{self._get_edge_number_from_label(label)}"]["queries"].append(query)
return results, "split"
From a49c750a915ad7a8ce0832540c6a0c93a198313a Mon Sep 17 00:00:00 2001
From: anakin87 <44616784+anakin87@users.noreply.github.com>
Date: Thu, 4 Aug 2022 16:16:05 +0200
Subject: [PATCH 3/8] improve tests
---
haystack/nodes/query_classifier/transformers.py | 9 ++-------
test/nodes/test_query_classifier.py | 4 ++--
2 files changed, 4 insertions(+), 9 deletions(-)
diff --git a/haystack/nodes/query_classifier/transformers.py b/haystack/nodes/query_classifier/transformers.py
index de7c242bd2..bbd40b66e6 100644
--- a/haystack/nodes/query_classifier/transformers.py
+++ b/haystack/nodes/query_classifier/transformers.py
@@ -55,6 +55,8 @@ class TransformersQueryClassifier(BaseQueryClassifier):
See also the [tutorial](https://haystack.deepset.ai/tutorials/pipelines) on pipelines.
"""
+ outgoing_edges = 10
+
def __init__(
self,
model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection",
@@ -108,13 +110,6 @@ def __init__(
self.task = task
self.batch_size = batch_size
- @classmethod
- def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
- labels = component_params.get("labels", None)
- if labels is None:
- return 2
- return len(labels)
-
def _get_edge_number_from_label(self, label):
return self.labels.index(label) + 1
diff --git a/test/nodes/test_query_classifier.py b/test/nodes/test_query_classifier.py
index bca56f12dd..57b0a922f5 100644
--- a/test/nodes/test_query_classifier.py
+++ b/test/nodes/test_query_classifier.py
@@ -75,7 +75,7 @@ def test_transformers_query_classifier_wrong_labels():
def test_zero_shot_transformers_query_classifier_no_labels():
- with pytest.raises(ValueError):
+ with pytest.raises(ValueError, match="Candidate labels must be provided for task zero-shot-classification"):
query_classifier = TransformersQueryClassifier(
model_name_or_path="typeform/distilbert-base-uncased-mnli",
use_gpu=False,
@@ -85,7 +85,7 @@ def test_zero_shot_transformers_query_classifier_no_labels():
def test_transformers_query_classifier_unsupported_task():
- with pytest.raises(ValueError):
+ with pytest.raises(ValueError, match="Task not supported"):
query_classifier = TransformersQueryClassifier(
model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
use_gpu=False,
From 8a3e6b869e02a4861bf671118d93955573187430 Mon Sep 17 00:00:00 2001
From: anakin87 <44616784+anakin87@users.noreply.github.com>
Date: Thu, 4 Aug 2022 16:23:43 +0200
Subject: [PATCH 4/8] fix unused import
---
haystack/nodes/query_classifier/transformers.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/haystack/nodes/query_classifier/transformers.py b/haystack/nodes/query_classifier/transformers.py
index bbd40b66e6..8d1b78ee75 100644
--- a/haystack/nodes/query_classifier/transformers.py
+++ b/haystack/nodes/query_classifier/transformers.py
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
-from typing import Union, List, Optional, Dict, Any
+from typing import Union, List, Optional
from transformers import pipeline
from haystack.nodes.query_classifier.base import BaseQueryClassifier
From 780d5a1df41345efaed8ca137165adfbb33da253 Mon Sep 17 00:00:00 2001
From: anakin87 <44616784+anakin87@users.noreply.github.com>
Date: Fri, 5 Aug 2022 12:53:44 +0200
Subject: [PATCH 5/8] lightweight approach
---
docs/_src/api/api/query_classifier.md | 4 +--
.../nodes/query_classifier/transformers.py | 31 +++++++++----------
test/nodes/test_query_classifier.py | 10 +++---
3 files changed, 22 insertions(+), 23 deletions(-)
diff --git a/docs/_src/api/api/query_classifier.md b/docs/_src/api/api/query_classifier.md
index 1fd1695689..3a405e1599 100644
--- a/docs/_src/api/api/query_classifier.md
+++ b/docs/_src/api/api/query_classifier.md
@@ -143,7 +143,7 @@ This node also supports zero-shot-classification.
#### TransformersQueryClassifier.\_\_init\_\_
```python
-def __init__(model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, task: str = "text-classification", labels: Optional[List[str]] = ["LABEL_1", "LABEL_0"], batch_size: int = 16)
+def __init__(model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, task: str = "text-classification", labels: List[str] = ["LABEL_1", "LABEL_0"], batch_size: int = 16)
```
**Arguments**:
@@ -155,7 +155,7 @@ See [Hugging Face models](https://huggingface.co/models) for a full list of avai
- `use_gpu`: Whether to use GPU (if available).
- `task`: Specifies the type of classification. Possible values: 'text-classification' or 'zero-shot-classification'.
- `labels`: If the task is 'text-classification' and an ordered list of labels is provided, the first label corresponds to output_1,
-the second label to output_2, and so on. The labels must match the model labels; only the order can differ. Otherwise, model labels are considered.
+the second label to output_2, and so on. The labels must match the model labels; only the order can differ.
If the task is 'zero-shot-classification', these are the candidate labels.
- `batch_size`: The number of queries to be processed at a time.
diff --git a/haystack/nodes/query_classifier/transformers.py b/haystack/nodes/query_classifier/transformers.py
index 8d1b78ee75..2b4b6cfe70 100644
--- a/haystack/nodes/query_classifier/transformers.py
+++ b/haystack/nodes/query_classifier/transformers.py
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
-from typing import Union, List, Optional
+from typing import Union, List, Optional, Dict, Any
from transformers import pipeline
from haystack.nodes.query_classifier.base import BaseQueryClassifier
@@ -55,8 +55,6 @@ class TransformersQueryClassifier(BaseQueryClassifier):
See also the [tutorial](https://haystack.deepset.ai/tutorials/pipelines) on pipelines.
"""
- outgoing_edges = 10
-
def __init__(
self,
model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection",
@@ -64,7 +62,7 @@ def __init__(
tokenizer: Optional[str] = None,
use_gpu: bool = True,
task: str = "text-classification",
- labels: Optional[List[str]] = ["LABEL_1", "LABEL_0"],
+ labels: List[str] = ["LABEL_1", "LABEL_0"],
batch_size: int = 16,
):
"""
@@ -75,7 +73,7 @@ def __init__(
:param use_gpu: Whether to use GPU (if available).
:param task: Specifies the type of classification. Possible values: 'text-classification' or 'zero-shot-classification'.
:param labels: If the task is 'text-classification' and an ordered list of labels is provided, the first label corresponds to output_1,
- the second label to output_2, and so on. The labels must match the model labels; only the order can differ. Otherwise, model labels are considered.
+ the second label to output_2, and so on. The labels must match the model labels; only the order can differ.
If the task is 'zero-shot-classification', these are the candidate labels.
:param batch_size: The number of queries to be processed at a time.
"""
@@ -88,21 +86,15 @@ def __init__(
)
self.labels = labels
- if task == "zero-shot-classification":
- if labels is None or len(labels) == 0:
- raise ValueError("Candidate labels must be provided for task zero-shot-classification")
- elif task == "text-classification":
+ if task == "text-classification":
labels_from_model = [label for label in self.model.model.config.id2label.values()]
- if labels is None:
- self.labels = labels_from_model
- elif set(labels) != set(labels_from_model):
- self.labels = labels_from_model
- logger.warning(
- f"The provided labels do not match the model labels, then the model labels are used.\n"
+ if set(labels) != set(labels_from_model):
+ raise ValueError(
+ f"For text-classification, the provided labels must match the model labels; only the order can differ.\n"
f"Provided labels: {labels}\n"
f"Model labels: {labels_from_model}"
)
- else:
+ if task not in ["text-classification", "zero-shot-classification"]:
raise ValueError(
f"Task not supported: {task}.\n"
f"Possible task values are: 'text-classification' or 'zero-shot-classification'"
@@ -110,6 +102,13 @@ def __init__(
self.task = task
self.batch_size = batch_size
+ @classmethod
+ def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
+ labels = component_params.get("labels", None)
+ if labels is None or len(labels) == 0:
+ raise ValueError("The labels must be provided")
+ return len(labels)
+
def _get_edge_number_from_label(self, label):
return self.labels.index(label) + 1
diff --git a/test/nodes/test_query_classifier.py b/test/nodes/test_query_classifier.py
index 57b0a922f5..a96eec594e 100644
--- a/test/nodes/test_query_classifier.py
+++ b/test/nodes/test_query_classifier.py
@@ -65,7 +65,7 @@ def test_zero_shot_transformers_query_classifier_batch(zero_shot_transformers_qu
def test_transformers_query_classifier_wrong_labels():
- with pytest.warns(None, match="The provided labels do not match the model labels"):
+ with pytest.raises(ValueError, match="For text-classification, the provided labels must match the model labels"):
query_classifier = TransformersQueryClassifier(
model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
use_gpu=False,
@@ -74,12 +74,12 @@ def test_transformers_query_classifier_wrong_labels():
)
-def test_zero_shot_transformers_query_classifier_no_labels():
- with pytest.raises(ValueError, match="Candidate labels must be provided for task zero-shot-classification"):
+def test_transformers_query_classifier_no_labels():
+ with pytest.raises(ValueError, match="The labels must be provided"):
query_classifier = TransformersQueryClassifier(
- model_name_or_path="typeform/distilbert-base-uncased-mnli",
+ model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
use_gpu=False,
- task="zero-shot-classification",
+ task="text-classification",
labels=None,
)
From f59f396533c6d5173ea56fd2aa6716a8d42db7a8 Mon Sep 17 00:00:00 2001
From: anakin87 <44616784+anakin87@users.noreply.github.com>
Date: Fri, 5 Aug 2022 15:08:01 +0200
Subject: [PATCH 6/8] fix _calculate_outgoing_edges
---
docs/_src/api/api/query_classifier.md | 2 +-
haystack/nodes/query_classifier/transformers.py | 8 ++++++--
2 files changed, 7 insertions(+), 3 deletions(-)
diff --git a/docs/_src/api/api/query_classifier.md b/docs/_src/api/api/query_classifier.md
index 3a405e1599..23ee398b3b 100644
--- a/docs/_src/api/api/query_classifier.md
+++ b/docs/_src/api/api/query_classifier.md
@@ -143,7 +143,7 @@ This node also supports zero-shot-classification.
#### TransformersQueryClassifier.\_\_init\_\_
```python
-def __init__(model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, task: str = "text-classification", labels: List[str] = ["LABEL_1", "LABEL_0"], batch_size: int = 16)
+def __init__(model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, task: str = "text-classification", labels: List[str] = DEFAULT_LABELS, batch_size: int = 16)
```
**Arguments**:
diff --git a/haystack/nodes/query_classifier/transformers.py b/haystack/nodes/query_classifier/transformers.py
index 2b4b6cfe70..c59e99ac3f 100644
--- a/haystack/nodes/query_classifier/transformers.py
+++ b/haystack/nodes/query_classifier/transformers.py
@@ -9,6 +9,8 @@
logger = logging.getLogger(__name__)
+DEFAULT_LABELS = ["LABEL_1", "LABEL_0"]
+
class TransformersQueryClassifier(BaseQueryClassifier):
"""
@@ -62,7 +64,7 @@ def __init__(
tokenizer: Optional[str] = None,
use_gpu: bool = True,
task: str = "text-classification",
- labels: List[str] = ["LABEL_1", "LABEL_0"],
+ labels: List[str] = DEFAULT_LABELS,
batch_size: int = 16,
):
"""
@@ -86,6 +88,8 @@ def __init__(
)
self.labels = labels
+ if labels is None or len(labels) == 0:
+ raise ValueError("The labels must be provided")
if task == "text-classification":
labels_from_model = [label for label in self.model.model.config.id2label.values()]
if set(labels) != set(labels_from_model):
@@ -104,7 +108,7 @@ def __init__(
@classmethod
def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
- labels = component_params.get("labels", None)
+ labels = component_params.get("labels", DEFAULT_LABELS)
if labels is None or len(labels) == 0:
raise ValueError("The labels must be provided")
return len(labels)
From d87c960c2eeaee2836a48ba9087b9d17dffcc290 Mon Sep 17 00:00:00 2001
From: anakin87 <44616784+anakin87@users.noreply.github.com>
Date: Fri, 5 Aug 2022 15:57:29 +0200
Subject: [PATCH 7/8] remove duplicate label validation
---
haystack/nodes/query_classifier/transformers.py | 2 --
1 file changed, 2 deletions(-)
diff --git a/haystack/nodes/query_classifier/transformers.py b/haystack/nodes/query_classifier/transformers.py
index c59e99ac3f..6bf5226eca 100644
--- a/haystack/nodes/query_classifier/transformers.py
+++ b/haystack/nodes/query_classifier/transformers.py
@@ -88,8 +88,6 @@ def __init__(
)
self.labels = labels
- if labels is None or len(labels) == 0:
- raise ValueError("The labels must be provided")
if task == "text-classification":
labels_from_model = [label for label in self.model.model.config.id2label.values()]
if set(labels) != set(labels_from_model):
From 2bd4cbe7eaecd90f7cb4c7928f6433701050d67b Mon Sep 17 00:00:00 2001
From: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
Date: Mon, 8 Aug 2022 23:51:58 +0200
Subject: [PATCH 8/8] Remove print
---
haystack/nodes/query_classifier/transformers.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/haystack/nodes/query_classifier/transformers.py b/haystack/nodes/query_classifier/transformers.py
index c4c2fb7af5..4b92c840a6 100644
--- a/haystack/nodes/query_classifier/transformers.py
+++ b/haystack/nodes/query_classifier/transformers.py
@@ -149,7 +149,6 @@ def run_batch(self, queries: List[str], batch_size: Optional[int] = None): # ty
desc="Classifying queries",
):
all_predictions.extend([predictions])
- print(all_predictions)
results = {f"output_{self._get_edge_number_from_label(label)}": {"queries": []} for label in self.labels} # type: ignore
for query, prediction in zip(queries, all_predictions):
if self.task == "zero-shot-classification":