Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding energy function for ferminet #3596

Merged
merged 3 commits into from
Oct 6, 2023

Conversation

shaipranesh2
Copy link
Contributor

Description

Fix #(issue)

Type of change

Please check the option that is related to your PR.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
    • In this case, we recommend to discuss your modification on GitHub issues before creating the PR
  • Documentations (modification for documents)

Checklist

  • My code follows the style guidelines of this project
    • Run yapf -i <modified file> and check no errors (yapf version must be 0.32.0)
    • Run mypy -p deepchem and check no errors
    • Run flake8 <modified file> --count and check no errors
    • Run python -m doctest <modified file> and check no errors
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New unit tests pass locally with my changes
  • I have checked my code and corrected any misspellings

@shaipranesh2 shaipranesh2 changed the title adding energy function adding energy function for ferminet Oct 5, 2023
mol = FerminetModel(H2_molecule, spin=0, ion_charge=0)
mol.train(nb_epoch=50)
mol.model.forward(mol.molecule.x)
energy = mol.model.calculate_electron_electron(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the actual way in which the total energy is calculated - K.E+E_E-N_E+N_N

K.E- kinetic energy
E_E- electron-electron potential energy
N_E- Nuclear-electron potential energy
E_E- Electron-Electron potential energy

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add this as a comment directly (can be done in a future PR)

@@ -118,12 +122,14 @@ def forward(self, input: np.ndarray) -> torch.Tensor:
contains the wavefunction - 'psi' value. It is in the shape (batch_size), where each row corresponds to the solution of one of the batches
"""
# creating one and two electron features
# torch.autograd.set_detect_anomaly(True)
eps = torch.tensor(1e-36)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adding epsilon/eps so that the derivative of norm of x - 1/sqrt(x) will be defined even when x is 0.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make this a comment as well

self.input = torch.from_numpy(input)
self.input.requires_grad = True
self.input = self.input.reshape((self.batch_size, -1, 3))
two_electron_vector = self.input.unsqueeze(1) - self.input.unsqueeze(2)
two_electron_distance = torch.norm(two_electron_vector,
dim=3).unsqueeze(3)
two_electron_distance = torch.linalg.norm(two_electron_vector + eps,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only 2 electron distance needs norm

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

@@ -141,9 +147,9 @@ def forward(self, input: np.ndarray) -> torch.Tensor:

one_electron, _ = self.ferminet_layer[0].forward(
one_electron.to(torch.float32), two_electron.to(torch.float32))
psi, self.psi_up, self.psi_down = self.ferminet_layer_envelope[
self.psi, self.psi_up, self.psi_down = self.ferminet_layer_envelope[
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

made it self.psi, so that methods can be called to calculate derivatives

"""
log_probability = torch.log(torch.abs(self.psi))
jacobian = list(
map(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

map works here instead of loop and is faster too!

self.v[0].weight.data = self.v[0].weight.data
self.v[0].bias.data = self.v[0].bias.data

self.w.append(nn.Linear(4, self.n_two[0], bias=True))
self.w[0].weight.data.fill_(2.5e-7)
self.w[0].bias.data.fill_(2.5e-7)
self.w[0].weight.data.fill_(1e-3)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now fixing the initial weight to 1e-3, which gives me decent results for H2. In the Jax implementation of ferminet, they used a different initialization of weights. I will look into it and put it up all these in the last 'optimization' PR(better learning rates, better initialization width,etc)

self.w[i].bias.data = self.w[i].bias.data

self.projection_module = nn.ModuleList()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding this from the suggestions made to me by Peter during my GSoC last year for adding skip connections even in first layer

map(lambda x: torch.sum(torch.pow(x, 2)), jacobian))
jacobian_square_sum = torch.tensor(0.0)
hessian = torch.tensor(0.0)
for i in range(self.batch_size):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may want to use https://pytorch.org/docs/master/generated/torch.vmap.html#torch.vmap instead to speed this up. Can be done in future PR

@@ -5011,9 +5020,12 @@ def forward(self, one_electron: torch.Tensor, two_electron: torch.Tensor):
dim=1)
if l == 0 or (self.n_one[l] != self.n_one[l - 1]) or (
self.n_two[l] != self.n_two[l - 1]):
one_electron_tmp[:, i, :] = torch.tanh(self.v[l](f))
one_electron_tmp[:, i, :] = torch.tanh(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't related to the energy changes

Copy link
Member

@rbharath rbharath left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK to merge in once tests finish running

Copy link
Member

@rbharath rbharath left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@rbharath rbharath merged commit 039c6a7 into deepchem:master Oct 6, 2023
23 of 33 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants