From 20182ebb1b5a0c22e5c86cfad6de90fd4b57b925 Mon Sep 17 00:00:00 2001
From: jzhang <39238584+jzhang-github@users.noreply.github.com>
Date: Thu, 28 Dec 2023 14:23:49 +0800
Subject: [PATCH] Update test_stress_model.py

---
 agat/test/test_stress_model.py | 95 +++++++++++++++++++++++++++-------
 1 file changed, 77 insertions(+), 18 deletions(-)

diff --git a/agat/test/test_stress_model.py b/agat/test/test_stress_model.py
index 77fd33a..bff085a 100644
--- a/agat/test/test_stress_model.py
+++ b/agat/test/test_stress_model.py
@@ -17,6 +17,7 @@
 import numpy as np
 import time
+from datetime import datetime
 import torch.optim as optim
 from torch.utils.data import DataLoader, random_split
 from torch.optim import lr_scheduler
@@ -110,7 +111,12 @@ def __init__(self,
         # self.stress_bias = torch.nn.Parameter(torch.randn(1), requires_grad=True)
         # self.stress_act = nn.LeakyReLU(negative_slope=self.negative_slope)
-        self.stress_outlayer = nn.Linear(6, 6, self.bias, self.device)
+        # self.stress_outlayer = nn.Linear(6, 6, False, self.device)
+        # self.stress_outlayer = nn.BatchNorm1d(6, device=self.device)
+        self.u2e = nn.Linear(self.gat_node_dim_list[0], self.stress_readout_node_list[0],
+                             False, self.device) # propagate source-node features to edges.
+        self.skip_connect = nn.Linear(1, self.stress_readout_node_list[0],
+                                      False, self.device)

         for l in range(self.num_gat_layers):
             self.gat_layers.append(Layer(self.__gat_real_node_dim_list[l],
@@ -140,16 +146,23 @@ def __init__(self,
             self.force_readout_layers.append(nn.Linear(self.force_readout_node_list[l-self.tail_readout_no_act[1]-1],
                                                        self.force_readout_node_list[l-self.tail_readout_no_act[1]],
                                                        self.bias, self.device))
+
+        # Input dim: (number of nodes, number of heads * output dim per head)
         # stress readout layer
         for l in range(self.num_stress_readout_layers-self.tail_readout_no_act[2]):
             self.stress_readout_layers.append(nn.Linear(self.stress_readout_node_list[l],
                                                         self.stress_readout_node_list[l+1],
                                                         self.bias, self.device))
+            # self.stress_readout_layers.append(nn.BatchNorm1d(
+            #     self.stress_readout_node_list[l+1], device=self.device))
             self.stress_readout_layers.append(nn.LeakyReLU(negative_slope=self.negative_slope))
         for l in range(self.tail_readout_no_act[2]):
             self.stress_readout_layers.append(nn.Linear(self.stress_readout_node_list[l-self.tail_readout_no_act[2]-1],
                                                         self.stress_readout_node_list[l-self.tail_readout_no_act[2]],
                                                         self.bias, self.device))
+            # self.stress_readout_layers.append(nn.BatchNorm1d(
+            #     self.stress_readout_node_list[l-self.tail_readout_no_act[2]],
+            #     device=self.device))

         self.__real_num_energy_readout_layers = len(self.energy_readout_layers)
         self.__real_num_force_readout_layers = len(self.force_readout_layers)
@@ -210,6 +223,9 @@ def get_head_mechanism(self, fn_list, TorchTensor):
             TorchTensor_list.append(self.head_fn[func](TorchTensor))
         return torch.cat(TorchTensor_list, 1)

+    # def get_unit_vector(self, direction):
+    #     return direction/torch.norm(direction, dim=0)
+
     def forward(self, graph):
         """The ``forward`` function of PotentialModel model.
@@ -255,19 +271,36 @@ def forward(self, graph):
         force = graph.ndata['force_pred']

         # Predict stress
+        graph.edata['stress_score'] = stress_score
+        graph.ndata['atom_code'] = self.u2e(graph.ndata['h'])
+        graph.apply_edges(fn.u_add_e('atom_code', 'stress_score', 'stress_score'))
+        stress_score = graph.edata['stress_score'] + self.skip_connect(graph.edata['dist'])
+
+        # graph.edata['stress_score_test'][6]
+        # graph.edata['stress_score_test'][690]
+
+        # torch.mean(graph.edata['stress_score_test'], dim=0)
+        # fn.copy_u('atom_code', 'm')
+
         for l in range(self.__real_num_stress_readout_layers):
             stress_score = self.stress_readout_layers[l](stress_score)
+
+        # unit_vector = graph.edata['direction']/torch.norm(graph.edata['direction'], dim=0)
         graph.edata['stress_score_vector'] = stress_score * torch.cat((graph.edata['direction'],
                                                                        graph.edata['direction']),
                                                                       dim=1) # shape: (number of edges, 6)
+
+        batch_edges = graph.batch_num_edges().tolist()
+        stress = torch.split(graph.edata['stress_score_vector'], batch_edges)
+        stress = torch.stack([torch.mean(s, dim=0) for s in stress])
+
         # graph.edata['stress_score_vector'] = self.stress_act(graph.edata['stress_score_vector'])
-        graph.edata['stress_score_vector'] = self.stress_outlayer(graph.edata['stress_score_vector'])
-        graph.update_all(fn.copy_e('stress_score_vector', 'm'), fn.sum('m', 'stress_pred')) # shape of graph.ndata['force_pred']: (number of nodes, 3)
+        # graph.edata['stress_score_vector'] = self.stress_outlayer(graph.edata['stress_score_vector'])
+        # graph.update_all(fn.copy_e('stress_score_vector', 'm'), fn.sum('m', 'stress_pred')) # shape of graph.ndata['stress_pred']: (number of nodes, 6)

         # stress = torch.sum(graph.ndata['stress_pred'], dim=0)
-        stress = torch.split(graph.ndata['stress_pred'], batch_nodes) # shape of stress: number of atoms * 6
-        stress = torch.stack([torch.mean(s, dim=0) for s in stress]) # + self.stress_bias
+        # stress = torch.split(graph.ndata['stress_pred'], batch_nodes) # shape of stress: number of atoms * 6
+        # stress = torch.stack([torch.sum(s, dim=0) for s in stress]) # + self.stress_bias

         return energy, force, stress

-
 class Fit(object):
     def __init__(self, **train_config):
         self.train_config = {**default_train_config, **config_parser(train_config)}
@@ -381,7 +414,9 @@ def __init__(self, **train_config):
             os.mkdir(self.train_config['output_files'])

         # debug
-        self.writer = SummaryWriter('fit_debug', flush_secs=10)
+        TIMESTAMP = "{0:%Y-%m-%d--%H-%M-%S}".format(datetime.now())
+        self.writer = SummaryWriter(os.path.join('fit_debug', TIMESTAMP),
+                                    flush_secs=10)

     def fit(self, **train_config):
         # update config if needed.
@@ -403,6 +438,11 @@ def fit(self, **train_config):
                                lr=self.train_config['learning_rate'],
                                weight_decay=self.train_config['weight_decay'])

+        # # reset parameters
+        # nn.init.orthogonal_(model.stress_outlayer.weight)
+        # for l in model.stress_readout_layers:
+        #     if hasattr(l, 'weight'):
+        #         nn.init.orthogonal_(l.weight)
         # self.writer.add_graph(model, self._dataset[0][0], )

         # load the state dict if it exists.
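The forward-pass hunks above replace the old node-level stress readout (`stress_outlayer` followed by `update_all`) with an edge-level one: node features are projected by `u2e` and broadcast onto edges with DGL's built-in `u_add_e`, a skip connection from the edge length is added, and the per-structure stress is the mean of the edge-level stress vectors within each graph of the batch. The sketch below illustrates that aggregation pattern in isolation; the layer sizes and random features are illustrative assumptions, not values from this patch.

    # Illustrative sketch (not part of the patch): edge-level stress aggregation.
    import dgl
    import dgl.function as fn
    import torch
    import torch.nn as nn

    in_dim, out_dim = 5, 6                               # assumed feature sizes
    u2e = nn.Linear(in_dim, out_dim, bias=False)         # projects source-node features onto edges
    skip_connect = nn.Linear(1, out_dim, bias=False)     # skip connection from the edge length

    # Two toy graphs batched together, mimicking a training mini-batch.
    graph = dgl.batch([dgl.graph(([0, 1], [1, 0])),
                       dgl.graph(([0, 1, 2], [1, 2, 0]))])
    graph.ndata['h'] = torch.randn(graph.num_nodes(), in_dim)
    graph.edata['dist'] = torch.rand(graph.num_edges(), 1)
    graph.edata['stress_score'] = torch.randn(graph.num_edges(), out_dim)

    # Broadcast projected node features to their outgoing edges and add the edge score.
    graph.ndata['atom_code'] = u2e(graph.ndata['h'])
    graph.apply_edges(fn.u_add_e('atom_code', 'stress_score', 'stress_score'))
    stress_score = graph.edata['stress_score'] + skip_connect(graph.edata['dist'])

    # Mean-pool the edge-level scores within each graph of the batch.
    batch_edges = graph.batch_num_edges().tolist()
    stress = torch.stack([s.mean(dim=0) for s in torch.split(stress_score, batch_edges)])
    print(stress.shape)                                  # torch.Size([2, 6])

Compared with the removed `fn.sum`-based node aggregation, mean-pooling over edges keeps the output scale independent of how many atoms or edges a structure contains.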
@@ -485,17 +525,27 @@ def fit(self, **train_config):
                 force_loss = criterion(force_pred, force_true)
                 stress_loss = criterion(stress_pred, stress_true)
                 self.writer.add_scalar('stress_loss', torch.flatten(stress_loss), batch_step)
+                self.writer.add_scalar('force_loss', torch.flatten(force_loss), batch_step)
                 total_loss = a*energy_loss + b*force_loss + c*stress_loss
                 total_loss.backward()
-                self.writer.add_histogram('stress_outlayer_weight_grad',
-                                          torch.flatten(model.stress_outlayer.weight.grad),
+                self.writer.add_histogram('u2e_weight_grad',
+                                          torch.flatten(model.u2e.weight.grad),
+                                          batch_step)
+                self.writer.add_histogram('u2e_weight',
+                                          torch.flatten(model.u2e.weight),
                                           batch_step)
-                self.writer.add_histogram('stress_outlayer_bias_grad',
-                                          torch.flatten(model.stress_outlayer.bias.grad),
+                self.writer.add_histogram('skip_weight_grad',
+                                          torch.flatten(model.skip_connect.weight.grad),
                                           batch_step)
+                # self.writer.add_histogram('stress_outlayer_bias_grad',
+                #                           torch.flatten(model.stress_outlayer.bias.grad),
+                #                           batch_step)
                 self.writer.add_histogram('last_force_readout_layer_weight_grad',
                                           torch.flatten(model.force_readout_layers[5].weight.grad),
                                           batch_step)
+                self.writer.add_histogram('last_force_readout_layer_weight',
+                                          torch.flatten(model.force_readout_layers[5].weight),
+                                          batch_step)
                 optimizer.step()
                 dur = time.time() - start_time
                 if self.verbose > 1:
@@ -541,6 +591,11 @@ def fit(self, **train_config):
                 energy_r = r(energy_pred_all, energy_true_all)
                 force_r = r(force_pred_all, force_true_all)
                 stress_r = r(stress_pred_all, stress_true_all)
+                self.writer.add_scalar('energy_r', energy_r, epoch)
+                self.writer.add_scalar('force_r', force_r, epoch)
+                self.writer.add_scalar('stress_r', stress_r, epoch)
+                self.writer.add_scalar('stress_mae', stress_mae, epoch)
+                # self.writer.add_scalar('stress_loss', stress_loss, epoch)
                 if self.verbose > 0:
                     print("{:0>5d} {:1.8f} {:1.8f} {:1.8f} {:1.8f} {:1.8f} {:1.8f} {:1.8f} {:1.8f} {:1.8f} {:1.8f} {:10.1f} Validation_info".format(
                         epoch, energy_loss.item(), force_loss.item(), stress_loss.item(),
@@ -662,7 +717,8 @@ def fit(self, **train_config):

 if __name__ == '__main__':
     # import dgl
-    # g_list, l_list = dgl.load_graphs('all_graphs_generation_0.bin')
+    # g_list, l_list = dgl.load_graphs(os.path.join(
+    #     '..', 'all_graphs_generation_0_aimd_only.bin'))
     # graph = g_list[1] #.to('cuda')

     # feat = graph.ndata['h']
@@ -684,17 +740,20 @@ def fit(self, **train_config):
     # # self = PM

-
-
+    import shutil
+    if os.path.isdir('agat_model'):
+        shutil.rmtree('agat_model')
+        print('remove agat_model')

     FIX_VALUE = [1,3,6]
     train_config = {
         'verbose': 2,
-        'dataset_path': os.path.join('all_graphs_generation_0.bin'),
+        'dataset_path': os.path.join('..', 'all_graphs_generation_0_aimd_only.bin'),
+        # 'dataset_path': os.path.join('..', 'all_graphs_generation_0.bin'),
         'model_save_dir': 'agat_model',
-        'epochs': 2,
+        'epochs': 200,
         'output_files': 'out_file',
-        'device': 'cpu',
+        'device': 'cuda',
         'validation_size': 0.15,
         'test_size': 0.15,
         'early_stop': True,
@@ -711,7 +770,7 @@
         'b': 1.0,
         'c': 1.0,
         'optimizer': 'adam', # Fix to sgd.
-        'learning_rate': 0.0005,
+        'learning_rate': 0.0001,
         'weight_decay': 0.0, # weight decay (L2 penalty)
         'batch_size': 64,
         'val_batch_size': 400,
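Beyond the model change, the patch reworks the debugging instrumentation: TensorBoard events now go to a per-run fit_debug/<timestamp> directory, and per-batch weight and gradient histograms are logged for the new `u2e` and `skip_connect` layers. Below is a minimal, self-contained sketch of that logging pattern; the tensor and loss are placeholders, not code from this patch.

    # Illustrative sketch (not part of the patch): timestamped TensorBoard runs
    # plus per-step scalar and gradient-histogram logging.
    import os
    import torch
    from datetime import datetime
    from torch.utils.tensorboard import SummaryWriter

    TIMESTAMP = "{0:%Y-%m-%d--%H-%M-%S}".format(datetime.now())
    writer = SummaryWriter(os.path.join('fit_debug', TIMESTAMP), flush_secs=10)

    weight = torch.randn(6, 10, requires_grad=True)      # stand-in for model.u2e.weight
    for batch_step in range(3):                          # stand-in for the training loop
        loss = (weight ** 2).sum()                       # placeholder loss
        loss.backward()
        writer.add_scalar('stress_loss', loss.item(), batch_step)
        writer.add_histogram('u2e_weight_grad', torch.flatten(weight.grad), batch_step)
        weight.grad = None                               # reset the gradient between steps
    writer.close()
    # Inspect with: tensorboard --logdir fit_debug

Writing each run into its own timestamped subdirectory lets TensorBoard overlay successive experiments instead of appending them to a single event stream.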