-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding function to calculate HF orbitals and made minor reformat - Ferminet #3466
Conversation
4653e95
to
ee9382c
Compare
except ModuleNotFoundError: | ||
pass | ||
|
||
|
||
@pytest.mark.torch | ||
def test_prepare_input_stream(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we deleting this older test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The old test checks for input streams shape, which is nothing but each electrons distance from each nuclei, this can be done using deepchem utils pairwise distance itself and doesn't require tests!
@@ -22,9 +24,55 @@ def test_f(x: np.ndarray) -> np.ndarray: | |||
return 2 * np.log(np.random.uniform(low=0, high=1.0, size=np.shape(x)[0])) | |||
|
|||
|
|||
class Ferminet: | |||
class Ferminet(torch.nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add this new layer to torch_layers.csv and layers.rst in the docs/?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can I do this once the model is fully completed? As I am breaking the chunk of PR into smaller ones and am adding one function by function.
@@ -22,9 +24,55 @@ def test_f(x: np.ndarray) -> np.ndarray: | |||
return 2 * np.log(np.random.uniform(low=0, high=1.0, size=np.shape(x)[0])) | |||
|
|||
|
|||
class Ferminet: | |||
class Ferminet(torch.nn.Module): | |||
"""Approximates the log probability of the wave function of a molecule system using DNNs. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a usage example here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can I do this once the model is fully completed? As I am breaking the chunk of PR into smaller ones and am adding one function by function.
self.determinant = determinant | ||
|
||
|
||
class FerminetModel(TorchModel): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add this class to the docs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can I do this once the model is fully completed? As I am breaking the chunk of PR into smaller ones and am adding one function by function.
self.nucleon_coordinates = nucleon_coordinates | ||
self.seed = seed | ||
self.batch_no = batch_no | ||
self.spin = spin | ||
self.ion_charge = charge | ||
|
||
def prepare_input_stream(self,) -> Tuple[Any, Any, Any, Any]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we deleting this helper function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This helper function computes pairwise distances of each electron from nuclei and can be done in forward itself
@@ -124,46 +164,77 @@ def prepare_input_stream(self,) -> Tuple[Any, Any, Any, Any]: | |||
self.electron_no[int(electro_neg[-1 - iter][0])][0] += 1 | |||
|
|||
total_electrons = np.sum(self.electron_no) | |||
self.up_spin = (total_electrons + 2 * self.spin) // 2 | |||
self.down_spin = (total_electrons - 2 * self.spin) // 2 | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you give me a summary of the refactoring changes you've made here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes I have done are:
- added a new helper function prepare_hf which gives hartree fock orbitals values for pretraining.
- Deleted the code to prepare input streams as it should be done in forwards, as we need the gradient of output wrt to the electron's coordinates we are going to pass
For now, this model is empty and have added the helper function only. In successive PRs I will complete the forward function and add example
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good start. I have some questions and requests for more docs below
@rbharath I have reformatted to numpydocs style. Is this fine? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Description
Added prepare_hf_solution function to calculate hartree fock solution required for pertaining.
Reformatted the code into Ferminet class (class where ferminet model will implemented) and FerminetModel class (driver class providing all the necessary data to Ferminet class)
Fix #(issue)
Fixed the calculation of the spin of the molecule system
Type of change
Please check the option that is related to your PR.
Checklist
yapf -i <modified file>
and check no errors (yapf version must be 0.32.0)mypy -p deepchem
and check no errorsflake8 <modified file> --count
and check no errorspython -m doctest <modified file>
and check no errors