Skip to content

Commit

Permalink
update stress training
Browse files Browse the repository at this point in the history
when extracting data, stress_new = stress_old*1000
  • Loading branch information
jzhang-github committed Jan 2, 2024
1 parent b723072 commit 9b9a51e
Show file tree
Hide file tree
Showing 12 changed files with 140 additions and 90 deletions.
2 changes: 1 addition & 1 deletion .spyproject/config/backups/workspace.ini.bak
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ save_data_on_exit = True
save_history = True
save_non_project_files = False
project_type = 'empty-project-type'
recent_files = []
recent_files = ['C:\\Users\\ZHANG Jun\\.spyder-py3\\temp.py', 'agat\\model\\model.py', 'Change_log.md', 'agat\\test\\POSCAR', 'agat\\test\\test_stress_model.py']

[main]
version = 0.2.0
Expand Down
2 changes: 1 addition & 1 deletion .spyproject/config/workspace.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ save_data_on_exit = True
save_history = True
save_non_project_files = False
project_type = 'empty-project-type'
recent_files = ['C:\\Users\\ZHANG Jun\\.spyder-py3\\temp.py', 'agat\\test\\test_stress_model.py', 'agat\\model\\model.py', 'Change_log.md']
recent_files = ['agat\\test\\dataloader_test.py', 'agat\\data\\build_dataset.py']

[main]
version = 0.2.0
Expand Down
2 changes: 2 additions & 0 deletions Change_log.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# [main](https://github.com/jzhang-github/AGAT/tree/main)
- when extracting data, the true stress is amplified 1000 to avoid the pytorch accuracy.
- Stresses of many snapshots are lower than 1e-6, which can be problematic when training with torch.float32

# [v9.0.0](https://github.com/jzhang-github/AGAT/tree/v9.0.0)
**Note: AGAT after this version (included) cannot load the well-trained model before.** If you need to do so, please use v8.0.5: https://pypi.org/project/agat/8.0.5/
Expand Down
3 changes: 2 additions & 1 deletion agat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
@author: ZHANG Jun
"""

__version__ = '9.0.0'
__version__ = '9.0.1'

import os
# os.environ['DGLBACKEND']="pytorch"
Expand All @@ -14,6 +14,7 @@
from .data.build_dataset import BuildDatabase
from .data.build_dataset import CrystalGraph
from .model.fit import Fit
from .model.model import PotentialModel
from .app.cata.high_throughput_predict import HtAds

del os
8 changes: 5 additions & 3 deletions agat/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,14 @@ def calculate(self, atoms=None, properties=None, system_changes=['positions', 'n
graph, info = self.cg.get_graph(atoms)
graph = graph.to(self.device)

# Stresses of many snapshots are lower than 1e-6, which can be
# problematic when training with torch.float32
with torch.no_grad():
energy_pred, force_pred, stress_pred = self.model.forward(graph)
energy_pred, force_pred, stress1000_pred = self.model.forward(graph)

self.results = {'energy': energy_pred[0].item() * len(atoms),
'forces': force_pred.cpu().numpy()}
'forces': force_pred.cpu().numpy(),
'stress': stress1000_pred.cpu().numpy() / 1000}

# class GatCalculator(Calculator):
# """ Calculator with ASE module. Pymatgen is also needed.
Expand Down Expand Up @@ -186,4 +189,3 @@ def calculate(self, atoms=None, properties=None, system_changes=['positions', 'n
# debug
if __name__ == '__main__':
pass

19 changes: 14 additions & 5 deletions agat/data/build_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,13 @@ def get_graph_from_ase(self, fname): # ase_atoms or file name
energy_true = torch.tensor(np.load(fname+'_energy.npy'), dtype=self.dtype)
graph_info['energy_true'] = energy_true
if self.data_config['build_properties']['stress']:
stress_true = torch.tensor(np.load(fname+'_stress.npy'), dtype=self.dtype)
graph_info['stress_true'] = stress_true
# Stresses of many snapshots are lower than 1e-6, which can be
# problematic when training with torch.float32
stress_true = np.load(fname+'_stress.npy')
# graph_info['stress_true'] = torch.tensor(stress_true,
# dtype=self.dtype)
graph_info['stress_true1000'] = torch.tensor(stress_true * 1000,
dtype=self.dtype)
if self.data_config['build_properties']['cell']:
cell_true = torch.tensor(ase_atoms.cell.array, dtype=self.dtype)
graph_info['cell_true'] = cell_true
Expand Down Expand Up @@ -354,8 +359,13 @@ def get_graph_from_pymatgen(self, crystal_fname):
energy_true = torch.tensor(np.load(crystal_fname+'_energy.npy'), dtype=self.dtype)
graph_info['energy_true'] = torch.tensor((energy_true), dtype=self.dtype)
if self.data_config['build_properties']['stress']:
stress_true = torch.tensor(np.load(crystal_fname+'_stress.npy'), dtype=self.dtype)
graph_info['stress_true'] = torch.tensor((stress_true), dtype=self.dtype)
# Stresses of many snapshots are lower than 1e-6, which can be
# problematic when training with torch.float32
stress_true = np.load(crystal_fname+'_stress.npy')
# graph_info['stress_true'] = torch.tensor(stress_true,
# dtype=self.dtype)
graph_info['stress_true1000'] = torch.tensor(stress_true * 1000,
dtype=self.dtype)
if self.data_config['build_properties']['cell']:
cell_true = torch.tensor(mycrystal.lattice.matrix, dtype=self.dtype)
graph_info['cell_true'] = torch.tensor((cell_true), dtype=self.dtype)
Expand Down Expand Up @@ -827,4 +837,3 @@ def select_graphs_random(fname: str, num: int):
if __name__ == '__main__':
ad = BuildDatabase(mode_of_NN='pymatgen_dist', num_of_cores=16)
ad.build()

4 changes: 2 additions & 2 deletions agat/default_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@
'negative_slope': 0.2,
'criterion': nn.MSELoss(),
'a': 1.0,
'b': 50.0,
'c': 1000.0,
'b': 1.0,
'c': 1.0,
# 'optimizer': 'adam',
'learning_rate': 0.0001,
'weight_decay': 0.0, # weight decay (L2 penalty)
Expand Down
3 changes: 2 additions & 1 deletion agat/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@

from .layer import Layer
from .model import PotentialModel, CrystalPropertyModel, AtomicPropertyModel, AtomicVectorModel
from .fit import Fit
from .fit import Fit
from agat.lib.model_lib import save_model, load_model, save_state_dict, load_state_dict
10 changes: 5 additions & 5 deletions agat/model/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def fit(self, **train_config):
for i, (graph, props) in enumerate(self.train_loader):
energy_true = props['energy_true']
force_true = graph.ndata['forces_true']
stress_true = props['stress_true']
stress_true = props['stress_true1000']
optimizer.zero_grad()
energy_pred, force_pred, stress_pred = model.forward(graph)
energy_loss = criterion(energy_pred, energy_true)
Expand Down Expand Up @@ -241,7 +241,7 @@ def fit(self, **train_config):
for i, (graph, props) in enumerate(self.val_loader):
energy_true_all.append(props['energy_true'])
force_true = graph.ndata['forces_true']
stress_true_all.append(props['stress_true'])
stress_true_all.append(props['stress_true1000'])
energy_pred, force_pred, stress_pred = model.forward(graph)
energy_pred_all.append(energy_pred)
if self._has_adsorbate:
Expand Down Expand Up @@ -326,7 +326,7 @@ def fit(self, **train_config):
for i, (graph, props) in enumerate(self.test_loader):
energy_true_all.append(props['energy_true'])
force_true_all.append(graph.ndata['forces_true'])
stress_true_all.append(props['stress_true'])
stress_true_all.append(props['stress_true1000'])
energy_pred, force_pred, stress_pred = model.forward(graph)
energy_pred_all.append(energy_pred)
force_pred_all.append(force_pred)
Expand Down Expand Up @@ -356,11 +356,11 @@ def fit(self, **train_config):
Epoch : {epoch}
Energy loss: {energy_loss.item()}
Force_Loss : {force_loss.item()}
Stress_Loss: {stress_loss.item()}
Stress_Loss: {stress_loss.item()} (units: 1e-3 ASE stress)
Total_Loss : {total_loss.item()}
Energy_MAE : {energy_mae.item()}
Force_MAE : {force_mae.item()}
Stress_MAE : {stress_mae.item()}
Stress_MAE : {stress_mae.item()} (units: 1e-3 ASE stress)
Energy_R : {energy_r.item()}
Force_R : {force_r.item()}
Stress_R : {stress_r.item()}
Expand Down
2 changes: 1 addition & 1 deletion agat/test/export_cell_stress.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'
graph_list, props = load_graphs('all_graphs_generation_0.bin')

stresses = props['stress_true'].numpy()
stresses = props['stress_true'].numpy() # props['stress_true1000']


plt.hist(stresses[:,0], bins=100, range=[-0.01, 0.01])
Loading

0 comments on commit 9b9a51e

Please sign in to comment.