In [1]:
# Automatically reload edited project modules (e.g. the local rgnn / rlmd packages)
# before each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [6]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from rgnn.graph.dataset.reaction import ReactionDataset
from rgnn.graph.reaction import ReactionGraph
from rgnn.graph.utils import batch_to
from rgnn.models.reaction_models import PaiNN
from rgnn.models.reaction import ReactionGNN
from rlmd.configuration import configuration;
from rlmd.trajectory import trajectory;
from rlmd.step import environment;
from rlmd.model import Q_NN;
from rlmd.train import Context_Bandit;
from rlmd.action_space import actions;
from rlmd.train_graph import ContextBandit

In [17]:
# --- Run configuration -----------------------------------------------------------
task = 'MEA_freq1'                # run name; also the directory that receives the environment log
os.makedirs(task, exist_ok=True)  # exist_ok=True already tolerates an existing directory
horizon = 30                      # NOTE(review): unused in this cell — presumably steps per trajectory
n_traj = 101                      # NOTE(review): unused in this cell — presumably number of trajectories

# Seed the RNGs used below (random choice of starting POSCAR, torch-based model
# initialisation) so that Restart-&-Run-All reproduces the same starting point.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Models ------------------------------------------------------------------------
# Descriptor-based Q network; `elements` are the atomic numbers of H, Cr, Co, Ni
# (same species, same order, as `species` below).
model = Q_NN(elements=[1, 24, 27, 28], r_cut=4, N_emb=24, N_fit=128, atom_max=40)

# Graph-based (PaiNN) reaction model over the same chemical species.
species = ["H", "Cr", "Co", "Ni"]
# Presumably per-target normalisation statistics for the reaction head — TODO confirm.
means = {'barrier': torch.tensor(0.6652), 'freq': torch.tensor(2.8553), 'delta_e': torch.tensor(0.0081)}
stddevs = {'barrier': torch.tensor(0.4665), 'freq': torch.tensor(0.7151), 'delta_e': torch.tensor(0.2713)}
model_graph = PaiNN(species=species)
reaction_model = ReactionGNN(model_graph)
# Attach the means from the single source of truth above instead of re-typing the
# literals in a later cell.
reaction_model.reaction_model.means = dict(means)
# NOTE(review): `stddevs` is never attached to the model in this notebook — confirm
# whether the reaction head needs it as well.

# Both bandit trainers are built with the same temperature.
trainer = Context_Bandit(model, temperature=1000)
trainer_graph = ContextBandit(reaction_model, temperature=1000)

# --- Environment -------------------------------------------------------------------
# Pick one of the nine starting structures at random (reproducible via SEED).
pool = [f'POSCARs/CONTCAR_H_CCN{i}' for i in range(1, 10)]
conf = configuration()
file = pool[np.random.randint(len(pool))]
conf.load(file)
conf.set_potential(platform="mace")
env = environment(conf, logfile=os.path.join(task, 'log'), max_iter=100)
env.relax(accuracy=0.1)

action_space = actions(conf, dist_mul_body=1.2, act_mul=1.6, act_mul_move=1.2)

None None
Using Materials Project MACE for MACECalculator with /home/hjchun/.cache/mace/5yyxdm76
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.


In [19]:
# Re-attach the target means defined in the configuration cell rather than duplicating
# the literal values here (duplicated literals silently drift apart when one is edited).
# NOTE(review): `stddevs` is not attached anywhere — confirm whether it is needed too.
reaction_model.reaction_model.means = dict(means)

In [21]:
# Display the `means` dict currently attached to the reaction head (cell's rich output).
getattr(reaction_model.reaction_model, "means")

{'barrier': tensor(0.6652), 'freq': tensor(2.8553), 'delta_e': tensor(0.0081)}

In [8]:
# Ask the descriptor-based bandit for an action on the current structure and show the
# chosen index, the per-action probabilities and the Q values, one per line.
act_id, act_probs, Q = trainer.select_action(conf.atoms, action_space)
print(act_id, act_probs, Q, sep='\n')

13
tensor([0.0537, 0.0558, 0.0555, 0.0548, 0.0478, 0.0480, 0.0489, 0.0483, 0.0530,
        0.0517, 0.0528, 0.0516, 0.0485, 0.0458, 0.0489, 0.0463, 0.0483, 0.0456,
        0.0487, 0.0461], grad_fn=<SoftmaxBackward0>)
tensor([[0.0685, 0.1099],
        [0.0718, 0.1102],
        [0.0713, 0.1104],
        [0.0702, 0.1105],
        [0.0591, 0.1028],
        [0.0594, 0.1027],
        [0.0610, 0.1035],
        [0.0600, 0.1030],
        [0.0683, 0.0980],
        [0.0662, 0.0979],
        [0.0681, 0.0981],
        [0.0661, 0.0979],
        [0.0594, 0.1130],
        [0.0544, 0.1131],
        [0.0602, 0.1133],
        [0.0555, 0.1134],
        [0.0591, 0.1128],
        [0.0542, 0.1130],
        [0.0598, 0.1130],
        [0.0550, 0.1132]], grad_fn=<SelectBackward0>)


In [16]:
# Same query as the descriptor model, but through the graph (PaiNN) bandit; the
# chosen index, per-action probabilities and Q values are printed one per line.
act_id, act_probs, Q = trainer_graph.select_action(conf.atoms, action_space)
print(act_id, act_probs, Q, sep='\n')

torch.Size([4])
torch.Size([16, 1]) torch.Size([16, 1]) torch.Size([16, 1]) torch.Size([16, 1])
torch.Size([16])
torch.Size([4])
torch.Size([4, 1]) torch.Size([4, 1]) torch.Size([4, 1]) torch.Size([4, 1])
torch.Size([4])
torch.Size([20])
torch.Size([20])
18
tensor([0.0379, 0.0435, 0.0418, 0.0436, 0.0472, 0.0446, 0.0467, 0.0480, 0.0432,
        0.0433, 0.0447, 0.0436, 0.0616, 0.0579, 0.0567, 0.0549, 0.0645, 0.0562,
        0.0624, 0.0576], device='cuda:0')
tensor([-0.1954, -0.0567, -0.0973, -0.0548,  0.0245, -0.0320,  0.0145,  0.0405,
        -0.0651, -0.0618, -0.0305, -0.0542,  0.2899,  0.2284,  0.2082,  0.1757,
         0.3372,  0.1989,  0.3030,  0.2236], device='cuda:0')


In [21]:
# Build the structure obtained by applying one displacement action to the current
# configuration, so the before/after pair can be written out and inspected.
# NOTE(review): `act` is not defined in any cell of this notebook (hidden kernel state).
# From its use below it is [atom_index, dx, dy, dz]; define it explicitly (presumably
# from `act_id` / `action_space`) before this cell — TODO confirm the intended source.
final_atoms = conf.atoms.copy()
atom_index = int(act[0])
if not 0 <= atom_index < len(final_atoms):
    # The previous per-atom loop silently left every atom untouched in this case.
    raise IndexError(
        f"atom index {atom_index} out of range for {len(final_atoms)} atoms"
    )
final_positions = final_atoms.get_positions()
# Shift only the selected atom; every other row of the positions array is unchanged.
final_positions[atom_index] += np.asarray(act[1:], dtype=float)
final_atoms.set_positions(final_positions)

[-0.8453077224347045, 1.0320070718457819, -0.4247543502455824]
[ 8.94668499 14.08908368  7.4275432 ]


In [22]:
from ase import io

# Dump the starting structure and the displaced one as VASP files for inspection.
for out_path, structure in (("inital.vasp", conf.atoms), ("final.vasp", final_atoms)):
    io.write(out_path, structure)