-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Wgan Porting #3666
Wgan Porting #3666
Conversation
deepchem/models/torch_models/gan.py
Outdated
|
||
References | ||
---------- | ||
[1] Arjovsky, Martin, Soumith Chintala, and Léon Bottou. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you check this renders correctly on the docs locally? It needs to be in numpydoc style
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
deepchem/models/torch_models/gan.py
Outdated
(https://arxiv.org/abs/1704.00028) | ||
""" | ||
|
||
def __init__(self, gradient_penalty=10.0, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Type annotations are missing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
discrim_output_gen[1]) + discrim_output_train[1] | ||
|
||
|
||
class GradientPenaltyLayer(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
More detailed docstrings and usage examples
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
discrim_output_gen[1]) + discrim_output_train[1] | ||
|
||
|
||
class GradientPenaltyLayer(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need a unit test for this layer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
deepchem/models/torch_models/gan.py
Outdated
class GradientPenaltyLayer(nn.Module): | ||
"""Implements the gradient penalty loss term for WGANs.""" | ||
|
||
def __init__(self, gan, discriminator, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Type annotations here and rest of the layer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
deepchem/models/torch_models/gan.py
Outdated
self.gan = gan | ||
self.discriminator = discriminator | ||
|
||
def forward(self, inputs, conditional_inputs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Type annotations and return type
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needs more docs and type annotations
deepchem/models/torch_models/gan.py
Outdated
|
||
Notes | ||
----- | ||
This class is not intended to be used directly. It is used internally by |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's remove this comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
deepchem/models/torch_models/gan.py
Outdated
the output from the discriminator, followed by the gradient penalty. | ||
""" | ||
# concatenate inputs and conditional_inputs | ||
# inputs = list(inputs) + list(conditional_inputs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cruft? Should be removed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few cleanup comments and requests for documentation
discrim_output_gen[1]) + discrim_output_train[1] | ||
|
||
|
||
class GradientPenaltyLayer(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You need to add thhis layer to the docs; can do in a follow up PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Description
WGAN Porting from Tensorflow to Pytorch
Type of change
Please check the option that is related to your PR.
Checklist
yapf -i <modified file>
and check no errors (yapf version must be 0.32.0)mypy -p deepchem
and check no errorsflake8 <modified file> --count
and check no errorspython -m doctest <modified file>
and check no errors