Simplify language_modeling.py and tokenization.py #2703

Merged
merged 119 commits into master from simplify-language-modeling on Jul 22, 2022
119 commits
f800184
Simplification of language_model.py to remove code duplication
ZanSara Jun 6, 2022
91caa7f
restructure language_model.py
ZanSara Jun 10, 2022
0f0fb64
Merge branch 'master' into image_retriever
ZanSara Jun 15, 2022
23d38ec
Working on removing Tokenizer
ZanSara Jun 15, 2022
c61ed79
Removing Tokenizer
ZanSara Jun 15, 2022
a7c9bc0
working on normalizing DPR implementation too
ZanSara Jun 17, 2022
b6b4e1d
Fixing dpr issue in test
ZanSara Jun 21, 2022
268cacd
Fixing DPRetriever, Embedding Retriever and usage of new API in modeling
ZanSara Jun 22, 2022
39419f3
Update Documentation & Code Style
github-actions[bot] Jun 22, 2022
6d4857b
Remove mentions to data2vecvision
ZanSara Jun 22, 2022
a551c05
Minor fixes
ZanSara Jun 22, 2022
63ab0cb
Update Documentation & Code Style
github-actions[bot] Jun 22, 2022
34b9973
fixing mypy issues
ZanSara Jun 22, 2022
e01efd0
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jun 22, 2022
0253b14
Update Documentation & Code Style
github-actions[bot] Jun 22, 2022
d78d55a
typing tokenization better
ZanSara Jun 22, 2022
263f55d
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jun 22, 2022
7551850
more fixes for mypy
ZanSara Jun 22, 2022
e78fe2e
Update Documentation & Code Style
github-actions[bot] Jun 22, 2022
8ed07ff
pylint
ZanSara Jun 22, 2022
4226eea
more mypy
ZanSara Jun 22, 2022
26d9eb0
more mypy
ZanSara Jun 22, 2022
e4e9ba1
remove merge tags
ZanSara Jun 22, 2022
7fc2443
Update Documentation & Code Style
github-actions[bot] Jun 22, 2022
94d5d0d
mypy
ZanSara Jun 22, 2022
c3db8a4
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jun 22, 2022
d385399
Update Documentation & Code Style
github-actions[bot] Jun 22, 2022
8c20ef0
mypy again
ZanSara Jun 22, 2022
cdb3b11
Update Documentation & Code Style
github-actions[bot] Jun 22, 2022
73f3a4a
last mypy errors
ZanSara Jun 22, 2022
9770986
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jun 22, 2022
dc9f753
Update Documentation & Code Style
github-actions[bot] Jun 22, 2022
0b36915
Add n_added_tokens to DPREncoder.__init__ for compatibility
ZanSara Jun 22, 2022
ba45908
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jun 22, 2022
fe957c1
fix tests
ZanSara Jun 22, 2022
efcd132
Merge branch 'master' into simplify-language-modeling
ZanSara Jun 24, 2022
129cbf6
not all models have encoders
ZanSara Jun 24, 2022
7442438
comma
ZanSara Jun 24, 2022
f4676b5
Update Documentation & Code Style
github-actions[bot] Jun 24, 2022
0c714bb
using params
ZanSara Jun 24, 2022
0cba55b
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jun 24, 2022
23cf8f2
Update Documentation & Code Style
github-actions[bot] Jun 24, 2022
922340a
trying to simplify DPREncoder
ZanSara Jun 24, 2022
10c5c89
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jun 24, 2022
17db500
Update Documentation & Code Style
github-actions[bot] Jun 24, 2022
9060896
Fix wrong param
ZanSara Jun 24, 2022
bf7bbce
mypy
ZanSara Jun 24, 2022
de23cf6
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jun 24, 2022
2b36db8
Update Documentation & Code Style
github-actions[bot] Jun 24, 2022
67b84da
Use segment_ids instead of token_type_ids
ZanSara Jun 24, 2022
43a501b
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jun 24, 2022
4d56310
Update Documentation & Code Style
github-actions[bot] Jun 24, 2022
2872966
Fix question_generator tests & factor out distilbert
ZanSara Jul 4, 2022
20f0e1d
Update Documentation & Code Style
github-actions[bot] Jul 4, 2022
b8287cf
mypy
ZanSara Jul 4, 2022
021f8c2
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jul 4, 2022
6278f69
mypy again
ZanSara Jul 4, 2022
35e464f
Update Documentation & Code Style
github-actions[bot] Jul 4, 2022
3e5f080
remove infer_tokenizer_classes from dense.py, unused
ZanSara Jul 4, 2022
03e13f5
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jul 4, 2022
3f97367
Update Documentation & Code Style
github-actions[bot] Jul 4, 2022
4c2d750
Merge branch 'master' into simplify-language-modeling
ZanSara Jul 4, 2022
e514b2f
Remove usage of kwargs in evaluation.py eval()
ZanSara Jul 4, 2022
16967bb
Merge branch 'master' into simplify-language-modeling
ZanSara Jul 4, 2022
b4c6beb
fix log
ZanSara Jul 4, 2022
1797d22
typo
ZanSara Jul 4, 2022
5cf92a6
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jul 4, 2022
c32e10c
Fix dpr tests
ZanSara Jul 6, 2022
41c8b1d
Update Documentation & Code Style
github-actions[bot] Jul 6, 2022
3770101
typo
ZanSara Jul 7, 2022
4c27ce1
Update Documentation & Code Style
github-actions[bot] Jul 7, 2022
676d554
mypy
ZanSara Jul 7, 2022
fe95b11
pylint
ZanSara Jul 7, 2022
2efa817
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jul 7, 2022
b662f0f
Update Documentation & Code Style
github-actions[bot] Jul 7, 2022
2d90609
capitalize model type
ZanSara Jul 7, 2022
09eb226
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jul 7, 2022
7df6778
Update Documentation & Code Style
github-actions[bot] Jul 7, 2022
b5f5b40
mypy
ZanSara Jul 7, 2022
3d58bc7
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jul 7, 2022
8b56255
mypy
ZanSara Jul 7, 2022
cbb644a
Update Documentation & Code Style
github-actions[bot] Jul 7, 2022
f4a37bf
mypy
ZanSara Jul 7, 2022
29be41c
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jul 7, 2022
38baeb7
mypy
ZanSara Jul 7, 2022
1a458cf
mypy
ZanSara Jul 7, 2022
976123a
Merge branch 'master' into simplify-language-modeling
ZanSara Jul 7, 2022
1d47c90
fix tests
ZanSara Jul 7, 2022
cacc11d
Update Documentation & Code Style
github-actions[bot] Jul 7, 2022
e4c08ba
typing
ZanSara Jul 8, 2022
6597c94
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jul 8, 2022
888d3d6
Fix for triadaptive model
ZanSara Jul 12, 2022
1762761
Update Documentation & Code Style
github-actions[bot] Jul 12, 2022
d47338e
Remove comment
ZanSara Jul 12, 2022
d52fe19
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jul 12, 2022
5e6c83a
split capitalize_and_get_class
ZanSara Jul 13, 2022
81ae6ac
improve triadaptive_model.py
ZanSara Jul 13, 2022
7b464e6
simplifying more **kwargs
ZanSara Jul 13, 2022
8bdb42b
more **kwargs gone
ZanSara Jul 13, 2022
c909599
Update Documentation & Code Style
github-actions[bot] Jul 13, 2022
a1685bd
mypy & pylint
ZanSara Jul 13, 2022
c1034a8
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jul 13, 2022
44c7726
Update Documentation & Code Style
github-actions[bot] Jul 13, 2022
d5eb606
mypy & pylint again
ZanSara Jul 13, 2022
fbbea3b
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jul 13, 2022
0bb1104
Improve management of output_hidden_states
ZanSara Jul 13, 2022
34121d7
Update Documentation & Code Style
github-actions[bot] Jul 13, 2022
c5a6dd0
mypy
ZanSara Jul 13, 2022
f1cdba1
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jul 13, 2022
8df63f7
fix tests
ZanSara Jul 13, 2022
2e9f12f
remove excess params from trainer
ZanSara Jul 14, 2022
3a5b9ec
Update Documentation & Code Style
github-actions[bot] Jul 14, 2022
8cf3969
Merge remote-tracking branch 'upstream/master' into simplify-language…
ZanSara Jul 18, 2022
e7ebad4
simplifying tokenizer tests
ZanSara Jul 20, 2022
a5e2d9d
Test `haystack/modeling/language_model.py` and remove `model_type` (#…
ZanSara Jul 21, 2022
8ec0cbf
Merge branch 'simplify-language-modeling' of github.com:deepset-ai/ha…
ZanSara Jul 21, 2022
b7c3329
fix tokenization tests
ZanSara Jul 21, 2022
216ef43
Update Documentation & Code Style
github-actions[bot] Jul 21, 2022
0e7ec82
Adjust model_type resolution to check config architectures (#2871)
ZanSara Jul 22, 2022
13 changes: 5 additions & 8 deletions docs/_src/api/api/retriever.md
@@ -519,7 +519,7 @@ Karpukhin, Vladimir, et al. (2020): "Dense Passage Retrieval for Open-Domain Que
#### DensePassageRetriever.\_\_init\_\_

```python
def __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None, scale_score: bool = True)
def __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None, scale_score: bool = True)
```

Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -561,8 +561,6 @@ The title is expected to be present in doc.meta["name"] and can be supplied in t
before writing them to the DocumentStore like this:
{"text": "my text", "meta": {"name": "my title"}}.
- `use_fast_tokenizers`: Whether to use fast Rust tokenizers
- `infer_tokenizer_classes`: Whether to infer tokenizer class from the model config / name.
If `False`, the class always loads `DPRQuestionEncoderTokenizer` and `DPRContextEncoderTokenizer`.
- `similarity_function`: Which function to apply for calculating the similarity of query and passage embeddings during training.
Options: `dot_product` (Default) or `cosine`
- `global_loss_buffer_size`: Buffer size for all_gather() in DDP.
@@ -871,7 +869,7 @@ None

```python
@classmethod
def load(cls, load_dir: Union[Path, str], document_store: BaseDocumentStore, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, similarity_function: str = "dot_product", query_encoder_dir: str = "query_encoder", passage_encoder_dir: str = "passage_encoder", infer_tokenizer_classes: bool = False)
def load(cls, load_dir: Union[Path, str], document_store: BaseDocumentStore, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, similarity_function: str = "dot_product", query_encoder_dir: str = "query_encoder", passage_encoder_dir: str = "passage_encoder")
```

Load DensePassageRetriever from the specified directory.
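
For context, a minimal sketch of constructing the retriever after this change; the document store and parameter values below are illustrative assumptions, not part of this diff. The tokenizer classes are now inferred from the model name/config, so `infer_tokenizer_classes` is simply no longer passed.

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import DensePassageRetriever

# Illustrative sketch: tokenizer classes are resolved from the model name/config,
# so there is no `infer_tokenizer_classes` argument anymore.
document_store = InMemoryDocumentStore(embedding_dim=768)
retriever = DensePassageRetriever(
    document_store=document_store,
    query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
    max_seq_len_query=64,
    max_seq_len_passage=256,
    use_fast_tokenizers=True,
)
```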
@@ -895,7 +893,7 @@ Kostić, Bogdan, et al. (2021): "Multi-modal Retrieval of Tables and Texts Using
#### TableTextRetriever.\_\_init\_\_

```python
def __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-question_encoder", passage_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-passage_encoder", table_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-table_encoder", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, max_seq_len_table: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_meta_fields: List[str] = ["name", "section_title", "caption"], use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None, scale_score: bool = True)
def __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-question_encoder", passage_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-passage_encoder", table_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-table_encoder", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, max_seq_len_table: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_meta_fields: List[str] = ["name", "section_title", "caption"], use_fast_tokenizers: bool = True, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None, scale_score: bool = True, use_fast: bool = True)
```

Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -923,8 +921,6 @@ This is the approach used in the original paper and is likely to improve
performance if your titles contain meaningful information for retrieval
(topic, entities etc.).
- `use_fast_tokenizers`: Whether to use fast Rust tokenizers
- `infer_tokenizer_classes`: Whether to infer tokenizer class from the model config / name.
If `False`, the class always loads `DPRQuestionEncoderTokenizer` and `DPRContextEncoderTokenizer`.
- `similarity_function`: Which function to apply for calculating the similarity of query and passage embeddings during training.
Options: `dot_product` (Default) or `cosine`
- `global_loss_buffer_size`: Buffer size for all_gather() in DDP.
@@ -942,6 +938,7 @@ Additional information can be found here https://huggingface.co/transformers/mai
- `scale_score`: Whether to scale the similarity score to the unit interval (range of [0,1]).
If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant.
Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
- `use_fast`: Whether to use the fast version of DPR tokenizers or fallback to the standard version. Defaults to True.
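
A minimal sketch of the new `use_fast` flag in use; the document store is an illustrative assumption and the model names are the defaults listed above.

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import TableTextRetriever

# Illustrative: use_fast=False would fall back to the slow (non-Rust) DPR tokenizers.
retriever = TableTextRetriever(
    document_store=InMemoryDocumentStore(embedding_dim=512),
    query_embedding_model="deepset/bert-small-mm_retrieval-question_encoder",
    passage_embedding_model="deepset/bert-small-mm_retrieval-passage_encoder",
    table_embedding_model="deepset/bert-small-mm_retrieval-table_encoder",
    use_fast=True,
)
```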

<a id="dense.TableTextRetriever.retrieve_batch"></a>

@@ -1153,7 +1150,7 @@ None

```python
@classmethod
def load(cls, load_dir: Union[Path, str], document_store: BaseDocumentStore, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, max_seq_len_table: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_meta_fields: List[str] = ["name", "section_title", "caption"], use_fast_tokenizers: bool = True, similarity_function: str = "dot_product", query_encoder_dir: str = "query_encoder", passage_encoder_dir: str = "passage_encoder", table_encoder_dir: str = "table_encoder", infer_tokenizer_classes: bool = False)
def load(cls, load_dir: Union[Path, str], document_store: BaseDocumentStore, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, max_seq_len_table: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_meta_fields: List[str] = ["name", "section_title", "caption"], use_fast_tokenizers: bool = True, similarity_function: str = "dot_product", query_encoder_dir: str = "query_encoder", passage_encoder_dir: str = "passage_encoder", table_encoder_dir: str = "table_encoder")
```

Load TableTextRetriever from the specified directory.
9 changes: 6 additions & 3 deletions haystack/document_stores/memory.py
@@ -10,7 +10,7 @@
from tqdm import tqdm

from haystack.schema import Document, Label
from haystack.errors import DuplicateDocumentError
from haystack.errors import DuplicateDocumentError, DocumentStoreError
from haystack.document_stores import BaseDocumentStore
from haystack.document_stores.base import get_batches_from_generator
from haystack.modeling.utils import initialize_device_settings
@@ -448,8 +448,11 @@ def update_embeddings(
) as progress_bar:
for document_batch in batched_documents:
embeddings = retriever.embed_documents(document_batch) # type: ignore
assert len(document_batch) == len(embeddings)

if not len(document_batch) == len(embeddings):
raise DocumentStoreError(
"The number of embeddings does not match the number of documents in the batch "
f"({len(embeddings)} != {len(document_batch)})"
)
if embeddings[0].shape[0] != self.embedding_dim:
raise RuntimeError(
f"Embedding dim. of model ({embeddings[0].shape[0]})"
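
For callers, the practical effect is that the mismatch now surfaces as a catchable exception instead of an `AssertionError`. A small self-contained sketch; the error message matches the one introduced above, and it is assumed that `DocumentStoreError` subclasses `HaystackError` as in `haystack/errors.py`.

```python
from haystack.errors import DocumentStoreError, HaystackError

try:
    # Simulates the new failure mode of update_embeddings() when the retriever
    # returns fewer embeddings than there are documents in the batch.
    raise DocumentStoreError(
        "The number of embeddings does not match the number of documents in the batch (2 != 3)"
    )
except HaystackError as err:
    print(f"Embedding update failed: {err}")
```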
7 changes: 7 additions & 0 deletions haystack/errors.py
@@ -35,6 +35,13 @@ def __repr__(self):
return str(self)


class ModelingError(HaystackError):
"""Exception for issues raised by the modeling module"""

def __init__(self, message: Optional[str] = None, docs_link: Optional[str] = "https://haystack.deepset.ai/"):
super().__init__(message=message, docs_link=docs_link)


class PipelineError(HaystackError):
"""Exception for issues raised within a pipeline"""

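
A minimal sketch of how the new exception type might be raised inside the modeling module; the helper function below is hypothetical and only illustrates the intended usage.

```python
from typing import Optional

from haystack.errors import ModelingError


def load_language_model(model_name_or_path: Optional[str]):
    # Hypothetical helper: fail with the modeling-specific error type
    # instead of a bare assert or ValueError.
    if not model_name_or_path:
        raise ModelingError("Can't load a language model: no model name or path was given.")
```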
15 changes: 5 additions & 10 deletions haystack/json-schemas/haystack-pipeline-master.schema.json
@@ -2116,11 +2116,6 @@
"default": true,
"type": "boolean"
},
"infer_tokenizer_classes": {
"title": "Infer Tokenizer Classes",
"default": false,
"type": "boolean"
},
"similarity_function": {
"title": "Similarity Function",
"default": "dot_product",
@@ -4326,11 +4321,6 @@
"default": true,
"type": "boolean"
},
"infer_tokenizer_classes": {
"title": "Infer Tokenizer Classes",
"default": false,
"type": "boolean"
},
"similarity_function": {
"title": "Similarity Function",
"default": "dot_product",
@@ -4375,6 +4365,11 @@
"title": "Scale Score",
"default": true,
"type": "boolean"
},
"use_fast": {
"title": "Use Fast",
"default": true,
"type": "boolean"
}
},
"required": [
11 changes: 10 additions & 1 deletion haystack/modeling/data_handler/data_silo.py
@@ -812,7 +812,16 @@ def _run_teacher(self, batch: dict) -> List[torch.Tensor]:
"""
Run the teacher model on the given batch.
"""
return self.teacher.inferencer.model(**batch)
params = {
"input_ids": batch["input_ids"],
"segment_ids": batch["segment_ids"],
"padding_mask": batch["padding_mask"],
}
if "output_hidden_states" in batch.keys():
params["output_hidden_states"] = batch["output_hidden_states"]
if "output_attentions" in batch.keys():
params["output_attentions"] = batch["output_attentions"]
return self.teacher.inferencer.model(**params)

def _pass_batches(
self,
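
The same filtering pattern in isolation, as a minimal sketch; the batch contents and tensor shapes below are made up. Only the three required keys plus the optional output flags are forwarded to the teacher model.

```python
import torch

# A toy batch: it carries more keys (e.g. labels) than the teacher's forward() accepts.
batch = {
    "input_ids": torch.zeros((2, 8), dtype=torch.long),
    "segment_ids": torch.zeros((2, 8), dtype=torch.long),
    "padding_mask": torch.ones((2, 8), dtype=torch.long),
    "labels": torch.tensor([0, 1]),  # must not be forwarded
}

params = {key: batch[key] for key in ("input_ids", "segment_ids", "padding_mask")}
for optional_key in ("output_hidden_states", "output_attentions"):
    if optional_key in batch:
        params[optional_key] = batch[optional_key]

# teacher_model(**params) would now receive only the arguments it understands.
print(sorted(params))
```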
64 changes: 42 additions & 22 deletions haystack/modeling/data_handler/processor.py
@@ -1,4 +1,4 @@
from typing import Optional, Dict, List, Union, Any, Iterable
from typing import Optional, Dict, List, Union, Any, Iterable, Type

import os
import json
@@ -16,9 +16,11 @@
import requests
from tqdm import tqdm
from torch.utils.data import TensorDataset
import transformers
from transformers import PreTrainedTokenizer

from haystack.modeling.model.tokenization import (
Tokenizer,
get_tokenizer,
tokenize_batch_question_answering,
tokenize_with_metadata,
truncate_sequences,
@@ -176,11 +178,9 @@ def load_from_dir(cls, load_dir: str):
"Loading tokenizer from deprecated config. "
"If you used `custom_vocab` or `never_split_chars`, this won't work anymore."
)
tokenizer = Tokenizer.load(
load_dir, tokenizer_class=config["tokenizer"], do_lower_case=config["lower_case"]
)
tokenizer = get_tokenizer(load_dir, tokenizer_class=config["tokenizer"], do_lower_case=config["lower_case"])
else:
tokenizer = Tokenizer.load(load_dir, tokenizer_class=config["tokenizer"])
tokenizer = get_tokenizer(load_dir, tokenizer_class=config["tokenizer"])

# we have to delete the tokenizer string from config, because we pass it as Object
del config["tokenizer"]
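
A minimal sketch of the new factory function that replaces `Tokenizer.load`; the model name is illustrative, and `use_fast` is assumed to be accepted as in the `convert_from_transformers` hunk below.

```python
from haystack.modeling.model.tokenization import get_tokenizer

# get_tokenizer() returns a regular Hugging Face tokenizer; the class is picked
# from the model name/config unless tokenizer_class is passed explicitly.
tokenizer = get_tokenizer("bert-base-uncased", use_fast=True)
encoded = tokenizer("Why did the Tokenizer wrapper go away?", truncation=True, max_length=64)
print(encoded["input_ids"][:5])
```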
@@ -216,7 +216,7 @@ def convert_from_transformers(
**kwargs,
):
tokenizer_args = tokenizer_args or {}
tokenizer = Tokenizer.load(
tokenizer = get_tokenizer(
tokenizer_name_or_path,
tokenizer_class=tokenizer_class,
use_fast=use_fast,
@@ -308,7 +308,9 @@ def file_to_dicts(self, file: str) -> List[dict]:
raise NotImplementedError()

@abstractmethod
def dataset_from_dicts(self, dicts: List[dict], indices: Optional[List[int]] = None, return_baskets: bool = False):
def dataset_from_dicts(
self, dicts: List[Dict], indices: List[int] = [], return_baskets: bool = False, debug: bool = False
):
raise NotImplementedError()

@abstractmethod
@@ -445,7 +447,9 @@ def __init__(
"using the default task or add a custom task later via processor.add_task()"
)

def dataset_from_dicts(self, dicts: List[dict], indices: Optional[List[int]] = None, return_baskets: bool = False):
def dataset_from_dicts(
self, dicts: List[Dict], indices: List[int] = [], return_baskets: bool = False, debug: bool = False
):
"""
Convert input dictionaries into a pytorch dataset for Question Answering.
For this we have an internal representation called "baskets".
@@ -492,7 +496,7 @@ def file_to_dicts(self, file: str) -> List[dict]:
return dicts

# TODO use Input Objects instead of this function, remove Natural Questions (NQ) related code
def convert_qa_input_dict(self, infer_dict: dict):
def convert_qa_input_dict(self, infer_dict: dict) -> Dict[str, Any]:
"""Input dictionaries in QA can either have ["context", "qas"] (internal format) as keys or
["text", "questions"] (api format). This function converts the latter into the former. It also converts the
is_impossible field to answer_type so that NQ and SQuAD dicts have the same format.
@@ -929,9 +933,15 @@ def load_from_dir(cls, load_dir: str):
# read config
processor_config_file = Path(load_dir) / "processor_config.json"
config = json.load(open(processor_config_file))
# init tokenizer
query_tokenizer = Tokenizer.load(load_dir, tokenizer_class=config["query_tokenizer"], subfolder="query")
passage_tokenizer = Tokenizer.load(load_dir, tokenizer_class=config["passage_tokenizer"], subfolder="passage")
# init tokenizers
query_tokenizer_class: Type[PreTrainedTokenizer] = getattr(transformers, config["query_tokenizer"])
query_tokenizer = query_tokenizer_class.from_pretrained(
pretrained_model_name_or_path=load_dir, subfolder="query"
)
passage_tokenizer_class: Type[PreTrainedTokenizer] = getattr(transformers, config["passage_tokenizer"])
passage_tokenizer = passage_tokenizer_class.from_pretrained(
pretrained_model_name_or_path=load_dir, subfolder="passage"
)

# we have to delete the tokenizer string from config, because we pass it as Object
del config["query_tokenizer"]
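
The resolution pattern in isolation, as a minimal sketch; the config dict and model name are illustrative stand-ins for what `processor_config.json` would contain, and loading from the Hugging Face Hub replaces the local `load_dir`/`subfolder` arguments used above.

```python
from typing import Type

import transformers
from transformers import PreTrainedTokenizer

# Illustrative stand-in for the saved processor config.
config = {"query_tokenizer": "DPRQuestionEncoderTokenizerFast"}

# The class name string is resolved against the transformers namespace at load time.
query_tokenizer_class: Type[PreTrainedTokenizer] = getattr(transformers, config["query_tokenizer"])
query_tokenizer = query_tokenizer_class.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
```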
@@ -978,7 +988,9 @@ def save(self, save_dir: Union[str, Path]):
with open(output_config_file, "w") as file:
json.dump(config, file)

def dataset_from_dicts(self, dicts: List[dict], indices: Optional[List[int]] = None, return_baskets: bool = False):
def dataset_from_dicts(
self, dicts: List[Dict], indices: List[int] = [], return_baskets: bool = False, debug: bool = False
):
"""
Convert input dictionaries into a pytorch dataset for TextSimilarity (e.g. DPR).
For conversion we have an internal representation called "baskets".
@@ -1334,9 +1346,9 @@ def load_from_dir(cls, load_dir: str):
processor_config_file = Path(load_dir) / "processor_config.json"
config = json.load(open(processor_config_file))
# init tokenizer
query_tokenizer = Tokenizer.load(load_dir, tokenizer_class=config["query_tokenizer"], subfolder="query")
passage_tokenizer = Tokenizer.load(load_dir, tokenizer_class=config["passage_tokenizer"], subfolder="passage")
table_tokenizer = Tokenizer.load(load_dir, tokenizer_class=config["table_tokenizer"], subfolder="table")
query_tokenizer = get_tokenizer(load_dir, tokenizer_class=config["query_tokenizer"], subfolder="query")
passage_tokenizer = get_tokenizer(load_dir, tokenizer_class=config["passage_tokenizer"], subfolder="passage")
table_tokenizer = get_tokenizer(load_dir, tokenizer_class=config["table_tokenizer"], subfolder="table")

# we have to delete the tokenizer string from config, because we pass it as Object
del config["query_tokenizer"]
@@ -1488,7 +1500,9 @@ def _read_multimodal_dpr_json(self, file: str, max_samples: Optional[int] = None
standard_dicts.append(sample)
return standard_dicts

def dataset_from_dicts(self, dicts: List[Dict], indices: Optional[List[int]] = None, return_baskets: bool = False):
def dataset_from_dicts(
self, dicts: List[Dict], indices: List[int] = [], return_baskets: bool = False, debug: bool = False
):
"""
Convert input dictionaries into a pytorch dataset for TextSimilarity.
For conversion we have an internal representation called "baskets".
@@ -1836,7 +1850,9 @@ def __init__(
def file_to_dicts(self, file: str) -> List[Dict]:
raise NotImplementedError

def dataset_from_dicts(self, dicts, indices=None, return_baskets=False, debug=False):
def dataset_from_dicts(
self, dicts: List[Dict], indices: List[int] = [], return_baskets: bool = False, debug: bool = False
):
self.baskets = []
# Tokenize in batches
texts = [x["text"] for x in dicts]
@@ -1958,7 +1974,7 @@ def load_from_dir(cls, load_dir: str):
processor_config_file = Path(load_dir) / "processor_config.json"
config = json.load(open(processor_config_file))
# init tokenizer
tokenizer = Tokenizer.load(load_dir, tokenizer_class=config["tokenizer"])
tokenizer = get_tokenizer(load_dir, tokenizer_class=config["tokenizer"])
# we have to delete the tokenizer string from config, because we pass it as Object
del config["tokenizer"]

@@ -1979,7 +1995,9 @@ def convert_labels(self, dictionary: Dict):
ret: Dict = {}
return ret

def dataset_from_dicts(self, dicts: List[Dict], indices=None, return_baskets: bool = False, debug: bool = False):
def dataset_from_dicts(
self, dicts: List[Dict], indices: List[int] = [], return_baskets: bool = False, debug: bool = False
):
"""
Function to convert input dictionaries containing text into a torch dataset.
For normal operation with Language Models it calls the superclass' TextClassification.dataset_from_dicts method.
@@ -2067,7 +2085,9 @@ def file_to_dicts(self, file: str) -> List[dict]:
dicts.append({"text": line})
return dicts

def dataset_from_dicts(self, dicts: List[dict], indices: Optional[List[int]] = None, return_baskets: bool = False):
def dataset_from_dicts(
self, dicts: List[Dict], indices: List[int] = [], return_baskets: bool = False, debug: bool = False
):
if return_baskets:
raise NotImplementedError("return_baskets is not supported by UnlabeledTextProcessor")
texts = [dict_["text"] for dict_ in dicts]