Add AMP patching of npi ops in _api_internal module (apache#19488)
mk-61 committed Nov 19, 2020
1 parent 5dc404d commit 6648866
Showing 3 changed files with 85 additions and 29 deletions.
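At a high level, MXNet AMP works by monkey-patching operator entry points with wrappers that cast eligible floating-point inputs to the target dtype (or back to float32). The numpy ("npi") operators can also be reached through the FFI module mxnet.ndarray.numpy._api_internal, so this commit extends the patching to cover that module as well. A minimal sketch of the casting-wrapper idea, using illustrative helper names rather than the actual MXNet internals:

import numpy as np

def _make_lp16_wrapper(orig_fn, target_dtype='float16'):
    # Illustrative sketch only: cast float32 array arguments to the
    # low-precision target dtype before calling the original operator.
    def _wrapped(*args, **kwargs):
        def cast(x):
            return x.astype(target_dtype) if getattr(x, 'dtype', None) == np.float32 else x
        return orig_fn(*[cast(a) for a in args],
                       **{k: cast(v) for k, v in kwargs.items()})
    return _wrapped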
15 changes: 10 additions & 5 deletions ci/docker/runtime_functions.sh
@@ -780,8 +780,9 @@ cd_unittest_ubuntu() {
MXNET_ENGINE_TYPE=NaiveEngine \
OMP_NUM_THREADS=$(expr $(nproc) / 4) pytest -m 'not serial' -k 'test_operator' -n 4 --durations=50 --verbose tests/python/gpu
MXNET_GPU_MEM_POOL_TYPE=Unpooled \
OMP_NUM_THREADS=$(expr $(nproc) / 4) pytest -m 'not serial' -k 'not test_operator' -n 4 --durations=50 --verbose tests/python/gpu
OMP_NUM_THREADS=$(expr $(nproc) / 4) pytest -m 'not serial' -k 'not test_operator and not test_amp_init.py' -n 4 --durations=50 --verbose tests/python/gpu
pytest -m 'serial' --durations=50 --verbose tests/python/gpu
pytest --durations=50 --verbose tests/python/gpu/test_amp_init.py

# TODO(szha): fix and reenable the hanging issue. tracked in #18098
# integrationtest_ubuntu_gpu_dist_kvstore
@@ -833,11 +834,12 @@ unittest_ubuntu_python3_gpu() {
export MXNET_ENABLE_CYTHON=0
export DMLC_LOG_STACK_TRACE_DEPTH=10
MXNET_GPU_MEM_POOL_TYPE=Unpooled \
OMP_NUM_THREADS=$(expr $(nproc) / 4) pytest -m 'not serial' -k 'not test_operator' -n 4 --durations=50 --cov-report xml:tests_gpu.xml --verbose tests/python/gpu
OMP_NUM_THREADS=$(expr $(nproc) / 4) pytest -m 'not serial' -k 'not test_operator and not test_amp_init.py' -n 4 --durations=50 --cov-report xml:tests_gpu.xml --verbose tests/python/gpu
MXNET_GPU_MEM_POOL_TYPE=Unpooled \
MXNET_ENGINE_TYPE=NaiveEngine \
OMP_NUM_THREADS=$(expr $(nproc) / 4) pytest -m 'not serial' -k 'test_operator' -n 4 --durations=50 --cov-report xml:tests_gpu.xml --cov-append --verbose tests/python/gpu
pytest -m 'serial' --durations=50 --cov-report xml:tests_gpu.xml --cov-append --verbose tests/python/gpu
pytest --durations=50 --cov-report xml:tests_gpu.xml --cov-append --verbose tests/python/gpu/test_amp_init.py
}

unittest_ubuntu_python3_gpu_cython() {
@@ -852,11 +854,12 @@ unittest_ubuntu_python3_gpu_cython() {
export DMLC_LOG_STACK_TRACE_DEPTH=10
check_cython
MXNET_GPU_MEM_POOL_TYPE=Unpooled \
OMP_NUM_THREADS=$(expr $(nproc) / 4) pytest -m 'not serial' -k 'not test_operator' -n 4 --durations=50 --cov-report xml:tests_gpu.xml --verbose tests/python/gpu
OMP_NUM_THREADS=$(expr $(nproc) / 4) pytest -m 'not serial' -k 'not test_operator and not test_amp_init.py' -n 4 --durations=50 --cov-report xml:tests_gpu.xml --verbose tests/python/gpu
MXNET_GPU_MEM_POOL_TYPE=Unpooled \
MXNET_ENGINE_TYPE=NaiveEngine \
OMP_NUM_THREADS=$(expr $(nproc) / 4) pytest -m 'not serial' -k 'test_operator' -n 4 --durations=50 --cov-report xml:tests_gpu.xml --cov-append --verbose tests/python/gpu
pytest -m 'serial' --durations=50 --cov-report xml:tests_gpu.xml --cov-append --verbose tests/python/gpu
pytest --durations=50 --cov-report xml:tests_gpu.xml --cov-append --verbose tests/python/gpu/test_amp_init.py
}

unittest_ubuntu_python3_gpu_nocudnn() {
@@ -868,11 +871,12 @@ unittest_ubuntu_python3_gpu_nocudnn() {
export MXNET_ENABLE_CYTHON=0
export DMLC_LOG_STACK_TRACE_DEPTH=10
MXNET_GPU_MEM_POOL_TYPE=Unpooled \
OMP_NUM_THREADS=$(expr $(nproc) / 4) pytest -m 'not serial' -k 'not test_operator' -n 4 --durations=50 --cov-report xml:tests_gpu.xml --verbose tests/python/gpu
OMP_NUM_THREADS=$(expr $(nproc) / 4) pytest -m 'not serial' -k 'not test_operator and not test_amp_init.py' -n 4 --durations=50 --cov-report xml:tests_gpu.xml --verbose tests/python/gpu
MXNET_GPU_MEM_POOL_TYPE=Unpooled \
MXNET_ENGINE_TYPE=NaiveEngine \
OMP_NUM_THREADS=$(expr $(nproc) / 4) pytest -m 'not serial' -k 'test_operator' -n 4 --durations=50 --cov-report xml:tests_gpu.xml --cov-append --verbose tests/python/gpu
pytest -m 'serial' --durations=50 --cov-report xml:tests_gpu.xml --cov-append --verbose tests/python/gpu
pytest --durations=50 --cov-report xml:tests_gpu.xml --cov-append --verbose tests/python/gpu/test_amp_init.py
}

unittest_cpp() {
@@ -898,11 +902,12 @@ unittest_centos7_gpu() {
export CUDNN_VERSION=${CUDNN_VERSION:-7.0.3}
export DMLC_LOG_STACK_TRACE_DEPTH=10
MXNET_GPU_MEM_POOL_TYPE=Unpooled \
OMP_NUM_THREADS=$(expr $(nproc) / 4) pytest -m 'not serial' -k 'not test_operator' -n 4 --durations=50 --cov-report xml:tests_gpu.xml --cov-append --verbose tests/python/gpu
OMP_NUM_THREADS=$(expr $(nproc) / 4) pytest -m 'not serial' -k 'not test_operator and not test_amp_init.py' -n 4 --durations=50 --cov-report xml:tests_gpu.xml --cov-append --verbose tests/python/gpu
MXNET_GPU_MEM_POOL_TYPE=Unpooled \
MXNET_ENGINE_TYPE=NaiveEngine \
OMP_NUM_THREADS=$(expr $(nproc) / 4) pytest -m 'not serial' -k 'test_operator' -n 4 --durations=50 --cov-report xml:tests_gpu.xml --cov-append --verbose tests/python/gpu
pytest -m 'serial' --durations=50 --cov-report xml:tests_gpu.xml --cov-append --verbose tests/python/gpu
pytest --durations=50 --cov-report xml:tests_gpu.xml --cov-append --verbose tests/python/gpu/test_amp_init.py
}

integrationtest_ubuntu_cpu_onnx() {
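In the CI scripts above, the new test file is filtered out of the parallel pytest runs (-k '... and not test_amp_init.py') and executed in its own pytest invocation (pytest --durations=50 --verbose tests/python/gpu/test_amp_init.py), presumably so that the process-wide operator patching performed by amp.init() cannot affect the other GPU tests sharing a pytest process.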
46 changes: 22 additions & 24 deletions python/mxnet/amp/amp.py
@@ -80,20 +80,27 @@ def _get_nd_fun_to_wrap(name, module, submodule_dict):
    else:
        func_name = name
        cur_module = module
    return func_name, cur_module
    return func_name, [cur_module]

def _get_np_fun_to_wrap(name, ns_prefix):
    for pre, mod, subs in ((_NP_OP_PREFIX, 'numpy', _NP_OP_SUBMODULE_LIST),
                           (_NP_EXT_OP_PREFIX, 'numpy_extension', _NP_EXT_OP_SUBMODULE_LIST),
                           (_NP_INTERNAL_OP_PREFIX, 'numpy._internal', [])):
        if name.startswith(pre):
            name = name[len(pre):]
            nm = name[len(pre):]
            for sub in subs:
                if name.startswith(sub):
                    return name[len(sub):], sys.modules[f'{ns_prefix}.{mod}.{sub[1:-1]}']
            return name, sys.modules[f'{ns_prefix}.{mod}']
    assert False
    return None # for pylint
                if nm.startswith(sub):
                    func, modules = nm[len(sub):], [sys.modules[f'{ns_prefix}.{mod}.{sub[1:-1]}']]
                    break
            else:
                func, modules = nm, [sys.modules[f'{ns_prefix}.{mod}']]
            break
    else:
        assert False, f'Unable to find target module for {name} in {ns_prefix}'
    if name.startswith(_NP_INTERNAL_OP_PREFIX) and ns_prefix == 'mxnet.ndarray':
        if hasattr(ndarray.numpy._api_internal, func):
            modules.append(ndarray.numpy._api_internal)
    return func, modules

def _wrap_module_functions(module, is_numpy_module, target_dtype, get_aliases, get_cond_aliases,
                           get_fun_to_wrap, target_precision_ops=None, conditional_fp32_ops=None,
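For illustration, with the new lookup an npi op now resolves to its bare name plus the list of modules to patch; assuming the FFI module defines the op, the result looks like:

func, modules = _get_np_fun_to_wrap('_npi_concatenate', 'mxnet.ndarray')
# func == 'concatenate'
# modules == [mxnet.ndarray.numpy._internal, mxnet.ndarray.numpy._api_internal]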
@@ -209,49 +216,40 @@ def _new_fun(*args, **kwargs):
    wrap_list = target_precision_ops if target_precision_ops is not None \
                    else list_lp16_ops(target_dtype)
    for fun_name in get_aliases(wrap_list):
        try:
            fun_name, cur_module = get_fun_to_wrap(fun_name, module)
        fun_name, modules = get_fun_to_wrap(fun_name, module)
        for cur_module in modules:
            f_to_wrap = getattr(cur_module, fun_name)
            fp32_param = fp32_param_list[fun_name] if (fp32_param_list and fun_name in fp32_param_list) else None
            setattr(cur_module, fun_name, _wrapper(f_to_wrap, target_dtype, fp32_param=fp32_param))
            if not is_numpy_module and cur_module == module:
                setattr(module.op, fun_name, _wrapper(f_to_wrap, target_dtype, fp32_param=fp32_param))
        except AttributeError:
            raise

    wrap_list = fp32_ops if fp32_ops is not None else list_fp32_ops(target_dtype)
    for fun_name in get_aliases(wrap_list):
        try:
            fun_name, cur_module = get_fun_to_wrap(fun_name, module)
        fun_name, modules = get_fun_to_wrap(fun_name, module)
        for cur_module in modules:
            f_to_wrap = getattr(cur_module, fun_name)
            setattr(cur_module, fun_name, _wrapper(f_to_wrap, np.float32))
            if not is_numpy_module and cur_module == module:
                setattr(module.op, fun_name, _wrapper(f_to_wrap, np.float32))
        except AttributeError:
            raise

    wrap_list = conditional_fp32_ops if conditional_fp32_ops is not None \
                    else list_conditional_fp32_ops(target_dtype)
    for fun_name, arg, arg_values in get_cond_aliases(wrap_list):
        try:
            fun_name, cur_module = get_fun_to_wrap(fun_name, module)
        fun_name, modules = get_fun_to_wrap(fun_name, module)
        for cur_module in modules:
            f_to_wrap = getattr(cur_module, fun_name)
            setattr(cur_module, fun_name, _wrapper(f_to_wrap, np.float32, cond_arg=(arg, arg_values)))
            if not is_numpy_module and cur_module == module:
                setattr(module.op, fun_name, _wrapper(f_to_wrap, np.float32, cond_arg=(arg, arg_values)))
        except AttributeError:
            raise


    for fun_name in get_aliases(list_widest_type_cast(target_dtype)):
        try:
            fun_name, cur_module = get_fun_to_wrap(fun_name, module)
        fun_name, modules = get_fun_to_wrap(fun_name, module)
        for cur_module in modules:
            f_to_wrap = getattr(cur_module, fun_name)
            setattr(cur_module, fun_name, _symbol_widest_wrapper(f_to_wrap))
            if not is_numpy_module and cur_module == module:
                setattr(module.op, fun_name, _symbol_widest_wrapper(f_to_wrap))
        except AttributeError:
            raise

def _wrap_loss_output_functions(module, ls, target_dtype):
    if module == ndarray:
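Because get_fun_to_wrap now returns a list of modules, each wrapping loop above installs the same wrapper on every module that exposes the op. A minimal sketch of that multi-module patching step, with an illustrative helper name rather than the actual MXNet code:

def _patch_everywhere(modules, fun_name, make_wrapper):
    # Install the same AMP wrapper on each module that exposes the op,
    # e.g. both numpy._internal and the FFI-based numpy._api_internal.
    for cur_module in modules:
        original = getattr(cur_module, fun_name)
        setattr(cur_module, fun_name, make_wrapper(original))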
53 changes: 53 additions & 0 deletions tests/python/gpu/test_amp_init.py
@@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
from mxnet.gluon import nn
from mxnet import amp
import numpy as np
import pytest


@pytest.fixture
def np_shape_array():
    flags = mx.npx.is_np_shape(), mx.npx.is_np_array(), mx.npx.is_np_default_dtype()
    mx.npx.set_np()
    yield
    mx.npx.set_np(*flags)


@pytest.fixture(scope='module')
def amp_init():
    amp.init()


def test_npi_concatenate_multicast(np_shape_array, amp_init):
    class Foo(nn.HybridBlock):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.dense0 = nn.Dense(16, in_units=8)

        def forward(self, x):
            y = self.dense0(x)
            return mx.np.concatenate([y, x], axis=-1)

    foo = Foo()
    foo.initialize(ctx=mx.gpu())

    data = mx.np.ones((32, 8), ctx=mx.gpu())
    out = foo(data)
    assert out.dtype == np.float32
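The test concatenates the Dense layer's output, which AMP runs in float16, with the original float32 input; since concatenate takes the widest input dtype under AMP (assuming it is registered on the widest-type-cast list), the result is expected to be float32, exercising the cast handling now applied through the _api_internal entry point as well.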
