Skip to content

Commit

Permalink
Adds regression test for #2081
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffreyangus committed Jun 2, 2022
1 parent 698a0e0 commit 2ce6f4c
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions tests/integration_tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,42 @@ def test_api_callbacks_fixed_train_steps(csv_filename):
assert mock_callback.on_epoch_start.call_count == 10


def test_api_callbacks_fixed_train_steps_less_than_one_epoch(csv_filename):
# If train_steps is set manually, epochs is ignored.
train_steps = total_batches = 6
steps_per_checkpoint = 2
batch_size = 8
num_examples = 80
mock_callback = mock.Mock(wraps=Callback())

with tempfile.TemporaryDirectory() as output_dir:
input_features = [sequence_feature(reduce_output="sum")]
output_features = [category_feature(vocab_size=5, reduce_input="sum")]
config = {
"input_features": input_features,
"output_features": output_features,
"combiner": {"type": "concat", "output_size": 14},
TRAINER: {
"train_steps": train_steps,
"steps_per_checkpoint": steps_per_checkpoint,
"batch_size": batch_size,
},
}
model = LudwigModel(config, callbacks=[mock_callback])
model.train(
training_set=generate_data(
input_features, output_features, os.path.join(output_dir, csv_filename), num_examples=num_examples
)
)

assert mock_callback.on_epoch_start.call_count == 1
assert mock_callback.on_epoch_end.call_count == 1
# The total number of batches is the number of train_steps
assert mock_callback.on_batch_end.call_count == total_batches
# The total number of evals is the number of times checkpoints are made
assert mock_callback.on_eval_end.call_count == train_steps // steps_per_checkpoint


def test_api_save_torchscript(tmpdir):
"""Tests successful saving and loading of model in TorchScript format."""
input_features = [category_feature(vocab_size=5)]
Expand Down

0 comments on commit 2ce6f4c

Please sign in to comment.