Refactor PartialDerivatives to contain single partial #413

Merged: 65 commits into develop, Jan 2, 2019

Conversation

gordoncaleb (Contributor):

This PR is quite a big refactor that contains few functional changes but significantly decreases the complexity of the forward and reverse mode auto diff logic. The changes include:

  • PartialDerivatives has been renamed to PartialDerivative and no longer contains multiple partial derivatives.

  • The way forward mode auto diff is calculated has been changed so that it never needs to store more than a single partial derivative with respect to anything at a time.

  • Operations that support broadcasting operands (e.g. +, -, *, /) are now responsible for correcting shape changes caused by that broadcast. This removes a lot of complex edge-case handling from the partial multiply and from the reverse mode autodiff algorithm itself.

  • The Differentiator reverse mode method now returns a PartialsOf object that contains a collection of partials with respect to many inputs but always of a single output. For example, Differentiator.reverseModeAutoDiff(A, B, C) gets the derivative of A with respect to both B and C. The returned object exposes a withRespectTo(...) method that returns the derivative of A with respect to B or C (see the usage sketch after this list).

  • The Differentiator forward mode method now returns a PartialsWithRespectTo object that, like PartialsOf, contains a collection of partials, but these are all with respect to the same input and of many outputs. For example, Differentiator.forwardModeAutoDiff(A, B, C) gets the derivatives of B and C with respect to A. The returned object exposes an of(...) method that returns the derivative of a single B or C with respect to A.

  • Differentiable::getDerivativeWrtLatents used to be backed by forward mode auto diff and was used quite heavily in our tests. It has been removed and replaced with direct calls to the forward mode algorithm (Differentiator::forwardModeAutoDiff).

  • Some operations have had their forward/reverse mode AD code refactored to reflect new guarantees provided by the fact that there is only a single partial derivative to deal with at a time.
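As a rough usage sketch of the two entry points described above: A, B, C are illustrative vertices; withRespectTo(...) returning a DoubleTensor reflects the final shape agreed in the review below, and the return type of of(...) is an assumption.

// Reverse mode: partials of a single output (A) with respect to many inputs (B, C).
PartialsOf dA = Differentiator.reverseModeAutoDiff(A, B, C);
DoubleTensor dAdB = dA.withRespectTo(B);
DoubleTensor dAdC = dA.withRespectTo(C);

// Forward mode: partials of many outputs (B, C), all with respect to a single input (A).
PartialsWithRespectTo dWrtA = Differentiator.forwardModeAutoDiff(A, B, C);
DoubleTensor dBdA = dWrtA.of(B);  // return type of of(...) is assumed here
DoubleTensor dCdA = dWrtA.of(C);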

@GeorgeNash (Contributor) left a comment:

Nice refactor, man. A few comments.

public static PartialDerivative matrixMultiplyAlongOfDimensions(PartialDerivative partial, DoubleTensor multiplier, boolean partialIsLeft) {

    if (partial.isEmpty()) {
        return partial;
Contributor: Why does this return partial and not this?

Author: Because it's static and there is no this. It's the same concept though.

Contributor: Nice, I missed that it was static.

Contributor: Any reason why these two methods are static when all the others aren't?

Author: They're static in order to allow the second tensor in the multiply to be a DoubleTensor. The other option was to make it a method on the partial but keep the partialIsLeft flag. That gives you an API where A.matrixMultiplyAlongOfDimensions(B, false) reads as AB but, because of the flag, is actually BA, which I thought was too confusing. If you make both A and B a PartialDerivative, then you have to create an extra PartialDerivative for each multiply in order to account for AB and BA, because either A or B will be a plain old DoubleTensor. I've taken another look at this today and there might be a nicer way to organise this, but all the alternatives are some combination of less clear and less performant. We can revisit this when we graph-ify everything.
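As a rough illustration of the trade-off described above, assuming the static methods live on PartialDerivative (the call sites below are hypothetical):

// partial is a PartialDerivative, multiplier is a plain DoubleTensor.
// partialIsLeft = true  -> result corresponds to partial x multiplier
// partialIsLeft = false -> result corresponds to multiplier x partial
PartialDerivative dOutLeft = PartialDerivative.matrixMultiplyAlongOfDimensions(partial, multiplier, true);
PartialDerivative dOutRight = PartialDerivative.matrixMultiplyAlongOfDimensions(partial, multiplier, false);

Keeping the method static avoids the instance-method form A.matrixMultiplyAlongOfDimensions(B, false), which reads as AB but would compute BA.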

Contributor: FWIW, I like this explanation.

@migwellian (Contributor) left a comment:

Good stuff. My comments are very minor.


@christophernorth (Contributor) left a comment:

Lots cleaner and easier to navigate now, IMO. Just a few small comments scattered around and potentially one bug (due to our confusingly named Tensor functions ;-)


/**
 * This class is meant to help with auto diff in operations that support implicit broadcasting. E.g. in
 * addition/subtraction/multiplication/division, scalar operands can be operated with non-scalar operands.
Contributor: Obviously it's not just scalars that can be broadcast and lead to implicit partial changes; it's also compatible Tensors. Presumably, given the naming, that's still a bug we have to fix at some point?

Author: Those non-scalar broadcasts aren't supported yet and are strictly prohibited by the shape check in the operations that could broadcast. When we add that support, this code will need to be tweaked to handle the new cases.
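For illustration only, the kind of shape guard being described might look like the sketch below; the method name, message, and exact placement are assumptions, not the library's actual check.

import java.util.Arrays;

// Allow only matching shapes or a length-one (scalar) operand; reject general broadcasting for now.
static void checkOperandShapes(long[] left, long[] right) {
    boolean leftIsScalar = Arrays.stream(left).reduce(1L, (a, b) -> a * b) == 1;
    boolean rightIsScalar = Arrays.stream(right).reduce(1L, (a, b) -> a * b) == 1;
    if (!leftIsScalar && !rightIsScalar && !Arrays.equals(left, right)) {
        throw new IllegalArgumentException("Only scalar broadcast is currently supported");
    }
}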

Contributor: Cool stuff. If you see that as a bigger refactor, no bother. I wonder if there's anything we can do to mark this as something that will need fixing if we add broadcasting properly, so it isn't missed?

if (existingPartialDerivative == null) {
    partials.put(id, entry.getValue().duplicate());
} else {
    existingPartialDerivative.plusInPlace(entry.getValue());
Contributor: Don't we have a potential bug hiding here now? plusInPlace sometimes doesn't actually do things in place...

Author: At the moment it's impossible for the existing partial to have a different shape to the next partial; all partials at this point are the correct deterministic shape [of, wrt]. That said, the inPlace contract doesn't guarantee the object will always be the same, so I've gone with the double put to make sure this is never an issue.
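A minimal sketch of the double put being described, assuming plusInPlace returns the resulting partial (names taken from the snippet above):

// Put the result back either way, so correctness doesn't depend on plusInPlace
// actually mutating the existing object rather than returning a new one.
if (existingPartialDerivative == null) {
    partials.put(id, entry.getValue().duplicate());
} else {
    partials.put(id, existingPartialDerivative.plusInPlace(entry.getValue()));
}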

Contributor: Better to be safe than sorry, definitely!


final Map<VertexId, DoubleTensor> tensorMap = new HashMap<>();

for (Map.Entry<VertexId, PartialDerivative> entry : partials.entrySet()) {
    tensorMap.put(entry.getKey(), entry.getValue().get());
Contributor: Given this is in the hot path of reverse mode auto-diff, I'm not sure I love having to do this translation from one map type to another. It just seems to be adding busy work to avoid other people seeing PartialDerivatives rather than DoubleTensors?

Author: I tried avoiding this earlier but gave up. I've now refactored it to do the conversion at the point the DoubleTensor is used, which avoids the double iteration and the new object creation.
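Roughly, the change amounts to something like the sketch below (inputVertexId and the surrounding call site are hypothetical):

// Before: eagerly copy the whole map just to expose DoubleTensors.
final Map<VertexId, DoubleTensor> tensorMap = new HashMap<>();
for (Map.Entry<VertexId, PartialDerivative> entry : partials.entrySet()) {
    tensorMap.put(entry.getKey(), entry.getValue().get());
}

// After: keep the Map<VertexId, PartialDerivative> and unwrap only where the
// DoubleTensor is actually needed.
DoubleTensor dOutputWrtInput = partials.get(inputVertexId).get();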

DoubleIfVertex outputVertex = (DoubleIfVertex)complexNet.getVertexByLabel(new VertexLabel(OUTPUT_NAME));
DoubleVertex inputVertex = (DoubleVertex)complexNet.getVertexByLabel(new VertexLabel(INPUT_NAME));
DoubleIfVertex outputVertex = (DoubleIfVertex) complexNet.getVertexByLabel(new VertexLabel(OUTPUT_NAME));
DoubleVertex inputVertex = (DoubleVertex) complexNet.getVertexByLabel(new VertexLabel(INPUT_NAME));
Contributor: Should we just cast this straight to the proper vertex type to avoid having to do the cast in the Differentiator below?

Author: Nice cleanup.

@@ -0,0 +1,40 @@
package io.improbable.keanu.vertices.dbl;
Contributor: Do you think it would be useful to wrap this as a benchmark so we can monitor changes to this perf?

Author: This tests "efficiency", which just means it's not doing any more calls to autodiff than needed. I think the benchmarks should try quite a few more complex graphs that vary in length, depth, width, etc. The benchmarks could use this for inspiration, but I don't think there's much benefit to coupling the benchmarks and this unit test.

DoubleTensor dCdA = dC.withRespectTo(A);
DoubleTensor dCdB = dC.withRespectTo(B);
DoubleTensor dCdA = dC.withRespectTo(A).get();
DoubleTensor dCdB = dC.withRespectTo(B).get();
Contributor: Bit ugly that we have to call this get() method all over the place?

Author: If you want it as a DoubleTensor, then yes. We could call it toDoubleTensor(), or we could just return the DoubleTensor from the withRespectTo(...) method. It doesn't look like withRespectTo(...) is ever used to get the PartialDerivative.

Author: Having withRespectTo(...) return a DoubleTensor removes all of these get() calls. Nice cleanup.

@christophernorth (Contributor) left a comment:

LGTM. Thanks for the changes.

gordoncaleb merged commit 506a81f into develop on Jan 2, 2019.
gordoncaleb deleted the feature/single-partial branch on January 2, 2019 at 17:53.