NDArray Sum Aggregator #9209

Conversation
Can you add a benchmark that sums a table of ndarrays?
Sure. Any size preference? I could do the sum of 100 2D ndarrays with dims 4096 x 4096, for block matrix / dndarray similarity.
Just big enough to get consistent timing numbers (probably easiest to check this on your laptop for now), imo. Also very curious to hear how master compares to this PR.
Doing 100 4096 x 4096s on my laptop took 1 min 20 s.
Your hacky version took 1 min 41 s on my laptop. Not as big a difference as I'd hoped.
Ah, I bet I'm hitting the same problem I hit with the original
Ok, so Cotton's new thing means emitting separate methods by hand is not something we do anymore. But there are two factors hurting the benchmark. One is that the benchmark is hiding the fact that we are spending ~25 seconds serializing and de-serializing JSON for this ndarray, so the real comparison is more like 55 seconds vs 75 seconds, which is roughly a 25% speed improvement. The other is that
Anyway, 25% improvement + a better interface is a win for now; we can revisit ways to make this faster in the future.
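To spell out that arithmetic (approximate, using the timings quoted above): this PR's run took about 1 min 20 s ≈ 80 s and the hacky version about 1 min 41 s ≈ 101 s. Subtracting the ~25 s of JSON (de)serialization common to both gives roughly 55 s vs 76 s, and (76 − 55) / 76 ≈ 28%, consistent with the "roughly 25%" figure.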
You can definitely still generate methods by hand!
I talked to Cotton about it and he said not to. But it's not clear how much of a difference that makes yet anyway. I think this version is pretty good and an improvement. Plus it'll add a benchmark which we can work on optimizing. |
Looks good! Just a few high-level comments.
@@ -54,7 +56,11 @@ abstract class PNDArray extends PType {
abstract class PNDArrayValue extends PValue {
  def apply(indices: IndexedSeq[Value[Long]], mb: EmitMethodBuilder[_]): Value[_]

  override def pt: PNDArray = ???
Why not leave it abstract?
This lets me set the return type for things that override it so they don't have to be cast.
Tim and I previously handled this using:

private def ptND: PNDArray
override def pt: PType = ptND
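For illustration, here is a minimal, self-contained sketch of that pattern (the type and class names below are stand-ins, not the actual Hail definitions): the public pt stays typed as PType, while subclasses supply the specifically typed value through a protected member, so callers never need a cast.

// Sketch of the pt/ptND pattern with stand-in types (hypothetical, for illustration only).
trait PType
trait PNDArray extends PType

abstract class PValueSketch {
  def pt: PType
}

abstract class PNDArrayValueSketch extends PValueSketch {
  // Subclasses override this with the precise physical type...
  protected def ptND: PNDArray
  // ...and the general interface is defined in terms of it,
  // so code that needs the specific type goes through ptND instead of casting pt.
  override def pt: PType = ptND
}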
override def resultType: PType = ndTyp

val stateType = PCanonicalTuple(true, PBooleanRequired, ndTyp)
Because the seqOp skips missing values, the ndarray field can never be missing if the state is initialized, right? Is there a reason not to just let a missing ndarray be the uninitialized state? Then the result is exactly the state.
If not, up to you if you want to change it. I think it's a bit cleaner, but not a big difference.
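As a self-contained sketch of the two encodings being discussed (using Option to stand in for a physical missing bit; these are illustrative types, not Hail's):

// Stand-in for a physical ndarray value (hypothetical).
final case class NDArraySketch(shape: IndexedSeq[Long], data: Array[Double])

// Current encoding: an explicit "initialized" flag next to the running sum,
// mirroring PCanonicalTuple(true, PBooleanRequired, ndTyp); sum is meaningless
// until initialized is true.
final case class StateWithFlag(initialized: Boolean, sum: NDArraySketch)

// Suggested encoding: a missing ndarray *is* the uninitialized state,
// so result() can return the stored field unchanged.
final case class StateViaMissingness(sum: Option[NDArraySketch])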
@@ -126,6 +126,10 @@ final case class PCanonicalNDArray(elementType: PType, nDims: Int, required: Boolean
    ))
  }

  def mutateElement(indices: IndexedSeq[Value[Long]], ndAddress: Value[Long], newElement: Code[_], mb: EmitMethodBuilder[_]): Code[Unit] = {
I think I'd expect this to be called setElement. What do you think?
Done
Addressed all comments, should be good to go. One question I did have is that I have to specify a region size in this aggregator, and I picked
Added a new test that checks summing a mix of regular and transposed data, to make sure that future changes respect striding.
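For context on why transposed inputs matter here, a minimal sketch of strided element addressing (hypothetical code, not Hail's implementation): a transpose reuses the same data but permutes the strides, so the element offset is the dot product of the indices with the byte strides, and summing transposed ndarrays exercises a different access order over the same buffer.

// Sketch: byte offset of an element in a strided ndarray (strides in bytes).
def elementOffset(base: Long, indices: IndexedSeq[Long], strides: IndexedSeq[Long]): Long =
  base + indices.zip(strides).map { case (i, s) => i * s }.sum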
That specifies the granularity of allocations. Each time the region needs more space than it has allocated, it allocates a new block of the size given by the size parameter, unless it's trying to make a single allocation larger than the block size, in which case it allocates a block of exactly the desired size. The ideal thing here would be to make a single block with exactly the size needed for the state, but the current Region interface doesn't support that. The closest we can get is using the smallest block size, which is
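A minimal sketch of the allocation policy described above (illustrative only, not Hail's actual Region class): requests are served out of the current block, a fresh block of the configured size is grabbed when it runs out, and an oversized single request gets a block of exactly its own size.

// Sketch of block-granular allocation (hypothetical; tracks sizes only, not real memory).
final class RegionSketch(blockSize: Long) {
  private var freeInCurrentBlock: Long = 0L
  var totalAllocated: Long = 0L

  def allocate(nBytes: Long): Unit = {
    if (nBytes > freeInCurrentBlock) {
      // New block: normally blockSize, but an oversized request gets its own exact-size block.
      val newBlock = math.max(blockSize, nBytes)
      totalAllocated += newBlock
      freeInCurrentBlock = newBlock
    }
    freeInCurrentBlock -= nBytes
  }
}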
def isInitialized(state: State): Code[Boolean] = {
  stateType.isFieldDefined(state.off, ndarrayFieldNumber)
}
Now that you're just using the missing bit, you shouldn't need this or ndArrayPointer. Instead of
cb.ifx(isInitialized(state), {
  val currentNDPValue = PCode(ndTyp, ndArrayPointer(state)).asNDArray.memoize(cb, "ndarray_sum_seqop_current")
  ...
}, {
  // uninitialized case
})
you can simplify to
val statePV = new PCanonicalBaseStructSettable(stateType, state.off)
...
statePV.loadField(ndarrayFieldNumber).consume(cb, {
  // uninitialized case
}, { currentNDPCode =>
  val currentNDPValue = currentNDPCode.asNDArray.memoize(cb, "ndarray_sum_seqop_current")
  ...
})
A bit simpler and more idiomatic.
Thanks for pushing me to clean that up with the
However, I can't seem to write a working
Weird. Did you do something like this?
Anyways, happy to approve if you don't want to mess with it.
Yup, I did that, and I also tried a version where I did
The ability to sum over many ndarrays with an aggregator is necessary for several distributed matmul-like operations. We'll probably want something more general in the future, but for now this should suffice. This PR:
- hl.agg.ndarray_sum, a Python aggregator that sums over ndarrays of the same shape
- NDArraySumAggregator
- PNDArrayValue.sameShape, a helper function that checks if two PNDArrays are the same shape.

Future work:

@danking