Change input data structure from named tuple to a class with slots #287

Merged
wiederm merged 25 commits into main from dev-mutable-input on Oct 19, 2024
@wiederm
Member

@wiederm wiederm commented Oct 16, 2024

Pull Request Summary

This PR addresses issue #286.

We had two data structures (NNPInput and NNPInputTuple) to pass information from the dataset to the models. NNPInput was designed as a dataclass that validated its inputs and cast them to the requested dtype/device during training. The dataclass then constructed a NamedTuple containing all of its fields, and that NamedTuple was passed to the model; in other words, the model only ever received the NamedTuple as input. To allow working with JAX models, the NamedTuple contained either JAX arrays or PyTorch tensors, depending on the keywords passed.

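NamedTuples are useful during training, but in MD simulations the immutable data structure leads to overhead: only the positions change between steps, yet an immutable structure has to be rebuilt to update that single field.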
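To illustrate the difference, here is a minimal sketch, not modelforge's actual code. The class names `InputTuple` and `InputSlots` are hypothetical, only two fields are shown, and plain lists stand in for tensors so the example runs with the standard library alone.

```python
from typing import List, NamedTuple


class InputTuple(NamedTuple):
    # Hypothetical stand-in for the old immutable input (lists replace tensors).
    atomic_numbers: List[int]
    positions: List[List[float]]


class InputSlots:
    # Hypothetical stand-in for a mutable input class with slots.
    __slots__ = ("atomic_numbers", "positions")

    def __init__(self, atomic_numbers, positions):
        self.atomic_numbers = atomic_numbers
        self.positions = positions


# Placeholder geometry for a three-atom system before and after one MD step.
numbers = [8, 1, 1]
start = [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]]
moved = [[0.0, 0.0, 0.1], [0.96, 0.0, 0.1], [-0.24, 0.93, 0.1]]

# Immutable: updating positions means building a new tuple object every step.
old_input = InputTuple(numbers, start)
new_input = old_input._replace(positions=moved)

# Mutable: the same object is reused and only one attribute is rebound.
md_input = InputSlots(numbers, start)
same_object = md_input
md_input.positions = moved
```

With the slotted class, the object handed to the model stays the same across steps; with the NamedTuple, every step produces a fresh object that copies references to all fields.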

Key changes

  • Change the NamedTuple data structure to a Python class with slots, and remove the dataclass/NamedTuple double structure that was used for convenience.
  • Rename fields in the NNPInput and Metadata classes from `E` and `F` to `per_system_energy` and `per_atom_force`.
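As a rough illustration of how slots interact with a rename like this, the sketch below uses a hypothetical `Targets` class; only the two field names are taken from this PR, and the numeric values are placeholders. It assumes a plain slotted class and says nothing about how modelforge's own classes are written.

```python
class Targets:
    # Hypothetical container; only the two field names come from the PR.
    __slots__ = ("per_system_energy", "per_atom_force")

    def __init__(self, per_system_energy, per_atom_force):
        self.per_system_energy = per_system_energy
        self.per_atom_force = per_atom_force


# Placeholder values, chosen only so the example runs.
targets = Targets(per_system_energy=-76.4, per_atom_force=[[0.0, 0.0, 0.0]])

# A slotted class has no instance __dict__, so assigning to a name that is
# not listed in __slots__ (such as the old short name E) raises an error.
try:
    targets.E = -76.4
    stale_name_rejected = False
except AttributeError:
    stale_name_rejected = True
```

One side effect of this design is that code still writing to the old short names fails immediately instead of silently creating a new attribute.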

Pull Request Checklist

  • Issue(s) raised/addressed and linked
  • Includes appropriate unit test(s)
  • Appropriate docstring(s) added/updated
  • Appropriate .rst doc file(s) added/updated
  • PR is ready for review

…to a class with slots. The reason for this is that during MD only positions change, so we want a data structure that doesn't add overhead when only changing a single field (e.g. positions).
@wiederm wiederm self-assigned this Oct 16, 2024
@wiederm wiederm linked an issue Oct 16, 2024 that may be closed by this pull request
@wiederm wiederm added the refactoring Improve the quality of the code without functional changes label Oct 17, 2024
@wiederm wiederm changed the base branch from main to dev-load-model-from-chkpt October 17, 2024 17:03
@codecov-commenter

codecov-commenter commented Oct 17, 2024

Codecov Report

Attention: Patch coverage is 84.46602% with 32 lines in your changes missing coverage. Please review.

Project coverage is 84.97%. Comparing base (442df66) to head (a6a791c).
Report is 26 commits behind head on main.

Additional details and impacted files

Comment thread modelforge/dataset/dataset.py Outdated
@chrisiacovella
Member

I think this is a good change; one concern, you've now started to introduce "per_system" instead of "per_molecule", but it's not consistent. I think that should be changed over fully before merging, otherwise that is going to be very confusing.

Comment thread modelforge/dataset/dataset.py Outdated
Comment thread modelforge/dataset/dataset.py Outdated
Comment thread modelforge/potential/aimnet2.py
Comment thread modelforge/tests/test_models.py Outdated
Comment thread modelforge/train/training.py Outdated
Base automatically changed from dev-load-model-from-chkpt to main October 19, 2024 10:20
@wiederm wiederm merged commit f56dc12 into main Oct 19, 2024
@wiederm wiederm deleted the dev-mutable-input branch October 19, 2024 20:57
Labels

refactoring Improve the quality of the code without functional changes

Development

Successfully merging this pull request may close these issues.

Current input data structure not optimized for MD

3 participants