-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
array_test.py
200 lines (173 loc) · 7.5 KB
/
array_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for GlobalDeviceArray."""
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src.lib import xla_client as xc
from jax._src.util import prod
from jax.experimental import PartitionSpec as P
from jax.experimental import sharding
from jax.experimental import array
from jax.config import config
# NOTE(review): presumably lets absl command-line flags set JAX config options
# for this test binary (inferred from the name) — confirm in jax.config.
config.parse_flags_with_absl()
def create_array(shape, sharding, global_data=None):
  """Builds an Array of `shape` whose shards are filled from `global_data`.

  Each addressable shard is populated by slicing `global_data` with the
  index that `sharding` assigns to it. When `global_data` is omitted, an
  `np.arange` of `prod(shape)` elements reshaped to `shape` is used instead.

  Returns:
    A `(array, global_data)` tuple, so callers can compare the result against
    the host data it was built from.
  """
  data = global_data
  if data is None:
    data = np.arange(prod(shape)).reshape(shape)

  def _slice_for_index(index):
    # Called once per shard with that shard's index into the global array.
    return data[index]

  built = array.make_array_from_callback(shape, sharding, _slice_for_index)
  return built, data
class JaxArrayTest(jtu.JaxTestCase):
  """Tests for `array.Array` with the `jax_array` config flag enabled.

  Every test enables the flag via the `jax._src.config.jax_array(True)`
  context manager. The sharded tests build a (4, 2) global mesh with axes
  ('x', 'y') via `jtu.create_global_mesh`. NOTE(review): that mesh presumably
  needs 8 devices — confirm how the test runner provides them.
  """

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y")),
      ("mesh_x", P("x")),
      ("mesh_y", P("y")),
      ("mesh_none_y", P(None, "y")),
      ("mesh_xy", P(("x", "y"))),
      ("mesh_fully_replicated", P()),
  )
  def test_jax_array_value(self, mesh_axes):
    """Every addressable shard and the full value match the source data."""
    with jax._src.config.jax_array(True):
      global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
      input_shape = (8, 2)
      arr, global_data = create_array(
          input_shape, sharding.MeshPspecSharding(global_mesh, mesh_axes))
      for s in arr.addressable_shards:
        # Each shard's data holds exactly one underlying array, equal to the
        # slice of the global data at that shard's index.
        self.assertLen(s.data._arrays, 1)
        self.assertArraysEqual(s.data._arrays[0], global_data[s.index])
      self.assertArraysEqual(arr._value, global_data)
      # `_value` is read first; `_npy_value` is then expected to hold the same
      # data (presumably a cache populated by `_value` — confirm).
      self.assertArraysEqual(arr._npy_value, global_data)

  def test_array_delete(self):
    """After `delete()`, the array reports deletion and drops its buffers."""
    with jax._src.config.jax_array(True):
      global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
      input_shape = (8, 2)
      arr, _ = create_array(
          input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
      arr.delete()
      with self.assertRaisesRegex(ValueError, 'Array has been deleted.'):
        arr._check_if_deleted()
      self.assertIsNone(arr._npy_value)
      self.assertIsNone(arr._arrays)

  def test_device_put(self):
    """`device_put` to one device gives a committed single-device Array."""
    with jax._src.config.jax_array(True):
      numpy_array = np.array([1, 2, 3])
      arr = jax.device_put(numpy_array, jax.devices()[0])
      self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)
      self.assertArraysEqual(arr, numpy_array)
      # An explicit target device is expected to mark the array committed.
      self.assertEqual(arr._committed, True)
      for shard in arr.addressable_shards:
        # A single-device array has one shard that covers the whole input.
        self.assertArraysEqual(shard.data, numpy_array)
        self.assertEqual(shard.device, jax.devices()[0])
        self.assertEqual(shard.index, (slice(None),))

  def test_device_put_array_delete(self):
    """Deleting a `device_put` result behaves like deleting a sharded one."""
    with jax._src.config.jax_array(True):
      arr = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])
      arr.delete()
      with self.assertRaisesRegex(ValueError, 'Array has been deleted.'):
        arr._check_if_deleted()
      self.assertIsNone(arr._npy_value)
      self.assertIsNone(arr._arrays)

  def test_array_device_get(self):
    """`jax.device_get` on a sharded Array returns the original data."""
    with jax._src.config.jax_array(True):
      global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
      input_shape = (8, 2)
      arr, input_data = create_array(
          input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
      self.assertArraysEqual(jax.device_get(arr), input_data)

  def test_repr(self):
    """`repr` of a sharded Array does not raise."""
    with jax._src.config.jax_array(True):
      global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
      input_shape = (8, 2)
      arr, _ = create_array(
          input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
      repr(arr)  # Only checks that repr doesn't crash.

  def test_jnp_array(self):
    """`jnp.array` yields an uncommitted single-device Array."""
    with jax._src.config.jax_array(True):
      arr = jnp.array([1, 2, 3])
      self.assertIsInstance(arr, array.Array)
      self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)
      # No explicit device was requested, so the array is expected to be
      # uncommitted.
      self.assertEqual(arr._committed, False)

  def test_jnp_array_jit_add(self):
    """A jitted add of two jnp arrays returns a single-device Array."""
    with jax._src.config.jax_array(True):
      a = jnp.array([1, 2, 3])
      b = jnp.array([4, 5, 6])
      arr = jax.jit(lambda x, y: x + y)(a, b)
      self.assertIsInstance(arr, array.Array)
      self.assertArraysEqual(arr, np.array([5, 7, 9]))
      self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)

  def test_jnp_array_jnp_add(self):
    """`jnp.add` of two jnp arrays returns a single-device Array."""
    with jax._src.config.jax_array(True):
      arr = jnp.add(jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
      self.assertIsInstance(arr, array.Array)
      self.assertArraysEqual(arr, np.array([5, 7, 9]))
      self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)

  def test_jnp_array_normal_add(self):
    """The `+` operator on two jnp arrays returns a single-device Array."""
    with jax._src.config.jax_array(True):
      a = jnp.array([1, 2, 3])
      b = jnp.array([4, 5, 6])
      arr = a + b
      self.assertIsInstance(arr, array.Array)
      self.assertArraysEqual(arr, np.array([5, 7, 9]))
      self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)

  def test_array_sharded_astype(self):
    """`astype` on a sharded Array converts the dtype and keeps the values."""
    with jax._src.config.jax_array(True):
      global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
      input_shape = (8, 2)
      arr, input_data = create_array(
          input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
      arr_float32 = arr.astype(jnp.float32)
      self.assertEqual(arr_float32.dtype, np.float32)
      self.assertArraysEqual(arr_float32, input_data.astype(np.float32))

  def test_jnp_array_astype(self):
    """`astype` on a jnp-created Array converts the dtype and keeps values."""
    with jax._src.config.jax_array(True):
      arr = jnp.array([1, 2, 3])
      arr_float32 = arr.astype(jnp.float32)
      self.assertEqual(arr_float32.dtype, np.float32)
      # Compare against an independently built NumPy reference. Comparing with
      # `arr.astype(np.float32)` would run the code under test on both sides,
      # so the assertion could never fail.
      self.assertArraysEqual(arr_float32,
                             np.array([1, 2, 3], dtype=np.float32))
class ShardingTest(jtu.JaxTestCase):
  """Tests for `MeshPspecSharding` and `OpShardingSharding`."""

  def test_mesh_pspec_sharding_interface(self):
    """Checks the index map, device order and XLA OpSharding of P('y', 'x')."""
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    shape = (8, 4)
    mps = sharding.MeshPspecSharding(mesh, P('y', 'x'))

    indices_by_device = mps.devices_indices_map(shape)
    xla_sharding = mps._to_xla_op_sharding(len(shape))
    assigned_devices = mps._device_assignment()

    flat_devices = list(mesh.devices.flat)
    # Dim 0 is split over 'y' and dim 1 over 'x'. With the (4, 2) mesh the
    # first device is expected to own rows 0:4 and column 0:1.
    self.assertEqual(indices_by_device[flat_devices[0]],
                     (slice(0, 4), slice(0, 1)))
    self.assertArraysEqual(assigned_devices, flat_devices)
    self.assertEqual(xla_sharding.type, xc.OpSharding.Type.OTHER)
    self.assertListEqual(xla_sharding.tile_assignment_dimensions, [2, 4])
    self.assertListEqual(xla_sharding.tile_assignment_devices,
                         [0, 2, 4, 6, 1, 3, 5, 7])

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y")),
      ("mesh_x", P("x")),
      ("mesh_y", P("y")),
      ("mesh_none_y", P(None, "y")),
      ("mesh_none_x", P(None, "x")),
      ("mesh_xy", P(("x", "y"))),
      ("mesh_fully_replicated", P()),
  )
  def test_op_sharding_indices(self, pspec):
    """An OpShardingSharding built from a MeshPspecSharding's XLA sharding
    maps every device to the same indices as the MeshPspecSharding."""
    global_shape = (8, 4)
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    mesh_sharding = sharding.MeshPspecSharding(mesh, pspec)
    xla_op_sharding = mesh_sharding._to_xla_op_sharding(len(global_shape))
    op_sharding = sharding.OpShardingSharding(
        list(mesh.devices.flat), xla_op_sharding)

    actual = op_sharding.devices_indices_map(global_shape)
    expected = mesh_sharding.devices_indices_map(global_shape)
    self.assertDictEqual(actual, expected)
if __name__ == '__main__':
  # Run the absltest suite with JAX's custom test loader.
  test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=test_loader)