adding complete forward and backward pass for ferminet pretraining #3553

Merged · 21 commits · Oct 2, 2023
22 changes: 22 additions & 0 deletions deepchem/models/tests/test_ferminet.py
@@ -37,3 +37,25 @@ def test_forward():
    mol = FerminetModel(FH_molecule, spin=1, ion_charge=-1)
    result = mol.model.forward(mol.molecule.x)
    assert result.size() == torch.Size([8])


@pytest.mark.dqc
def test_evaluate_hf_solution():
    # Test for the evaluate_hf method of the FerminetModel class
    FHe_molecule = [['F', [0, 0, 0]], ['He', [0, 0, 0.748]]]
    mol = FerminetModel(FHe_molecule, spin=1, ion_charge=0)
    electron_coordinates = np.random.rand(10, 11, 3)
    spin_up_orbitals, spin_down_orbitals = mol.evaluate_hf(electron_coordinates)
    # Each returned array has shape (batch, number of spin electrons, number of spin electrons)
    assert np.shape(spin_up_orbitals) == (10, 6, 6)
    assert np.shape(spin_down_orbitals) == (10, 5, 5)


@pytest.mark.dqc
def test_FerminetModel_pretrain():
    # Test for pretraining of the FerminetModel class
    H2_molecule = [['H', [0, 0, 0]], ['H', [0, 0, 0.748]]]
    mol = FerminetModel(H2_molecule, spin=0, ion_charge=0)
    mol.train(nb_epoch=3)
    assert mol.loss_value <= torch.tensor(1.0)
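
For reference, the shape assertions above follow from simple electron bookkeeping; a hedged sketch of the assumed spin split (illustrative, not code from this PR):

# F contributes 9 electrons and He contributes 2, so 11 in total for charge 0.
total_electrons = 9 + 2
spin = 1                              # assumed to mean n_up - n_down
n_up = (total_electrons + spin) // 2  # 6
n_down = total_electrons - n_up       # 5
assert (n_up, n_down) == (6, 5)       # matches the (10, 6, 6) / (10, 5, 5) shapes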
143 changes: 127 additions & 16 deletions deepchem/models/torch_models/ferminet.py
@@ -2,7 +2,7 @@
Implementation of the Ferminet class in pytorch
"""

from typing import List, Optional
from typing import List, Optional, Tuple
# import torch.nn as nn
from rdkit import Chem
import numpy as np
@@ -14,12 +14,6 @@
from deepchem.utils.electron_sampler import ElectronSampler


def test_f(x: np.ndarray) -> np.ndarray:
    # dummy function which can be passed as the parameter f. f gives the log probability
    # TODO replace this function with forward pass of the model in future
    return 2 * np.log(np.random.uniform(low=0, high=1.0, size=np.shape(x)[0]))


class Ferminet(torch.nn.Module):
"""A deep-learning based Variational Monte Carlo method [1]_ for calculating the ab-initio
solution of a many-electron system.
@@ -71,6 +65,8 @@ def __init__(self,

        Attributes
        ----------
        running_diff: torch.Tensor
            torch tensor containing the loss which gets updated for each random walk performed
[Contributor Author] Adding an attribute tensor to keep track of the running sum of the MSELoss.
        ferminet_layer: torch.nn.ModuleList
            Modulelist containing the ferminet electron feature layer
        ferminet_layer_envelope: torch.nn.ModuleList
@@ -95,6 +91,7 @@
        self.ferminet_layer: torch.nn.ModuleList = torch.nn.ModuleList()
        self.ferminet_layer_envelope: torch.nn.ModuleList = torch.nn.ModuleList()
        self.running_diff: torch.Tensor = torch.zeros(self.batch_size)

        self.ferminet_layer.append(
            FerminetElectronFeature(self.n_one, self.n_two,
@@ -106,7 +103,7 @@
                self.batch_size, [self.spin[0], self.spin[1]],
                self.nucleon_pos.size()[0], self.determinant))

    def forward(self, input) -> torch.Tensor:
    def forward(self, input: np.ndarray) -> torch.Tensor:
        """
        forward function

@@ -148,6 +145,30 @@
            0].forward(one_electron, one_electron_vector_permuted)
        return psi

[Contributor Author] Loss function; right now it's implemented for pretraining (MSELoss).
    def loss(self,
             psi_up_mo: Optional[np.ndarray] = None,
             psi_down_mo: Optional[np.ndarray] = None,
             pretrain: bool = True):
"""
Implements the loss function for both pretraining and the actual training parts.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a follow up PR, please add more details on the loss. You should be able to list the actual latex for the formulas as well

Parameters
----------
psi_up_mo: List[Optional[np.ndarray]] (default [None])
numpy array containing the sampled hartreee fock up-spin orbitals
psi_down_mo: List[Optional[np.ndarray]] (default [None])
numpy array containing the sampled hartreee fock down-spin orbitals
pretrain: List[bool] (default [True])
indicates whether the model is pretraining
"""
        criterion = torch.nn.MSELoss()
        if pretrain:
            psi_up_mo_torch = torch.from_numpy(psi_up_mo).unsqueeze(1)
            psi_down_mo_torch = torch.from_numpy(psi_down_mo).unsqueeze(1)
            self.running_diff = self.running_diff + criterion(
                self.psi_up, psi_up_mo_torch.float()) + criterion(
                    self.psi_down, psi_down_mo_torch.float())
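
For reference, one plausible way to write the pretraining objective that the review comment above asks for (an assumed reading of the code in this diff, following the FermiNet paper's MSE pretraining scheme; not text from this PR):

\mathcal{L}_{\text{pretrain}} = \frac{1}{N_{\text{walk}}}\, \mathbb{E}_{\text{batch}} \sum_{t=1}^{N_{\text{walk}}} \Big[ \mathrm{MSE}\big(\psi^{\uparrow}(x_t),\, \phi^{\uparrow}_{\mathrm{HF}}(x_t)\big) + \mathrm{MSE}\big(\psi^{\downarrow}(x_t),\, \phi^{\downarrow}_{\mathrm{HF}}(x_t)\big) \Big]

where N_walk is random_walk_steps; the batch mean and the division by N_walk happen later in train().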


class FerminetModel(TorchModel):
"""A deep-learning based Variational Monte Carlo method [1]_ for calculating the ab-initio
Expand All @@ -162,6 +183,15 @@ class FerminetModel(TorchModel):

    This method is based on the following paper:

    Example
    -------
    >>> from deepchem.models.torch_models.ferminet import FerminetModel
    >>> H2_molecule = [['H', [0, 0, 0]], ['H', [0, 0, 0.748]]]
    >>> mol = FerminetModel(H2_molecule, spin=0, ion_charge=0, tasks='pretraining')
    >>> mol.train(nb_epoch=3)
    >>> print(mol.model.psi_up.size())
    torch.Size([1, 1])

    References
    ----------
    .. [1] Spencer, James S., et al. Better, Faster Fermionic Neural Networks. arXiv:2011.07125, arXiv, 13 Nov. 2020. arXiv.org, http://arxiv.org/abs/2011.07125.
@@ -177,8 +207,9 @@ def __init__(self,
                 ion_charge: int,
                 seed: Optional[int] = None,
                 batch_no: int = 8,
                 random_walk_steps=10,
                 steps_per_update=10):
                 random_walk_steps: int = 10,
                 steps_per_update: int = 10,
                 tasks: str = 'pretraining'):
        """
        Parameters:
        -----------
@@ -196,6 +227,8 @@
            Number of random walk steps to be performed in a single move.
        steps_per_update: int (default: 10)
            Number of steps after which the electron sampler should update the electron parameters.
        tasks: str (default: 'pretraining')
            The type of task to be performed: 'pretraining' or 'training'

        Attributes:
        -----------
@@ -205,10 +238,8 @@
            Torch tensor containing electrons for each atom in the nucleus
        molecule: ElectronSampler
            ElectronSampler object which performs MCMC and samples electrons
        loss_value: Optional[torch.Tensor] (default None)
        loss_value: torch.Tensor (default torch.tensor(0))
            torch tensor storing the loss value from the last iteration
[Contributor Author] Attribute that keeps track of the last loss value. It differs from the running mean by calculating the mean over all the batches and the number of MCMC steps.
        pretraining_loss_list: List (default [])
            list with losses for every epoch
        """
        self.nucleon_coordinates = nucleon_coordinates
        self.seed = seed
@@ -218,8 +249,8 @@
        self.batch_no = batch_no
        self.random_walk_steps = random_walk_steps
        self.steps_per_update = steps_per_update
        self.loss_value: Optional[torch.Tensor] = None
        self.pretraining_loss_list: List = []
        self.loss_value: torch.Tensor = torch.tensor(0)
        self.tasks = tasks

        no_electrons = []
        nucleons = []
@@ -280,7 +311,8 @@ def __init__(self,
            batch_no=self.batch_no,
            central_value=self.nucleon_pos,
            seed=self.seed,
            f=lambda x: test_f(x),  # Will be replaced in successive PR
            f=lambda x: self.random_walk(
                x),  # Will be replaced in successive PR
[Contributor Author] Changing the sampling function to the random walk function, which returns the log wavefunction probability.
[Member] Make sure to remove the comment in your next update.
            steps=self.random_walk_steps,
            steps_per_update=self.steps_per_update
        )  # sample the electrons using the electron sampler
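
For orientation, a minimal sketch of the callback contract that ElectronSampler appears to expect here (the input shape and return convention are assumptions read off this diff, not a documented API):

import numpy as np

def log_prob(x: np.ndarray) -> np.ndarray:
    # x: sampled electron coordinates, shape (batch_size, n_electrons * 3);
    # return one log-probability per batch element to drive the MCMC moves.
    assert x.ndim == 2 and x.shape[1] % 3 == 0
    return -0.5 * np.sum(x**2, axis=1)  # toy Gaussian log-density, for illustration only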
@@ -291,6 +323,34 @@
            self.model,
            loss=torch.nn.MSELoss())  # will update the loss in successive PR

[Contributor Author] Function to evaluate the HF orbitals at the given electron coordinates. (As mentioned in the last PR, the prepare_hf function was split into smaller ones.)
    def evaluate_hf(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
Helper function to calculate orbital values at sampled electron's position.

Parameters:
-----------
x: np.ndarray
Contains the sampled electrons coordinates in a numpy array.

Returns:
--------
2 numpy arrays containing the up-spin and down-spin orbitals in a numpy array respectively.
"""
        x = np.reshape(x, [-1, 3 * (self.up_spin + self.down_spin)])
        leading_dims = x.shape[:-1]
        x = np.reshape(x, [-1, 3])
        coeffs = self.mf.mo_coeff
        gto_op = 'GTOval_sph'
        # Evaluate atomic orbitals at each electron position, then project
        # onto the molecular-orbital coefficients from the UHF calculation
        ao_values = self.mol.eval_gto(gto_op, x)
        mo_values = tuple(np.matmul(ao_values, coeff) for coeff in coeffs)
        mo_values_list = [
            np.reshape(mo, leading_dims + (self.up_spin + self.down_spin, -1))
            for mo in mo_values
        ]
        # Slice square orbital matrices: up-spin electrons in the up orbitals,
        # down-spin electrons in the down orbitals
        return mo_values_list[0][
            ..., :self.up_spin, :self.up_spin], mo_values_list[1][
                ..., self.up_spin:, :self.down_spin]
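
A quick usage sketch, with shapes mirroring test_evaluate_hf_solution above (a fragment: mol is the FerminetModel instance from that test, and numpy is imported as np; the batch size 10 and the 6/5 spin split are taken from its assertions):

coords = np.random.rand(10, 11, 3)   # (batch, n_electrons, 3)
up, down = mol.evaluate_hf(coords)
print(up.shape, down.shape)          # expected: (10, 6, 6) (10, 5, 5)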

    def prepare_hf_solution(self):
        """Prepares the HF solution for the molecule system which is to be used in pretraining
        """
@@ -313,3 +373,54 @@
        self.mol.build(parse_arg=False)
        self.mf = pyscf.scf.UHF(self.mol)
        _ = self.mf.kernel()

[Contributor Author] This function gets called at every step of electron sampling, and it returns the probability of the sampled electrons: the log of the wavefunction from the model and the HF solution (this is only for pretraining).
    def random_walk(self, x: np.ndarray) -> np.ndarray:

"""
Function to be passed on to electron sampler for random walk and gets called at each step of sampling

Parameters
----------
x: np.ndarray
contains the sampled electrons coordinate in the shape (batch_size,number_of_electrons*3)

Returns:
--------
A numpy array containing the joint probability of the hartree fock and the sampled electron's position coordinates
"""
        output = self.model.forward(x)
        np_output = output.detach().cpu().numpy()
        up_spin_mo, down_spin_mo = self.evaluate_hf(x)
        # HF probability: product of squared diagonal orbital values over both spins
        hf_product = np.prod(
            np.diagonal(up_spin_mo, axis1=1, axis2=2)**2, axis=1) * np.prod(
                np.diagonal(down_spin_mo, axis1=1, axis2=2)**2, axis=1)
        self.model.loss(up_spin_mo, down_spin_mo, pretrain=True)
        return np.log(hf_product + np_output**2) + np.log(0.5)
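
In equation form, a hedged reading of the return line above: the walk samples from an equal mixture of the model density and the Hartree-Fock density,

p(x) = \tfrac{1}{2}\left(\psi_{\theta}(x)^{2} + \prod_{i}\phi^{\mathrm{HF}}_{ii}(x)^{2}\right), \qquad \log p(x) = \log\!\left(\psi_{\theta}(x)^{2} + \prod_{i}\phi^{\mathrm{HF}}_{ii}(x)^{2}\right) + \log\tfrac{1}{2},

where \phi^{\mathrm{HF}}_{ii} are the diagonal orbital values taken in hf_product.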

    def train(self,
              nb_epoch: int = 200,
              lr: float = 0.0075,
              weight_decay: float = 0.0001):
        """
        Function to run training or pretraining.

        Parameters
        ----------
        nb_epoch: int (default: 200)
            contains the number of pretraining steps to be performed
        lr: float (default: 0.0075)
            contains the learning rate for the model fitting
        weight_decay: float (default: 0.0001)
            contains the weight decay for the model fitting
        """
        optimizer = torch.optim.Adam(self.model.parameters(),
                                     lr=lr,
                                     weight_decay=weight_decay)
        if (self.tasks == 'pretraining'):
            for _ in range(nb_epoch):
                optimizer.zero_grad()
                # One MCMC move; the sampler's callback accumulates running_diff
                self.molecule.move()
                self.loss_value = (torch.mean(self.model.running_diff) /
                                   self.random_walk_steps)
                self.loss_value.backward()
                optimizer.step()
                # Reset the running loss for the next epoch
                self.model.running_diff = torch.zeros(self.batch_no)
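
Putting the pieces together, a minimal pretraining run might look like the following (a sketch mirroring the docstring example and the tests in this PR; the geometry and epoch count are illustrative):

from deepchem.models.torch_models.ferminet import FerminetModel

# Illustrative H2 geometry (bond length in angstroms, as in the tests above)
H2_molecule = [['H', [0, 0, 0]], ['H', [0, 0, 0.748]]]
mol = FerminetModel(H2_molecule, spin=0, ion_charge=0, tasks='pretraining')
mol.train(nb_epoch=3)   # short pretraining run against the UHF orbitals
print(mol.loss_value)   # mean running MSE from the last epoch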