Added Weave class and WeaveModel class #3529
Conversation
@ARY2260 Can you do a preliminary round of review?
pad_batches=pad_batches):
    if y_b is not None:
        if self.model.mode == 'classification':
            y_b = to_one_hot(y_b.flatten(),
Separate handling of labels within the default generator may not be required.
Please check.
The default_generator function gets called internally for a NumpyDataset when model.fit() is run, and to keep the port close to the TensorFlow implementation I thought it would be best to keep the default_generator function.
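For reference, a minimal sketch of such a default generator, assuming a DeepChem-style model wrapper with `batch_size`, `mode`, `n_tasks`, and `n_classes` attributes; the signature and names are illustrative, not the exact PR code:

```python
from deepchem.metrics import to_one_hot


def default_generator(self, dataset, epochs=1, mode='fit',
                      deterministic=True, pad_batches=True):
    # Iterate over the dataset in batches, one-hot encoding labels only
    # when the model is in classification mode.
    for epoch in range(epochs):
        for (X_b, y_b, w_b, ids_b) in dataset.iterbatches(
                batch_size=self.batch_size,
                deterministic=deterministic,
                pad_batches=pad_batches):
            if y_b is not None:
                if self.model.mode == 'classification':
                    # Flatten, one-hot encode, and reshape so the loss sees
                    # class scores of shape (batch, n_tasks, n_classes).
                    y_b = to_one_hot(y_b.flatten(),
                                     self.model.n_classes).reshape(
                                         -1, self.model.n_tasks,
                                         self.model.n_classes)
            yield ([X_b], [y_b], [w_b])
```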
n_tasks = self.n_tasks
if self.mode == 'classification':
    n_classes = self.n_classes
    self.layer_2 = nn.LazyLinear(n_tasks * n_classes)
@rbharath Is it fine to use nn.LazyLinear instead of the default nn.Linear here? I used it because the input size is not known at construction time, but it may be possible to compute it.
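For context on the trade-off, a small self-contained sketch of how nn.LazyLinear infers in_features on the first forward pass (the sizes here are illustrative):

```python
import torch
import torch.nn as nn

# LazyLinear postpones choosing in_features until the first forward call,
# which avoids hard-coding the flattened feature size in __init__.
layer = nn.LazyLinear(out_features=8)
x = torch.randn(4, 1408)            # assumed input width, for illustration
y = layer(x)                        # in_features is inferred as 1408 here
print(layer.in_features, y.shape)   # 1408 torch.Size([4, 8])
```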
if weight_decay_penalty != 0.0:
    weights = [layer.weight for layer in self.model.layers2]
    if weight_decay_penalty_type == 'l1':
        regularization_loss = lambda: weight_decay_penalty * torch.sum(  # noqa: E731
Please check again; we generally don't use # noqa to silence lint issues.
# noqa has to be used for the lambda assignment (flake8 E731), and since regularization_loss needs to be a callable I think it's required here.
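One way to sidestep the `# noqa: E731` while still exposing a callable is to build it with a nested def. A hedged sketch (the helper name, layer sizes, and penalty value are illustrative, not the PR's code):

```python
from typing import Callable, List

import torch
import torch.nn as nn


def make_regularization_loss(weights: List[torch.Tensor], penalty: float,
                             penalty_type: str) -> Callable[[], torch.Tensor]:
    """Return a zero-argument callable computing an L1 or L2 weight penalty."""
    if penalty_type == 'l1':

        def regularization_loss() -> torch.Tensor:
            # L1: sum of absolute values of all layer weights
            return penalty * torch.sum(
                torch.stack([torch.abs(w).sum() for w in weights]))
    else:

        def regularization_loss() -> torch.Tensor:
            # L2: sum of squared layer weights
            return penalty * torch.sum(
                torch.stack([torch.square(w).sum() for w in weights]))

    return regularization_loss


# Usage sketch with illustrative layers:
layers2 = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 2)])
reg = make_regularization_loss([layer.weight for layer in layers2],
                               penalty=1e-4, penalty_type='l1')
print(reg())
```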
@@ -172,7 +171,8 @@ def __init__(
]
self.batch_normalize: bool = batch_normalize
self.n_weave: int = n_weave
torch.manual_seed(21)

# torch.manual_seed(21)
I think you can remove this comment
I have re-added the seed statement because otherwise the reload test fails.
Is it fine to add the seed statement here?
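For reference, a minimal sketch of why a fixed seed makes a reload test pass: two separately constructed modules get identical initial weights. The seed value and layer sizes below are illustrative:

```python
import torch
import torch.nn as nn


def build_layer(seed: int = 21) -> nn.Linear:
    # Seeding before construction makes weight initialization deterministic.
    torch.manual_seed(seed)
    return nn.Linear(8, 4)


a = build_layer()
b = build_layer()
print(torch.equal(a.weight, b.weight))  # True: identical initial parameters
```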
@@ -227,10 +227,15 @@ def __init__(

if n_layers > 0:
    self.layers2: nn.ModuleList = nn.ModuleList()
    in_size = 1408
Is the input size always fixed?
No, it actually depends on n_graph_feat, so I have rewritten this statement in terms of n_graph_feat.
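A rough sketch of what deriving the size from n_graph_feat could look like; the factor of 11 is an assumption based on WeaveGather's Gaussian membership expansion (128 * 11 = 1408 for the defaults), not a value quoted from the PR:

```python
import torch.nn as nn

n_graph_feat = 128          # illustrative default
gaussian_expand = True      # WeaveGather option assumed here

# Derive the dense-layer input width instead of hard-coding 1408.
in_size = n_graph_feat * 11 if gaussian_expand else n_graph_feat

layers2 = nn.ModuleList()
layers2.append(nn.Linear(in_size, 100))  # 100 is an illustrative layer size
```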
def test_weave_singletask_classification_overfit():
    """Test weave model overfits tiny data."""
    # np.random.seed(123)
    # torch.manual_seed(123)
I think the seed should be turned on for this test.
Done
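With the seeds turned back on, the top of the test might look roughly like this; only the seeding lines are shown, and the dataset/model setup is omitted:

```python
import numpy as np
import torch


def test_weave_singletask_classification_overfit():
    """Test weave model overfits tiny data."""
    # Seeds enabled so the overfit check is deterministic across runs.
    np.random.seed(123)
    torch.manual_seed(123)
    # ... build the tiny dataset, fit the WeaveModel, and assert the
    # training score exceeds the threshold (omitted here).
```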
# Eval model on train
scores = model.evaluate(dataset, [classification_metric])
Add a comment here suggesting that the model be inspected in the future to understand the low score.
Okay, I will do that. The unit test for the TensorFlow code uses the same value, though.
Done
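The requested note could sit right above the evaluation call inside the test, along these lines (a fragment sketch; the threshold shown is illustrative, not quoted from the PR):

```python
# Eval model on train
# TODO: inspect the model in the future to understand why the training
# score is lower than expected for an overfit test.
scores = model.evaluate(dataset, [classification_metric])
assert scores[classification_metric.name] > 0.65  # illustrative threshold
```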
else:
    self.layer_2 = nn.Linear(fully_connected_layer_sizes[1], n_tasks)

def forward(self, inputs: OneOrMany[torch.Tensor]) -> List[torch.Tensor]:
Please add docstrings.
Done
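For instance, a NumPy-style docstring on forward might look like the sketch below; the class stub and the Returns description are assumptions based on the signature, not quoted from the PR:

```python
from typing import List

import torch
from deepchem.utils.typing import OneOrMany


class Weave(torch.nn.Module):  # stub just to host the docstring example

    def forward(self, inputs: OneOrMany[torch.Tensor]) -> List[torch.Tensor]:
        """Forward pass of the Weave module.

        Parameters
        ----------
        inputs: OneOrMany[torch.Tensor]
            Should contain 5 tensors [atom_features, pair_features,
            pair_split, atom_split, atom_to_pair].

        Returns
        -------
        List[torch.Tensor]
            Output tensors of the module.
        """
        ...
```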
Parameters
----------
inputs: OneOrMany[torch.Tensor]
    Should contain 5 tensors [atom_features, pair_features, pair_split, atom_split, atom_to_pair]
Formatting here is a little off
LGTM
Description
Added Weave class and WeaveModel class.
Type of change
Please check the option that relates to your PR.
Checklist
- Run yapf -i <modified file> and check no errors (yapf version must be 0.32.0)
- Run mypy -p deepchem and check no errors
- Run flake8 <modified file> --count and check no errors
- Run python -m doctest <modified file> and check no errors