Increase more code coverage in _chgnet.py #261
Conversation
…in TensorNet class
…oat_th and including linear layer in TensorNet to match the original implementations
Actionable comments posted: 1
for readout_field in ["atom_feat", "bond_feat", "angle_feat"]:
    for final_mlp_type in ["gated", "mlp"]:
        model = CHGNet(
            element_types=["Mo", "S"],
            activation_type=activation,
            bond_update_hidden_dims=bond_dim,
            learn_basis=learn_basis,
            angle_update_hidden_dims=angle_dim,
            conv_dropout=dropout,
            readout_field=readout_field,
            final_mlp_type=final_mlp_type,
        )
        global_out = model(g=graph)
        assert torch.numel(global_out) == 1
        assert torch.numel(graph.ndata["magmom"]) == graph.num_nodes()
        model.save(".")
        CHGNet.load(".")
        os.remove("model.pt")
        os.remove("model.json")
        os.remove("state.pt")
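An aside on the quoted test (an observation beyond the review comment below): the save/load round trip writes model.pt, model.json, and state.pt into the current directory and deletes them by hand, so a failing assertion in between leaks files. A minimal, self-contained sketch of the usual fix with tempfile.TemporaryDirectory, where the real model.save / CHGNet.load calls are stood in by plain file writes since CHGNet is not importable in isolation:

```python
import os
import tempfile


def save_and_reload(payload: bytes) -> bytes:
    # Stand-in for model.save(path) / CHGNet.load(path): the directory and
    # every file inside it are removed automatically when the with-block
    # exits, even if an assertion in between raises.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "model.pt")
        with open(path, "wb") as fh:
            fh.write(payload)
        with open(path, "rb") as fh:
            return fh.read()


roundtripped = save_and_reload(b"weights")
assert roundtripped == b"weights"
```

This removes the need for the trailing os.remove calls entirely, and the cleanup still happens on test failure.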
Refactor repeated code in the test_model method.

The instantiation of CHGNet and the subsequent operations are repeated for each combination of readout_field and final_mlp_type. Consider refactoring this to reduce code duplication and improve maintainability: define a helper function that takes readout_field and final_mlp_type as parameters and performs the instantiation and checks.
+ def instantiate_and_test_model(readout_field, final_mlp_type):
+     model = CHGNet(
+         element_types=["Mo", "S"],
+         activation_type=activation,
+         bond_update_hidden_dims=bond_dim,
+         learn_basis=learn_basis,
+         angle_update_hidden_dims=angle_dim,
+         conv_dropout=dropout,
+         readout_field=readout_field,
+         final_mlp_type=final_mlp_type,
+     )
+     global_out = model(g=graph)
+     assert torch.numel(global_out) == 1
+     assert torch.numel(graph.ndata["magmom"]) == graph.num_nodes()
+     model.save(".")
+     CHGNet.load(".")
+     os.remove("model.pt")
+     os.remove("model.json")
+     os.remove("state.pt")
+ for readout_field in ["atom_feat", "bond_feat", "angle_feat"]:
+     for final_mlp_type in ["gated", "mlp"]:
+         instantiate_and_test_model(readout_field, final_mlp_type)
- for readout_field in ["atom_feat", "bond_feat", "angle_feat"]:
-     for final_mlp_type in ["gated", "mlp"]:
-         model = CHGNet(
-             element_types=["Mo", "S"],
-             activation_type=activation,
-             bond_update_hidden_dims=bond_dim,
-             learn_basis=learn_basis,
-             angle_update_hidden_dims=angle_dim,
-             conv_dropout=dropout,
-             readout_field=readout_field,
-             final_mlp_type=final_mlp_type,
-         )
-         global_out = model(g=graph)
-         assert torch.numel(global_out) == 1
-         assert torch.numel(graph.ndata["magmom"]) == graph.num_nodes()
-         model.save(".")
-         CHGNet.load(".")
-         os.remove("model.pt")
-         os.remove("model.json")
-         os.remove("state.pt")
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
def instantiate_and_test_model(readout_field, final_mlp_type):
    model = CHGNet(
        element_types=["Mo", "S"],
        activation_type=activation,
        bond_update_hidden_dims=bond_dim,
        learn_basis=learn_basis,
        angle_update_hidden_dims=angle_dim,
        conv_dropout=dropout,
        readout_field=readout_field,
        final_mlp_type=final_mlp_type,
    )
    global_out = model(g=graph)
    assert torch.numel(global_out) == 1
    assert torch.numel(graph.ndata["magmom"]) == graph.num_nodes()
    model.save(".")
    CHGNet.load(".")
    os.remove("model.pt")
    os.remove("model.json")
    os.remove("state.pt")

for readout_field in ["atom_feat", "bond_feat", "angle_feat"]:
    for final_mlp_type in ["gated", "mlp"]:
        instantiate_and_test_model(readout_field, final_mlp_type)
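In the same spirit as the helper above, and assuming the suite already runs under pytest (an assumption, not stated in the review), pytest.mark.parametrize could generate one independent test case per combination instead of hand-written nested loops. The sketch below only demonstrates the combination generation with itertools.product, since CHGNet itself is not importable in isolation:

```python
import itertools

READOUT_FIELDS = ["atom_feat", "bond_feat", "angle_feat"]
FINAL_MLP_TYPES = ["gated", "mlp"]

# pytest.mark.parametrize over this product would yield one individually
# reported test case per (readout_field, final_mlp_type) pair.
COMBOS = list(itertools.product(READOUT_FIELDS, FINAL_MLP_TYPES))

assert len(COMBOS) == len(READOUT_FIELDS) * len(FINAL_MLP_TYPES)  # 3 * 2 = 6
assert COMBOS[0] == ("atom_feat", "gated")
```

A practical benefit over the loop-based versions: with parametrize, a failure in one combination no longer aborts the remaining ones.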
Summary
Increase more code coverage in _chgnet.py
Checklist
- ruff.
- mypy.
- duecredit @due.dcite decorators to reference relevant papers by DOI (example).

Tip: Install pre-commit hooks to auto-check types and linting before every commit.