Commit

update test_spark_keras.py and test_spark_lightning.py to cover val dataloader cases

Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com>
chongxiaoc committed May 19, 2021
1 parent 72e38ff commit bef20b8
Showing 2 changed files with 61 additions and 53 deletions.
58 changes: 32 additions & 26 deletions test/integration/test_spark_keras.py
@@ -204,32 +204,38 @@ def test_keras_direct_parquet_train(self, mock_fit_fn, mock_pin_gpu_fn):
                 store.get_train_data_path = lambda v=None: store._train_path
                 store.get_val_data_path = lambda v=None: store._val_path

-                with util.prepare_data(backend.num_processes(),
-                                       store,
-                                       df,
-                                       feature_columns=['features'],
-                                       label_columns=['y']):
-                    model = create_xor_model()
-                    optimizer = tf.keras.optimizers.SGD(lr=0.1)
-                    loss = 'binary_crossentropy'
-
-                    for reader_pool_type in ['process', 'thread']:
-                        est = hvd.KerasEstimator(
-                            backend=backend,
-                            store=store,
-                            model=model,
-                            optimizer=optimizer,
-                            loss=loss,
-                            feature_cols=['features'],
-                            label_cols=['y'],
-                            batch_size=1,
-                            epochs=3,
-                            reader_pool_type=reader_pool_type,
-                            verbose=2)
-
-                        transformer = est.fit_on_parquet()
-                        predictions = transformer.transform(df)
-                        assert predictions.count() == df.count()
+                # Make sure we cover val dataloader cases
+                for validation in [0.0, 0.5]:
+                    with util.prepare_data(backend.num_processes(),
+                                           store,
+                                           df,
+                                           feature_columns=['features'],
+                                           label_columns=['y'],
+                                           validation=validation):
+                        model = create_xor_model()
+                        optimizer = tf.keras.optimizers.SGD(lr=0.1)
+                        loss = 'binary_crossentropy'
+
+                        for inmemory_cache_all in [False, True]:
+                            for reader_pool_type in ['process', 'thread']:
+                                est = hvd.KerasEstimator(
+                                    backend=backend,
+                                    store=store,
+                                    model=model,
+                                    optimizer=optimizer,
+                                    loss=loss,
+                                    feature_cols=['features'],
+                                    label_cols=['y'],
+                                    batch_size=1,
+                                    epochs=3,
+                                    reader_pool_type=reader_pool_type,
+                                    validation=validation,
+                                    inmemory_cache_all=inmemory_cache_all,
+                                    verbose=2)
+
+                                transformer = est.fit_on_parquet()
+                                predictions = transformer.transform(df)
+                                assert predictions.count() == df.count()

     @mock.patch('horovod.spark.keras.remote._pin_gpu_fn')
     @mock.patch('horovod.spark.keras.util.TFKerasUtil.fit_fn')
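Aside: the commit enumerates the new cases with nested for loops inside a single test, which keeps all combinations inside one Spark session setup. A hypothetical alternative sketch (not the commit's code; the test name is made up and the body is elided) is pytest parametrization, which reports each combination as a separate test:

import pytest

# Hypothetical sketch, not part of this commit: the same case matrix
# (validation split x in-memory caching x reader pool) as parametrized tests.
@pytest.mark.parametrize('validation', [0.0, 0.5])
@pytest.mark.parametrize('inmemory_cache_all', [False, True])
@pytest.mark.parametrize('reader_pool_type', ['process', 'thread'])
def test_case_matrix(validation, inmemory_cache_all, reader_pool_type):
    # The real test would build a KerasEstimator with these arguments,
    # call fit_on_parquet(), and assert on the transformed row count.
    assert validation in (0.0, 0.5)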
56 changes: 29 additions & 27 deletions test/integration/test_spark_lightning.py
@@ -362,33 +362,35 @@ def test_direct_parquet_train(self):
                 store.get_train_data_path = lambda v=None: store._train_path
                 store.get_val_data_path = lambda v=None: store._val_path

-                with util.prepare_data(backend.num_processes(),
-                                       store,
-                                       df,
-                                       feature_columns=['features'],
-                                       label_columns=['y'],
-                                       validation=0.2):
-                    model = create_xor_model()
-
-                    for inmemory_cache_all in [False, True]:
-                        for reader_pool_type in ['process', 'thread']:
-                            est = hvd_spark.TorchEstimator(
-                                backend=backend,
-                                store=store,
-                                model=model,
-                                input_shapes=[[-1, 2]],
-                                feature_cols=['features'],
-                                label_cols=['y'],
-                                validation=0.2,
-                                batch_size=1,
-                                epochs=3,
-                                verbose=2,
-                                inmemory_cache_all=inmemory_cache_all,
-                                reader_pool_type=reader_pool_type)
-
-                            transformer = est.fit_on_parquet()
-                            predictions = transformer.transform(df)
-                            assert predictions.count() == df.count()
+                # Make sure to cover val dataloader cases
+                for validation in [0.0, 0.5]:
+                    with util.prepare_data(backend.num_processes(),
+                                           store,
+                                           df,
+                                           feature_columns=['features'],
+                                           label_columns=['y'],
+                                           validation=validation):
+                        model = create_xor_model()
+
+                        for inmemory_cache_all in [False, True]:
+                            for reader_pool_type in ['process', 'thread']:
+                                est = hvd_spark.TorchEstimator(
+                                    backend=backend,
+                                    store=store,
+                                    model=model,
+                                    input_shapes=[[-1, 2]],
+                                    feature_cols=['features'],
+                                    label_cols=['y'],
+                                    validation=validation,
+                                    batch_size=1,
+                                    epochs=3,
+                                    verbose=2,
+                                    inmemory_cache_all=inmemory_cache_all,
+                                    reader_pool_type=reader_pool_type)
+
+                                transformer = est.fit_on_parquet()
+                                predictions = transformer.transform(df)
+                                assert predictions.count() == df.count()

     def test_legacy_calculate_loss_with_sample_weight(self):
         labels = torch.tensor([[1.0, 2.0, 3.0]])
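Aside: the two values are chosen to hit both branches of the validation dataloader logic. Assuming a float validation argument acts as a hold-out fraction (my reading of the estimator API), validation=0.0 yields an empty validation set, so no val dataloader is built, while validation=0.5 holds out roughly half the rows and exercises the val dataloader path. A standalone conceptual sketch (not Horovod's implementation):

import random

def split_rows(rows, validation):
    # Assign each row to the validation set with probability `validation`;
    # validation=0.0 therefore yields an empty validation set.
    train, val = [], []
    for row in rows:
        (val if random.random() < validation else train).append(row)
    return train, val

train, val = split_rows(list(range(100)), validation=0.0)
assert val == []  # no validation rows, so no val dataloader to construct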
