natural_breaks(): drop gpu support #705

Merged
merged 1 commit into from May 4, 2022
85 changes: 14 additions & 71 deletions xrspatial/classify.py
@@ -539,68 +539,12 @@ def _run_numpy_jenks_matrices(data, n_classes):
     return lower_class_limits, var_combinations
 
 
-def _run_cupy_jenks_matrices(data, n_classes):
-    n_data = data.shape[0]
-    lower_class_limits = cupy.zeros((n_data + 1, n_classes + 1), dtype='f4')
-    lower_class_limits[1, 1:n_classes + 1] = 1.0
-
-    var_combinations = cupy.zeros((n_data + 1, n_classes + 1), dtype='f4')
-    var_combinations[2:n_data + 1, 1:n_classes + 1] = cupy.inf
-
-    nl = data.shape[0] + 1
-    variance = 0.0
-
-    for l in range(2, nl):  # noqa
-        sum = 0.0
-        sum_squares = 0.0
-        w = 0.0
-
-        for m in range(1, l + 1):
-            # `III` originally
-            lower_class_limit = l - m + 1
-            i4 = lower_class_limit - 1
-
-            val = data[i4]
-
-            # here we're estimating variance for each potential classing
-            # of the data, for each potential number of classes. `w`
-            # is the number of data points considered so far.
-            w += 1.0
-
-            # increase the current sum and sum-of-squares
-            sum += val
-            sum_squares += val * val
-
-            # the variance at this point in the sequence is the difference
-            # between the sum of squares and the total x 2, over the number
-            # of samples.
-            variance = sum_squares - (sum * sum) / w
-
-            if i4 != 0:
-                for j in range(2, n_classes + 1):
-                    jm1 = j - 1
-                    if var_combinations[l, j] >= \
-                            (variance + var_combinations[i4, jm1]):
-                        lower_class_limits[l, j] = lower_class_limit
-                        var_combinations[l, j] = variance + \
-                            var_combinations[i4, jm1]
-
-        lower_class_limits[l, 1] = 1.
-        var_combinations[l, 1] = variance
-
-    return lower_class_limits, var_combinations
-
-
-def _run_jenks(data, n_classes, module):
+def _run_jenks(data, n_classes):
     # ported from existing cython implementation:
     # https://github.com/perrygeo/jenks/blob/master/jenks.pyx
 
     data.sort()
 
-    if module == np:
-        lower_class_limits, _ = _run_numpy_jenks_matrices(data, n_classes)
-    elif module == cupy:
-        lower_class_limits, _ = _run_cupy_jenks_matrices(data, n_classes)
+    lower_class_limits, _ = _run_numpy_jenks_matrices(data, n_classes)
 
     k = data.shape[0]
     kclass = np.zeros(n_classes + 1, dtype=np.float32)
@@ -617,12 +561,10 @@ def _run_jenks(data, n_classes, module):
     return kclass
 
 
-def _run_natural_break(agg, num_sample, k, module):
+def _run_natural_break(agg, num_sample, k):
     data = agg.data
     num_data = data.size
-    max_data = module.max(data[module.isfinite(data)])
-    if module == cupy:
-        max_data = max_data.get()
+    max_data = np.max(data[np.isfinite(data)])
 
     if num_sample is not None and num_sample < num_data:
         # randomly select sample from the whole dataset
@@ -631,7 +573,7 @@ def _run_natural_break(agg, num_sample, k, module):
         # use numpy.random to ensure the same result
         generator = np.random.RandomState(1234567890)
         idx = np.linspace(
-            0, data.size, data.size, endpoint=False, dtype=module.uint32
+            0, data.size, data.size, endpoint=False, dtype=np.uint32
         )
         generator.shuffle(idx)
         sample_idx = idx[:num_sample]
@@ -649,11 +591,11 @@ def _run_natural_break(agg, num_sample, k, module):
                       'a long time.'.format(sample_data.size),
                       Warning)
 
-    sample_data = module.asarray(sample_data)
+    sample_data = np.asarray(sample_data)
 
     # only include finite values
-    sample_data = sample_data[module.isfinite(sample_data)]
-    uv = module.unique(sample_data)
+    sample_data = sample_data[np.isfinite(sample_data)]
+    uv = np.unique(sample_data)
     uvk = len(uv)
 
     if uvk < k:
@@ -667,11 +609,11 @@ def _run_natural_break(agg, num_sample, k, module):
         uv.sort()
         bins = uv
     else:
-        centroids = _run_jenks(sample_data, k, module)
-        bins = module.array(centroids[1:])
+        centroids = _run_jenks(sample_data, k)
+        bins = np.array(centroids[1:])
         bins[-1] = max_data
 
-    out = _bin(agg, bins, module.arange(uvk))
+    out = _bin(agg, bins, np.arange(uvk))
     return out
 
 
@@ -760,10 +702,11 @@ def natural_breaks(agg: xr.DataArray,
     """
 
     mapper = ArrayTypeFunctionMapping(
-        numpy_func=lambda *args: _run_natural_break(*args, module=np),
+        numpy_func=lambda *args: _run_natural_break(*args),
         dask_func=lambda *args: not_implemented_func(
             *args, messages='natural_breaks() does not support dask with numpy backed DataArray.'),  # noqa
-        cupy_func=lambda *args: _run_natural_break(*args, module=cupy),
+        cupy_func=lambda *args: not_implemented_func(
+            *args, messages='natural_breaks() does not support cupy backed DataArray.'),  # noqa
         dask_cupy_func=lambda *args: not_implemented_func(
             *args, messages='natural_breaks() does not support dask with cupy backed DataArray.'),  # noqa
     )
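For context, a minimal usage sketch of `natural_breaks()` after this change. This is not part of the diff; the sample array and the import path are assumptions. A numpy-backed DataArray still goes through the CPU Jenks path, while a cupy-backed DataArray is now routed to `not_implemented_func`.

```python
import numpy as np
import xarray as xr

from xrspatial.classify import natural_breaks

# numpy-backed DataArray: dispatched to _run_natural_break and classified on the CPU
agg = xr.DataArray(np.array([[0., 1., 2.], [3., 4., 5.]]), dims=['y', 'x'])
breaks = natural_breaks(agg, k=3)
print(breaks)

# A cupy-backed DataArray would now hit the cupy_func branch of the mapper,
# i.e. not_implemented_func, which reports that the cupy backend is unsupported.
```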
16 changes: 0 additions & 16 deletions xrspatial/tests/test_classify.py
@@ -239,22 +239,6 @@ def test_natural_breaks_cpu_deterministic():
     )
 
 
-@cuda_and_cupy_available
-def test_natural_breaks_cupy(result_natural_breaks):
-    cupy_agg = input_data('cupy')
-    k, expected_result = result_natural_breaks
-    cupy_natural_breaks = natural_breaks(cupy_agg, k=k)
-    general_output_checks(cupy_agg, cupy_natural_breaks, expected_result, verify_dtype=True)
-
-
-@cuda_and_cupy_available
-def test_natural_breaks_cupy_num_sample(result_natural_breaks_num_sample):
-    cupy_agg = input_data('cupy')
-    k, num_sample, expected_result = result_natural_breaks_num_sample
-    cupy_natural_breaks = natural_breaks(cupy_agg, k=k, num_sample=num_sample)
-    general_output_checks(cupy_agg, cupy_natural_breaks, expected_result, verify_dtype=True)
-
-
 @pytest.fixture
 def result_equal_interval():
     k = 3
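For reference, a condensed, standalone sketch of the Jenks matrices recurrence that the retained CPU path (`_run_numpy_jenks_matrices`) implements and that the removed cupy function mirrored. The function name and sample data are illustrative only, not part of the library:

```python
import numpy as np


def jenks_matrices_sketch(data, n_classes):
    # For every prefix of the sorted data and every class count, track the
    # lowest achievable sum of within-class variances and the break position
    # (lower class limit) that achieves it.
    n_data = data.shape[0]
    lower_class_limits = np.zeros((n_data + 1, n_classes + 1), dtype='f4')
    lower_class_limits[1, 1:n_classes + 1] = 1.0
    var_combinations = np.zeros((n_data + 1, n_classes + 1), dtype='f4')
    var_combinations[2:n_data + 1, 1:n_classes + 1] = np.inf

    for last in range(2, n_data + 1):
        total, total_sq, w = 0.0, 0.0, 0.0
        variance = 0.0
        for m in range(1, last + 1):
            lower_class_limit = last - m + 1
            i4 = lower_class_limit - 1
            val = data[i4]
            w += 1.0
            total += val
            total_sq += val * val
            # sum of squared deviations of data[i4:last]
            variance = total_sq - (total * total) / w
            if i4 != 0:
                for j in range(2, n_classes + 1):
                    if var_combinations[last, j] >= variance + var_combinations[i4, j - 1]:
                        lower_class_limits[last, j] = lower_class_limit
                        var_combinations[last, j] = variance + var_combinations[i4, j - 1]
        lower_class_limits[last, 1] = 1.0
        var_combinations[last, 1] = variance
    return lower_class_limits, var_combinations


data = np.sort(np.array([1., 2., 4., 5., 7., 9.], dtype='f4'))
limits, variances = jenks_matrices_sketch(data, n_classes=3)
print(limits)
```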