Skip to content

Commit

Permalink
functional test
Browse files Browse the repository at this point in the history
  • Loading branch information
maxpumperla committed Aug 13, 2018
1 parent 26e2d4e commit f9f76c5
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 7 deletions.
29 changes: 22 additions & 7 deletions elephas/utils/functional_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,51 @@
from six.moves import zip


def add_params(p1, p2):
"""Add two lists of parameters
def add_params(param_list_left, param_list_right):
"""Add two lists of parameters one by one
:param param_list_left: list of numpy arrays
:param param_list_right: list of numpy arrays
:return: list of numpy arrays
"""
res = []
for x, y in zip(p1, p2):
for x, y in zip(param_list_left, param_list_right):
res.append(x + y)
return res


def subtract_params(p1, p2):
def subtract_params(param_list_left, param_list_right):
"""Subtract two lists of parameters
:param param_list_left: list of numpy arrays
:param param_list_right: list of numpy arrays
:return: list of numpy arrays
"""
res = []
for x, y in zip(p1, p2):
for x, y in zip(param_list_left, param_list_right):
res.append(x - y)
return res


def get_neutral(array):
def get_neutral(array_list):
"""Get list of zero-valued numpy arrays for
specified list of numpy arrays
:param array_list: list of numpy arrays
:return: list of zeros of same shape as input
"""
res = []
for x in array:
for x in array_list:
res.append(np.zeros_like(x))
return res


def divide_by(array_list, num_workers):
"""Divide a list of parameters by an integer num_workers.
:param array_list:
:param num_workers:
:return:
"""
for i, x in enumerate(array_list):
array_list[i] /= num_workers
Expand Down
42 changes: 42 additions & 0 deletions tests/utils/test_functional_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest
import numpy as np
from elephas.utils import functional_utils

pytest.mark.usefixtures("spark_context")


def test_add_params():
p1 = [np.ones((5, 5)) for i in range(10)]
p2 = [np.ones((5, 5)) for i in range(10)]

res = functional_utils.add_params(p1, p2)
assert len(res) == 10
for i in range(5):
for j in range(5):
assert res[0][i, j] == 2


def test_subtract_params():
p1 = [np.ones((5, 5)) for i in range(10)]
p2 = [np.ones((5, 5)) for i in range(10)]

res = functional_utils.subtract_params(p1, p2)

assert len(res) == 10
for i in range(5):
for j in range(5):
assert res[0][i, j] == 0


def test_get_neutral():
x = [np.ones((3, 4))]
res = functional_utils.get_neutral(x)
assert res[0].shape == x[0].shape
assert res[0][0, 0] == 0


def test_divide_by():
x = [np.ones((3, 4))]
res = functional_utils.divide_by(x, num_workers=10)
assert res[0].shape == x[0].shape
assert res[0][0, 0] == 0.1

0 comments on commit f9f76c5

Please sign in to comment.