In [1]:
%load_ext Cython

In [109]:
%%cython
import numpy as np
cimport numpy as np
cimport cython
DTYPE = np.int8
ctypedef np.int8_t DTYPE_t
# np.int was only an alias of the builtin int and was removed in NumPy 1.24
# (AttributeError at import); np.int_ is the NumPy scalar type intended to
# match the np.int_t buffer dtype below.
# NOTE(review): confirm np.int_ and np.int_t agree on your platform
# (they may differ on Windows under NumPy 2).
IDXTYPE = np.int_
ctypedef np.int_t IDXTYPE_t

@cython.boundscheck(False)
@cython.wraparound(False)
cdef flip_wrapped_1d(np.ndarray[DTYPE_t, ndim=2] source,
                     int pad_size, np.ndarray[IDXTYPE_t, ndim=2] pos):
    """Negate, in place, one site per row of `source` plus its wrapped copies.

    Each row of `source` is viewed as an interior of length
    `source.shape[1] - 2*pad_size` with `pad_size` cells on either side.
    For row n, the element at padded column `pad_size + pos[n, 0]` is
    negated, together with the columns exactly one interior length to the
    left and to the right.  Columns that fall outside the row are skipped.
    Only the +/-1 period copies are touched, so this presumably assumes
    pad_size <= interior length (e.g. padding made by np.pad(..., 'wrap')).
    """
    assert source.shape[0] == pos.shape[0]
    assert pos.shape[1] == 1
    cdef int n_samples = source.shape[0]
    cdef int width = source.shape[1]
    cdef int period = width - 2 * pad_size
    cdef int sample, shift, col
    for sample in range(n_samples):
        # shift = 0, 1, 2 selects the copy at -1, 0, +1 periods
        for shift in range(3):
            col = pad_size + (shift - 1) * period + pos[sample, 0]
            if col < 0 or col >= width:
                continue
            source[sample, col] *= -1

@cython.boundscheck(False)
@cython.wraparound(False)
cdef flip_wrapped_2d(np.ndarray[DTYPE_t, ndim=3] source,
                     int pad_size, np.ndarray[IDXTYPE_t, ndim=2] pos):
    """Negate, in place, one (y, x) site per sample plus its wrapped copies.

    `source[n]` is a padded 2-D grid: axis 1 is y and axis 2 is x, each
    padded by `pad_size` on both sides.  `pos[n] = (y, x)` are interior
    coordinates.  Every combination of the -1/0/+1 period shifts in y and x
    that lands inside the grid is negated; others are skipped.
    """
    assert source.shape[0] == pos.shape[0]
    assert pos.shape[1] == 2
    cdef int n_samples = source.shape[0]
    cdef int height = source.shape[1]
    cdef int width = source.shape[2]
    cdef int period_y = height - 2 * pad_size
    cdef int period_x = width - 2 * pad_size
    cdef int sample, shift_y, shift_x, row, col
    for sample in range(n_samples):
        for shift_y in range(3):
            row = pad_size + (shift_y - 1) * period_y + pos[sample, 0]
            if row < 0 or row >= height:
                continue
            for shift_x in range(3):
                col = pad_size + (shift_x - 1) * period_x + pos[sample, 1]
                if col < 0 or col >= width:
                    continue
                source[sample, row, col] *= -1

@cython.boundscheck(False)
@cython.wraparound(False)
cdef flip_wrapped_3d(np.ndarray[DTYPE_t, ndim=4] source,
                     int pad_size, np.ndarray[IDXTYPE_t, ndim=2] pos):
    """Negate, in place, one (z, y, x) site per sample plus its wrapped copies.

    `source[n]` is a padded 3-D volume: axes 1, 2, 3 are z, y, x, each padded
    by `pad_size` on both sides.  `pos[n] = (z, y, x)` are interior
    coordinates.  Every combination of -1/0/+1 period shifts along the three
    axes that lands inside the volume is negated; others are skipped.
    """
    assert source.shape[0] == pos.shape[0]
    assert pos.shape[1] == 3
    cdef int n_samples = source.shape[0]
    cdef int depth = source.shape[1]
    cdef int height = source.shape[2]
    cdef int width = source.shape[3]
    cdef int period_z = depth - 2 * pad_size
    cdef int period_y = height - 2 * pad_size
    cdef int period_x = width - 2 * pad_size
    cdef int sample, shift_z, shift_y, shift_x, plane, row, col
    for sample in range(n_samples):
        for shift_z in range(3):
            plane = pad_size + (shift_z - 1) * period_z + pos[sample, 0]
            if plane < 0 or plane >= depth:
                continue
            for shift_y in range(3):
                row = pad_size + (shift_y - 1) * period_y + pos[sample, 1]
                if row < 0 or row >= height:
                    continue
                for shift_x in range(3):
                    col = pad_size + (shift_x - 1) * period_x + pos[sample, 2]
                    if col < 0 or col >= width:
                        continue
                    source[sample, plane, row, col] *= -1


def check_flip(before, after, pad, pos):
    """Assert `after` equals `before` with interior site pos[n] of each sample
    negated and the wrap padding still consistent with the interior.

    Assumes every non-batch axis of `before` was padded by `pad` with
    np.pad(..., 'wrap') and that `pad` does not exceed the interior size.
    """
    for n in range(before.shape[0]):
        interior_slice = tuple([slice(pad, s - pad) for s in before.shape[1:]])
        interior = before[n][interior_slice].copy()
        interior[tuple(pos[n])] *= -1
        assert np.array_equal(after[n], np.pad(interior, pad, 'wrap'))


# flip_wrapped_* modify their argument in place, so convert to DTYPE once and
# keep the reference; passing z.astype(...) directly would flip a temporary
# copy and throw the result away.

# 1-D: N copies of 1..10, wrap-padded by r.
r = 2
N = 4
x = np.arange(10)+1
y = np.pad(x, r, 'wrap')
z = np.stack([y]*N, 0).astype(DTYPE)
z_before = z.copy()
pos = np.array([0, 4, 8, 7]).reshape((-1, 1))
flip_wrapped_1d(z, r, pos)
check_flip(z_before, z, r, pos)

# 2-D: N copies of a 3x3 grid, wrap-padded by r.
x = np.arange(9).reshape((3,3))
y = np.pad(x, r, 'wrap')
z = np.stack([y]*N, 0).astype(DTYPE)
z_before = z.copy()
pos = np.array([(1,1), (2,1), (2,2), (2, 1)])
flip_wrapped_2d(z, r, pos)
check_flip(z_before, z, r, pos)

# 3-D: N copies of a 2x2x2 volume, wrap-padded by r.
N=2
r=1
x = np.arange(8).reshape((2,2,2))+1
y = np.pad(x, r, 'wrap')
z = np.stack([y]*N, 0).astype(DTYPE)
z_before = z.copy()
pos = np.array([[0,0,0],[1,0,1]])
flip_wrapped_3d(z, r, pos)
check_flip(z_before, z, r, pos)

In [148]:
%%cython

import numpy as np
cimport numpy as np
cimport cython
DTYPE = np.complex64
ctypedef np.complex64_t DTYPE_t
# np.int was only an alias of the builtin int and was removed in NumPy 1.24
# (AttributeError at import); np.int_ is the NumPy scalar type intended to
# match the np.int_t buffer dtype below.
# NOTE(review): confirm np.int_ and np.int_t agree on your platform
# (they may differ on Windows under NumPy 2).
IDXTYPE = np.int_
ctypedef np.int_t IDXTYPE_t

@cython.boundscheck(False)
@cython.wraparound(False)
cdef replace_wrapped_1d(np.ndarray[DTYPE_t, ndim=2] old,
                        np.ndarray[DTYPE_t, ndim=2] replacement,
                        np.ndarray[IDXTYPE_t, ndim=2] center):
    """Paste replacement[n] into row n of `old`, in place, with wrapped copies.

    The window length R = replacement.shape[1] must be odd.  `old` is
    treated as padded by R - 1 cells on each side, so its interior length is
    old.shape[1] - 2*(R - 1).  The window is centred on interior index
    center[n, 0] and is also written one interior length to the left and to
    the right; any part that falls outside the row is clipped.
    """
    assert old.shape[0] == replacement.shape[0] == center.shape[0]
    assert center.shape[1] == 1
    assert replacement.shape[1] % 2 == 1
    cdef int n_samples = old.shape[0]
    cdef int width = old.shape[1]
    cdef int window = replacement.shape[1]
    # distance from the window's first cell to its centre
    cdef int half = (window - 1) // 2
    cdef int period = width - 2 * (window - 1)
    cdef int sample, shift, k, start, col
    for sample in range(n_samples):
        # shift = 0, 1, 2 selects the copy at -1, 0, +1 periods
        for shift in range(3):
            start = half + (shift - 1) * period + center[sample, 0]
            for k in range(window):
                col = start + k
                if 0 <= col < width:
                    old[sample, col] = replacement[sample, k]


@cython.boundscheck(False)
@cython.wraparound(False)
cdef replace_wrapped_2d(np.ndarray[DTYPE_t, ndim=3] old,
                        np.ndarray[DTYPE_t, ndim=3] replacement,
                        np.ndarray[IDXTYPE_t, ndim=2] center):
    """Paste replacement[n] into grid n of `old`, in place, with wrapped copies.

    Axis 1 is y and axis 2 is x.  The window sizes R_y = replacement.shape[1]
    and R_x = replacement.shape[2] must be odd.  `old` is treated as padded
    by R - 1 cells on each side of the corresponding axis.  center[n] = (y, x)
    are interior coordinates of the window centre.  The window is also
    written at -1/+1 interior periods along each axis, and clipped to the
    grid.
    """
    assert old.shape[0] == replacement.shape[0] == center.shape[0]
    assert center.shape[1] == 2
    assert replacement.shape[1] % 2 == replacement.shape[2] % 2 == 1
    cdef int wrap_x, wrap_y
    cdef int x, y
    cdef int n, N = old.shape[0]
    cdef int R_x = replacement.shape[2]
    cdef int R_y = replacement.shape[1]
    cdef int pad_x = R_x-1
    cdef int pad_y = R_y-1
    cdef int size_x = old.shape[2]-2*pad_x
    cdef int size_y = old.shape[1]-2*pad_y
    cdef int offset_x, offset_y
    for n in range(N):
        for wrap_y in range(-1, 2):
            # pad_y is even (R_y odd), so integer division is exact; this
            # avoids a float round-trip through int().
            offset_y = pad_y // 2 + wrap_y*size_y + center[n, 0]
            for y in range(max(offset_y, 0), min(offset_y+R_y, old.shape[1])):
                for wrap_x in range(-1, 2):
                    # The column offset must use the x coordinate center[n, 1],
                    # not the y coordinate center[n, 0].
                    offset_x = pad_x // 2 + wrap_x*size_x + center[n, 1]
                    for x in range(max(offset_x, 0), min(offset_x+R_x, old.shape[2])):
                        old[n, y, x] = replacement[n, y-offset_y, x-offset_x]
                        


@cython.boundscheck(False)
@cython.wraparound(False)
cdef replace_wrapped_3d(np.ndarray[DTYPE_t, ndim=4] old,
                        np.ndarray[DTYPE_t, ndim=4] replacement,
                        np.ndarray[IDXTYPE_t, ndim=2] center):
    """Paste replacement[n] into volume n of `old`, in place, with wrapped copies.

    Axes 1, 2, 3 are z, y, x.  Every window size (replacement.shape[1:]) must
    be odd.  `old` is treated as padded by (window size - 1) cells on each
    side of the corresponding axis.  center[n] = (z, y, x) are interior
    coordinates of the window centre.  The window is also written at -1/+1
    interior periods along each axis, and clipped to the volume.
    """
    assert old.shape[0] == replacement.shape[0] == center.shape[0]
    assert center.shape[1] == 3
    assert replacement.shape[1] % 2 == 1
    assert replacement.shape[2] % 2 == 1
    assert replacement.shape[3] % 2 == 1
    cdef int n_samples = old.shape[0]
    cdef int depth = old.shape[1], height = old.shape[2], width = old.shape[3]
    cdef int win_z = replacement.shape[1]
    cdef int win_y = replacement.shape[2]
    cdef int win_x = replacement.shape[3]
    # distance from each window's first cell to its centre
    cdef int half_z = (win_z - 1) // 2
    cdef int half_y = (win_y - 1) // 2
    cdef int half_x = (win_x - 1) // 2
    cdef int period_z = depth - 2 * (win_z - 1)
    cdef int period_y = height - 2 * (win_y - 1)
    cdef int period_x = width - 2 * (win_x - 1)
    cdef int sample, shift_z, shift_y, shift_x
    cdef int k, j, i, start_z, start_y, start_x, zz, yy, xx
    for sample in range(n_samples):
        for shift_z in range(3):
            start_z = half_z + (shift_z - 1) * period_z + center[sample, 0]
            for k in range(win_z):
                zz = start_z + k
                if zz < 0 or zz >= depth:
                    continue
                for shift_y in range(3):
                    start_y = half_y + (shift_y - 1) * period_y + center[sample, 1]
                    for j in range(win_y):
                        yy = start_y + j
                        if yy < 0 or yy >= height:
                            continue
                        for shift_x in range(3):
                            start_x = half_x + (shift_x - 1) * period_x + center[sample, 2]
                            for i in range(win_x):
                                xx = start_x + i
                                if 0 <= xx < width:
                                    old[sample, zz, yy, xx] = replacement[sample, k, j, i]
                                


N = 1

# 1-D: window (-1, -9, -2) centred on the last interior site (7) of 0..7, so
# it wraps round onto interior site 0.  Padding is R - 1 = 2.
x = np.arange(8)
y = np.pad(x, 2, 'wrap')
z = np.stack([y]*N, 0).astype(DTYPE)
pos = np.array([[7]])
r = np.stack([np.array([-1,-9,-2])]*N, 0).astype(DTYPE)
replace_wrapped_1d(z, r, pos)
expected = x.astype(DTYPE)
expected[[6, 7, 0]] = r[0]
assert np.array_equal(z[0], np.pad(expected, 2, 'wrap'))

# 2-D: 3x3 window centred on the corner (8, 8) of a 9x9 grid; wraps on both axes.
x = np.arange(81).reshape((9,9))
y = np.pad(x, 2, 'wrap')
z = np.stack([y]*N, 0).astype(DTYPE)
pos = np.array([[8,8]])
r = np.stack([(np.arange(9).reshape((3,3))+1)*-1]*N, 0).astype(DTYPE)
replace_wrapped_2d(z, r, pos)
expected = x.astype(DTYPE)
expected[np.ix_([7, 8, 0], [7, 8, 0])] = r[0]
assert np.array_equal(z[0], np.pad(expected, 2, 'wrap'))

# 3-D: 1x1x3 window, so only the x axis is padded (by 3 - 1 = 2).
x = np.arange(27).reshape((3, 3, 3))
y = np.pad(x, [(0,0), (0,0), (2,2)], 'wrap')
z = np.stack([y]*N, 0).astype(DTYPE)
pos = np.array([[0,0,0]])
r = np.array([[[[-9, -8, -7]]]]).astype(DTYPE)
print(z)
replace_wrapped_3d(z, r, pos)
print(z)
expected = x.astype(DTYPE)
expected[0, 0, [2, 0, 1]] = r[0, 0, 0]
assert np.array_equal(z[0], np.pad(expected, [(0,0), (0,0), (2,2)], 'wrap'))

[[[[  1.+0.j   2.+0.j   0.+0.j   1.+0.j   2.+0.j   0.+0.j   1.+0.j]
   [  4.+0.j   5.+0.j   3.+0.j   4.+0.j   5.+0.j   3.+0.j   4.+0.j]
   [  7.+0.j   8.+0.j   6.+0.j   7.+0.j   8.+0.j   6.+0.j   7.+0.j]]

  [[ 10.+0.j  11.+0.j   9.+0.j  10.+0.j  11.+0.j   9.+0.j  10.+0.j]
   [ 13.+0.j  14.+0.j  12.+0.j  13.+0.j  14.+0.j  12.+0.j  13.+0.j]
   [ 16.+0.j  17.+0.j  15.+0.j  16.+0.j  17.+0.j  15.+0.j  16.+0.j]]

  [[ 19.+0.j  20.+0.j  18.+0.j  19.+0.j  20.+0.j  18.+0.j  19.+0.j]
   [ 22.+0.j  23.+0.j  21.+0.j  22.+0.j  23.+0.j  21.+0.j  22.+0.j]
   [ 25.+0.j  26.+0.j  24.+0.j  25.+0.j  26.+0.j  24.+0.j  25.+0.j]]]]
[[[[ -7.+0.j  -9.+0.j  -8.+0.j  -7.+0.j  -9.+0.j  -8.+0.j  -7.+0.j]
   [  4.+0.j   5.+0.j   3.+0.j   4.+0.j   5.+0.j   3.+0.j   4.+0.j]
   [  7.+0.j   8.+0.j   6.+0.j   7.+0.j   8.+0.j   6.+0.j   7.+0.j]]

  [[ 10.+0.j  11.+0.j   9.+0.j  10.+0.j  11.+0.j   9.+0.j  10.+0.j]
   [ 13.+0.j  14.+0.j  12.+0.j  13.+0.j  14.+0.j  12.+0.j  13.+0.j]
   [ 16.+0.j  17.+0.j  15.+0.j  16.+0.j