New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sequential API for Model Construction #925
Conversation
if not len(layer.in_layers) == 0: | ||
raise ValueError("Cannot specify in_layers for Sequential.") | ||
layer.in_layers += [prev_layer] | ||
self._add_layer(layer) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you need this? When you set loss this should go through the tree and call already?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, will remove.
|
||
if loss == "binary_crossentropy": | ||
smce = SoftMaxCrossEntropy(in_layers=[labels, prev_layer]) | ||
self._add_layer(smce) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto for this _add_layer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will remove.
else: | ||
# TODO(rbharath): Add in support for additional losses. | ||
raise ValueError("Unsupported loss.") | ||
self._built = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this needed? super().fit() should call build and will also install the queue for faster training/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. Will remove.
This class is kind of strange, in that the model isn't fully defined until you call |
@peastman You raise a good point about A more general solution would probably be to write placeholder shapes in metadata and add placeholders if necessary in |
Going to go ahead and merge in this PR. Will address any outstanding issues in a follow-on PR. LGTM |
This PR adds a new API for constructing sequential models (those that are a linear stack of models) based on the Keras Sequential API. For simple networks, this new API allows users to skip explicitly specifying
Feature
,Label
, loss inputs, outputs, andin_layers
. For example, here's a simple classifier inSequential
:For comparison, here's the same model in the
TensorGraph
APIUnderneath the hood,
Sequential
inherits fromTensorGraph
and simply constructs the explicitTensorGraph
as needed. For now, losses are explicitly specified as strings passed intoSequential.fit()
. The initial implementation will support the same string arguments as the Keras Sequential API. The PR also removes an old (non-functional) version ofSequential
that was not based onTensorGraph
.This PR isn't quite ready to merge (will need to add in losses beyond
binary_crossentropy
andmse
), but I wanted to put it out for feedback on the API design.