diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index a0781ebf9ac1..f0267afb4e81 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -3519,24 +3519,18 @@ def insert(arr, obj, values, axis=None): N = arr.shape[axis] newshape = list(arr.shape) if isinstance(obj, (int, long, integer)): + if (obj < 0): obj += N if obj < 0 or obj > N: raise ValueError( "index (%d) out of range (0<=index<=%d) "\ "in dimension %d" % (obj, N, axis)) - newshape[axis] += 1; - new = empty(newshape, arr.dtype, arr.flags.fnc) - slobj[axis] = slice(None, obj) - new[slobj] = arr[slobj] - slobj[axis] = obj - new[slobj] = values - slobj[axis] = slice(obj+1,None) - slobj2 = [slice(None)]*ndim - slobj2[axis] = slice(obj,None) - new[slobj] = arr[slobj2] - if wrap: - return wrap(new) - return new + + if isinstance(values, (int, long, integer)): + obj = [obj] + else: + obj = [obj] * len(values) + elif isinstance(obj, slice): # turn it into a range object diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 95b32e47c721..2ed6e7edd092 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -145,7 +145,8 @@ def test_basic(self): assert_equal(insert(a, 0, 1), [1, 1, 2, 3]) assert_equal(insert(a, 3, 1), [1, 2, 3, 1]) assert_equal(insert(a, [1, 1, 1], [1, 2, 3]), [1, 1, 2, 3, 2, 3]) - + assert_equal(insert(a, 1,[1,2,3]), [1, 1, 2, 3, 2, 3]) + assert_equal(insert(a,[1,2,3],9),[1,9,2,9,3,9]) class TestAmax(TestCase): def test_basic(self):