[query] fix ndarray concat with size 0 dims #13755

Merged
merged 3 commits into from
Oct 13, 2023

Conversation

patrick-schultz
Collaborator

ndarray concat was broken when the first input had size 0 along the concat axis. For example:

In [3]: hl.eval(hl.nd.hstack([hl.nd.zeros((2, 0)), hl.nd.array([[1.0, 2.0], [3.0, 4.0]])]))
Out[3]:
array([[0., 2.],
       [0., 4.]])

The zeros matrix is 2 by 0, so horizontal concatenation should just return the other matrix. Instead, the first column of the result is 0 where it should be 1 and 3.
(I once saw the first column filled with random numbers, presumably from a buffer overflow.)
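
For comparison, the expected result can be checked against NumPy, which treats a zero-width block as a no-op in horizontal concatenation. This sketch uses only NumPy, not Hail; the variable names are illustrative.

```python
import numpy as np

empty = np.zeros((2, 0))                     # 2 rows, 0 columns
other = np.array([[1.0, 2.0], [3.0, 4.0]])

result = np.hstack([empty, other])
print(result.shape)                          # (2, 2)
print(np.array_equal(result, other))         # True
```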

I did some cleaning up in the concat implementation, but the functional change is to record the index of the first input which is non-empty along the concat axis and, when resetting to the start of the axis, to reset to that non-empty index. Other size 0 inputs were already handled correctly when incrementing the index; the problem was that the first read happens before any increment.
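
The change described above can be modelled in pure Python. This is a simplified sketch, not the Hail code generator: `first_non_empty`, `hstack_rows`, and the `skip_leading_empty` flag are hypothetical names, and arrays are represented as lists of rows. It assumes a cursor that reads an element first and advances afterwards.

```python
def first_non_empty(widths):
    """Index of the first input with non-zero size along the concat axis."""
    i = 0
    while i < len(widths) and widths[i] == 0:
        i += 1
    return i


def hstack_rows(blocks, skip_leading_empty=True):
    """Concatenate 2-D blocks (lists of rows) along axis 1 using a cursor.

    The cursor reads an element and only then advances, so the position it
    is reset to at the start of each row must already be readable.
    """
    n_rows = len(blocks[0])
    widths = [len(block[0]) if block else 0 for block in blocks]
    total = sum(widths)
    # The fix: reset to the first non-empty input rather than to input 0.
    start = first_non_empty(widths) if skip_leading_empty else 0

    out = []
    for r in range(n_rows):
        which, col = start, 0            # reset to the start of the concat axis
        row = []
        for _ in range(total):
            row.append(blocks[which][r][col])   # read happens before increment
            col += 1
            if col >= widths[which]:
                # Step to the next input, skipping any that are empty.
                col = 0
                which += 1
                while which < len(blocks) and widths[which] == 0:
                    which += 1
        out.append(row)
    return out


empty = [[], []]                          # shape (2, 0)
m = [[1.0, 2.0], [3.0, 4.0]]              # shape (2, 2)

print(hstack_rows([empty, m]))            # [[1.0, 2.0], [3.0, 4.0]]
print(hstack_rows([m, empty, m]))         # [[1.0, 2.0, 1.0, 2.0], [3.0, 4.0, 3.0, 4.0]]
```

With `skip_leading_empty=False`, the first read indexes into the empty block. Python raises `IndexError` there; code reading from a raw buffer would presumably return whatever happens to be in memory, which is consistent with the wrong first column reported above.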

ehigham
ehigham previously requested changes Oct 4, 2023
Comment on lines 1147 to 1152
assert(np.array_equal(hl.eval(hl.nd.vstack((a, b))), np.vstack((a, b))))
assert(np.array_equal(hl.eval(hl.nd.vstack(hl.array([a, b]))), np.vstack((a, b))))
assert(np.array_equal(hl.eval(hl.nd.vstack((a, empty, b))), np.vstack((a, empty, b))))
assert(np.array_equal(hl.eval(hl.nd.vstack(hl.array([a, empty, b]))), np.vstack((a, empty, b))))
assert(np.array_equal(hl.eval(hl.nd.vstack((empty, a, b))), np.vstack((empty, a, b))))
assert(np.array_equal(hl.eval(hl.nd.vstack(hl.array([empty, a, b]))), np.vstack((empty, a, b))))
Member

This is quite a lot of asserts for one test. Perhaps break these apart or consider parameterisation.
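
One way the parameterisation suggestion might look with `pytest.mark.parametrize` is sketched below. The array values and the `vstack_under_test` helper are assumptions for illustration; in the Hail test suite the helper would presumably wrap `hl.eval(hl.nd.vstack(...))`, and NumPy stands in here so the sketch runs on its own.

```python
import numpy as np
import pytest

# Illustrative inputs (not taken from the PR's test file).
a = np.array([[1.0, 2.0], [3.0, 4.0]])
b = np.array([[5.0, 6.0]])
empty = np.zeros((0, 2))                 # size 0 along the vstack axis


def vstack_under_test(arrays):
    # Placeholder (assumption): swap in the implementation being tested.
    return np.vstack(arrays)


@pytest.mark.parametrize(
    "arrays",
    [(a, b), (a, empty, b), (empty, a, b)],
    ids=["no_empty", "empty_middle", "empty_first"],
)
def test_vstack_matches_numpy(arrays):
    # Each parameter set is reported as its own test case.
    assert np.array_equal(vstack_under_test(arrays), np.vstack(arrays))
```

Each case then passes or fails independently, so a failure in the `empty_first` case does not hide the others.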

val dimLength = cb.newLocal[Long]("dimLength", shapeOfNDAtIdx.loadField(cb, dimIdx).get(cb).asInt64.value)
// compute index of first input which has non-zero concat axis size
val firstNonEmpty = cb.newLocal[Int]("ndarray_concat_first_nonempty", 0)
cb.whileLoop(stagedArrayOfSizes.loadElement(cb, firstNonEmpty).get(cb).asInt64.value.ceq(0L), {
Member

Is this checking for null? Don't we have an isna function to make this clearer?

Collaborator Author

No, this is checking for size 0 axes.

Member

Oh yes, arrayOfSizes makes this kinda obvious. duh.

val mismatchedDim = cb.newLocal[Int]("ndarray_concat_mismatched_dim", -1)
val expected = cb.newLocal[Long]("ndarray_concat_expected_size")
val found = cb.newLocal[Long]("ndarray_concat_found_size")
for (i <- (0 until firstND.st.nDims).reverse if i != axis) {
Member

I think you can do something like first.st.nDims - 1 to 0 by -1, though .reverse is fine too

Collaborator Author

I find it tricky to think about endpoint inclusivity when using the -1 step; reverse is harder to get wrong.
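
The endpoint concern can be illustrated with a Python analogue, since Python ranges are half-open like Scala's `until`. This is not the Scala from the PR; `n_dims` and `axis` are made-up values.

```python
n_dims = 3
axis = 1

# Reverse an ascending half-open range: the endpoints are the familiar 0 and n_dims.
via_reversed = [i for i in reversed(range(n_dims)) if i != axis]

# Count down explicitly: the stop must be -1 (exclusive) for 0 to be included.
via_step = [i for i in range(n_dims - 1, -1, -1) if i != axis]

print(via_reversed)  # [2, 0]
print(via_step)      # [2, 0]
```

In Scala, `to` is inclusive, so `n - 1 to 0 by -1` does reach 0, whereas `n - 1 until 0 by -1` would stop at 1. Reversing `0 until n` avoids having to reason about which end is inclusive.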

Member

@ehigham ehigham left a comment

Thanks!

@danking danking merged commit 79c1c82 into hail-is:main Oct 13, 2023
@patrick-schultz patrick-schultz deleted the fix-ndarray-concat branch January 2, 2025 13:45