-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
adding complete forward and backward pass for ferminet pretraining #3553
Changes from all commits
d4cf0d4
f3ecd09
c584568
996d189
59821a7
3bd274d
8006a91
3a81881
3c88ead
ee92a6d
0b64a1f
e3429c2
157af88
360f84c
382223f
c329f21
15b3404
7690834
1a2dd18
92fab49
c000e3b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
Implementation of the Ferminet class in pytorch | ||
""" | ||
|
||
from typing import List, Optional | ||
from typing import List, Optional, Tuple | ||
# import torch.nn as nn | ||
from rdkit import Chem | ||
import numpy as np | ||
|
@@ -14,12 +14,6 @@ | |
from deepchem.utils.electron_sampler import ElectronSampler | ||
|
||
|
||
def test_f(x: np.ndarray) -> np.ndarray: | ||
# dummy function which can be passed as the parameter f. f gives the log probability | ||
# TODO replace this function with forward pass of the model in future | ||
return 2 * np.log(np.random.uniform(low=0, high=1.0, size=np.shape(x)[0])) | ||
|
||
|
||
class Ferminet(torch.nn.Module): | ||
"""A deep-learning based Variational Monte Carlo method [1]_ for calculating the ab-initio | ||
solution of a many-electron system. | ||
|
@@ -71,6 +65,8 @@ def __init__(self, | |
|
||
Attributes | ||
---------- | ||
running_diff: torch.Tensor | ||
torch tensor containing the loss which gets updated for each random walk performed | ||
ferminet_layer: torch.nn.ModuleList | ||
Modulelist containing the ferminet electron feature layer | ||
ferminet_layer_envelope: torch.nn.ModuleList | ||
|
@@ -95,6 +91,7 @@ def __init__(self, | |
self.ferminet_layer: torch.nn.ModuleList = torch.nn.ModuleList() | ||
self.ferminet_layer_envelope: torch.nn.ModuleList = torch.nn.ModuleList( | ||
) | ||
self.running_diff: torch.Tensor = torch.zeros(self.batch_size) | ||
|
||
self.ferminet_layer.append( | ||
FerminetElectronFeature(self.n_one, self.n_two, | ||
|
@@ -106,7 +103,7 @@ def __init__(self, | |
self.batch_size, [self.spin[0], self.spin[1]], | ||
self.nucleon_pos.size()[0], self.determinant)) | ||
|
||
def forward(self, input) -> torch.Tensor: | ||
def forward(self, input: np.ndarray) -> torch.Tensor: | ||
""" | ||
forward function | ||
|
||
|
@@ -148,6 +145,30 @@ def forward(self, input) -> torch.Tensor: | |
0].forward(one_electron, one_electron_vector_permuted) | ||
return psi | ||
|
||
def loss(self, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. loss functions, right now it's implemented for pretraining (MSELoss) |
||
psi_up_mo: List[Optional[np.ndarray]] = [None], | ||
psi_down_mo: List[Optional[np.ndarray]] = [None], | ||
pretrain: List[bool] = [True]): | ||
""" | ||
Implements the loss function for both pretraining and the actual training parts. | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In a follow up PR, please add more details on the loss. You should be able to list the actual latex for the formulas as well |
||
Parameters | ||
---------- | ||
psi_up_mo: List[Optional[np.ndarray]] (default [None]) | ||
numpy array containing the sampled Hartree-Fock up-spin orbitals | ||
psi_down_mo: List[Optional[np.ndarray]] (default [None]) | ||
numpy array containing the sampled Hartree-Fock down-spin orbitals | ||
pretrain: List[bool] (default [True]) | ||
indicates whether the model is pretraining | ||
""" | ||
criterion = torch.nn.MSELoss() | ||
if pretrain: | ||
psi_up_mo_torch = torch.from_numpy(psi_up_mo).unsqueeze(1) | ||
psi_down_mo_torch = torch.from_numpy(psi_down_mo).unsqueeze(1) | ||
self.running_diff = self.running_diff + criterion( | ||
self.psi_up, psi_up_mo_torch.float()) + criterion( | ||
self.psi_down, psi_down_mo_torch.float()) | ||
|
||
|
||
class FerminetModel(TorchModel): | ||
"""A deep-learning based Variational Monte Carlo method [1]_ for calculating the ab-initio | ||
|
@@ -162,6 +183,15 @@ class FerminetModel(TorchModel): | |
|
||
This method is based on the following paper: | ||
|
||
Example | ||
------- | ||
>>> from deepchem.models.torch_models.Ferminet import FerminetModel | ||
>>> H2_molecule = [['H', [0, 0, 0]], ['H', [0, 0, 0.748]]] | ||
>>> mol = FerminetModel(H2_molecule, spin=0, ion_charge=0, tasks='pretraining') | ||
>>> mol.train(nb_epoch=3) | ||
>>> print(mol.model.psi_up.size()) | ||
torch.Size([1, 1]) | ||
|
||
References | ||
---------- | ||
.. [1] Spencer, James S., et al. Better, Faster Fermionic Neural Networks. arXiv:2011.07125, arXiv, 13 Nov. 2020. arXiv.org, http://arxiv.org/abs/2011.07125. | ||
|
@@ -177,8 +207,9 @@ def __init__(self, | |
ion_charge: int, | ||
seed: Optional[int] = None, | ||
batch_no: int = 8, | ||
random_walk_steps=10, | ||
steps_per_update=10): | ||
random_walk_steps: int = 10, | ||
steps_per_update: int = 10, | ||
tasks: str = 'pretraining'): | ||
""" | ||
Parameters: | ||
----------- | ||
|
@@ -196,6 +227,8 @@ def __init__(self, | |
Number of random walk steps to be performed in a single move. | ||
steps_per_update: int (default: 10) | ||
Number of steps after which the electron sampler should update the electron parameters. | ||
tasks: str (default: 'pretraining') | ||
The type of task to be performed - 'pretraining', 'training' | ||
|
||
Attributes: | ||
----------- | ||
|
@@ -205,10 +238,8 @@ def __init__(self, | |
Torch tensor containing electrons for each atom in the nucleus | ||
molecule: ElectronSampler | ||
ElectronSampler object which performs MCMC and samples electrons | ||
loss_value: Optional[torch.Tensor] (default None) | ||
loss_value: torch.Tensor (default torch.tensor(0)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. attribute that keeps track of the last loss value. It differs from the running mean by calculating the mean over all batches and the number of steps in MCMC |
||
torch tensor storing the loss value from the last iteration | ||
pretraining_loss_list: List (default []) | ||
list with losses for every epoch | ||
""" | ||
self.nucleon_coordinates = nucleon_coordinates | ||
self.seed = seed | ||
|
@@ -218,8 +249,8 @@ def __init__(self, | |
self.batch_no = batch_no | ||
self.random_walk_steps = random_walk_steps | ||
self.steps_per_update = steps_per_update | ||
self.loss_value: Optional[torch.Tensor] = None | ||
self.pretraining_loss_list: List = [] | ||
self.loss_value: torch.Tensor = torch.tensor(0) | ||
self.tasks = tasks | ||
|
||
no_electrons = [] | ||
nucleons = [] | ||
|
@@ -280,7 +311,8 @@ def __init__(self, | |
batch_no=self.batch_no, | ||
central_value=self.nucleon_pos, | ||
seed=self.seed, | ||
f=lambda x: test_f(x), # Will be replaced in successive PR | ||
f=lambda x: self.random_walk(x | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changing the sampling function to the random walk function which returns the log wavefunction probability |
||
), # Will be replaced in successive PR | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make sure to remove the comment in your next update |
||
steps=self.random_walk_steps, | ||
steps_per_update=self.steps_per_update | ||
) # sample the electrons using the electron sampler | ||
|
@@ -291,6 +323,34 @@ def __init__(self, | |
self.model, | ||
loss=torch.nn.MSELoss()) # will update the loss in successive PR | ||
|
||
def evaluate_hf(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function to evaluate hf orbitals at the given electron coordinates. (As mentioned in the last PR split the prepare_hf function into smaller ones) |
||
""" | ||
Helper function to calculate orbital values at sampled electron's position. | ||
|
||
Parameters: | ||
----------- | ||
x: np.ndarray | ||
Contains the sampled electrons coordinates in a numpy array. | ||
|
||
Returns: | ||
-------- | ||
Two numpy arrays containing the up-spin and down-spin orbital values, respectively. | ||
""" | ||
x = np.reshape(x, [-1, 3 * (self.up_spin + self.down_spin)]) | ||
leading_dims = x.shape[:-1] | ||
x = np.reshape(x, [-1, 3]) | ||
coeffs = self.mf.mo_coeff | ||
gto_op = 'GTOval_sph' | ||
ao_values = self.mol.eval_gto(gto_op, x) | ||
mo_values = tuple(np.matmul(ao_values, coeff) for coeff in coeffs) | ||
mo_values_list = [ | ||
np.reshape(mo, leading_dims + (self.up_spin + self.down_spin, -1)) | ||
for mo in mo_values | ||
] | ||
return mo_values_list[0][ | ||
..., :self.up_spin, :self.up_spin], mo_values_list[1][ | ||
..., self.up_spin:, :self.down_spin] | ||
|
||
def prepare_hf_solution(self): | ||
"""Prepares the HF solution for the molecule system which is to be used in pretraining | ||
""" | ||
|
@@ -313,3 +373,54 @@ def prepare_hf_solution(self): | |
self.mol.build(parse_arg=False) | ||
self.mf = pyscf.scf.UHF(self.mol) | ||
_ = self.mf.kernel() | ||
|
||
def random_walk(self, x: np.ndarray) -> np.ndarray: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function gets called at every step of electron sampling, and it returns the probability of the sampled electrons - log of the wavefunction from the model and the HF solution (this is only for pretraining). |
||
""" | ||
Function to be passed on to electron sampler for random walk and gets called at each step of sampling | ||
|
||
Parameters | ||
---------- | ||
x: np.ndarray | ||
contains the sampled electrons coordinate in the shape (batch_size,number_of_electrons*3) | ||
|
||
Returns: | ||
-------- | ||
A numpy array containing the log of the equally weighted (0.5) sum of the Hartree-Fock probability and the squared model output at the sampled electron positions | ||
""" | ||
output = self.model.forward(x) | ||
np_output = output.detach().cpu().numpy() | ||
up_spin_mo, down_spin_mo = self.evaluate_hf(x) | ||
hf_product = np.prod( | ||
np.diagonal(up_spin_mo, axis1=1, axis2=2)**2, axis=1) * np.prod( | ||
np.diagonal(down_spin_mo, axis1=1, axis2=2)**2, axis=1) | ||
self.model.loss(up_spin_mo, down_spin_mo, pretrain=True) | ||
return np.log(hf_product + np_output**2) + np.log(0.5) | ||
|
||
def train(self, | ||
nb_epoch: int = 200, | ||
lr: float = 0.0075, | ||
weight_decay: float = 0.0001): | ||
""" | ||
function to run training or pretraining. | ||
|
||
Parameters | ||
---------- | ||
nb_epoch: int (default: 200) | ||
contains the number of pretraining steps to be performed | ||
lr : float (default: 0.0075) | ||
contains the learning rate for the model fitting | ||
weight_decay: float (default: 0.0001) | ||
contains the weight_decay for the model fitting | ||
""" | ||
optimizer = torch.optim.Adam(self.model.parameters(), | ||
lr=lr, | ||
weight_decay=weight_decay) | ||
if (self.tasks == 'pretraining'): | ||
for _ in range(nb_epoch): | ||
optimizer.zero_grad() | ||
self.molecule.move() | ||
self.loss_value = (torch.mean(self.model.running_diff) / | ||
self.random_walk_steps) | ||
self.loss_value.backward() | ||
optimizer.step() | ||
self.model.running_diff = torch.zeros(self.batch_no) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding an attribute tensor to keep track of running sum of MSELoss