Skip to content

Commit

Permalink
simplify test
Browse files Browse the repository at this point in the history
  • Loading branch information
danielenricocahall committed Jan 24, 2021
1 parent 380cc5c commit 00bade9
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions tests/integration/test_custom_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import random

import numpy as np
import pytest
from tensorflow.keras.backend import sigmoid
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
Expand All @@ -11,8 +9,7 @@
from elephas.utils import to_simple_rdd


@pytest.mark.parametrize('mode', ['synchronous', 'asynchronous', 'hogwild'])
def test_training_custom_activation(mode, spark_context):
def test_training_custom_activation(spark_context):
def custom_activation(x):
return sigmoid(x) + 1

Expand All @@ -30,8 +27,7 @@ def custom_activation(x):
y_train[:500] = 1
rdd = to_simple_rdd(spark_context, x_train, y_train)

spark_model = SparkModel(model, frequency='epoch', mode=mode,
port=4000 + random.randint(0, 300),
spark_model = SparkModel(model, frequency='epoch', mode='synchronous',
custom_objects={'custom_activation': custom_activation})
spark_model.fit(rdd, epochs=1, batch_size=16, verbose=0, validation_split=0.1)
assert spark_model.predict(x_test)
Expand Down

0 comments on commit 00bade9

Please sign in to comment.