Skip to content

Commit

Permalink
Revert "update stress training"
Browse files Browse the repository at this point in the history
This reverts commit 9b9a51e.
  • Loading branch information
jzhang-github committed Jan 2, 2024
1 parent 9b9a51e commit 4f7a8a3
Show file tree
Hide file tree
Showing 12 changed files with 90 additions and 140 deletions.
2 changes: 1 addition & 1 deletion .spyproject/config/backups/workspace.ini.bak
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ save_data_on_exit = True
save_history = True
save_non_project_files = False
project_type = 'empty-project-type'
recent_files = ['C:\\Users\\ZHANG Jun\\.spyder-py3\\temp.py', 'agat\\model\\model.py', 'Change_log.md', 'agat\\test\\POSCAR', 'agat\\test\\test_stress_model.py']
recent_files = []

[main]
version = 0.2.0
Expand Down
2 changes: 1 addition & 1 deletion .spyproject/config/workspace.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ save_data_on_exit = True
save_history = True
save_non_project_files = False
project_type = 'empty-project-type'
recent_files = ['agat\\test\\dataloader_test.py', 'agat\\data\\build_dataset.py']
recent_files = ['C:\\Users\\ZHANG Jun\\.spyder-py3\\temp.py', 'agat\\test\\test_stress_model.py', 'agat\\model\\model.py', 'Change_log.md']

[main]
version = 0.2.0
Expand Down
2 changes: 0 additions & 2 deletions Change_log.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# [main](https://github.com/jzhang-github/AGAT/tree/main)
- when extracting data, the true stress is amplified 1000 to avoid the pytorch accuracy.
- Stresses of many snapshots are lower than 1e-6, which can be problematic when training with torch.float32

# [v9.0.0](https://github.com/jzhang-github/AGAT/tree/v9.0.0)
**Note: AGAT after this version (included) cannot load the well-trained model before.** If you need to do so, please use v8.0.5: https://pypi.org/project/agat/8.0.5/
Expand Down
3 changes: 1 addition & 2 deletions agat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
@author: ZHANG Jun
"""

__version__ = '9.0.1'
__version__ = '9.0.0'

import os
# os.environ['DGLBACKEND']="pytorch"
Expand All @@ -14,7 +14,6 @@
from .data.build_dataset import BuildDatabase
from .data.build_dataset import CrystalGraph
from .model.fit import Fit
from .model.model import PotentialModel
from .app.cata.high_throughput_predict import HtAds

del os
8 changes: 3 additions & 5 deletions agat/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,11 @@ def calculate(self, atoms=None, properties=None, system_changes=['positions', 'n
graph, info = self.cg.get_graph(atoms)
graph = graph.to(self.device)

# Stresses of many snapshots are lower than 1e-6, which can be
# problematic when training with torch.float32
with torch.no_grad():
energy_pred, force_pred, stress1000_pred = self.model.forward(graph)
energy_pred, force_pred, stress_pred = self.model.forward(graph)

self.results = {'energy': energy_pred[0].item() * len(atoms),
'forces': force_pred.cpu().numpy(),
'stress': stress1000_pred.cpu().numpy() / 1000}
'forces': force_pred.cpu().numpy()}

# class GatCalculator(Calculator):
# """ Calculator with ASE module. Pymatgen is also needed.
Expand Down Expand Up @@ -189,3 +186,4 @@ def calculate(self, atoms=None, properties=None, system_changes=['positions', 'n
# debug
if __name__ == '__main__':
pass

19 changes: 5 additions & 14 deletions agat/data/build_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,8 @@ def get_graph_from_ase(self, fname): # ase_atoms or file name
energy_true = torch.tensor(np.load(fname+'_energy.npy'), dtype=self.dtype)
graph_info['energy_true'] = energy_true
if self.data_config['build_properties']['stress']:
# Stresses of many snapshots are lower than 1e-6, which can be
# problematic when training with torch.float32
stress_true = np.load(fname+'_stress.npy')
# graph_info['stress_true'] = torch.tensor(stress_true,
# dtype=self.dtype)
graph_info['stress_true1000'] = torch.tensor(stress_true * 1000,
dtype=self.dtype)
stress_true = torch.tensor(np.load(fname+'_stress.npy'), dtype=self.dtype)
graph_info['stress_true'] = stress_true
if self.data_config['build_properties']['cell']:
cell_true = torch.tensor(ase_atoms.cell.array, dtype=self.dtype)
graph_info['cell_true'] = cell_true
Expand Down Expand Up @@ -359,13 +354,8 @@ def get_graph_from_pymatgen(self, crystal_fname):
energy_true = torch.tensor(np.load(crystal_fname+'_energy.npy'), dtype=self.dtype)
graph_info['energy_true'] = torch.tensor((energy_true), dtype=self.dtype)
if self.data_config['build_properties']['stress']:
# Stresses of many snapshots are lower than 1e-6, which can be
# problematic when training with torch.float32
stress_true = np.load(crystal_fname+'_stress.npy')
# graph_info['stress_true'] = torch.tensor(stress_true,
# dtype=self.dtype)
graph_info['stress_true1000'] = torch.tensor(stress_true * 1000,
dtype=self.dtype)
stress_true = torch.tensor(np.load(crystal_fname+'_stress.npy'), dtype=self.dtype)
graph_info['stress_true'] = torch.tensor((stress_true), dtype=self.dtype)
if self.data_config['build_properties']['cell']:
cell_true = torch.tensor(mycrystal.lattice.matrix, dtype=self.dtype)
graph_info['cell_true'] = torch.tensor((cell_true), dtype=self.dtype)
Expand Down Expand Up @@ -837,3 +827,4 @@ def select_graphs_random(fname: str, num: int):
if __name__ == '__main__':
ad = BuildDatabase(mode_of_NN='pymatgen_dist', num_of_cores=16)
ad.build()

4 changes: 2 additions & 2 deletions agat/default_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@
'negative_slope': 0.2,
'criterion': nn.MSELoss(),
'a': 1.0,
'b': 1.0,
'c': 1.0,
'b': 50.0,
'c': 1000.0,
# 'optimizer': 'adam',
'learning_rate': 0.0001,
'weight_decay': 0.0, # weight decay (L2 penalty)
Expand Down
3 changes: 1 addition & 2 deletions agat/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,4 @@

from .layer import Layer
from .model import PotentialModel, CrystalPropertyModel, AtomicPropertyModel, AtomicVectorModel
from .fit import Fit
from agat.lib.model_lib import save_model, load_model, save_state_dict, load_state_dict
from .fit import Fit
10 changes: 5 additions & 5 deletions agat/model/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def fit(self, **train_config):
for i, (graph, props) in enumerate(self.train_loader):
energy_true = props['energy_true']
force_true = graph.ndata['forces_true']
stress_true = props['stress_true1000']
stress_true = props['stress_true']
optimizer.zero_grad()
energy_pred, force_pred, stress_pred = model.forward(graph)
energy_loss = criterion(energy_pred, energy_true)
Expand Down Expand Up @@ -241,7 +241,7 @@ def fit(self, **train_config):
for i, (graph, props) in enumerate(self.val_loader):
energy_true_all.append(props['energy_true'])
force_true = graph.ndata['forces_true']
stress_true_all.append(props['stress_true1000'])
stress_true_all.append(props['stress_true'])
energy_pred, force_pred, stress_pred = model.forward(graph)
energy_pred_all.append(energy_pred)
if self._has_adsorbate:
Expand Down Expand Up @@ -326,7 +326,7 @@ def fit(self, **train_config):
for i, (graph, props) in enumerate(self.test_loader):
energy_true_all.append(props['energy_true'])
force_true_all.append(graph.ndata['forces_true'])
stress_true_all.append(props['stress_true1000'])
stress_true_all.append(props['stress_true'])
energy_pred, force_pred, stress_pred = model.forward(graph)
energy_pred_all.append(energy_pred)
force_pred_all.append(force_pred)
Expand Down Expand Up @@ -356,11 +356,11 @@ def fit(self, **train_config):
Epoch : {epoch}
Energy loss: {energy_loss.item()}
Force_Loss : {force_loss.item()}
Stress_Loss: {stress_loss.item()} (units: 1e-3 ASE stress)
Stress_Loss: {stress_loss.item()}
Total_Loss : {total_loss.item()}
Energy_MAE : {energy_mae.item()}
Force_MAE : {force_mae.item()}
Stress_MAE : {stress_mae.item()} (units: 1e-3 ASE stress)
Stress_MAE : {stress_mae.item()}
Energy_R : {energy_r.item()}
Force_R : {force_r.item()}
Stress_R : {stress_r.item()}
Expand Down
2 changes: 1 addition & 1 deletion agat/test/export_cell_stress.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'
graph_list, props = load_graphs('all_graphs_generation_0.bin')

stresses = props['stress_true'].numpy() # props['stress_true1000']
stresses = props['stress_true'].numpy()


plt.hist(stresses[:,0], bins=100, range=[-0.01, 0.01])
Loading

0 comments on commit 4f7a8a3

Please sign in to comment.