# run.py
# NOTE: the web-page chrome and line-number gutter that preceded the code in
# this copy were extraction residue and have been reduced to this header.
import time
import os
import torch
from torch.optim import Adam
from torch_geometric.data import DataLoader
import numpy as np
from torch.autograd import grad
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
class run():
    r"""
    The base script for running different 3DGN methods.

    Bundles the epoch loop (:meth:`run`), a single training pass
    (:meth:`train`) and a single evaluation pass (:meth:`val`).
    """

    def __init__(self):
        pass

    def run(self, device, train_dataset, valid_dataset, test_dataset, model, loss_func, evaluation, epochs=500, batch_size=32, vt_batch_size=32, lr=0.0005, lr_decay_factor=0.5, lr_decay_step_size=50, weight_decay=0,
            energy_and_force=False, p=100, save_dir='', log_dir=''):
        r"""
        The run script for training and validation.

        Trains with Adam and a StepLR schedule, evaluates on the validation
        and test sets after every epoch, and checkpoints the model whenever
        the validation score improves.

        Args:
            device (torch.device): Device for computation.
            train_dataset: Training data.
            valid_dataset: Validation data.
            test_dataset: Test data.
            model: Which 3DGN model to use. Should be one of the SchNet, DimeNetPP, and SphereNet.
            loss_func (function): The used loss function for training.
            evaluation: Evaluator object; its ``eval(dict)`` method must return a dict containing the key ``'mae'``.
            epochs (int, optional): Number of total training epochs. (default: :obj:`500`)
            batch_size (int, optional): Number of samples in each minibatch in training. (default: :obj:`32`)
            vt_batch_size (int, optional): Number of samples in each minibatch in validation/testing. (default: :obj:`32`)
            lr (float, optional): Initial learning rate. (default: :obj:`0.0005`)
            lr_decay_factor (float, optional): Learning rate decay factor. (default: :obj:`0.5`)
            lr_decay_step_size (int, optional): Number of epochs between successive multiplications of the learning rate by :obj:`lr_decay_factor`. (default: :obj:`50`)
            weight_decay (float, optional): Weight decay factor at the regularization term. (default: :obj:`0`)
            energy_and_force (bool, optional): If set to :obj:`True`, will predict energy and take the minus derivative of the energy with respect to the atomic positions as predicted forces. (default: :obj:`False`)
            p (int, optional): The forces' weight for a joint loss of forces and conserved energy during training. (default: :obj:`100`)
            save_dir (str, optional): The path to save trained models. If set to :obj:`''`, will not save the model. (default: :obj:`''`)
            log_dir (str, optional): The path to save log files. If set to :obj:`''`, will not save the log files. (default: :obj:`''`)
        """
        model = model.to(device)
        num_params = sum(param.numel() for param in model.parameters())
        print(f'#Params: {num_params}')

        optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = StepLR(optimizer, step_size=lr_decay_step_size, gamma=lr_decay_factor)

        train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
        valid_loader = DataLoader(valid_dataset, vt_batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset, vt_batch_size, shuffle=False)

        best_valid = float('inf')
        best_test = float('inf')

        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        if save_dir != '':
            os.makedirs(save_dir, exist_ok=True)
        writer = None
        if log_dir != '':
            os.makedirs(log_dir, exist_ok=True)
            writer = SummaryWriter(log_dir=log_dir)

        try:
            for epoch in range(1, epochs + 1):
                print("\n=====Epoch {}".format(epoch), flush=True)

                print('\nTraining...', flush=True)
                train_mae = self.train(model, optimizer, train_loader, energy_and_force, p, loss_func, device)

                print('\n\nEvaluating...', flush=True)
                valid_mae = self.val(model, valid_loader, energy_and_force, p, evaluation, device)

                print('\n\nTesting...', flush=True)
                test_mae = self.val(model, test_loader, energy_and_force, p, evaluation, device)

                print()
                print({'Train': train_mae, 'Validation': valid_mae, 'Test': test_mae})

                if writer is not None:
                    writer.add_scalar('train_mae', train_mae, epoch)
                    writer.add_scalar('valid_mae', valid_mae, epoch)
                    writer.add_scalar('test_mae', test_mae, epoch)

                # Model selection is driven by the validation score only; the
                # test score is merely recorded alongside the best epoch.
                if valid_mae < best_valid:
                    best_valid = valid_mae
                    best_test = test_mae
                    if save_dir != '':
                        print('Saving checkpoint...')
                        checkpoint = {'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'best_valid_mae': best_valid, 'num_params': num_params}
                        torch.save(checkpoint, os.path.join(save_dir, 'valid_checkpoint.pt'))

                scheduler.step()

            print(f'Best validation MAE so far: {best_valid}')
            print(f'Test MAE when got best validation result: {best_test}')
        finally:
            # Close the log writer even when an epoch raises, so the writer
            # is never left open.
            if writer is not None:
                writer.close()

    def train(self, model, optimizer, train_loader, energy_and_force, p, loss_func, device):
        r"""
        The script for training (one pass over :obj:`train_loader`).

        Args:
            model: Which 3DGN model to use. Should be one of the SchNet, DimeNetPP, and SphereNet.
            optimizer (Optimizer): Pytorch optimizer for trainable parameters in training.
            train_loader (Dataloader): Dataloader for training.
            energy_and_force (bool): If set to :obj:`True`, will predict energy and take the minus derivative of the energy with respect to the atomic positions as predicted forces.
            p (int): The forces' weight for a joint loss of forces and conserved energy during training.
            loss_func (function): The used loss function for training.
            device (torch.device): The device where the model is deployed.

        Raises:
            ValueError: If :obj:`train_loader` yields no batches.

        :rtype: Training loss averaged over the batches of this pass.
        """
        model.train()

        loss_accum = 0.0
        num_batches = 0
        for batch_data in tqdm(train_loader):
            optimizer.zero_grad()
            batch_data = batch_data.to(device)
            out = model(batch_data)
            if energy_and_force:
                # create_graph=True keeps the force differentiable so the force
                # loss can be back-propagated to the model parameters.
                # NOTE(review): assumes the model enables gradients on
                # batch_data.pos during its forward pass -- confirm.
                force = -grad(outputs=out, inputs=batch_data.pos, grad_outputs=torch.ones_like(out), create_graph=True, retain_graph=True)[0]
                e_loss = loss_func(out, batch_data.y.unsqueeze(1))
                f_loss = loss_func(force, batch_data.force)
                loss = e_loss + p * f_loss
            else:
                loss = loss_func(out, batch_data.y.unsqueeze(1))
            loss.backward()
            optimizer.step()
            loss_accum += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError('train_loader yielded no batches; cannot compute a training loss.')
        return loss_accum / num_batches

    def val(self, model, data_loader, energy_and_force, p, evaluation, device):
        r"""
        The script for validation/test (one pass over :obj:`data_loader`).

        Args:
            model: Which 3DGN model to use. Should be one of the SchNet, DimeNetPP, and SphereNet.
            data_loader (Dataloader): Dataloader for validation or test.
            energy_and_force (bool): If set to :obj:`True`, will predict energy and take the minus derivative of the energy with respect to the atomic positions as predicted forces.
            p (int): The forces' weight for a joint loss of forces and conserved energy.
            evaluation: Evaluator object; its ``eval(dict)`` method must return a dict containing the key ``'mae'``.
            device (torch.device): The device where the model is deployed.

        :rtype: Evaluation result: the energy ``mae``, or ``energy_mae + p * force_mae`` when :obj:`energy_and_force` is :obj:`True`.
        """
        model.eval()

        # Collect per-batch chunks and concatenate once at the end instead of
        # re-copying an ever-growing tensor on every batch.
        preds, targets = [], []
        preds_force, targets_force = [], []

        # Gradients are only needed to differentiate the energy w.r.t. the
        # positions; energy-only evaluation runs without autograd bookkeeping.
        with torch.set_grad_enabled(bool(energy_and_force)):
            for batch_data in tqdm(data_loader):
                batch_data = batch_data.to(device)
                out = model(batch_data)
                if energy_and_force:
                    # No create_graph/retain_graph here: nothing is
                    # back-propagated through the force during evaluation.
                    force = -grad(outputs=out, inputs=batch_data.pos, grad_outputs=torch.ones_like(out))[0]
                    preds_force.append(force.detach())
                    targets_force.append(batch_data.force)
                preds.append(out.detach())
                targets.append(batch_data.y.unsqueeze(1))

        input_dict = {"y_true": self._cat_or_empty(targets, device), "y_pred": self._cat_or_empty(preds, device)}

        if energy_and_force:
            input_dict_force = {"y_true": self._cat_or_empty(targets_force, device), "y_pred": self._cat_or_empty(preds_force, device)}
            energy_mae = evaluation.eval(input_dict)['mae']
            force_mae = evaluation.eval(input_dict_force)['mae']
            print({'Energy MAE': energy_mae, 'Force MAE': force_mae})
            return energy_mae + p * force_mae

        return evaluation.eval(input_dict)['mae']

    @staticmethod
    def _cat_or_empty(chunks, device):
        r"""Concatenate :obj:`chunks` along dim 0, or return an empty tensor on :obj:`device` when there are none."""
        if not chunks:
            return torch.Tensor([]).to(device)
        return torch.cat(chunks, dim=0)