FermiNet Training Complete (Backward pass + Forward) #3689

Merged (2 commits) on Nov 30, 2023
16 changes: 14 additions & 2 deletions deepchem/models/tests/test_ferminet.py
@@ -52,7 +52,7 @@ def test_evaluate_hf_solution():


@pytest.mark.dqc
-def test_FerminetMode_pretrain():
+def test_FerminetModel_pretrain():
    # Test for the init function of FerminetModel class
    H2_molecule = [['H', [0, 0, 0]], ['H', [0, 0, 0.748]]]
    # Testing ionic initialization
@@ -62,7 +62,7 @@ def test_FerminetMode_pretrain():


@pytest.mark.dqc
-def test_FerminetMode_energy():
+def test_FerminetModel_energy():
    # Test for the init function of FerminetModel class
    H2_molecule = [['H', [0, 0, 0]], ['H', [0, 0, 0.748]]]
    # Testing ionic initialization
@@ -74,3 +74,15 @@ def test_FerminetMode_energy():
    )
    mean_energy = torch.mean(energy)
    assert mean_energy <= torch.tensor(1.0)


@pytest.mark.dqc
def test_FerminetModel_train():
    # Test for the train function of FerminetModel class
    H2_molecule = [['H', [0, 0, 0]], ['H', [0, 0, 0.748]]]
    # Testing ionic initialization
    mol = FerminetModel(H2_molecule, spin=0, ion_charge=0)
    mol.train(nb_epoch=10)
    mol.prepare_train()
    mol.train(nb_epoch=10)
    assert mol.final_energy <= torch.tensor(0.0)
157 changes: 137 additions & 20 deletions deepchem/models/torch_models/ferminet.py
@@ -164,12 +164,19 @@ def loss(self,
            indicates whether the model is pretraining
Review comment (Member): A List[bool] is very awkward; let's swap this to True/False in a subsequent PR
"""
criterion = torch.nn.MSELoss()
if pretrain:
if pretrain[0]:
psi_up_mo_torch = torch.from_numpy(psi_up_mo).unsqueeze(1)
psi_down_mo_torch = torch.from_numpy(psi_down_mo).unsqueeze(1)
self.running_diff = self.running_diff + criterion(
self.psi_up, psi_up_mo_torch.float()) + criterion(
self.psi_down, psi_down_mo_torch.float())
self.running_diff = self.running_diff.float() + criterion(
rbharath marked this conversation as resolved.
Show resolved Hide resolved
self.psi_up.float(),
psi_up_mo_torch.float()).float() + criterion(
self.psi_down.float(), psi_down_mo_torch.float()).float()
else:
energy = self.calculate_electron_electron(
) - self.calculate_electron_nuclear(
) + self.nuclear_nuclear_potential + self.calculate_kinetic_energy(
)
return energy.detach()

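To make the pretraining branch concrete, here is a self-contained sketch of how running_diff accumulates into the pretraining loss over one random walk. The shapes and the random stand-in tensors are illustrative assumptions, not values from the PR:

```python
import numpy as np
import torch

criterion = torch.nn.MSELoss()
batch_no, random_walk_steps = 8, 10
running_diff = torch.zeros(batch_no)

for _ in range(random_walk_steps):
    # stand-ins for the network's spin orbitals and the Hartree-Fock orbitals
    psi_up = torch.randn(batch_no, 1, 1, 1)
    psi_down = torch.randn(batch_no, 1, 1, 1)
    psi_up_mo_t = torch.from_numpy(np.random.randn(batch_no, 1,
                                                   1)).unsqueeze(1).float()
    psi_down_mo_t = torch.from_numpy(np.random.randn(batch_no, 1,
                                                     1)).unsqueeze(1).float()
    # the scalar MSE is broadcast-added across the per-batch accumulator
    running_diff = running_diff + criterion(psi_up, psi_up_mo_t) + \
        criterion(psi_down, psi_down_mo_t)

# mirrors self.loss_value in the pretraining loop further below
loss_value = torch.mean(running_diff) / random_walk_steps
```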
    def calculate_nuclear_nuclear(self,) -> torch.Tensor:
        """
@@ -405,7 +412,8 @@ def __init__(self,
            steps_per_update=self.steps_per_update
        )  # sample the electrons using the electron sampler
        self.molecule.gauss_initialize_position(
-            self.electron_no)  # initialize the position of the electrons
+            self.electron_no,
+            stddev=1.0)  # initialize the position of the electrons
        self.prepare_hf_solution()
        super(FerminetModel, self).__init__(
            self.model,
@@ -462,7 +470,7 @@ def prepare_hf_solution(self):
        self.mf = pyscf.scf.UHF(self.mol)
        _ = self.mf.kernel()

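For readers unfamiliar with the PySCF call here: scf.UHF(mol).kernel() runs an unrestricted Hartree-Fock calculation and returns the converged total energy. A standalone sketch with the same H2 geometry used in the tests (the default basis and Angstrom units of pyscf.gto.M are assumed):

```python
import pyscf

# build the H2 molecule used throughout the tests (0.748 A bond length)
mol = pyscf.gto.M(atom='H 0 0 0; H 0 0 0.748', spin=0, charge=0)
mf = pyscf.scf.UHF(mol)
hf_energy = mf.kernel()  # converged unrestricted Hartree-Fock energy (Hartree)
print(hf_energy)
```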
-    def random_walk(self, x: np.ndarray) -> np.ndarray:
+    def random_walk(self, x: np.ndarray):
Review comment (Contributor Author): doing this to avoid Mypy errors
"""
Function to be passed on to electron sampler for random walk and gets called at each step of sampling

Review comment (Member): Document why the random walk is different for pretraining and finetuning

Review comment (Member): Also document burn as a phase
@@ -476,42 +484,151 @@ def random_walk(self, x: np.ndarray) -> np.ndarray:
            A numpy array containing the joint probability of the Hartree-Fock solution and the sampled electrons' position coordinates
        """
        x_torch = torch.from_numpy(x).view(self.batch_no, -1, 3)
-        x_torch.requires_grad = True
+        if self.tasks == 'pretraining':
+            x_torch.requires_grad = True
+        else:
+            x_torch.requires_grad = False

Review comment (Contributor Author): adding the pretraining and training parts separately in the random_walk function
        output = self.model.forward(x_torch)
        np_output = output.detach().cpu().numpy()

-        up_spin_mo, down_spin_mo = self.evaluate_hf(x)
-        hf_product = np.prod(
-            np.diagonal(up_spin_mo, axis1=1, axis2=2)**2, axis=1) * np.prod(
-                np.diagonal(down_spin_mo, axis1=1, axis2=2)**2, axis=1)
-        self.model.loss(up_spin_mo, down_spin_mo, pretrain=True)
-        return np.log(hf_product + np_output**2) + np.log(0.5)
+        if self.tasks == 'pretraining':
+            up_spin_mo, down_spin_mo = self.evaluate_hf(x)
+            hf_product = np.prod(
+                np.diagonal(up_spin_mo, axis1=1, axis2=2), axis=1) * np.prod(
+                    np.diagonal(down_spin_mo, axis1=1, axis2=2), axis=1)
+            self.model.loss(up_spin_mo, down_spin_mo)
+            np_output[:int(self.batch_no / 2)] = hf_product[:int(self.batch_no / 2)]
+            return 2 * np.log(np.abs(np_output))
+
+        if self.tasks == 'burn':
+            return 2 * np.log(np.abs(np_output))
+
+        if self.tasks == 'training':
+            energy = self.model.loss(pretrain=[False])
+            self.energy_sampled: torch.Tensor = torch.cat(
+                (self.energy_sampled, energy.unsqueeze(0)))
+            return 2 * np.log(np.abs(np_output))

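Since random_walk returns 2 * log|psi| (a log-density up to a constant), the sampler presumably compares these values in a Metropolis acceptance test. The ElectronSampler internals are not shown in this diff, so the following self-contained sketch of that convention, with a toy density in place of the network, is an assumption:

```python
import numpy as np

rng = np.random.default_rng(0)


def metropolis_step(x, log_density, stddev):
    """One Metropolis step: propose a Gaussian move and accept it with
    probability min(1, p(x_new) / p(x)) computed from log-densities."""
    x_new = x + rng.normal(scale=stddev, size=x.shape)
    log_ratio = log_density(x_new) - log_density(x)
    accept = np.log(rng.uniform(size=log_ratio.shape)) < log_ratio
    return np.where(accept[:, None], x_new, x), accept.mean()


# toy example: sample a 1D standard normal, log p(x) = -x^2 / 2 + const
x = np.zeros((16, 1))
log_p = lambda pos: -0.5 * np.sum(pos**2, axis=-1)
for _ in range(200):
    x, acceptance_rate = metropolis_step(x, log_p, stddev=0.5)
```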
    def prepare_train(self, burn_in: int = 100):

Review comment (Contributor Author): this function performs burn-in and changes the model parameters before training is done
"""
Function to perform burn-in and to change the model parameters for training.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More details on why this is necessary


Parameters
----------
burn_in:int (default: 100)
number of steps for to perform burn-in before the aactual training.
"""
        self.tasks = 'burn'
        self.molecule.gauss_initialize_position(self.electron_no, stddev=1.0)
        tmp_x = self.molecule.x
        for _ in range(burn_in):
            self.molecule.move(stddev=0.02)
        self.molecule.x = tmp_x
        self.molecule.move(stddev=0.02)
        self.tasks = 'training'

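On the reviewer's request to document burn as a phase: burn-in is the usual MCMC warm-up, where the walkers are moved for burn_in steps so they equilibrate before any gradient is taken, and those samples are discarded. A minimal sketch of the resulting three-phase workflow, mirroring test_FerminetModel_train above (the geometry and epoch counts are placeholders):

```python
# Hypothetical end-to-end usage; mirrors test_FerminetModel_train.
H2_molecule = [['H', [0, 0, 0]], ['H', [0, 0, 0.748]]]
mol = FerminetModel(H2_molecule, spin=0, ion_charge=0)  # tasks == 'pretraining'
mol.train(nb_epoch=10)   # phase 1: fit the network to Hartree-Fock orbitals
mol.prepare_train()      # phase 2: 'burn', equilibrate the MCMC walkers
mol.train(nb_epoch=10)   # phase 3: 'training', minimize the variational energy
print(mol.final_energy)  # running mean of the clipped local energies
```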
    def train(self,
              nb_epoch: int = 200,
-              lr: float = 0.0075,
-              weight_decay: float = 0.0001):
+              lr: float = 0.002,
+              weight_decay: float = 0,
+              std: float = 0.08,
+              std_init: float = 0.02,
+              steps_std: int = 100):
"""
function to run training or pretraining.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This docstring should explain in detail why we need to overwrite the TorchModel implementation

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explain in detail how pretrain works vs train works. Multiple paragraphs please


        Parameters
        ----------
        nb_epoch: int (default: 200)
            contains the number of pretraining steps to be performed
-        lr : float (default: 0.0075)
+        lr : float (default: 0.002)
            contains the learning rate for the model fitting
-        weight_decay: float (default: 0.0001)
+        weight_decay: float (default: 0)
            contains the weight_decay for the model fitting
+        std: float (default: 0.08)
+            The standard deviation for the electron update during training
+        std_init: float (default: 0.02)
+            The standard deviation for the electron update during pretraining
+        steps_std: int (default: 100)
+            The number of steps between standard deviation updates
        """

        # the hook below is an efficient way of modifying the gradients on the
        # fly, rather than looping over the parameters after backward
        def energy_hook(grad, random_walk_steps):
            """
            Hook function to modify the gradients
            """
            # using non-local variables as a means of parameter passing
            nonlocal energy_local, energy_mean
            new_grad = (2 / random_walk_steps) * (
                (energy_local - energy_mean) * grad)
            return new_grad.float()

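For reference, the hook implements the standard variational Monte Carlo gradient estimator used in FermiNet-style training: with walker positions x_i drawn from |psi|^2 over the N random-walk steps and local energies E_L,

$$
\nabla_\theta E \;\approx\; \frac{2}{N} \sum_{i=1}^{N} \left( E_L(x_i) - \bar{E}_L \right) \, \nabla_\theta \log \lvert \psi_\theta(x_i) \rvert
$$

Backpropagating torch.mean(torch.log(torch.abs(self.model.psi))) once per walk step and rescaling every gradient by (2 / random_walk_steps) * (energy_local - energy_mean) inside the hook accumulates exactly this sum before optimizer.step().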
        optimizer = torch.optim.Adam(self.model.parameters(),
                                     lr=lr,
                                     weight_decay=weight_decay)

        if (self.tasks == 'pretraining'):
-            for _ in range(nb_epoch):
+            for iteration in range(nb_epoch):
                optimizer.zero_grad()
-                self.molecule.move()
+                accept = self.molecule.move(stddev=std_init)
+                if iteration % steps_std == 0:
+                    if accept > 0.55:
+                        std_init *= 1.1
Review comment (Member): Magic numbers are bad; these need to be documented or tunable parameters
+                    else:
+                        std_init /= 1.1
                self.loss_value = (torch.mean(self.model.running_diff) /
                                   self.random_walk_steps)
                self.loss_value.backward()
                optimizer.step()
                self.model.running_diff = torch.zeros(self.batch_no)

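On the reviewer's point about magic numbers: the 0.55 / 1.1 pair is a step-size adaptation heuristic that nudges the proposal width toward a moderate Metropolis acceptance rate. A self-contained sketch of the rule; the name adapt_stddev and the defaults are illustrative, not part of the PR:

```python
def adapt_stddev(stddev: float, accept_rate: float,
                 target: float = 0.55, factor: float = 1.1) -> float:
    """Widen the proposal when acceptance is high (steps are too timid),
    narrow it when acceptance is low (steps are rejected too often)."""
    return stddev * factor if accept_rate > target else stddev / factor
```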
        if (self.tasks == 'training'):

Review comment (Contributor Author): adding the training part
            energy_local = None
            optimizer = torch.optim.Adam(self.model.parameters(),
                                         lr=lr,
                                         weight_decay=weight_decay)
            self.final_energy = torch.tensor(0.0)
            with torch.no_grad():
                hooks = list(
                    map(
                        lambda param: param.register_hook(
                            lambda grad: energy_hook(grad, self.random_walk_steps)),
                        self.model.parameters()))
            for iteration in range(nb_epoch):
                optimizer.zero_grad()
                self.energy_sampled = torch.tensor([])
                # the move function calculates the energy of the sampled electrons
                # and samples a new set of electrons (it does not calculate the loss)
                accept = self.molecule.move(stddev=std)
                if iteration % steps_std == 0:
                    if accept > 0.55:
                        std *= 1.2
                    else:
                        std /= 1.2
                median, _ = torch.median(self.energy_sampled, 0)
                variance = torch.mean(torch.abs(self.energy_sampled - median))
                # clip local energies that lie more than 5 mean absolute
                # deviations away from the median
                clamped_energy = torch.clamp(self.energy_sampled,
                                             max=median + 5 * variance,
                                             min=median - 5 * variance)
                energy_mean = torch.mean(clamped_energy)
                self.final_energy = self.final_energy + energy_mean
                # replay the sampled electrons from the electron sampler for the
                # backward pass, modifying gradients via the registered hooks
                sample_history = torch.from_numpy(
                    self.molecule.sampled_electrons).view(
                        self.random_walk_steps, self.batch_no, -1, 3)
                optimizer.zero_grad()
                for i in range(self.random_walk_steps):
                    # go through each step of the random walk and compute the
                    # modified gradients with the local energies
                    input_electron = sample_history[i]
                    input_electron.requires_grad = True
                    energy_local = torch.mean(clamped_energy[i])
                    self.model.forward(input_electron)
                    self.loss_value = torch.mean(
                        torch.log(torch.abs(self.model.psi)))
                    self.loss_value.backward()
                optimizer.step()
            self.final_energy = self.final_energy / nb_epoch
            list(map(lambda hook: hook.remove(), hooks))
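To make the hook mechanics above easier to follow in isolation, here is a self-contained toy sketch of the same pattern: per-parameter gradient hooks rescale gradients during backward, and the returned handles are removed when training ends. The model and the scaling constant are stand-ins, not part of the PR:

```python
import torch

model = torch.nn.Linear(3, 1)
scale = 0.25  # stand-in for (2 / random_walk_steps) * (energy_local - energy_mean)

# register one hook per parameter; each rescales that parameter's gradient
hooks = [p.register_hook(lambda grad: scale * grad)
         for p in model.parameters()]

loss = model(torch.randn(8, 3)).mean()
loss.backward()  # gradients arrive already multiplied by `scale`

for h in hooks:
    h.remove()   # detach the hooks so later backward passes are unmodified
```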