-
Notifications
You must be signed in to change notification settings - Fork 547
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[V2]: Overhaul MultiHotAtomFeaturizer
#658
Conversation
At the point we're altering the initial featurization scheme, is there a reason we don't just eliminate alkali, alkaline earth, and transition metals as well as noble gases? The frequency of these elements in typical inputs must be I agree with the notion that we can improve the featurization scheme to improve density, but I think if we're doing so, then we should take it a step further:
|
Thanks for the comment. I think we all agree to limit the element types. It is worth a discussion regarding the specific elements to include as default. The elements you suggested are commonly seen in drug discovery applications, but I do think some 4th row metals are commonly seen in materials design datasets. Na and Mg are also commonly seen. The reason I included 4th row metals are trying to keep some generalizability, but the inclusion of them is certainly open to discussion. Regarding the hybridization, I think adding s makes sense for H, and the sp2d is mainly because I chose to include 4th row elements that can be hybridized this way. Another approach is to simply go through the dataset and only include elements that have appeared in the dataset. What is your opinion on this? |
I think the question comes down to frequency, i.e., what fraction of total atoms in a total dataset are represented by these "less common" elements. If we can produce a bar charts of atomic frequency for representative datasets, we can set some principled criterion of what to include or exclude. I don't doubt that Na and Mg are present in typical datasets, but if less than 1% of compounds have one of these atoms and these compounds only contain a single one, I'd be hard pressed to believe that an MPNN is learning anything about these atoms' contributions to molecular properties beyond just guessing the unconditional mean of an known atom type. More specifically, these atoms contribute an independent channel in the atomic featurization scheme, how many gradient updates do you expect these channels to receive relative to more conventional atom types in a standard organic dataset. I'm less familiar with materials datasets, so I'd be curious to hear more about chemical composition of these. I come to chemprop from a small molecule background, and my impression is that the large majority of users (e.g., MLPDS) fall into this camp. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You included the 0 padding in the atomic_nums
array. I am afraid that might cause problems for users who want to supply their atomic_nums (e.g. featurizer = MultiHotAtomFeaturizer(atomic_nums = [1, 6, 7, 8, 9])
) because they would need to remember to include the 0 padding. (My example would throw an error if one of the inputs had a sulfur with the current configuration.) If having a 0 pad for atomic number is always a good idea, what do you think of reverting your changes that put it in atomic_nums
? I've suggested these changes. I didn't make the corresponding required changes in test_atom.py
though.
In any case you'll also need to update i = self.atomic_nums.get(a.GetAtomicNum() - 1, len(self.atomic_nums))
in num_only()
, maybe to i = self.atomic_nums.get(a.GetAtomicNum(), len(self.atomic_nums))
.
chemprop/featurizers/atom.py
Outdated
+---------------------+-----------------+--------------+ | ||
|
||
NOTE: the above signature only applies for the default arguments, as the each slice (save for | ||
the final two) can increase in size depending on the input arguments. | ||
""" | ||
|
||
max_atomic_num: InitVar[int] = 100 | ||
# all elements in the first 4 rows of periodic talbe plus iodine and 0 padding for other elements |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# all elements in the first 4 rows of periodic talbe plus iodine and 0 padding for other elements | |
# all elements in the first 4 rows of periodic table plus iodine |
chemprop/featurizers/atom.py
Outdated
atomic_nums: Sequence[int] = field(default_factory=lambda: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, | ||
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, | ||
20, 21, 22, 23, 24, 25, 26, 27, 28, 29, | ||
30, 31, 32, 33, 34, 35, 36, 53]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
atomic_nums: Sequence[int] = field(default_factory=lambda: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, | |
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, | |
20, 21, 22, 23, 24, 25, 26, 27, 28, 29, | |
30, 31, 32, 33, 34, 35, 36, 53]) | |
atomic_nums: Sequence[int] = field(default_factory=lambda: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, | |
11, 12, 13, 14, 15, 16, 17, 18, 19, 20, | |
21, 22, 23, 24, 25, 26, 27, 28, 29, 30, | |
31, 32, 33, 34, 35, 36, 53]) |
chemprop/featurizers/atom.py
Outdated
@@ -88,7 +95,7 @@ def __post_init__(self, max_atomic_num: int = 100): | |||
self.hybridizations, | |||
] | |||
subfeat_sizes = [ | |||
1 + len(self.atomic_nums), | |||
len(self.atomic_nums), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
len(self.atomic_nums), | |
1 + len(self.atomic_nums), |
chemprop/featurizers/atom.py
Outdated
@@ -109,18 +116,23 @@ def __call__(self, a: Atom | None) -> np.ndarray: | |||
return x | |||
|
|||
feats = [ | |||
a.GetAtomicNum() - 1, | |||
a.GetAtomicNum() if a.GetAtomicNum() in self.atomic_nums else 0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a.GetAtomicNum() if a.GetAtomicNum() in self.atomic_nums else 0, | |
a.GetAtomicNum(), |
chemprop/featurizers/atom.py
Outdated
a.GetTotalDegree(), | ||
a.GetFormalCharge(), | ||
int(a.GetChiralTag()), | ||
int(a.GetTotalNumHs()), | ||
a.GetHybridization(), | ||
] | ||
i = 0 | ||
pad = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pad = False |
chemprop/featurizers/atom.py
Outdated
if not pad: | ||
i += len(choices) | ||
pad = True | ||
else: | ||
i += len(choices) + 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if not pad: | |
i += len(choices) | |
pad = True | |
else: | |
i += len(choices) + 1 | |
i += len(choices) + 1 |
edit: this was mistaken. I got confused by a variable that seemingly does nothing? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe this is outside the scope of your PR, but I see that the current tests in test_atom.py
aren't very robust. The problem is that the list of atoms tested only has carbon, nitrogen, flurine, and oxygen. I suggest that instead of:
SMI = "Cn1nc(CC(=O)Nc2ccc3oc4ccccc4c3c2)c2ccccc2c1=O"
@pytest.fixture(params=list(Chem.MolFromSmiles(SMI).GetAtoms())[:5])
def atom(request):
...
@pytest.mark.parametrize(
"a,x_v_orig",
zip(
list(Chem.MolFromSmiles("Fc1cccc(C2(c3nnc(Cc4cccc5ccccc45)o3)CCOCC2)c1").GetAtoms()),
We use:
parser = Chem.SmilesParserParams()
parser.removeHs = False
SMI = "IC([Rb])([H])c1ccccc1"
@pytest.fixture(params=list(Chem.MolFromSmiles(SMI, parser).GetAtoms())[:5])
def atom(request):
...
@pytest.mark.parametrize(
"a,x_v_orig",
zip(
list(Chem.MolFromSmiles(SMI, parser).GetAtoms())[:5],
This approach tests the featurizer on Iodine, normal Carbon, Rubdium (not in default), Hydrogen, and aromatic Carbon.
tests/unit/featurizers/test_atom.py
Outdated
] | ||
# fmt: on | ||
), | ||
) | ||
def test_x_orig(a, x_v_orig): | ||
f = MultiHotAtomFeaturizer() | ||
x_v_calc = f(a) | ||
print(x_v_calc) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a reminder to remove this print before the PR is finished.
Could you clarify what you mean here David? The |
chemprop/featurizers/atom.py
Outdated
@@ -36,44 +36,51 @@ class MultiHotAtomFeaturizer(AtomFeaturizer): | |||
+---------------------+-----------------+--------------+ | |||
| slice [start, stop) | subfeature | unknown pad? | |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| slice [start, stop) | subfeature | unknown pad? | | |
| slice [start, stop) | subfeature | pad for unknown? | |
I just realized that "unknown pad" doesn't mean that we don't know how big the padding is, but that it means there is a (single) padding bit set aside for any values not explicitly in the subfeature. That is probably obvious for others, but I've been confused about this for a while. Perhaps "pad for unknown" is more clear. If you made this change, then all the following rows would also need to have this column made wider.
I was confused by the |
Okay, my understanding of why |
Yesterday Oscar and I had discussions about this PR. I'll summarize a bit here and @oscarwumit /@kevingreenman can correct me if needed. DefaultsWe first focused on what the defaults of featurization should be. Generally we feel that the default length of the 1 hot encoding doesn't need to be very small and that a separate featurizer can be the "small" one. So the lengths as are currently in the PR seem good enough to move forward. It is a bit unclear though whether there should even be a padding for unknown values for any features, including atomic number, degree, formal charge, chiral tag, #Hs, and hybridization. The added length isn't really the issue as much as the user experience. Here's a breakdown of arguments for padding vs not padding for unknown values:
A main argument for including the pad is that Chemprop v1 did this. Doesn't mean we have to, but also means that we should have a reason before changing it. ImplementationAs noted before, currently the PR treats I feel it would be better to treat Side noteThe tests are failing due to Simpler method?My main current concern with the PR is the way it pads To make sure we are all on the same page, I want to give a quick summary of how the v1 code works. The purpose of the code is to map (atom type) to (bit index in one hot encoding). V1 does this via a user accessible variable If we aren't using all consecutive atom types then I think the added complexity of this PR stems from trying to put the pad for unknown atomic numbers at the beginning of the one hot encoding, while it goes at the end for the rest of the features. A final thought about custom featurizers.The question was raised about how users could use
|
2d94a81
to
b4fbe07
Compare
Thanks for the insightful comments. I have modified the implementation as discussed in the meeting. Currently the test on atoms will pass but not for the CGRs, and resolving this could take some time. Help from someone familiar with the CGR code is appreciated. |
Agreed that resolving those CGR tests will take some time. So we can plan to include this in the 2.0 formal release and not the release candidate. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am going to approve this PR but would like to add food for thought: we should consider refactoring the "setup"s into separate @classmethod
s. That is, we define a MultiHotAtomFeaturizer
class with no default argument values and instead rely on separate constructors that set the defaults, e.g.,:
class MultiHotAtomFeaturizer:
... # dataclass fields go here but WITHOUT the `default_factory` values
@classmethod
def yang2020(cls, condensed: bool=True):
r"""build the atom featurizer used in [1]_
Parameters
-----------
condensed : bool, default=False
whether to use a condensed list of atom types. If `False`, use all atomic numbers :math:`z in [1, 100]` . Otherwise, use atomic numbers :math:`z in [1, 37] \union {53}`
References
-----------
.. [1] REF TO OG CHEMPROP PAPER
"""
@classmethod
def organic(cls):
r"""build a minimal featurizer with atom types for typical organic elements, i.e., :math:`z \in {1, 5, 6, 7, 8, 9, 15, 16, 17, 35, 53}`"""
This would be more idiomatic because users would now build their atom featurizers like so:
af = MultiHotAtomFeaturizer.yang2020(condensed=False)`
rather than just assuming that the initializer is doing one thing when in reality the documentation/init has changed in a previous commit.
Thanks for the comment. @shihchengli and I will work on making sure the CGR tests pass for this PR before merging. And I will incorporate the comments. Please do not merge this PR yet. |
b4fbe07
to
a4f237a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for making this PR. The changes look good to me. Some minor suggestions are left. I will work on the CGR tests.
chemprop/cli/common.py
Outdated
choices=list(RxnMode.keys()), | ||
help="""Choices for multi-hot atom featurization scheme. This will affect both non-reatction and reaction feturization (case insensitive): | ||
- 'default': Includes all elements in the first 4 rows of the periodic talbe plus iodine and an 0 padding for other elements (default in Chemprop v2). | ||
- 'v1': Same implementation as Chemprop v1 default. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- 'v1': Same implementation as Chemprop v1 default. | |
- 'v1': Includes the first 100 elements in the periodic table (same implementation as Chemprop v1 default). |
tests/unit/featurizers/test_atom.py
Outdated
if n == 53: # special check for Iodine | ||
assert x[len(atomic_num) - 1] == 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest using a SMILES containing iodine as a test case to avoid these two lines.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do not think using an iodine example will eliminate these two lines because for atomic number 1-36, it directly corresponds to entries 0 to 35 of the atomic number feature vector. However, iodine has an atomic number of 53 but it is mapped to the 36 index of the feature vector, and therefore need a special check anyway.
tests/unit/featurizers/test_atom.py
Outdated
def test_x_orig(a, x_v_orig): | ||
f = MultiHotAtomFeaturizer() | ||
def test_x_orig_default(a, x_v_orig): | ||
f = MultiHotAtomFeaturizer.default() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, test the other two methods. It would also be good to make the test code as similar as test_bond.py
. The zip is used here to extract the first 4 atoms to compare, but the index is used in test_bond.py
instead.
chemprop/cli/utils/parsing.py
Outdated
|
||
case "ORGANIC": | ||
atom_featurizer=MultiHotAtomFeaturizer.organic() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need to raise an error for an unknow multi_hot_atom_featurizer_mode?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There should be a new enum:
class AtomFeatureMode(EnumMapping):
DEFAULT = auto()
V1 = auto()
ORGANIC = auto()
so that the unknown case will be handled by that and we can throw a RuntimeError
if it falls through the match
-statement
Various updates based on PR review comments.
Default behavior for atom featurizer is set in mixins.py, so no need to specify here.
e1d4224
to
053b017
Compare
After investigation, the failure of the tests is due to the fact that the output scalers are not saved in the checkpoint files. I have manually updated the values in the checkpoint files so that we can pass the tests. The issue with the output scalers has been mentioned in #694 and will be resolved in #726. |
Thanks for the update. @shihchengli Can you rebase to consolidate similar commits together? After that, we can merge this in. |
138c516
to
e73e85b
Compare
Co-authored-by: david graff <60193893+davidegraff@users.noreply.github.com> remove scheme tables in MultiHotAtomFeaturizer
Co-authored-by: david graff <60193893+davidegraff@users.noreply.github.com>
e73e85b
to
e972c33
Compare
Thanks everyone for the good work. I will merge. |
Description
This PR attempts to improve the initial atom featurization by limiting the default supported elements to common chemistry.
Example / Current workflow
The current setup in Chemprop, which allocates 101 bits for atomic number, assumes that most of the training sets likely contain chemistry involving the first 100 elements of the periodic table. This design choice, while comprehensive, may not be optimally aligned with the practical needs of most chemical property prediction tasks because these tasks typically involve a much narrower range of elements. As a result, the current encoding method tends to create a very sparse vector that is not necessary and also can negatively impact model training speed and memory requirement.
Bugfix / Desired workflow
This PR seeks to address the abovementioned issue by changing the default encoding of atomic number to elements that are commonly used in applications like pharmaceuticals and materials design. Specifically, the default is changed to the first 4 rows of the periodic table plus iodine and a zero padding for other elements. This design choice should be sufficient for most common use-cases, and the implementation can be easily adapted to include additional elements for special cases.
I carried out some pre-liminary benchmark using Chemprop v1. When training on ~300k bi-molecular reactions to predict barrier heights, models with new featurization strategy can be trained ~40% faster while achieving similar accuracy to current implementation. Therefore, I think we should implement this change in Chemprop v2.
Questions
I also included more hybridization types supported by rdkit. I think the s hybridization makes sense for H atoms, especially when explicit H is used. The sp2d hybridization is less common, but I think it does not hurt to be included.
Relevant issues
This PR partially addresses the issue: #547
Further discussions are needed to decide if we want to change some features (e.g., formal charges, bond orders, num of Hs) from one-hot encoding to ordinal encoding.
Checklist
All relevant unit tests have been updated and passed check.