Change input data structure from named tuple to a class with slots#287
Merged
Change input data structure from named tuple to a class with slots#287
Conversation
…to a class with slots. The reason for this is that during MD only positions change, so we want t datastructure that doesn't add overhead when only changing a single field (e.g. positions)
…rch JIT and JAX accept --- this is a first working attempt. Note that this requires that we do a global jax import. @chrisiacovella, please take a closer look at this
Member
|
I think this is a good change; one concern, you've now started to introduce "per_system" instead of "per_molecule", but it's not consistent. I think that should be changed over fully before merging, otherwise that is going to be very confusing. |
chrisiacovella
requested changes
Oct 17, 2024
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Pull Request Summary
This PR addresses the issue #286.
We had two data structures (NNPInput and NNPInputTuple) to pass information from the dataset to the models. The NNPInput class was designed as dataclass and did input validation and casting to dtype/device during training. The dataclass then constructed a NamedTuple that contained all fields of the dataclass, which was then passed as input to the model --- in other words, the model only ever got the NamedTuple as input. To allow working with Jax models, the Named Tuple contained either jax arrays or pytorch tensors depending on passed keywords.
NamedTuples are useful during training, but in MD simulations the immutable datastructure leads to overhead.
Key changes
EandF`` toper_system_energyandper_atom_force`This PR proposes the following changes:
Pull Request Checklist