Skip to content

Commit

Permalink
Added tests for the complex integer types.
Browse files Browse the repository at this point in the history
  • Loading branch information
jaycedowell committed Nov 16, 2021
1 parent 5a6e6bd commit dfb73ce
Showing 1 changed file with 35 additions and 2 deletions.
37 changes: 35 additions & 2 deletions test/test_map.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

# Copyright (c) 2016-2020, The Bifrost Authors. All rights reserved.
# Copyright (c) 2016-2021, The Bifrost Authors. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -131,7 +131,7 @@ def test_shift(self):
a = a.copy('system')
b = b.copy('system')
np.testing.assert_equal(b, np.fft.fftshift(a))
def test_complex(self):
def test_complex_float(self):
n = 89
real = np.random.randint(-127, 128, size=(n,n)).astype(np.float32)
imag = np.random.randint(-127, 128, size=(n,n)).astype(np.float32)
Expand All @@ -141,6 +141,39 @@ def test_complex(self):
self.run_simple_test(x, "y = x*x.conj()", lambda x: x * x.conj())
self.run_simple_test(x, "y = x.mag2()", lambda x: x * x.conj())
self.run_simple_test(x, "y = 3*x", lambda x: 3 * x)
def test_complex_integer(self):
n = 7919
for in_dtype in ('ci4', 'ci8', 'ci16', 'ci32'):
a_orig = bf.ndarray(shape=(n,), dtype=in_dtype, space='system')
try:
a_orig['re'] = np.random.randint(256, size=n)
a_orig['im'] = np.random.randint(256, size=n)
except ValueError:
# ci4 is different
a_orig['re_im'] = np.random.randint(256, size=n)
for out_dtype in (in_dtype, 'cf32'):
a = a_orig.copy(space='cuda')
b = bf.ndarray(shape=(n,), dtype=out_dtype, space='cuda')
bf.map('b(i) = a(i)', {'a': a, 'b': b}, shape=a.shape, axis_names=('i',))
a = a.copy(space='system')
try:
a = a['re'] + 1j*a['im']
except ValueError:
# ci4 is different
a = np.int8(a['re_im'] & 0xF0) + 1j*np.int8((a['re_im'] & 0x0F) << 4)
a /= 16
b = b.copy(space='system')
try:
b = b['re'] + 1j*b['im']
except ValueError:
# ci4 is different
b = np.int8(b['re_im'] & 0xF0) + 1j*np.int8((b['re_im'] & 0x0F) << 4)
b /= 16
except IndexError:
# pass through cf32
pass
np.testing.assert_equal(a, b)

def test_polarisation_products(self):
n = 89
real = np.random.randint(-127, 128, size=(n,2)).astype(np.float32)
Expand Down

0 comments on commit dfb73ce

Please sign in to comment.