Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't include saved state in keras_model.weights #140

Merged
merged 3 commits into from
Mar 17, 2020
Merged

Conversation

drasmuss
Copy link
Member

@drasmuss drasmuss commented Mar 10, 2020

This makes it easier to re-use saved weights between models, as models with different saved state variables can still use the same saved Keras parameters.

Fixes #133

@drasmuss drasmuss force-pushed the model-checkpoint branch 2 times, most recently from 6fc1f6e to 3380717 Compare March 13, 2020 17:15
@@ -22,7 +22,7 @@ def configure_settings(**kwargs):
Parameters
----------
trainable : bool or None
Adds a parameter to Nengo Ensembles/Connections/Networks that controls
Adds a parameter to Nengo Ensembles/Connections that controls
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a change in function that this can't be set on networks, or just a fix to the documentation?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a fix for the documentation (setting it on networks was removed a while ago, I just missed this docstring).

non-trainable parameters of the network (this includes the internal
simulation state).
include_state : bool
If True (default False) also save the internal simulation state.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would opt for still having a line for include_non_trainable, that basically says it's deprecated and equivalent to include_state.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at nengo.Simulator.trange and the switch in parameters there, we don't have documentation for both the old and new names, but we do have a "version changed" tag, so I've opted for that instead.


vars = self.keras_model.weights
if include_state:
vars.extend(self.tensor_graph.saved_state.values())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will modify self.keras_model.weights, right? Is that a problem?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah good call probably safer to do this on a copy

nengo_dl/simulator.py Show resolved Hide resolved
nengo_dl/simulator.py Outdated Show resolved Hide resolved
"trainable": OrderedDict(),
"non_trainable": OrderedDict(),
"state": OrderedDict(),
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would make base_arrays an ordered dict just so it's deterministic when we call .items on it (right now, that's just in a logger message, but it's nice to have things consistent IMO).

CHANGES.rst Outdated
- Model parameters (e.g., connection weights) that are not trainable (because they've
been marked non-trainable by user or targeted by an online learning rule) will now
be treated separately from simulator state. For example, resetting the simulator
state will not reset those parameters, and the results of any online learning
Copy link
Collaborator

@hunse hunse Mar 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't calling sim.reset reset online learning, though (as per test_online_learning_reset)? It's just sim.soft_reset with include_trainable=False that doesn't reset.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I was intending "resetting the simulator state" to refer specifically to include_trainable=False, but that may not be clear.


sim.reset()

assert np.allclose(w0, sim.data[conn].weights)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to group the lines a bit and add some short comments just making it clear what we expect for each case.

assert np.allclose(sim_load.data[p1], sim_save.data[p0][10:])
else:
assert not np.allclose(sim_load.data[p1], sim_save.data[p0][10:])
assert np.allclose(sim_load.data[p1], sim_save.data[p0][:10])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NTS: add a comment about this (took a minute of thinking to figure it out)

@hunse
Copy link
Collaborator

hunse commented Mar 16, 2020

Done with my review, and pushed a fixup commit making the changes I've suggested above. If it looks good to you @drasmuss, and passes CI, I'll merge.

EDIT: Oh, I did have one question above about the changelog, basically wondering if it's slightly unclear about when online learning weights are reset.

@drasmuss
Copy link
Member Author

Fixups all look good to me, I added a clarification to the changelog, if that looks good to you go ahead and merge!

@hunse hunse merged commit 32c3e67 into master Mar 17, 2020
@hunse hunse deleted the model-checkpoint branch March 17, 2020 18:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Development

Successfully merging this pull request may close these issues.

Implement NengoDL version of ModelCheckpoint
2 participants