Lots of new functions on NDArrays #78

Merged: 6 commits into kyonifer:master, Oct 23, 2018

Conversation

@peastman (Contributor):

This implements #74.

For the moment, mean() and sum() are only implemented as functions, not methods. So you can write sum(x), but not x.sum(). Once the codegen issues are figured out, that will be easy to change.

For all the reduction functions (min(), argMin(), sum(), etc.) I've included two versions: one that works over the whole array, and a second that does the reduction over a single axis. I modelled the API on NumPy.

@kyonifer (Owner):

Thanks for looking into this. I like the idea overall-- I have a few comments on the implementation.

* @param rtol the maximum relative (i.e. fractional) difference to allow between elements
* @param atol the maximum absolute difference to allow between elements
*/
fun <R> allClose(other: NDArray<R>, rtol: Double=1e-05, atol: Double=1e-08): Boolean {
@kyonifer (Owner), Oct 21, 2018:

Adding allClose directly to the NDArray interface means it will be available on all NDArrays, including e.g. NDArray<String>. One of the goals of implementing it as an extension function on Matrix was that it would only appear when the user had a supported numeric matrix type. For example, Matrix<Double>.allClose would compile but Matrix<Int>.allClose wouldn't. With the implementation in this PR, ndArrayOf("a","b").allClose(ndArrayOf("a","b")) compiles but throws an exception at runtime, so we lose type safety.

I agree with removing the Matrix version of allClose (since NDArray is a supertype), but I'd recommend keeping the implementation as a set of extension functions, ideally implemented only for floating point (since, as you pointed out in the issue, these functions don't make sense for integral types). If the difficulty in doing so is the currently complicated codegen implementation, I can take over from here and work that in if you'd like.

@peastman (Contributor Author):

How about defining it as an extension function for type <out Number>? That way you could still use it to compare between different numerical types, like I do in the test case.
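
For concreteness, a rough sketch of that shape, reusing the shape(), size, and getDouble members that appear in the diff under review (the tolerance check mirrors NumPy's allclose and is an assumption, not the final implementation):

import kotlin.math.abs

fun NDArray<out Number>.allClose(other: NDArray<out Number>,
                                 rtol: Double = 1e-05, atol: Double = 1e-08): Boolean {
    // Shapes must match before comparing elements.
    if (!(shape().toIntArray() contentEquals other.shape().toIntArray()))
        return false
    for (i in 0 until size) {
        val a = getDouble(i)
        val b = other.getDouble(i)
        // NumPy-style tolerance: |a - b| <= atol + rtol * |b|
        if (abs(a - b) > atol + rtol * abs(b))
            return false
    }
    return true
}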

@kyonifer (Owner):

Sounds good to me, as long as we can be careful to implement it so we don't box every element.

@peastman (Contributor Author):

Unless we can guarantee that getDouble() will always work, I think boxing is unavoidable. What about modifying the generic version to cast between numeric types, just like the specialized ones do? For example, in DefaultGenericNDArray.getDouble() change

if (ele is Double)
    return ele

to

if (ele is Number)
    return ele.toDouble()

Likewise, require that any future backends also support casting. Then we can just call getDouble() and know it will be implemented in whatever way is most efficient given the storage of the particular array.
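
In context, that change might look roughly like this inside DefaultGenericNDArray (the linear-index signature and the getGeneric lookup are assumptions for illustration; only the is Number branch comes from this thread):

override fun getDouble(i: Int): Double {
    val ele = getGeneric(i)
    // Convert any boxed numeric element instead of requiring it to be exactly a Double.
    if (ele is Number)
        return ele.toDouble()
    throw IllegalStateException("Double methods not implemented for generic NDArray")
}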

@kyonifer (Owner):

I was thinking we could dispatch in a when block, similar to the NDArray.invoke block, that checks the reified type and calls a specific accessor for each primitive; i.e. for an NDArray<Int> we could call .getInt(..).toDouble(). I haven't tried it to see whether it would break down, though.

if (ele is Number)
    return ele.toDouble()

I think this route would work as well and might be cleaner. As long as we only expose the extension functions for NDArrays containing primitive types and we don't box in the specialized implementations, no objections here.
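
A rough sketch of that reified dispatch (the helper itself is hypothetical, and the accessors other than getDouble and getInt are assumed members of the get$PRIMITIVE family discussed below):

inline fun <reified T> NDArray<T>.elementAsDouble(i: Int): Double = when (T::class) {
    // Dispatch once on the reified element type, then use the non-boxing accessor.
    Double::class -> getDouble(i)
    Float::class  -> getFloat(i).toDouble()
    Int::class    -> getInt(i).toDouble()
    Long::class   -> getLong(i).toDouble()
    // Fall back to a boxed conversion; non-numeric element types would fail here at runtime.
    else          -> (getGeneric(i) as Number).toDouble()
}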

@peastman (Contributor Author):

Ok, I'll change it to work like that.

@peastman (Contributor Author):

Another option would be to not even have the extension function, and just put the whole implementation into the allclose() function.

Contributor:

Inside the interface, you only get an opportunity to reify <R>, so you're still doing needlessly-expensive number conversion logic on the contents of the NDArray.

@peastman (Contributor Author):

I don't see why? If getDouble() is declared to return a Double, on the JVM it will be implemented with a method whose return type is a primitive double. It's up to each implementation of the interface to provide a value of that type. In the case of an array of primitive doubles, no conversion or boxing should be needed.

Contributor:

I think I misread much of the above. You guys were talking about boxing and I misread it as being about type conversion. My comment was trying to avoid doing O(n) primitive casts from whatever the true underlying data type is to double, which seems like it might be important given

>>> Long.MAX_VALUE.toDouble() == Long.MAX_VALUE.toDouble() + 1
true

if (!(shape().toIntArray() contentEquals other.shape().toIntArray()))
    return false
for (i in 0 until size) {
    val a = getDouble(i)
@kyonifer (Owner), Oct 21, 2018:

getDouble is only guaranteed to be implemented by NDArray<Double> backends, so this function is only guaranteed to work with Double currently. The current Default implementations for numerical types do cast between primitives, but the DefaultGenericNDArray implementation doesn't, and it's unclear what future NDArray implementations like #70 will do. If one ends up needing to call e.g. getDouble in a context where the type is only known to be a generic NDArray<T>, a try/catch is probably necessary to make sure the user doesn't get an undiagnosable error (for a DefaultGenericNDArray this method currently gives them the error Exception in thread "main" java.lang.IllegalStateException: Double methods not implemented for generic NDArray, which is difficult for a user to decipher).

Note that the goal of the get$PRIMITIVE methods is to have a non-boxing accessor that should only be called by code that knows the primitive type is correct. In particular, they were designed to be called by extension functions with a known primitive type, such as fun NDArray<Double>.foo(...) = this.getDouble(), thus retaining type safety for the user. Ideally these get$PRIMITIVE methods would be marked internal so users wouldn't be able to see them at all, but there are some issues with module access that haven't been resolved yet (it looks like they are resolved in the new mpp though, so perhaps #77 will allow us to do that).
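
For example, the pattern being described is roughly the following, where firstElement is a hypothetical helper rather than part of the API:

// Only visible on NDArray<Double>, so the non-boxing accessor getDouble is safe to call.
fun NDArray<Double>.firstElement(): Double = this.getDouble(0)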

* Find the linear index of the minimum element in this array.
* If the array contains non-comparable values, this throws an exception.
*/
fun argMin(): Int
@kyonifer (Owner), Oct 21, 2018:

These could also be implemented as extension functions with a T: Comparable upper bound on their applicability, i.e. something like fun <T: Comparable<T>> NDArray<T>.argMin(). Thus if the user has an NDArray holding something comparable, argMin is a valid method; otherwise it's a compile-time error.

For Matrix these were defined on the interface itself because matrices are guaranteed to be numerical, so they applied to all Matrix implementations. For NDArray, there's a wide class of comparable types that aren't numerical to which these would apply, and an even wider class of objects that aren't ordered at all, for which this function doesn't apply.

@peastman (Contributor Author):

I assumed you made them methods rather than extensions so different backends could provide their own implementations? I can change them to extensions if you prefer, but that will rule out the possibility of a different backend having an optimized implementation.

@kyonifer (Owner):

How about a hybrid approach, similar to how we implement the indexed access operators? The user would see an extension function NDArray<Double>.argMin, which would in turn delegate the action to a method on the instance with a different name like argMinDouble (for the index accessors, the analog is an NDArray<Double>.get(i) extension function that forwards to getDouble(i)).

The disadvantage is that right now the should-not-be-used-by-the-end-user-directly type-specific operations are still exposed to the user. There are a few potential ways this could be hidden from the user, though:

  • Marking the delegated methods as protected. This is unavailable since we are using an interface to implement NDArray instead of an abstract base (which would cause other issues).

  • Marking the delegated methods as internal. This causes problems because the platform specific implementations are in a different project, which don't see internal declarations. However, this might be changing in the new multiplatform gradle plugin.

  • Hiding them all on some internal class that the user likely won't dive into and could be overridden, e.g. declaring a NumericalOperations class that has all the platform specific stuff on it. I've considered doing this a few times to solve other problems in the past.

I think for now we could just do the straightforward hybrid approach with publicly delegated functions, like the array accessors do now, and then convert all of them together in the future if we ever decide to.
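
A minimal sketch of that hybrid shape, using a hypothetical stand-in NDArraySketch for the real NDArray interface (argMinDouble mirrors the placeholder name above and isn't the final API):

interface NDArraySketch<T> {
    val size: Int
    fun getDouble(i: Int): Double
    // Type-specific worker that backends can override; not meant to be called by users directly.
    fun argMinDouble(): Int
}

// User-facing, type-safe extension that forwards to the worker,
// analogous to NDArray<Double>.get(i) forwarding to getDouble(i).
fun NDArraySketch<Double>.argMin(): Int = this.argMinDouble()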

@kyonifer (Owner), Oct 21, 2018:

Another perspective on this: BLAS level 1 operations like argmin are almost certainly memory bound and don't benefit from any sort of computational optimization a backend could bring. So I don't think there will be any immediate advantage to delegating these to the backends to do the work (my own benchmarks on the cblas k/native backend showed that implementing things like matrix addition directly in Kotlin made no difference over making the call into openblas). So maybe the right answer here is that for O(N) operations like finding the minimum, we just go ahead and implement them once in the extension function.

The only concern there is that, for future GPU backends, such an implementation of argmin will require recalling all the memory from the GPU to the CPU to do the work in Kotlin. However, that's a bridge we can cross when we get there.

@peastman (Contributor Author):

BLAS level 1 operations like argmin are almost certainly memory bound and don't benefit from any sort of computational optimization a backend could bring.

Actually there's a lot of room for optimization. If the reduction is over the whole array, or the last axis, then memory access is fast. You're accessing everything in order so the cache makes it efficient. Otherwise you break the array up into tiles and compute several output values at once. That way you still only have to load each cache line once. Then you vectorize it and optionally multithread it.

Also, if accessing an individual element from Kotlin involves a native method call, putting the loop over elements in Kotlin rather than native code will be much slower.
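
For illustration only (not code from this PR), a sketch of the access-pattern point on a row-major 2-D array stored in a flat DoubleArray:

// Sum over the last axis: each inner loop walks contiguous memory,
// so every cache line is loaded once.
fun sumLastAxis(data: DoubleArray, rows: Int, cols: Int): DoubleArray {
    val out = DoubleArray(rows)
    for (r in 0 until rows) {
        var acc = 0.0
        for (c in 0 until cols)
            acc += data[r * cols + c]
        out[r] = acc
    }
    return out
}

// Sum over the first axis: a naive column-at-a-time loop would stride through memory,
// but iterating row by row and accumulating into all outputs keeps the access contiguous.
fun sumFirstAxis(data: DoubleArray, rows: Int, cols: Int): DoubleArray {
    val out = DoubleArray(cols)
    for (r in 0 until rows)
        for (c in 0 until cols)
            out[c] += data[r * cols + c]
    return out
}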

@kyonifer (Owner):

Not sure I'm following, but that may be because I was again too brief in my reply and therefore unclear. I'm good with having a delegating virtual method on the NDArray instance that the extension function calls, to leave our options open, so let's go with that.

@peastman (Contributor Author):

Ok, I think I just misunderstood. So I'll do what you described in #78 (comment). And for the moment I won't worry about trying to hide the internal method.

@peastman (Contributor Author):

Just to clarify, the base NDArray class will define all possible reductions for all types? argMinDouble(), argMinFloat(), argMinInt(), etc., and the same for argMax, min, and max? Then there will be an extension function argMin() on NDArray<Double> that calls argMinDouble(). Is that right?

It's not obvious to me what advantage that has over just a single method for each one. Perhaps I'm still misunderstanding.

@kyonifer (Owner):

Sorry, I'm bad at explaining myself. Let me see if I can shed some light on the idea...

The goal of using extension functions instead of just having members on NDArray is type safety. Suppose we define:

fun <T: Comparable<T>> NDArray<T>.argMin(): Int {
  ...
}

then we would have

ndArrayOf(1, 2, 3).argMin()                         // Okay
ndArrayOf(1.0, 2.0).argMin()                        // Okay
ndArrayOf(Image("a.jpg"), Image("b.jpg")).argMin()  // Compile-time error: Image isn't Comparable

Thus trying to use argMin on something non-comparable becomes an error at compile time. If argMin were defined on the instance itself (as it currently is in this PR), then (foo as NDArray<Image>).argMin() doesn't fail until runtime. This is noted in the current docstring: If the array contains non-comparable values, this throws an exception. It would be better to get a compile-time error than a runtime exception, hence the recommendation to use extension functions.

So we could just end it there, and have the extension function contain a full implementation of argMin. However, as you noted there might be use cases where the backend wants to take control over that implementation. In this case we'll need to delegate the extension function implementation to virtual methods on NDArray that can be overridden by the implementation. Then we'd have:

fun <T: Comparable<T>> NDArray<T>.argMin(): Int {
    return this.argMinHiddenFromUser()
}

The virtual method should be "hidden" from the user one way or another, with the extension functions being what they see and call (the hiding part is TBD). Doing it this way with a user-facing extension function backed by a "hidden" virtual method on NDArray allows us to present the user with type safety and still delegate the implementation to the backends.

I don't think we need virtual methods argMin$PRIMITIVE for each primitive. We probably just need a single argMinHiddenFromUser() virtual method on NDArray, which is overridden by the backends to do the operation in an efficient way for their particular data configuration, similar to what you currently have for fun argMin(): Int (except renamed to something hidden, so the member doesn't shadow the extension function).

The duplicate-methods-for-each-primitive thing is only needed when we are boxing on each individual element. For example, with the array accessors we did need separate methods for each primitive, because otherwise array access would be extremely slow (it was actually re-implemented this way after a user report that it was too slow to use). The decision as to whether or not we need one-per-primitive overloads of a method should hinge on whether or not we're incurring a box per element. For aggregate operations like finding the argmin over the entire array, there is no need. If there are methods that pass elements in or out one by one, then we may need to look at what's getting boxed to decide.
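
A self-contained sketch of that arrangement, using hypothetical stand-in names (ArraySketch for the real NDArray interface, argMinWorker for the hidden member):

interface ArraySketch<T> {
    val size: Int
    operator fun get(i: Int): T
    // Single overridable worker; the generic default may box, but a backend can
    // override it with a non-boxing loop over its own storage (see below).
    fun argMinWorker(comparator: Comparator<in T>): Int {
        var best = 0
        for (i in 1 until size)
            if (comparator.compare(get(i), get(best)) < 0) best = i
        return best
    }
}

// The user-facing call site: only compiles when the element type is Comparable.
fun <T : Comparable<T>> ArraySketch<T>.argMin(): Int = argMinWorker(naturalOrder())

// A hypothetical double-backed implementation overrides the worker with a
// non-boxing loop over its primitive storage; users still just call argMin().
class DoubleArraySketch(private val data: DoubleArray) : ArraySketch<Double> {
    override val size get() = data.size
    override fun get(i: Int) = data[i]
    override fun argMinWorker(comparator: Comparator<in Double>): Int {
        var best = 0
        for (i in 1 until data.size)
            if (data[i] < data[best]) best = i
        return best
    }
}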

Does that help clarify at all?

@peastman (Contributor Author):

Ok, I think I've made all the changes.

@kyonifer (Owner):

Looks good now, thanks!

@kyonifer merged commit 6a44fac into kyonifer:master on Oct 23, 2018.
@peastman deleted the arrayfuncs branch on October 24, 2018.