From 320ec5df11d104fbe36ba7e6d467b159a4fbb1c9 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 26 Oct 2021 16:45:29 +0200 Subject: [PATCH] fix(api): return model and summary in highlevel fit (#175) --- finetuner/__init__.py | 8 +++++--- finetuner/tuner/summary.py | 1 - tests/integration/fit/test_fit_lstm.py | 3 ++- tests/integration/fit/test_fit_mlp.py | 3 ++- tests/integration/keras/test_tail_and_tune.py | 4 +++- tests/integration/paddle/test_tail_and_tune.py | 4 +++- tests/integration/torch/test_tail_and_tune.py | 4 +++- 7 files changed, 18 insertions(+), 9 deletions(-) diff --git a/finetuner/__init__.py b/finetuner/__init__.py index 922726175..921da0c4a 100644 --- a/finetuner/__init__.py +++ b/finetuner/__init__.py @@ -97,7 +97,7 @@ def fit( def fit( model: 'AnyDNN', train_data: 'DocumentArrayLike', *args, **kwargs -) -> Optional['Summary']: +) -> Optional[Tuple['AnyDNN', 'Summary']]: if kwargs.get('to_embedding_model', False): from .tailor import to_embedding_model @@ -106,8 +106,10 @@ def fit( if kwargs.get('interactive', False): from .labeler import fit - return fit(model, train_data, *args, **kwargs) + # TODO: atm return will never hit as labeler UI hangs the + # flow via `.block()` + fit(model, train_data, *args, **kwargs) else: from .tuner import fit - return fit(model, train_data, *args, **kwargs) + return model, fit(model, train_data, *args, **kwargs) diff --git a/finetuner/tuner/summary.py b/finetuner/tuner/summary.py index b84aa40dc..27e0af050 100644 --- a/finetuner/tuner/summary.py +++ b/finetuner/tuner/summary.py @@ -97,7 +97,6 @@ def plot( **plt_kwargs, ) axes[idx].set_ylabel(record.name) - axes[idx].set_box_aspect(1) axes[idx].set_xlabel('Steps') if output: diff --git a/tests/integration/fit/test_fit_lstm.py b/tests/integration/fit/test_fit_lstm.py index b14160434..4c7a24f79 100644 --- a/tests/integration/fit/test_fit_lstm.py +++ b/tests/integration/fit/test_fit_lstm.py @@ -52,7 +52,7 @@ def test_fit_all(tmpdir): for kb, b in embed_models.items(): for h in all_test_losses: - result = fit( + model, result = fit( b(), loss=h, train_data=lambda: generate_qa_match( @@ -63,4 +63,5 @@ def test_fit_all(tmpdir): ), epochs=2, ) + assert model result.save(tmpdir / f'result-{kb}-{h}.json') diff --git a/tests/integration/fit/test_fit_mlp.py b/tests/integration/fit/test_fit_mlp.py index 5fd2e08b0..a042e9099 100644 --- a/tests/integration/fit/test_fit_mlp.py +++ b/tests/integration/fit/test_fit_mlp.py @@ -44,7 +44,7 @@ def test_fit_all(tmpdir): for kb, b in embed_models.items(): for h in all_test_losses: - result = finetuner.fit( + model, result = finetuner.fit( b(), loss=h, train_data=lambda: generate_fashion_match( @@ -55,4 +55,5 @@ def test_fit_all(tmpdir): ), epochs=2, ) + assert model result.save(tmpdir / f'result-{kb}-{h}.json') diff --git a/tests/integration/keras/test_tail_and_tune.py b/tests/integration/keras/test_tail_and_tune.py index 88f4937b0..5e4f96358 100644 --- a/tests/integration/keras/test_tail_and_tune.py +++ b/tests/integration/keras/test_tail_and_tune.py @@ -19,7 +19,7 @@ def embed_model(): def test_tail_and_tune(embed_model, create_easy_data): data, _ = create_easy_data(10, 128, 1000) - rv = fit( + model, rv = fit( model=embed_model, train_data=data, epochs=5, @@ -29,3 +29,5 @@ def test_tail_and_tune(embed_model, create_easy_data): layer_name='dense_2', ) assert rv.dict() + assert model + assert model != embed_model diff --git a/tests/integration/paddle/test_tail_and_tune.py b/tests/integration/paddle/test_tail_and_tune.py index e57201fe5..7c945dd84 100644 --- a/tests/integration/paddle/test_tail_and_tune.py +++ b/tests/integration/paddle/test_tail_and_tune.py @@ -20,7 +20,7 @@ def embed_model(): def test_tail_and_tune(embed_model, create_easy_data): data, _ = create_easy_data(10, 128, 1000) - rv = fit( + model, rv = fit( model=embed_model, train_data=data, epochs=5, @@ -30,3 +30,5 @@ def test_tail_and_tune(embed_model, create_easy_data): layer_name='linear_4', ) assert rv.dict() + assert model + assert model != embed_model diff --git a/tests/integration/torch/test_tail_and_tune.py b/tests/integration/torch/test_tail_and_tune.py index ef7f6ce74..bfbb48a5e 100644 --- a/tests/integration/torch/test_tail_and_tune.py +++ b/tests/integration/torch/test_tail_and_tune.py @@ -20,7 +20,7 @@ def embed_model(): def test_tail_and_tune(embed_model, create_easy_data): data, _ = create_easy_data(10, 128, 1000) - rv = fit( + model, rv = fit( model=embed_model, train_data=data, epochs=5, @@ -30,3 +30,5 @@ def test_tail_and_tune(embed_model, create_easy_data): layer_name='linear_4', ) assert rv.dict() + assert model + assert model != embed_model