Skip to content

Commit

Permalink
Merge 335b616 into 2ebb455
Browse files Browse the repository at this point in the history
  • Loading branch information
trax-robot committed May 1, 2020
2 parents 2ebb455 + 335b616 commit d8c7f8f
Show file tree
Hide file tree
Showing 15 changed files with 901 additions and 974 deletions.
4 changes: 2 additions & 2 deletions trax/tf_numpy/extensions/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@

from trax.tf_numpy.numpy import arrays
from trax.tf_numpy.numpy import random
from trax.tf_numpy.numpy.array_creation import array
from trax.tf_numpy.numpy.array_creation import asarray
from trax.tf_numpy.numpy.array_ops import array
from trax.tf_numpy.numpy.array_ops import asarray
from trax.tf_numpy.numpy.arrays import ndarray
from trax.tf_numpy.numpy.arrays import ShardedNdArray

Expand Down
37 changes: 19 additions & 18 deletions trax/tf_numpy/extensions/extensions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@
import tensorflow.compat.v2 as tf

from trax.tf_numpy import extensions
from trax.tf_numpy.numpy import array_creation
from trax.tf_numpy.numpy import array_methods
from trax.tf_numpy.numpy import array_ops
from trax.tf_numpy.numpy import arrays
from trax.tf_numpy.numpy import math
from trax.tf_numpy.numpy import math_ops
from trax.tf_numpy.numpy import random
from trax.tf_numpy.numpy.array_creation import asarray
from trax.tf_numpy.numpy.array_ops import asarray

FLAGS = flags.FLAGS

Expand Down Expand Up @@ -87,7 +86,8 @@ def _hasGPU(self):

def testGrad(self):
def f(a, b):
return array_methods.sum(math.sqrt(math.exp(a)) + b)
return array_ops.sum(math_ops.sqrt(math_ops.exp(a)) + b)

g = extensions.grad(f)
def compare(a, b):
with tf.GradientTape() as tape:
Expand Down Expand Up @@ -125,7 +125,8 @@ def g_jitted(a):

def testJit(self):
def f(a, b):
return array_methods.sum(math.sqrt(math.exp(a)) + b)
return array_ops.sum(math_ops.sqrt(math_ops.exp(a)) + b)

f_jitted = extensions.jit(f)
shape = [10]
a = random.randn(*shape)
Expand Down Expand Up @@ -155,12 +156,13 @@ def f(a):

def _testEvalOnShapes(self, transformer):
def f(a, b):
return array_methods.sum(math.sqrt(math.exp(a)) + b)
return array_ops.sum(math_ops.sqrt(math_ops.exp(a)) + b)

f_prime = transformer(f)
shape = [10]
dtype = np.float16
a = array_creation.zeros(shape=shape, dtype=dtype)
b = array_creation.zeros(shape=shape, dtype=dtype)
a = array_ops.zeros(shape=shape, dtype=dtype)
b = array_ops.zeros(shape=shape, dtype=dtype)
expected = f(a, b)
got = f_prime(a, b)
self.assertAllEqual(expected.shape, got.shape)
Expand All @@ -181,8 +183,7 @@ def transformer(f):
@extensions.jit
def f_prime(a, b):
shape_dtype = extensions.eval_on_shapes(f)(a, b)
return array_creation.zeros(shape=shape_dtype.shape,
dtype=shape_dtype.dtype)
return array_ops.zeros(shape=shape_dtype.shape, dtype=shape_dtype.dtype)
return f_prime
self._testEvalOnShapes(transformer)

Expand Down Expand Up @@ -363,7 +364,7 @@ def _train_and_reduce(params, inputs, targets, learning_rate=0.1):
_train_and_reduce, devices=devices)

def replicate(x, num_devices=2):
return array_methods.broadcast_to(x, (num_devices,) + x.shape)
return array_ops.broadcast_to(x, (num_devices,) + x.shape)

params = tf.nest.map_structure(replicate, params)

Expand All @@ -374,7 +375,7 @@ def reshape(x, num_devices=2):

# New shape.
new_shape_prefix = [num_devices, batch_size_per_device]
return array_methods.reshape(x, new_shape_prefix + x_shape[1:])
return array_ops.reshape(x, new_shape_prefix + x_shape[1:])

inputs = tf.nest.map_structure(reshape, inputs)
targets = tf.nest.map_structure(reshape, targets)
Expand All @@ -398,7 +399,7 @@ def testPsum(self):
def reduce_sum(f):
return extensions.psum(f)

data = array_creation.asarray(tf.convert_to_tensor(value=[1, 3]))
data = array_ops.asarray(tf.convert_to_tensor(value=[1, 3]))
pmapped = extensions.pmap(reduce_sum, devices=devices)
result = pmapped(data)

Expand All @@ -413,7 +414,7 @@ def testPmean(self):
def reduce_mean(f):
return extensions.pmean(f)

data = array_creation.asarray(tf.convert_to_tensor(value=[1, 3]))
data = array_ops.asarray(tf.convert_to_tensor(value=[1, 3]))
pmapped = extensions.pmap(reduce_mean, devices=devices)
result = pmapped(data)

Expand All @@ -426,7 +427,7 @@ def testAxisName(self):
def reduce_sum(f):
return extensions.psum(f, axis_name="foo")

data = array_creation.asarray(tf.convert_to_tensor(value=[1, 3]))
data = array_ops.asarray(tf.convert_to_tensor(value=[1, 3]))
pmapped = extensions.pmap(reduce_sum, axis_name="foo", devices=devices)
pmapped(data)

Expand All @@ -436,7 +437,7 @@ def testWrongAxisName(self):
def reduce_sum(f):
return extensions.psum(f, axis_name="bar")

data = array_creation.asarray(tf.convert_to_tensor(value=[1, 3]))
data = array_ops.asarray(tf.convert_to_tensor(value=[1, 3]))
with self.assertRaisesWithPredicateMatch(
ValueError, r"axis_name (.*) is not equal to that of the surrounding"):
pmapped = extensions.pmap(reduce_sum, axis_name="foo", devices=devices)
Expand All @@ -448,7 +449,7 @@ def testNoNestedPmap(self):
def f(x):
return x + 1.0

data = array_creation.asarray(tf.convert_to_tensor(value=[1, 3]))
data = array_ops.asarray(tf.convert_to_tensor(value=[1, 3]))
with self.assertRaisesWithPredicateMatch(
ValueError, r"Nested pmap is not supported"):
f = extensions.pmap(f, devices=devices)
Expand Down
6 changes: 2 additions & 4 deletions trax/tf_numpy/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,10 @@
from trax.tf_numpy.numpy import random

# pylint: disable=wildcard-import
from trax.tf_numpy.numpy.array_creation import *
from trax.tf_numpy.numpy.array_manipulation import *
from trax.tf_numpy.numpy.array_methods import *
from trax.tf_numpy.numpy.array_ops import *
from trax.tf_numpy.numpy.arrays import ndarray
from trax.tf_numpy.numpy.dtypes import *
from trax.tf_numpy.numpy.math import *
from trax.tf_numpy.numpy.math_ops import *
from trax.tf_numpy.numpy.utils import finfo
from trax.tf_numpy.numpy.utils import promote_types
from trax.tf_numpy.numpy.utils import result_type
Expand Down

0 comments on commit d8c7f8f

Please sign in to comment.