Implement sample GAT model for working PyG with DeepChem #2109
This PR is ready to review!
This is really neat! Well written, well documented code :)
I've done a first pass with a couple of comments
get_bond_stereo_one_hot

def constrcut_atom_feature(atom: RDKitAtom, h_bond_infos: List[Tuple[int, str]],
Typo: this should be `construct_atom_feature`.
def constrcut_atom_feature(atom: RDKitAtom, h_bond_infos: List[Tuple[int, str]],
                           sssr: List[Sequence]) -> List[float]:
Should we perhaps make this return a NumPy array instead of a list?
Is there a reason? I think this function is basically not used anywhere else. (I have now added a leading underscore to the function name: `_construct_atom_feature`.)
class MolGraphConvFeaturizer(MolecularFeaturizer):
  """This class is a featurizer of general graph convolution networks for molecules.

  The default node(atom) and edge(bond) representations are based on WeaveNet paper.
Perhaps add a hyperlink to the Weave paper in the references section below.
The default node(atom) and edge(bond) representations are based on WeaveNet paper.
If you want to use your own representations, you could use this class as a guide
to define your original Featurizer. In many cases, it's enough to modify return values of
`constrcut_atom_feature` or `constrcut_bond_feature`.
Typo: both of these should be spelled `construct`.
One more quick ask: could you update the model cheatsheet to add
I will add a lot of documentation related to TorchModel, CGCNN, and GAT in #2124.
- Chirality: A one-hot vector of the chirality, "R" or "S".
- Formal charge: Integer electronic charge.
- Partial charge: Calculated partial charge.
- Ring sizes: A one-hot vector of the number of rings (3-8) that include this atom.
I think you mean the size of the ring? Not many atoms belong to three rings, much less eight!
Yes! I fixed it.
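For readers following the ring-size discussion, here is a minimal sketch of what a "ring sizes" feature could look like. The helper name `ring_size_one_hot` and its arguments are hypothetical and are not taken from the PR; the sketch assumes `sssr` is a list of rings, each given as a sequence of atom indices. Because an atom can sit in rings of two different sizes, the result is strictly a multi-hot vector.

```python
from typing import List, Sequence


def ring_size_one_hot(atom_idx: int,
                      sssr: List[Sequence[int]],
                      sizes: Sequence[int] = (3, 4, 5, 6, 7, 8)) -> List[float]:
    """Hypothetical helper: flag every ring size (3-8) that contains the atom."""
    one_hot = [0.0] * len(sizes)
    for ring in sssr:
        # Only rings that contain this atom and have a tracked size count.
        if atom_idx in ring and len(ring) in sizes:
            one_hot[sizes.index(len(ring))] = 1.0
    return one_hot


# A fused six-membered and five-membered ring sharing atoms 4 and 5.
rings = [(0, 1, 2, 3, 4, 5), (4, 5, 6, 7, 8)]
shared_atom = ring_size_one_hot(4, rings)
six_ring_only = ring_size_one_hot(0, rings)
acyclic_atom = ring_size_one_hot(9, rings)
```

Each position in the vector corresponds to a ring size, not to a count of rings, which is the distinction raised in the review comment.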
"""
Parameters
----------
add_self_loop: bool, default False
Perhaps `add_self_edges` would be clearer? This isn't really about loops, so the name could be confusing.
I agree! I fixed it.
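As an illustration of what a self-edge option typically does, the sketch below appends one edge from every node to itself to an existing edge list. The function `with_self_edges` is a hypothetical name, and the source and destination lists are an assumed representation; this is not the featurizer's actual code.

```python
from typing import List, Tuple


def with_self_edges(src: List[int], dest: List[int],
                    num_nodes: int) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: add an edge i -> i for every node i."""
    nodes = list(range(num_nodes))
    # New lists are returned so the caller's inputs are left untouched.
    return src + nodes, dest + nodes


# Two atoms joined by one bond, stored in both directions.
src_out, dest_out = with_self_edges([0, 1], [1, 0], num_nodes=2)
```

Here every added entry connects a node to itself, which is why "self edges" describes the option more directly than "loop".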
)

# construct edge (bond) information
src, dist, bond_features = [], [], []
You probably want the variable to be called `dest` (short for destination?), not `dist`, which sounds like it's short for "distance"?
It is short for destination. I fixed it!
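To make the `src` / `dest` naming concrete, here is a small sketch of how bond endpoints might be turned into directed edge lists. The function name `bonds_to_edges` and the choice to store each undirected bond in both directions are assumptions made for illustration, not a copy of the code under review.

```python
from typing import List, Tuple


def bonds_to_edges(bonds: List[Tuple[int, int]]) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: expand undirected bonds into directed edges."""
    src: List[int] = []
    dest: List[int] = []
    for begin, end in bonds:
        # Store begin -> end and end -> begin so messages can flow both ways.
        src += [begin, end]
        dest += [end, begin]
    return src, dest


# A three-atom chain: 0 - 1 - 2.
chain_src, chain_dest = bonds_to_edges([(0, 1), (1, 2)])
```

Reading the two lists column by column gives the edges 0→1, 1→0, 1→2, and 2→1, so `dest` holds destinations rather than distances.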
deepchem/models/torch_models/gat.py
Outdated
>> dataset_config = {"reload": False, "featurizer": featurizer, "transformers": []}
>> tasks, datasets, transformers = dc.molnet.load_tox21(**dataset_config)
Wouldn't it be simpler to just write this as
tasks, datasets, transformers = dc.molnet.load_tox21(reload=False, featurizer=featurizer, transformers=[])
I fixed it!
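The two call styles discussed here are equivalent in Python, which a short sketch can show. `load_dataset` below is a stand-in written only for this illustration; it is not `dc.molnet.load_tox21` and merely echoes its arguments.

```python
def load_dataset(reload=True, featurizer=None, transformers=None):
    """Stand-in loader (hypothetical); it only returns what it was given."""
    return {"reload": reload, "featurizer": featurizer,
            "transformers": transformers}


featurizer = "my_featurizer"  # placeholder value for the illustration

# Style 1: build a dict, then unpack it into keyword arguments.
dataset_config = {"reload": False, "featurizer": featurizer, "transformers": []}
via_dict = load_dataset(**dataset_config)

# Style 2: pass the same keyword arguments directly.
via_keywords = load_dataset(reload=False, featurizer=featurizer, transformers=[])
```

Both calls receive identical arguments, so the direct form is usually easier to read unless the configuration is reused in several places.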
deepchem/models/torch_models/gat.py
Outdated
in_node_dim: int = 38,
hidden_node_dim: int = 64,
heads: int = 4,
dropout_rate: float = 0.0,
For consistency with other models, this should be just `dropout`.
I fixed it!
deepchem/models/torch_models/gat.py
Outdated
heads: int = 4,
dropout_rate: float = 0.0,
num_conv: int = 3,
predicator_hidden_feats: int = 32,
Was that supposed to be "predictor"?
Yes. I fixed it!
deepchem/utils/graph_conv_utils.py
Outdated
@@ -0,0 +1,512 @@
"""
This file doesn't really have anything to do with convolutions. How about calling it `molecule_feature_utils.py`?
I fixed it!
args_info += arg_name + '=' + str(self.__dict__[arg_name]) + ', '
return self.__class__.__name__ + '[' + args_info[:-2] + ']'

def __str__(self) -> str:
This implementation is needed to fix the Windows CI; see #1829. The `repr` function shows all the arguments used when instantiating the class, while the `str` function shows only the arguments that were updated when instantiating the class.
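The split between the two methods can be sketched with a toy class. `ToyFeaturizer`, its parameters, and the bracketed output format are assumptions made for this illustration, and "updated" is read here as "different from the default value"; the actual implementation in the PR may differ.

```python
import inspect
from typing import List


class ToyFeaturizer:
    """Hypothetical class used only to contrast __repr__ and __str__."""

    def __init__(self, use_edges: bool = False, use_chirality: bool = False):
        self.use_edges = use_edges
        self.use_chirality = use_chirality

    def _format_args(self, only_changed: bool) -> List[str]:
        # Compare each stored attribute against the default in the signature.
        params = inspect.signature(self.__init__).parameters
        items = []
        for name, param in params.items():
            value = getattr(self, name)
            if only_changed and value == param.default:
                continue
            items.append(f"{name}={value!r}")
        return items

    def __repr__(self) -> str:
        # All constructor arguments, including those left at their defaults.
        return f"{type(self).__name__}[{', '.join(self._format_args(False))}]"

    def __str__(self) -> str:
        # Only the arguments whose values differ from the defaults.
        return f"{type(self).__name__}[{', '.join(self._format_args(True))}]"


feat = ToyFeaturizer(use_edges=True)
full_text = repr(feat)
short_text = str(feat)
```

With this layout `repr` gives a complete record of how the object was built, and `str` gives a shorter label that stays stable when new default arguments are added later.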
This PR is ready for a second review.
This is looking good! I have a couple of minor comments. Once those are addressed, I think this is good to merge in.
deepchem/feat/molecule_featurizers/mol_graph_conv_featurizer.py
Outdated
Looks good! Feel free to merge in whenever ready :)
Thanks! I'll merge it in.
This PR is a part of #1942
What I did
utils/graph_conv_utils.py
and the featurizer is more readable and customizable for users
TODO
utils/graph_conv_utils.py
) I will make another PR.