Skip to content

Commit

Permalink
fix(api): return model and summary in highlevel fit (#175)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao committed Oct 26, 2021
1 parent 115a0aa commit 320ec5d
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 9 deletions.
8 changes: 5 additions & 3 deletions finetuner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def fit(

def fit(
model: 'AnyDNN', train_data: 'DocumentArrayLike', *args, **kwargs
) -> Optional['Summary']:
) -> Optional[Tuple['AnyDNN', 'Summary']]:
if kwargs.get('to_embedding_model', False):
from .tailor import to_embedding_model

Expand All @@ -106,8 +106,10 @@ def fit(
if kwargs.get('interactive', False):
from .labeler import fit

return fit(model, train_data, *args, **kwargs)
# TODO: atm return will never hit as labeler UI hangs the
# flow via `.block()`
fit(model, train_data, *args, **kwargs)
else:
from .tuner import fit

return fit(model, train_data, *args, **kwargs)
return model, fit(model, train_data, *args, **kwargs)
1 change: 0 additions & 1 deletion finetuner/tuner/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def plot(
**plt_kwargs,
)
axes[idx].set_ylabel(record.name)
axes[idx].set_box_aspect(1)
axes[idx].set_xlabel('Steps')

if output:
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/fit/test_fit_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_fit_all(tmpdir):

for kb, b in embed_models.items():
for h in all_test_losses:
result = fit(
model, result = fit(
b(),
loss=h,
train_data=lambda: generate_qa_match(
Expand All @@ -63,4 +63,5 @@ def test_fit_all(tmpdir):
),
epochs=2,
)
assert model
result.save(tmpdir / f'result-{kb}-{h}.json')
3 changes: 2 additions & 1 deletion tests/integration/fit/test_fit_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_fit_all(tmpdir):

for kb, b in embed_models.items():
for h in all_test_losses:
result = finetuner.fit(
model, result = finetuner.fit(
b(),
loss=h,
train_data=lambda: generate_fashion_match(
Expand All @@ -55,4 +55,5 @@ def test_fit_all(tmpdir):
),
epochs=2,
)
assert model
result.save(tmpdir / f'result-{kb}-{h}.json')
4 changes: 3 additions & 1 deletion tests/integration/keras/test_tail_and_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def embed_model():

def test_tail_and_tune(embed_model, create_easy_data):
data, _ = create_easy_data(10, 128, 1000)
rv = fit(
model, rv = fit(
model=embed_model,
train_data=data,
epochs=5,
Expand All @@ -29,3 +29,5 @@ def test_tail_and_tune(embed_model, create_easy_data):
layer_name='dense_2',
)
assert rv.dict()
assert model
assert model != embed_model
4 changes: 3 additions & 1 deletion tests/integration/paddle/test_tail_and_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def embed_model():

def test_tail_and_tune(embed_model, create_easy_data):
data, _ = create_easy_data(10, 128, 1000)
rv = fit(
model, rv = fit(
model=embed_model,
train_data=data,
epochs=5,
Expand All @@ -30,3 +30,5 @@ def test_tail_and_tune(embed_model, create_easy_data):
layer_name='linear_4',
)
assert rv.dict()
assert model
assert model != embed_model
4 changes: 3 additions & 1 deletion tests/integration/torch/test_tail_and_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def embed_model():

def test_tail_and_tune(embed_model, create_easy_data):
data, _ = create_easy_data(10, 128, 1000)
rv = fit(
model, rv = fit(
model=embed_model,
train_data=data,
epochs=5,
Expand All @@ -30,3 +30,5 @@ def test_tail_and_tune(embed_model, create_easy_data):
layer_name='linear_4',
)
assert rv.dict()
assert model
assert model != embed_model

0 comments on commit 320ec5d

Please sign in to comment.