
Make weight initialization extendable #6820

Merged: 18 commits merged into deeplearning4j:master on Dec 22, 2018

@DrChainsaw (Contributor) commented Dec 8, 2018

As discussed in #6813

I just made the initial interface and ported all existing strategies to new classes and added a testcase. Nothing is plugged in yet, I'm just hoping to get some initial feedback before wreaking havoc in layers and parameter initializers.

The plan is to copy the design from Activations/IActivations by replacing the WeightInit and Distribution members of e.g. BaseLayer with an IWeightInit, doing a slight name change (weightInit -> weightInitFn) and catching the old name when deserializing to support (soon-to-be) legacy nets.

For serialized layers which have a Distribution but WeightInit != DISTRIBUTION, I guess the Distribution can be ignored. If not, I'll keep it as a member but not use it for weight init when deserializing (as the new WeightInitDistribution already has a Distribution object internally).

A less code-y approach would be to replace all the new classes with a LegacyWeightInit which encapsulates the WeightInit and Distribution. Calls to init would then just redirect to WeightInitUtil. Imo this would be less attractive due to the extra verbosity (e.g. builder.weightInit(new LegacyWeightInit(WeightInit.SOME_INIT))), which would probably discourage users from using the class API.
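For reference, here is a minimal sketch of the interface shape I have in mind (the Serializable bound and the javadoc wording are my assumptions; the exact signature is the one quoted further down):

```java
import java.io.Serializable;

import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Sketch of the proposed extendable weight initialization interface.
 * Implementations write initialized values into the supplied parameter view.
 */
public interface IWeightInit extends Serializable {

    /**
     * @param fanIn     number of inputs to the layer
     * @param fanOut    number of outputs from the layer
     * @param shape     shape of the weight array to create
     * @param order     array order ('c' or 'f') for the returned view
     * @param paramView view of the full parameter array to initialize in place
     * @return a view of paramView, reshaped to the given shape and order
     */
    INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView);
}
```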

@DrChainsaw (Contributor Author) commented Dec 9, 2018

I have now proceeded with my plan and so far so good (although it is a bit tedious, since more than a few unit tests assert against the WeightInit).

I'm a bit unsure about what to do with the following bullet from the CONTRIBUTING checklist:

> Don't put submodule updates in your pull request unless they are to landed commits.

My changes spilled over to unit tests in modelimport and core (maybe more) and also to some classes inside modelimport (basically a workaround for the fact that I disallowed setting the legacy weight init to DISTRIBUTION).

Shall I commit only my changes to deeplearning-nn first and then the rest, or does the "landed commits" part of the bullet just mean "unless they are required for the project to build"? Or do I just misunderstand what a submodule is?

@AlexDBlack (Member) commented Dec 10, 2018

At first glance, this looks good, design/implementation is pretty good so far. 👍

One minor thing: the signature...

INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView);

I don't think fanIn/fanOut can ever be anything other than integers, so let's make these args integers. (Unless you know of a use case where double is required?).

The only thing I'm not really happy with is that we'll have essentially duplicates in DL4J and ND4J:
https://github.com/deeplearning4j/deeplearning4j/tree/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl

Given it's basically the same thing, we might be able to combine them.
The main challenge I foresee here (other than slight API differences) is the fanIn/fanOut values - they are inferred automatically in DL4J from the network structure, but user provided in ND4J/SameDiff.
Not sure what (if anything) we can do about that, I'll have to think about it...

> The plan is to copy the design from Activations/IActivations by replacing the WeightInit and Distribution members of e.g. BaseLayer with an IWeightInit, doing a slight name change (weightInit -> weightInitFn) and catching the old name when deserializing to support (soon-to-be) legacy nets.

That's what I'd do - so sounds good.

> For serialized layers which have a Distribution but WeightInit != DISTRIBUTION, I guess the Distribution can be ignored. If not, I'll keep it as a member but not use it for weight init when deserializing (as the new WeightInitDistribution already has a Distribution object internally).

Yes, I'd do the same. Ignore field in legacy JSON unless it's DISTRIBUTION weight init.

> A less code-y approach

Let's stick with the class-based approach. Yes, there's a ton of classes, but otherwise it's fine IMO.
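Purely illustrative, a rough sketch of how the enum shortcut and the class-based API could sit side by side on the builder (the class-based overload, class names, and packages here are assumptions about the post-change API, and layer configuration details are omitted):

```java
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitXavier;

public class WeightInitBuilderSketch {
    public static void main(String[] args) {
        // Existing enum shortcut, kept as a convenience
        NeuralNetConfiguration.Builder enumStyle = new NeuralNetConfiguration.Builder()
                .weightInit(WeightInit.XAVIER);

        // Class-based API; a custom IWeightInit implementation would plug in the same way
        NeuralNetConfiguration.Builder classStyle = new NeuralNetConfiguration.Builder()
                .weightInit(new WeightInitXavier());
    }
}
```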

> I'm a bit unsure about what to do with the following bullet from the CONTRIBUTING checklist:
> Don't put submodule updates in your pull request unless they are to landed commits.

I have no idea what that means or why it's there :)
I'm fine with PRs with changes across multiple modules when they really are required, like here.

DrChainsaw added some commits Dec 10, 2018

Change weight initialization to new type in NeuralNetConfiguration and BaseLayer (and some extending classes) and rename to weightInitFn. Also remove Distribution field.

Consequence of the latter is that an exception is thrown if legacy weightInit method is called with WeightInit.DISTRIBUTION as argument.

SameDiff not changed.

@DrChainsaw (Contributor Author) commented Dec 10, 2018

Firstly, sorry for the hard-to-review blob.

The most controversial change imo is that calling the (legacy) BaseLayer (or NeuralNetConfiguration.Builder) weightInit with WeightInit.DISTRIBUTION now generates an exception (as there is no distribution yet, and WeightInitDistribution requires one as constructor input).

I guess I could work around this by having some WeightInitBuilder which is built at the same time as the config is built, but I thought it was better to just rip off the bandaid and not mess around.
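Roughly, a method-level sketch of what the legacy setter now does (the exception type and message are placeholders, and a weightInitFn field plus the no-arg getWeightInitFunction() discussed below are assumed):

```java
// Sketch only; the exact exception type and message in the actual change may differ.
public Builder weightInit(WeightInit weightInit) {
    if (weightInit == WeightInit.DISTRIBUTION) {
        throw new IllegalArgumentException(
                "WeightInit.DISTRIBUTION carries no Distribution instance; "
                        + "use weightInit(Distribution) instead");
    }
    this.weightInitFn = weightInit.getWeightInitFunction();
    return this;
}
```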

I could not get all test cases to pass. A lot (all?) of the SameDiff test cases failed with an NPE on some properties. Same for a bunch of gradient checking tests (which is a bit worrying, as my changes could affect them). A couple of other tests failed with some P.A.N.I.C exception.

Many of the failed test cases also failed on a "clean" copy of the branch, but since it was a kind of manual back-and-forth testing, I did not manage to check every single test case this way.

I did ensure that the deserialization testcases in org.deeplearning4j.regressiontest passed (made sure they ran and failed first and then implemented legacy deserialization).

The same goes for a bunch of weight-init-asserting test cases in org.deeplearning4j.nn. In total I have 20 test cases in that package which fail in my environment both before and after my changes.

I'll remove the WIP in the hope that it was a config error on my side. If not, I think I need some assistance getting things to run in my environment (Windows 10) so I can test effectively. I found instructions for how to install the test resources, but they might not have installed properly. I didn't get any errors, but some test cases fail due to a missing 1.0.0-SNAPSHOT.jar in the test resources dir. Btw, I found the instructions under the "build from source" instructions; maybe they belong in CONTRIBUTING?

Finally, as per @AlexDBlack's comments: I changed fanIn and fanOut to longs. The convolution param initializers had them as doubles because they divide by the stride(s) for fanOut. I think that truncating the division gives the actual fanOut, though.

About duplication in Nd4j: I can give that one a shot in a separate PR perhaps (along with SameDiff). This one seems big enough already.

@DrChainsaw DrChainsaw changed the title [WIP] Make weight initialization extendable Make weight initialization extendable Dec 10, 2018

@AlexDBlack (Member) commented Dec 10, 2018

Hm... if we're replacing .weightInit(WeightInit.DISTRIBUTION), then let's deprecate .dist(...) too, in favor of .weightInit(Distribution).
The latter is clearer; there's no point having two methods for this.
And .weightInit(weightInit.getWeightInitFunction(distribution)) should be replaced by .weightInit(distribution) (Keras import).

Let's add a WeightInit.getWeightInitFunction() (i.e., a no-arg method) instead of always passing null when no distribution is used; it can pass null internally, which is just a bit cleaner.
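Something along these lines, as a rough sketch of the pair of enum methods (only a few cases are shown, and the implementation class names are illustrative):

```java
// Sketch of the proposed enum helpers on WeightInit.
public IWeightInit getWeightInitFunction() {
    return getWeightInitFunction(null);   // delegate, passing null internally
}

public IWeightInit getWeightInitFunction(Distribution distribution) {
    switch (this) {
        case ZERO:
            return new WeightInitConstant(0.0);
        case XAVIER:
            return new WeightInitXavier();
        case DISTRIBUTION:
            // WeightInitDistribution can do the null check on the distribution itself
            return new WeightInitDistribution(distribution);
        // ... remaining enum values map to their own IWeightInit implementations
        default:
            throw new UnsupportedOperationException("Weight init not supported: " + this);
    }
}
```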

> The same goes for a bunch of weight-init-asserting test cases in org.deeplearning4j.nn. In total I have 20 test cases in that package which fail in my environment both before and after my changes.

IIRC there were 2 or 3 tests left failing from the recent datatypes PR merge that I need to get to.
If it's missing test resources, make sure you have installed this and that it's up to date: https://github.com/deeplearning4j/dl4j-test-resources
And yes, it should be mentioned anywhere contributors will be looking for instructions.

> Finally, as per @AlexDBlack's comments: I changed fanIn and fanOut to longs. The convolution param initializers had them as doubles because they divide by the stride(s) for fanOut. I think that truncating the division gives the actual fanOut, though.

I had completely forgotten about that edge case. Consider ConvolutionParamInitializer:

double fanOut = outputDepth * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]);  // current: double
long fanOut = outputDepth * kernel[0] * kernel[1] / (stride[0] * stride[1]);             // proposed: long

Now consider kernel 2, stride 3, output depth 1. With integer division, that gives us 0 fan-out, whereas in reality some elements have 0 fan-out and some have non-zero fan-out.
So I think in this case using double does make sense - it's like the average fan-out.
So in light of that, we might have to change back to double, not long, there, sorry. (Maybe add a note on the API explaining why it's double, not long, though.)
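To make the arithmetic concrete, a standalone illustration of that example (not DL4J code, just the division):

```java
public class FanOutTruncationExample {
    public static void main(String[] args) {
        int outputDepth = 1;
        int[] kernel = {2, 2};
        int[] stride = {3, 3};

        // Integer division truncates the fractional fan-out down to 0
        long fanOutLong = outputDepth * kernel[0] * kernel[1] / (stride[0] * stride[1]);

        // Double division keeps the "average" fan-out of roughly 0.444
        double fanOutDouble = outputDepth * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]);

        System.out.println(fanOutLong);   // prints 0
        System.out.println(fanOutDouble); // prints 0.4444444444444444
    }
}
```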

> About duplication in Nd4j: I can give that one a shot in a separate PR perhaps (along with SameDiff). This one seems big enough already.

Yeah, that sounds good. I'm in the process of making some fairly major changes to samediff, so perhaps it would be better after that is merged too (hopefully late this week).

Anyway, this looks close - let me know when you're done, I'll do a more thorough review + test run and we can get this merged. Thanks for your work here.

@DrChainsaw (Contributor Author) commented Dec 12, 2018

Thanks a lot for the support @AlexDBlack! Barring messed-up test cases and further comments, I think I'm done now. I was thinking I could look at the results from CI, as it probably has the right environment to run the tests, but it didn't seem to finish.

Comments on comments:
About fanIn/fanOut: Agreed. I somehow had it in my mind that they were the total number of elements in the activation, and therefore that the 0 edge case was not valid anyway. Commit reverted!

I didn't change .weightInit(weightInit.getWeightInitFunction(distribution)) to .weightInit(distribution) for the Keras import. The Distribution returned from KerasInitilizationUtils.getWeightInitFromConfig is null for WeightInits != DISTRIBUTION, so making your change would require a null check before deciding whether to use distribution or weightInit as the input argument.
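For reference, this is the kind of null check that change would need; a hypothetical helper, not the actual import code, and the builder method names are assumptions:

```java
// Hypothetical sketch; the real Keras import code path and builder types may differ.
static void applyWeightInit(BaseLayer.Builder<?> builder, WeightInit init, Distribution distribution) {
    if (distribution != null) {
        // DISTRIBUTION case: the Distribution object carries everything needed
        builder.weightInit(distribution);
    } else {
        // Every other WeightInit value maps directly, with no Distribution involved
        builder.weightInit(init);
    }
}
```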

@DrChainsaw (Contributor Author) commented Dec 14, 2018

@AlexDBlack Just checking if you expect something more from me here.

@AlexDBlack (Member) left a comment

Apologies for the delay in getting back to this.
All LGTM - thanks!
Given we've got some unrelated failing tests, I'll run tests manually before merging this.

@AlexDBlack (Member) commented Dec 21, 2018

Looks like there's a couple of things in the UI module that need updating here; they should show up if you try an install:

[ERROR] Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.7.0:compile (default-compile) on project deeplearning4j-play_2.11: Compilation failure: Compilation failure:
[ERROR] /C:/DL4J/Git/deeplearning4j/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java:[943,47] cannot find symbol
[ERROR] symbol:   method getWeightInit()
[ERROR] location: variable bl of type org.deeplearning4j.nn.conf.layers.BaseLayer
[ERROR] /C:/DL4J/Git/deeplearning4j/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java:[946,42] cannot find symbol
[ERROR] symbol:   method getDist()
[ERROR] location: variable bl of type org.deeplearning4j.nn.conf.layers.BaseLayer

Also, mind dealing with that conflict?
I expect to run tests tomorrow, but I can fix those two issues if you haven't by the time I look at it.

@DrChainsaw (Contributor Author) commented Dec 21, 2018

@AlexDBlack Fixed the issue and checked manually that the other projects compile (apparently it was not enough to rebuild at the top level). I tried mvn install, but then it started running all the test cases and aborted when it hit one that fails.

I couldn't be bothered to implement toString in each new WeightInit, so I went the "lazy" route and added the JSON string as info. In some sense it is a bit less error-prone, as it does not require users to implement toString in new custom weight inits to get meaningful info.

If that's not OK, I can rework it tomorrow.
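The idea is roughly along these lines, a sketch using plain Jackson (the mapper handling and class layout here are assumptions, not the actual PR code):

```java
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;

/**
 * Sketch of the "JSON string as info" toString approach: custom weight inits
 * get a meaningful string representation without overriding toString themselves.
 * Assumes the IWeightInit interface sketched earlier in this thread.
 */
public abstract class WeightInitToStringSketch implements IWeightInit {

    private static final ObjectMapper MAPPER = new ObjectMapper();

    @Override
    public String toString() {
        try {
            return getClass().getSimpleName() + " " + MAPPER.writeValueAsString(this);
        } catch (JsonProcessingException e) {
            // Fall back to just the class name if serialization fails
            return getClass().getSimpleName();
        }
    }
}
```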

@AlexDBlack AlexDBlack merged commit 75e3996 into deeplearning4j:master Dec 22, 2018

@AlexDBlack (Member) commented Dec 22, 2018

@DrChainsaw I checked this locally, fixed one minor merge conflict and merged.
Thanks again for the PR!

@DrChainsaw DrChainsaw deleted the DrChainsaw:newWeightInit branch Dec 23, 2018

@DrChainsaw (Contributor Author) commented Dec 23, 2018

@AlexDBlack Thanks for merging!

printomi added a commit to printomi/deeplearning4j that referenced this pull request Jan 7, 2019

Make weight initialization extendable (deeplearning4j#6820)
* Added classes for legacy WeightInit variants

* Forgot to make LegacyWeightInitTest pass

* Added testcase for distributions

* Optimized imports

* Add test for output shape of IWeightInit

* Remove temp code

* Add ser/deser

* Change weight initialization to new type in NeuralNetConfiguration and BaseLayer (and some extending classes) and rename to weightInitFn. Also remove Distribution field.

Consequence of the latter is that an exception is thrown if legacy weightInit method is called with WeightInit.DISTRIBUTION as argument.

SameDiff not changed.

* Change API to IWeightInit to use long for fanIn and fanOut instead of double.

* Revert "Change API to IWeightInit to use long for fanIn and fanOut instead of double."

This reverts commit f0c1bec

* Fix javadoc

* Deprecate method dist

* Add no args version of WeightInit.getWeightInitFunction and replaced calls with null
Add null check of distribution in WeightInitDistribution

* Add ID mapping for convolution layers

* Clarify the use of doubles in API description

* Fix weightinit to string