Skip to content

Commit

Permalink
Add asserts to prevent issues with sliding window (#538)
Browse files: browse the repository at this point in the history
  • Loading branch information
brandenchan committed Sep 14, 2020
1 parent 9a89adc commit ce34cc2
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 0 deletions.
4 changes: 4 additions & 0 deletions farm/data_handler/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,10 @@ def __init__(
self.target = "classification"
self.ph_output_type = "per_token_squad"

assert doc_stride < max_seq_len, "doc_stride is longer than max_seq_len. This means that there will be gaps " \
"as the passage windows slide, causing the model to skip over parts of the document. "\
"Please set a lower value for doc_stride (Suggestions: doc_stride=128, max_seq_len=384) "

self.doc_stride = doc_stride
self.max_query_length = max_query_length
self.max_answers = max_answers
Expand Down
5 changes: 5 additions & 0 deletions farm/data_handler/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,11 @@ def chunk_into_passages(doc_offsets,
doc_text):
""" Returns a list of dictionaries which each describe the start, end and id of a passage
that is formed when chunking a document using a sliding window approach. """

assert doc_stride < passage_len_t, "doc_stride is longer than passage_len_t. This means that there will be gaps " \
"as the passage windows slide, causing the model to skip over parts of the document. "\
"Please set a lower value for doc_stride (Suggestions: doc_stride=128, max_seq_len=384) "

passage_spans = []
passage_id = 0
doc_len_t = len(doc_offsets)
Expand Down
3 changes: 3 additions & 0 deletions farm/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ def load(
# override processor attributes loaded from config file with inferencer params
processor.max_seq_len = max_seq_len
if hasattr(processor, "doc_stride"):
assert doc_stride < max_seq_len, "doc_stride is longer than max_seq_len. This means that there will be gaps " \
"as the passage windows slide, causing the model to skip over parts of the document. "\
"Please set a lower value for doc_stride (Suggestions: doc_stride=128, max_seq_len=384) "
processor.doc_stride = doc_stride

# b) or from remote transformers model hub
Expand Down

0 comments on commit ce34cc2

Please sign in to comment.