-
Notifications
You must be signed in to change notification settings - Fork 251
[query] fix ndarray concat with size 0 dims #13755
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
assert(np.array_equal(hl.eval(hl.nd.vstack((a, b))), np.vstack((a, b)))) | ||
assert(np.array_equal(hl.eval(hl.nd.vstack(hl.array([a, b]))), np.vstack((a, b)))) | ||
assert(np.array_equal(hl.eval(hl.nd.vstack((a, empty, b))), np.vstack((a, empty, b)))) | ||
assert(np.array_equal(hl.eval(hl.nd.vstack(hl.array([a, empty, b]))), np.vstack((a, empty, b)))) | ||
assert(np.array_equal(hl.eval(hl.nd.vstack((empty, a, b))), np.vstack((empty, a, b)))) | ||
assert(np.array_equal(hl.eval(hl.nd.vstack(hl.array([empty, a, b]))), np.vstack((empty, a, b)))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is quite a lot of asserts for one test. Perhaps break these apart or consider parameterisation.
val dimLength = cb.newLocal[Long]("dimLength", shapeOfNDAtIdx.loadField(cb, dimIdx).get(cb).asInt64.value) | ||
// compute index of first input which has non-zero concat axis size | ||
val firstNonEmpty = cb.newLocal[Int]("ndarray_concat_first_nonempty", 0) | ||
cb.whileLoop(stagedArrayOfSizes.loadElement(cb, firstNonEmpty).get(cb).asInt64.value.ceq(0L), { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this checking is null? don't we have a isna
function to make this clearer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, this is checking for size 0 axes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh yes, arrayOfSizes makes this kinda obvious. duh.
val mismatchedDim = cb.newLocal[Int]("ndarray_concat_mismatched_dim", -1) | ||
val expected = cb.newLocal[Long]("ndarray_concat_expected_size") | ||
val found = cb.newLocal[Long]("ndarray_concat_found_size") | ||
for (i <- (0 until firstND.st.nDims).reverse if i != axis) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can do something like first.st.nDims - 1 to 0 by -1
, though .reverse
is fine too
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I find it tricky to think about endpoint inclusivity using the -1 step, reverse is harder to get wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
ndarray concat was broken when the first input has size 0 along the concat axis. For example
The zeros matrix is 2 by 0, so horizontal concatenation should just return the other matrix.
(I once saw the first column filled with random numbers, presumably from a buffer overflow)
I did some cleaning up in the concat implementation, but the functional change is to record the index of the first input which is non-empty along the concat axis, and when resetting to the start of the axis, reset to that non-empty index. Other size 0 inputs are correctly handled when incrementing the index, the problem was that the first read happens before an increment.