Make weight initialization extendable #6820
As discussed in #6813
I just made the initial interface, ported all existing strategies to new classes, and added a test case. Nothing is plugged in yet; I'm just hoping to get some initial feedback before wreaking havoc in layers and parameter initializers.
Plan is to copy the design from Activations/IActivation: replace the WeightInit and Distribution members of e.g. BaseLayer with an IWeightInit, do a slight name change (weightInit -> weightInitFn), and catch the old name when deserializing to support (to-be) legacy nets.
For serialized layers which have a Distribution but WeightInit != DISTRIBUTION, I guess the Distribution can be ignored. If not, I'll keep it as a member but won't use it for weight init when deserializing (as the new WeightInitDistribution already has a Distribution object internally).
A less code-y approach would be to replace all the new classes with a single LegacyWeightInit which encapsulates the WeightInit and Distribution. Calls to init would then just redirect to WeightInitUtil. Imo this would be less attractive due to the extra verbosity (e.g. builder.weightInit(new LegacyWeightInit(WeightInit.SOME_INIT))), which would probably discourage users from using the class API.
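To make the proposed design concrete, here is a minimal sketch of what the interface might look like, modeled on IActivation as described above. All names are hypothetical, and a plain double[] stands in for the ND4J INDArray the real implementation would operate on:

```java
// Hypothetical sketch of the proposed weight-init interface (not the actual
// DL4J API). The real version would fill an INDArray parameter view; a plain
// double[] is used here only to keep the example self-contained.
interface IWeightInitSketch {
    // Fill the given parameter view in place, using fanIn/fanOut for scaling.
    void init(long fanIn, long fanOut, double[] paramView);
}

// Example strategy: Xavier/Glorot uniform, with limit sqrt(6 / (fanIn + fanOut)).
class WeightInitXavierUniformSketch implements IWeightInitSketch {
    @Override
    public void init(long fanIn, long fanOut, double[] paramView) {
        double limit = Math.sqrt(6.0 / (fanIn + fanOut));
        java.util.Random rng = new java.util.Random(42); // fixed seed for the sketch
        for (int i = 0; i < paramView.length; i++) {
            // Uniform in [-limit, limit)
            paramView[i] = (rng.nextDouble() * 2 - 1) * limit;
        }
    }
}
```

With this shape, each legacy WeightInit enum value maps to one small class, and custom strategies just implement the same interface.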
I have now proceeded with my plan and so far so good (although a bit tedious, since more than a few unit tests assert against the WeightInit).
I'm a bit unsure about what to do with the following bullet from the CONTRIBUTING checklist:
My changes spilled over to unit tests in modelimport and core (maybe more) and also to some classes inside modelimport (basically a workaround for the fact that I disallowed setting the legacy weight init to DISTRIBUTION).
Shall I commit only my changes to deeplearning-nn first and then the rest, or does the "landed commits" part of the bullet just mean "unless they are required for the project to build"? Or do I just misunderstand what a submodule is?
At first glance, this looks good, design/implementation is pretty good so far.
One minor thing: the signature...
I don't think fanIn/fanOut can ever be anything other than integers, so let's make these args integers. (Unless you know of a use case where double is required?).
The only thing I'm not really happy with is that we'll have essentially duplicates in DL4J and ND4J:
Given it's basically the same thing, we might be able to combine them.
That's what I'd do - so sounds good.
Yes, I'd do the same. Ignore field in legacy JSON unless it's DISTRIBUTION weight init.
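The legacy-deserialization rule agreed on above can be sketched as follows. Names are hypothetical and a string stands in for the real Distribution/IWeightInit types; the point is only the branching: the Distribution field is honored when and only when the legacy enum value is DISTRIBUTION.

```java
// Sketch of the legacy JSON rule discussed above (hypothetical names):
// ignore the serialized Distribution unless the weight init is DISTRIBUTION.
enum LegacyWeightInitSketch { XAVIER, RELU, DISTRIBUTION }

class LegacyConversionSketch {
    // Returns a short description of the init function a legacy config maps to.
    static String resolve(LegacyWeightInitSketch weightInit, String distribution) {
        if (weightInit == LegacyWeightInitSketch.DISTRIBUTION) {
            if (distribution == null) {
                throw new IllegalStateException(
                        "DISTRIBUTION weight init requires a distribution");
            }
            return "WeightInitDistribution(" + distribution + ")";
        }
        // A distribution present alongside any other init value is simply ignored.
        return "WeightInit" + weightInit;
    }
}
```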
Let's stick with the class-based approach. Yes, there's a ton of classes, but otherwise it's fine IMO.
I have no idea what that means or why it's there :)
Firstly, sorry for hard-to-review-blob.
Most controversial change imo is that calling the (legacy) BaseLayer (or NeuralNetConfiguration.Builder) weightInit with WeightInit.DISTRIBUTION now throws an exception (as there is no distribution yet, and WeightInitDistribution requires one as constructor input).
I guess I could work around this by having some WeightInitBuilder which is built at the same time as the config, but I thought it was better to just rip off the bandaid and not mess around.
I could not get all test cases to pass. A lot (all?) of the SameDiff test cases failed with an NPE on some properties. Same for a bunch of gradient-checking tests (which is a bit worrying, as my changes could affect them). A couple of other tests failed with some P.A.N.I.C exception.
Many of the failed test cases also failed on a "clean" copy of the branch, but since it was a kind of manual back-and-forth testing I did not manage to check every single test case this way.
I did ensure that the deserialization test cases in org.deeplearning4j.regressiontest passed (I made sure they ran and failed first, and then implemented legacy deserialization).
Same goes for a bunch of weight-init-asserting test cases in org.deeplearning4j.nn. In total I have 20 test cases in that package which fail in my environment both before and after my changes.
I'll remove the WIP in hopes that it was a config error on my side. If not, I think I need some assistance in getting things to run in my environment (Windows 10) so I can test effectively. I found instructions for how to install test resources, but they might not have installed properly. I didn't get any errors, but some test cases fail due to a missing 1.0.0-SNAPSHOT.jar in the testresources dir. Btw, I found the instructions under "build from source instructions"; maybe they belong in contributing?
Finally, as per @AlexDBlack's comments: I changed fanIn and fanOut to longs. The convolution param initializers had them as doubles because they divide by the stride(s) for fanOut. I think truncating the division gives the actual fanOut, though.
About duplication in Nd4j: I can give that one a shot in a separate PR perhaps (along with SameDiff). This one seems big enough already.
hm... If we're replacing .weightInit(WeightInit.DISTRIBUTION) then let's deprecate .dist(...) too, in favor of .weightInit(Distribution).
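The suggested builder change could look roughly like this. Both the builder and distribution types here are hypothetical stand-ins; the idea is just that a weightInit(Distribution) overload replaces the .weightInit(WeightInit.DISTRIBUTION) plus .dist(...) two-step:

```java
// Sketch of the proposed builder API change (hypothetical names): accept a
// distribution directly via an overload, deprecating the separate .dist(...).
class DistributionSketch {
    final String name;
    DistributionSketch(String name) { this.name = name; }
}

class BuilderSketch {
    String weightInitDesc = "XAVIER"; // default init

    // Existing enum-based route (enum modeled as a string here).
    BuilderSketch weightInit(String enumName) {
        this.weightInitDesc = enumName;
        return this;
    }

    // New overload replacing the deprecated .dist(...) + DISTRIBUTION combo.
    BuilderSketch weightInit(DistributionSketch dist) {
        this.weightInitDesc = "Distribution(" + dist.name + ")";
        return this;
    }
}
```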
Let's add a WeightInit.getWeightInitFunction() (i.e., no-arg method) instead of always passing null when no distribution is used... it can pass null internally... just a bit cleaner that way.
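The no-arg convenience suggested here is just a delegating overload. A minimal sketch, with a string standing in for the real IWeightInit return type:

```java
// Sketch of the suggested no-arg convenience method (hypothetical names): the
// no-arg overload delegates with a null distribution, so callers without a
// distribution never pass null themselves.
enum WeightInitEnumSketch {
    XAVIER, RELU;

    String getWeightInitFunction() {
        return getWeightInitFunction(null); // null distribution handled internally
    }

    String getWeightInitFunction(Object distribution) {
        // Real code would map each enum value to its IWeightInit implementation;
        // a descriptive string keeps the sketch self-contained.
        return name() + (distribution == null ? "" : "(" + distribution + ")");
    }
}
```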
IIRC there were 2 or 3 tests left failing from the recent datatypes PR merge that I need to get to.
I had completely forgotten about that edge case. Consider ConvolutionParamInitializer:
now consider kernel 2, stride 3, out depth 1. With integer division, that gives us 0 fan out, whereas in reality some elements have 0 fan out and some have non-zero fan out.
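The arithmetic behind this edge case is worth spelling out. A rough sketch of the fan-out computation under discussion (simplified to one spatial dimension; the real param initializers also factor in the full kernel shape):

```java
// Sketch of the fan-out edge case above: kernel 2, stride 3, out depth 1.
class FanOutSketch {
    // Average fan out per weight, kept as a double (fractional values allowed).
    static double fanOut(long kernel, long outDepth, long stride) {
        return (double) (kernel * outDepth) / stride;
    }

    // Same computation with long arithmetic: integer division truncates,
    // collapsing any fractional fan out to zero.
    static long fanOutTruncated(long kernel, long outDepth, long stride) {
        return kernel * outDepth / stride;
    }
}
```

For kernel 2, stride 3, out depth 1 the double version gives 2/3 (some inputs feed one output, some feed none), while the long version truncates to 0, losing that information.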
Yeah, that sounds good. I'm in the process of making some fairly major changes to samediff, so perhaps it would be better after that is merged too (hopefully late this week).
Anyway, this looks close - let me know when you're done, I'll do a more thorough review + test run and we can get this merged. Thanks for your work here.
Thanks a lot for the support @AlexDBlack! Barring messed-up test cases and further comments, I think I'm done now. I was thinking I could look at the results from CI, as it probably has the right environment to run the tests, but it didn't seem to finish.
Comments on comments:
I didn't change
Looks like there's a couple of things in the UI module that need updating here... should show up if you try an install:
Also mind dealing with that conflict?
@AlexDBlack Fixed the issue and checked manually that the other projects compiled (apparently it was not enough to rebuild at the top level). I tried mvn install, but then it started running all test cases and aborted when it hit one that fails.
I CBA to implement toString in each new WeightInit, so I went the "lazy" route and added the JSON string as info. In some sense it's a bit less error prone, as it does not require users to implement toString in new custom weight inits to get meaningful info.
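The "lazy" route can be approximated like this. DL4J would use its existing JSON mapper for the actual serialization; this self-contained stand-in uses reflection to produce a JSON-ish field dump, just to show why subclasses get meaningful output for free:

```java
import java.lang.reflect.Field;
import java.util.StringJoiner;

// Sketch of a base class whose default toString emits a JSON-like dump of the
// subclass's fields, so custom weight inits get meaningful output without
// overriding toString. (The real code would serialize via the JSON mapper;
// reflection is used here only to keep the example self-contained.)
abstract class BaseWeightInitSketch {
    @Override
    public String toString() {
        StringJoiner fields = new StringJoiner(", ", "{", "}");
        for (Field f : getClass().getDeclaredFields()) {
            f.setAccessible(true);
            try {
                fields.add("\"" + f.getName() + "\": " + f.get(this));
            } catch (IllegalAccessException e) {
                fields.add("\"" + f.getName() + "\": ?");
            }
        }
        return getClass().getSimpleName() + " " + fields;
    }
}

// A custom weight init gets a useful toString with no extra work.
class WeightInitConstantSketch extends BaseWeightInitSketch {
    private final double value;
    WeightInitConstantSketch(double value) { this.value = value; }
}
```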
If not ok I can rework it tomorrow.