Skip to content

Commit

Permalink
Update test_stress_model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jzhang-github committed Dec 28, 2023
1 parent 5bf2359 commit 20182eb
Showing 1 changed file with 77 additions and 18 deletions.
95 changes: 77 additions & 18 deletions agat/test/test_stress_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
import time
from datetime import datetime
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torch.optim import lr_scheduler
Expand Down Expand Up @@ -110,7 +111,12 @@ def __init__(self,

# self.stress_bias = torch.nn.Parameter(torch.randn(1), requires_grad=True)
# self.stress_act = nn.LeakyReLU(negative_slope=self.negative_slope)
self.stress_outlayer = nn.Linear(6, 6, self.bias, self.device)
# self.stress_outlayer = nn.Linear(6, 6, False, self.device)
# self.stress_outlayer = nn.BatchNorm1d(6, device=self.device)
self.u2e = nn.Linear(self.gat_node_dim_list[0], self.stress_readout_node_list[0],
False, self.device) # propagate source nodes to edges.
self.skip_connect = nn.Linear(1, self.stress_readout_node_list[0],
False, self.device)

for l in range(self.num_gat_layers):
self.gat_layers.append(Layer(self.__gat_real_node_dim_list[l],
Expand Down Expand Up @@ -140,16 +146,23 @@ def __init__(self,
self.force_readout_layers.append(nn.Linear(self.force_readout_node_list[l-self.tail_readout_no_act[1]-1],
self.force_readout_node_list[l-self.tail_readout_no_act[1]],
self.bias, self.device))

# Input dim: (number of nodes, number of heads * number of out)
# stress readout layer
for l in range(self.num_stress_readout_layers-self.tail_readout_no_act[2]):
self.stress_readout_layers.append(nn.Linear(self.stress_readout_node_list[l],
self.stress_readout_node_list[l+1],
self.bias, self.device))
# self.stress_readout_layers.append(nn.BatchNorm1d(
# self.stress_readout_node_list[l+1], device=self.device))
self.stress_readout_layers.append(nn.LeakyReLU(negative_slope=self.negative_slope))
for l in range(self.tail_readout_no_act[2]):
self.stress_readout_layers.append(nn.Linear(self.stress_readout_node_list[l-self.tail_readout_no_act[2]-1],
self.stress_readout_node_list[l-self.tail_readout_no_act[2]],
self.bias, self.device))
# self.stress_readout_layers.append(nn.BatchNorm1d(
# self.stress_readout_node_list[l-self.tail_readout_no_act[2]],
# device=self.device))

self.__real_num_energy_readout_layers = len(self.energy_readout_layers)
self.__real_num_force_readout_layers = len(self.force_readout_layers)
Expand Down Expand Up @@ -210,6 +223,9 @@ def get_head_mechanism(self, fn_list, TorchTensor):
TorchTensor_list.append(self.head_fn[func](TorchTensor))
return torch.cat(TorchTensor_list, 1)

# def get_unit_vector(self, direction):
# return direction/torch.norm(direction, dim=0)

def forward(self, graph):
"""The ``forward`` function of PotentialModel model.
Expand Down Expand Up @@ -255,19 +271,36 @@ def forward(self, graph):
force = graph.ndata['force_pred']

# Predict stress
graph.edata['stress_score'] = stress_score
graph.ndata['atom_code'] = self.u2e(graph.ndata['h'])
graph.apply_edges(fn.u_add_e('atom_code', 'stress_score', 'stress_score'))
stress_score = graph.edata['stress_score'] + self.skip_connect(graph.edata['dist'])

# graph.edata['stress_score_test'][6]
# graph.edata['stress_score_test'][690]

# torch.mean(graph.edata['stress_score_test'],dim=0)
# fn.copy_u('atom_code', 'm')

for l in range(self.__real_num_stress_readout_layers):
stress_score = self.stress_readout_layers[l](stress_score)

# unit_vector = graph.edata['direction']/torch.norm(graph.edata['direction'], dim=0)
graph.edata['stress_score_vector'] = stress_score * torch.cat((graph.edata['direction'],
graph.edata['direction']), dim=1) # shape (number of edges, 2)

batch_edges = graph.batch_num_edges().tolist()
stress = torch.split(graph.edata['stress_score_vector'], batch_edges)
stress = torch.stack([torch.mean(s, dim=0) for s in stress])

# graph.edata['stress_score_vector'] = self.stress_act(graph.edata['stress_score_vector'])
graph.edata['stress_score_vector'] = self.stress_outlayer(graph.edata['stress_score_vector'])
graph.update_all(fn.copy_e('stress_score_vector', 'm'), fn.sum('m', 'stress_pred')) # shape of graph.ndata['force_pred']: (number of nodes, 3)
# graph.edata['stress_score_vector'] = self.stress_outlayer(graph.edata['stress_score_vector'])
# graph.update_all(fn.copy_e('stress_score_vector', 'm'), fn.sum('m', 'stress_pred')) # shape of graph.ndata['force_pred']: (number of nodes, 3)
# stress = torch.sum(graph.ndata['stress_pred'], dim=0)
stress = torch.split(graph.ndata['stress_pred'], batch_nodes) # shape of stress: number of atoms * 6
stress = torch.stack([torch.mean(s, dim=0) for s in stress]) # + self.stress_bias
# stress = torch.split(graph.ndata['stress_pred'], batch_nodes) # shape of stress: number of atoms * 6
# stress = torch.stack([torch.sum(s, dim=0) for s in stress]) # + self.stress_bias
return energy, force, stress


class Fit(object):
def __init__(self, **train_config):
self.train_config = {**default_train_config, **config_parser(train_config)}
Expand Down Expand Up @@ -381,7 +414,9 @@ def __init__(self, **train_config):
os.mkdir(self.train_config['output_files'])

# debug
self.writer = SummaryWriter('fit_debug', flush_secs=10)
TIMESTAMP = "{0:%Y-%m-%d--%H-%M-%S}".format(datetime.now())
self.writer = SummaryWriter(os.path.join('fit_debug', TIMESTAMP),
flush_secs=10)

def fit(self, **train_config):
# update config if needed.
Expand All @@ -403,6 +438,11 @@ def fit(self, **train_config):
lr=self.train_config['learning_rate'],
weight_decay=self.train_config['weight_decay'])

# # reset parameters
# nn.init.orthogonal_(model.stress_outlayer.weight)
# for l in model.stress_readout_layers:
# if hasattr(l, 'weight'):
# nn.init.orthogonal_(l.weight)

# self.writer.add_graph(model, self._dataset[0][0], )
# load state dict if one exists.
Expand Down Expand Up @@ -485,17 +525,27 @@ def fit(self, **train_config):
force_loss = criterion(force_pred, force_true)
stress_loss = criterion(stress_pred, stress_true)
self.writer.add_scalar('stress_loss', torch.flatten(stress_loss), batch_step)
self.writer.add_scalar('force_loss', torch.flatten(force_loss), batch_step)
total_loss = a*energy_loss + b*force_loss + c*stress_loss
total_loss.backward()
self.writer.add_histogram('stress_outlayer_weight_grad',
torch.flatten(model.stress_outlayer.weight.grad),
self.writer.add_histogram('u2e_weight_grad',
torch.flatten(model.u2e.weight.grad),
batch_step)
self.writer.add_histogram('u2e_weight',
torch.flatten(model.u2e.weight),
batch_step)
self.writer.add_histogram('stress_outlayer_bias_grad',
torch.flatten(model.stress_outlayer.bias.grad),
self.writer.add_histogram('skip_weight_grad',
torch.flatten(model.skip_connect.weight.grad),
batch_step)
# self.writer.add_histogram('stress_outlayer_bias_grad',
# torch.flatten(model.stress_outlayer.bias.grad),
# batch_step)
self.writer.add_histogram('last_force_readout_layer_weight_grad',
torch.flatten(model.force_readout_layers[5].weight.grad),
batch_step)
self.writer.add_histogram('last_force_readout_layer_weight',
torch.flatten(model.force_readout_layers[5].weight),
batch_step)
optimizer.step()
dur = time.time() - start_time
if self.verbose > 1:
Expand Down Expand Up @@ -541,6 +591,11 @@ def fit(self, **train_config):
energy_r = r(energy_pred_all, energy_true_all)
force_r = r(force_pred_all, force_true_all)
stress_r = r(stress_pred_all, stress_true_all)
self.writer.add_scalar('energy_r', energy_r, epoch)
self.writer.add_scalar('force_r', force_r, epoch)
self.writer.add_scalar('stress_r', stress_r, epoch)
self.writer.add_scalar('stress_mae', stress_mae, epoch)
# self.writer.add_scalar('stress_loss', stress_loss, epoch)
if self.verbose > 0:
print("{:0>5d} {:1.8f} {:1.8f} {:1.8f} {:1.8f} {:1.8f} {:1.8f} {:1.8f} {:1.8f} {:1.8f} {:1.8f} {:10.1f} Validation_info".format(
epoch, energy_loss.item(), force_loss.item(), stress_loss.item(),
Expand Down Expand Up @@ -662,7 +717,8 @@ def fit(self, **train_config):

if __name__ == '__main__':
# import dgl
# g_list, l_list = dgl.load_graphs('all_graphs_generation_0.bin')
# g_list, l_list = dgl.load_graphs(os.path.join(
# '..', 'all_graphs_generation_0_aimd_only.bin'))
# graph = g_list[1] #.to('cuda')

# feat = graph.ndata['h']
Expand All @@ -684,17 +740,20 @@ def fit(self, **train_config):
# # self = PM




import shutil
if os.path.isdir('agat_model'):
shutil.rmtree('agat_model')
print('remove agat_model')

FIX_VALUE = [1,3,6]
train_config = {
'verbose': 2,
'dataset_path': os.path.join('all_graphs_generation_0.bin'),
'dataset_path': os.path.join('..', 'all_graphs_generation_0_aimd_only.bin'),
# 'dataset_path': os.path.join('..', 'all_graphs_generation_0.bin'),
'model_save_dir': 'agat_model',
'epochs': 2,
'epochs': 200,
'output_files': 'out_file',
'device': 'cpu',
'device': 'cuda',
'validation_size': 0.15,
'test_size': 0.15,
'early_stop': True,
Expand All @@ -711,7 +770,7 @@ def fit(self, **train_config):
'b': 1.0,
'c': 1.0,
'optimizer': 'adam', # Fix to sgd.
'learning_rate': 0.0005,
'learning_rate': 0.0001,
'weight_decay': 0.0, # weight decay (L2 penalty)
'batch_size': 64,
'val_batch_size': 400,
Expand Down

0 comments on commit 20182eb

Please sign in to comment.