In [11]:
from rlmd.configuration import configuration;
from rlmd.trajectory import trajectory;
from rlmd.step import environment;
from rlmd.model import Q_NN;
from rlmd.train import Context_Bandit;
from rlmd.action_space import actions;

import numpy as np;
import json;
import torch;
import os;

# Run settings: `task` names the output directory, `horizon` is the number of
# steps per trajectory, and `n_traj` is the number of trajectories (epochs).
task = 'Cu_2'
horizon = 30
n_traj = 100

# Build the Q-network and load parameters from 'Cu_all/model/model90'.
# NOTE(review): elements=[1, 29] presumably are atomic numbers (H, Cu) - confirm.
model = Q_NN(
    elements=[1, 29],
    r_cut=4,
    N_emb=12,
    N_fit=64,
    atom_max=40,
)
model.load('Cu_all/model/model90')
trainer = Context_Bandit(model, temperature=1000)

# Pool of starting-structure files: for every prefix below, the files
# numbered 1 to 3, grouped by prefix in the order listed.
pool = [
    prefix + str(index + 1)
    for prefix in (
        'POSCARs/CONTCAR_H_Cu',
        'POSCARs/CONTCAR_H_111Cu',
        'POSCARs/CONTCAR_H_100Cu',
        'POSCARs/CONTCAR_H_Cu_s3_111_110_train',
        'POSCARs/CONTCAR_H_110Cu',
    )
    for index in range(3)
]

# Trajectories collected during this run; the epoch loop appends one per epoch.
traj_list = []

# Create the output tree <task>/traj and <task>/model. os.makedirs with
# exist_ok=True creates the missing parent directory as needed and tolerates
# reruns, avoiding the check-then-create race of testing os.listdir() first.
os.makedirs(task + '/traj', exist_ok=True)
os.makedirs(task + '/model', exist_ok=True)

# Start a fresh loss log (mode 'w' truncates any previous file) with a header row.
with open(task + '/loss.txt', 'w') as loss_log:
    loss_log.write('epoch\t loss\n')

# Main loop: each epoch generates one trajectory of `horizon` steps from a
# randomly chosen starting structure, then updates the model on every
# trajectory gathered so far, logs the loss and saves the trajectory.
for epoch in range(n_traj):
    # Load a random structure from the pool and relax it before acting.
    conf = configuration()
    start_file = pool[np.random.randint(len(pool))]
    conf.load(start_file)
    print('epoch = ' + str(epoch) + ':  ' + start_file)
    conf.set_potential()
    env = environment(conf, max_iter=50)
    env.relax(accuracy=0.1)

    # NOTE(review): meaning of the trajectory(1, 0) arguments is not visible
    # here - confirm against rlmd.trajectory.
    current_traj = trajectory(1, 0)
    traj_list.append(current_traj)

    for tstep in range(horizon):
        # Enumerate candidate actions for the current state and let the
        # trainer pick one.
        action_space = actions(conf, act_mul=1.6, act_mul_move=1.2)
        act_id, act_probs, Q = trainer.select_action(conf.atoms, action_space)
        action = action_space[act_id]

        # Record the state and its energy before the action is applied.
        info = {
            'act': act_id,
            'act_probs': act_probs.tolist(),
            'act_space': action_space,
            'state': conf.atoms.copy(),
            'E_min': conf.potential(),
        }

        E_next, fail = env.step(action, accuracy=0.1)
        if not fail:
            # The saddle search runs only after a successful step; its own
            # failure flag replaces the one returned by the step.
            E_s, fail = env.saddle(accuracy=0.1)
            info['E_s'] = E_s
        else:
            # Placeholder value when the step itself failed.
            info['E_s'] = 0
            print('fail step 1')

        # Snapshot of conf.atoms taken after the step (presumably the step
        # updates conf in place - confirm against rlmd.step).
        info['next'] = conf.atoms.copy()
        info['fail'] = fail
        info['E_next'] = E_next
        current_traj.add(info)

        if fail:
            print('fail')
        if tstep % 10 == 0:
            print('    t = ' + str(tstep))

    # The number of update passes grows with the epoch index (always >= 3);
    # only the loss of the last pass is logged.
    for _ in range(3 * int(1 + np.sqrt(epoch))):
        loss = trainer.update(traj_list, 0)

    # Append to the same log whose header was written under the task
    # directory (previously the rows went to ./loss.txt by mistake).
    with open(task + '/loss.txt', 'a') as loss_log:
        loss_log.write(str(epoch) + '\t' + str(loss) + '\n')

    # Saving is best-effort: report the problem and keep training, but do not
    # swallow KeyboardInterrupt/SystemExit as a bare `except:` would.
    try:
        current_traj.save(task + '/traj/traj' + str(epoch))
    except Exception as err:
        print('saving failure: ' + str(err))

    # Checkpoint the model every 10 epochs.
    if epoch % 10 == 0:
        model.save(task + '/model/model' + str(epoch))
        

epoch = 0:  POSCARs/CONTCAR_H_Cu_s3_111_110_train2
    t = 0
    t = 1
    t = 2
    t = 3
    t = 4
    t = 5
    t = 6
    t = 7
    t = 8
    t = 9
epoch = 1:  POSCARs/CONTCAR_H_Cu_s3_111_110_train1
    t = 0
    t = 1
    t = 2
    t = 3
    t = 4
    t = 5
    t = 6
    t = 7
    t = 8
    t = 9
epoch = 2:  POSCARs/CONTCAR_H_Cu_s3_111_110_train3
    t = 0
    t = 1
    t = 2
    t = 3
    t = 4
    t = 5
    t = 6
    t = 7
    t = 8
    t = 9
epoch = 3:  POSCARs/CONTCAR_H_Cu3
    t = 0
    t = 1
    t = 2
    t = 3
    t = 4
    t = 5
    t = 6
    t = 7
    t = 8
    t = 9
epoch = 4:  POSCARs/CONTCAR_H_110Cu2
    t = 0
    t = 1
    t = 2
    t = 3
    t = 4
    t = 5
    t = 6
    t = 7
    t = 8
    t = 9
epoch = 5:  POSCARs/CONTCAR_H_110Cu2
    t = 0
    t = 1
    t = 2
    t = 3
    t = 4
    t = 5
    t = 6
    t = 7
    t = 8
    t = 9
epoch = 6:  POSCARs/CONTCAR_H_110Cu2
    t = 0
    t = 1
    t = 2
    t = 3
    t = 4
    t = 5
    t = 6
    t = 7
    t = 8
    t = 9
ep