Skip to content

Commit

Permalink
Merge pull request #8072 from asi1024/refactor-radartools
Browse files Browse the repository at this point in the history
Refactor radartools
  • Loading branch information
takagi committed Jan 11, 2024
2 parents 2912651 + 0d714f6 commit cf4e09b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 43 deletions.
69 changes: 30 additions & 39 deletions cupyx/signal/_radartools/_radartools.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,25 @@
from cupyx.scipy.signal import windows


def _pulse_preprocess(x, normalize, window):
if window is not None:
n = x.shape[-1]
if callable(window):
w = window(cupy.fft.fftfreq(n).astype(x.dtype))
elif isinstance(window, cupy.ndarray):
if window.shape != (n,):
raise ValueError("window must have the same length as data")
w = window
else:
w = windows.get_window(window, n, False).astype(x.dtype)
x = x * w

if normalize:
x = x / cupy.linalg.norm(x)

return x


def pulse_compression(x, template, normalize=False, window=None, nfft=None):
"""
Pulse Compression is used to increase the range resolution and SNR
Expand Down Expand Up @@ -56,30 +75,19 @@ def pulse_compression(x, template, normalize=False, window=None, nfft=None):
compressedIQ : ndarray
Pulse compressed output
"""
[num_pulses, samples_per_pulse] = x.shape
num_pulses, samples_per_pulse = x.shape
dtype = cupy.result_type(x, template)

if nfft is None:
nfft = samples_per_pulse

if window is not None:
Nx = len(template)
if callable(window):
W = window(cupy.fft.fftfreq(Nx))
elif isinstance(window, cupy.ndarray):
if window.shape != (Nx,):
raise ValueError("window must have the same length as data")
W = window
else:
W = windows.get_window(window, Nx, False)

template = template * W

if normalize is True:
template = template / cupy.linalg.norm(template)

t = _pulse_preprocess(template, normalize, window)
fft_x = cupy.fft.fft(x, nfft)
fft_template = cupy.fft.fft(template, nfft).conj()
return cupy.fft.ifft(fft_x * fft_template, nfft)
fft_t = cupy.fft.fft(t, nfft)
out = cupy.fft.ifft(fft_x * fft_t.conj(), nfft)
if dtype.kind != 'c':
out = out.real
return out


def pulse_doppler(x, window=None, nfft=None):
Expand All @@ -106,30 +114,13 @@ def pulse_doppler(x, window=None, nfft=None):
pd_dataMatrix : ndarray
Pulse-doppler output (range/doppler matrix)
"""
[num_pulses, samples_per_pulse] = x.shape
num_pulses, samples_per_pulse = x.shape

if nfft is None:
nfft = num_pulses

if window is not None:
Nx = num_pulses
if callable(window):
W = window(cupy.fft.fftfreq(Nx))[cupy.newaxis]
elif isinstance(window, cupy.ndarray):
if window.shape != (Nx,):
raise ValueError("window must have the same length as data")
W = window[cupy.newaxis]
else:
W = windows.get_window(window, Nx, False)[cupy.newaxis]

pd_dataMatrix = cupy.fft.fft(
cupy.multiply(x, cupy.tile(W.T, (1, samples_per_pulse))),
nfft, axis=0
)
else:
pd_dataMatrix = cupy.fft.fft(x, nfft, axis=0)

return pd_dataMatrix
xT = _pulse_preprocess(x.T, False, window)
return cupy.fft.fft(xT, nfft).T


def cfar_alpha(pfa, N):
Expand Down
14 changes: 10 additions & 4 deletions tests/cupyx_tests/signal_tests/radartools_tests/test_radartools.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,18 @@ def _numpy_pulse_compression(x, t, normalize, window):
return out.astype(dtype)


tol = {
numpy.float32: 2e-3,
numpy.float64: 1e-7,
numpy.complex64: 2e-3,
numpy.complex128: 1e-7,
}


@pytest.mark.parametrize('normalize', [True, False])
@pytest.mark.parametrize('window', [None, 'hamming', numpy.negative])
@testing.for_dtypes('fdFD')
@testing.numpy_cupy_allclose(
rtol=1e-3, type_check=False, contiguous_check=False)
@testing.numpy_cupy_allclose(rtol=tol, contiguous_check=False)
def test_pulse_compression(xp, normalize, window, dtype):
x = testing.shaped_random((8, 700), xp=xp, dtype=dtype)
template = testing.shaped_random((100,), xp=xp, dtype=dtype)
Expand All @@ -57,8 +64,7 @@ def test_pulse_compression(xp, normalize, window, dtype):

@pytest.mark.parametrize('window', [None, 'hamming', numpy.negative])
@testing.for_dtypes('fdFD')
@testing.numpy_cupy_allclose(
rtol=1e-3, type_check=False, contiguous_check=False)
@testing.numpy_cupy_allclose(rtol=tol, contiguous_check=False)
def test_pulse_doppler(xp, window, dtype):
x = testing.shaped_random((8, 700), xp=xp, dtype=dtype)

Expand Down

0 comments on commit cf4e09b

Please sign in to comment.