-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DeepGraphInfomax GNNModular pretraining task #3358
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some comments mostly around modification to the docs and requests for more unit tests
deepchem/models/torch_models/gnn.py
Outdated
""" | ||
The forward method takes two inputs, `x` (local node representations) and `summary` (global graph representations), both of shape `(batch_size, hidden_dim)`. | ||
It computes the product of `summary` and `self.weight`, and then calculates the element-wise product of `x` and the resulting matrix `h`. | ||
Finally, it returns the sum of the element-wise product along dimension 1 (i.e., summing over the `hidden_dim`), resulting in a tensor of shape `(batch_size,)`, which represents the similarity scores between the local and global representations. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a numpydoc style Parameters field here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
deepchem/models/torch_models/gnn.py
Outdated
@@ -178,6 +180,36 @@ def forward(self, data): | |||
return out | |||
|
|||
|
|||
class Discriminator(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a unit test for this layer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Feel free to merge in once CI looks clear
Description
Adds in the unsupervised deepgraphinfomax pretraining task for GNNModular. This maximizes the mutual information between positive (same graph) and negative (different graphs) graph representations.
Type of change
Please check the option that is related to your PR.
Checklist
yapf -i <modified file>
and check no errors (yapf version must be 0.32.0)mypy -p deepchem
and check no errorsflake8 <modified file> --count
and check no errorspython -m doctest <modified file>
and check no errors