Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Ensure eval mode for farm and transformer models for predictions #3791

Merged
merged 21 commits into from
Jun 6, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
6d9d11c
Updated DefaultEmbeddingEncoder to call model.eval() and added test.
sjrl Dec 30, 2022
01d22b5
Updated retribert retriever to set model.eval() and added test
sjrl Dec 30, 2022
8505cb3
Updated test to test _SentenceTransformersEmbeddingEncoder
sjrl Dec 30, 2022
4da4c14
Set model.eval() in FARMReader and TransformerReader when running pre…
sjrl Dec 30, 2022
3ce1bef
Added model.eval calls to _get_prediction methods of Inferencer
sjrl Dec 30, 2022
ebb58e4
Added model.eval to PromptNode and DocumentClassifier. Also added new…
sjrl Dec 30, 2022
3c7299e
Added model.eval to Text2SparqlRetriever
sjrl Dec 30, 2022
a83eff1
Fix to unit test
sjrl Dec 30, 2022
71b5a4e
Added model.eval to TransformersTranslator
sjrl Dec 30, 2022
ab6eab1
Added model.eval() to EntityExtractor and Text2Speech
sjrl Dec 30, 2022
5284ab3
Added model.eval to TransformersQueryClassifier and QuestionGenerator
sjrl Dec 30, 2022
1335c32
Added model eval to RAGenerator
sjrl Dec 30, 2022
09e8870
Added model eval to Seq2SeqGenerator
sjrl Dec 30, 2022
6cb4df7
Merge branch 'main' of github.com:deepset-ai/haystack into inf-model-…
sjrl Mar 27, 2023
ad789f7
Undoing additions of eval as discussed in PR
sjrl Mar 27, 2023
68def3a
Added self.model.eval() to end of training loop in FARMReader
sjrl Mar 27, 2023
1ce9e90
Try removing integration tags
sjrl Mar 27, 2023
85015d8
Merge branch 'main' of github.com:deepset-ai/haystack into inf-model-…
sjrl Apr 18, 2023
0f31a6f
Merge branch 'main' of github.com:deepset-ai/haystack into inf-model-…
sjrl Apr 19, 2023
dbbac67
Merge branch 'main' into inf-model-eval
sjrl May 8, 2023
2002b65
Merge branch 'main' into inf-model-eval
masci May 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions haystack/modeling/training/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def train(self):
)
self.test_result = evaluator_test.eval(self.model)
evaluator_test.log_results(self.test_result, "Test", self.global_step)
self.model.eval()
return self.model

def compute_loss(self, batch: dict, step: int) -> torch.Tensor:
Expand Down
28 changes: 28 additions & 0 deletions test/nodes/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,34 @@ def test_no_answer_reader_skips_empty_documents(no_answer_reader):
assert predictions["answers"][1][1].answer == "Carla" # answer given for 2nd query as usual


def test_reader_training_returns_eval(tmp_path):
max_seq_len = 16
max_query_length = 8
reader = FARMReader(
model_name_or_path="deepset/tinyroberta-squad2",
use_gpu=False,
num_processes=0,
max_seq_len=max_seq_len,
doc_stride=2,
max_query_length=max_query_length,
)

save_dir = f"{tmp_path}/test_dpr_training"
reader.train(
data_dir=str(SAMPLES_PATH / "squad"),
train_filename="tiny.json",
dev_filename="tiny.json",
n_epochs=1,
batch_size=1,
grad_acc_steps=1,
evaluate_every=0,
save_dir=save_dir,
max_seq_len=max_seq_len,
max_query_length=max_query_length,
)
assert reader.inferencer.model.training is False


def test_reader_training(tmp_path):
max_seq_len = 16
max_query_length = 8
Expand Down
3 changes: 0 additions & 3 deletions test/nodes/test_retriever.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from typing import List

import os
import logging
import os
from math import isclose
Expand Down