Test the performance of pre-trained TinNet models

In [None]:
# Loading modules

from __future__ import print_function, division

import numpy as np
import multiprocessing

from ase import io
from ase.db import connect

from tinnet.regression.regression import Regression

In [None]:
# Setting variables
#
# Loads every row of ../Database.db, keeps a (configurable) fraction of the
# non-pure-metal rows plus all pure-metal rows, converts the kept rows to
# NumPy arrays, and allocates the 10x10 result tables filled by the next cell.

# Optimized hyperparameters
lr = 0.0044485033567158005
atom_fea_len = 106
n_conv = 9
h_fea_len = 60
n_h = 2

db = connect('../Database.db')

# Keys read from each row's `data` payload (same order as they are consumed).
field_names = (
    'd_cen',
    'full_width',
    'tabulated_d_cen_inf',
    'tabulated_full_width_inf',
    'tabulated_mulliken',
    'tabulated_site_index',
    'tabulated_v2dd',
    'tabulated_v2ds',
    'atom_fea',
    'nbr_fea',
    'nbr_fea_idx',
    'tabulated_padding_fillter',
)

# One list per key, each holding that key's value for every row, in row order.
columns = {name: [] for name in field_names}
for r in db.select():
    payload = r['data']
    for name in field_names:
        columns[name].append(payload[name])

idx = np.arange(len(columns['d_cen']))

idx_1 = idx[:-38]
idx_2 = idx[-38:]  # Last 38 images are pure metals

num = int(len(idx_1) * 1.00)  # % of database for training (1.00 means 100%)
np.random.seed(12345)
# idx_1 is a view of idx, so this shuffles in place; it also advances the
# global NumPy RNG, which later code may depend on.
np.random.shuffle(idx_1)

idx_1 = idx_1[:num]

# Kept rows in ascending database order (pure metals always included).
idx = np.sort(np.concatenate((idx_1, idx_2)))

np.savetxt('index.txt', idx)


def _gather(values, order, dtype):
    """Stack values[i] for each i in `order` into one array of `dtype`."""
    return np.array([values[i] for i in order], dtype=dtype)


def _gather_each(values, order, dtype):
    """Convert values[i] for each i in `order` into its own array of `dtype`."""
    return [np.array(values[i], dtype=dtype) for i in order]


# Per-row targets and tabulated descriptors: stacked into single arrays.
d_cen = _gather(columns['d_cen'], idx, np.float32)
full_width = _gather(columns['full_width'], idx, np.float32)

tabulated_d_cen_inf = _gather(columns['tabulated_d_cen_inf'], idx, np.float32)
tabulated_full_width_inf = _gather(columns['tabulated_full_width_inf'], idx, np.float32)
tabulated_mulliken = _gather(columns['tabulated_mulliken'], idx, np.float32)
tabulated_site_index = _gather(columns['tabulated_site_index'], idx, np.int32)
tabulated_v2dd = _gather(columns['tabulated_v2dd'], idx, np.float32)
tabulated_v2ds = _gather(columns['tabulated_v2ds'], idx, np.float32)

# Graph features: kept as lists of per-row arrays.
atom_fea = _gather_each(columns['atom_fea'], idx, np.float32)
nbr_fea = _gather_each(columns['nbr_fea'], idx, np.float32)
# NOTE(review): neighbor indices are stored as float32 here; confirm that
# Regression casts them to an integer type before using them to index.
nbr_fea_idx = _gather_each(columns['nbr_fea_idx'], idx, np.float32)
tabulated_padding_fillter = _gather_each(columns['tabulated_padding_fillter'], idx, np.int32)

# Result tables indexed [idx_test, idx_validation]; filled by the next cell.
(check_ans_train_mae,
 check_ans_train_mse,
 check_ans_val_mae,
 check_ans_val_mse,
 check_ans_test_mae,
 check_ans_test_mse) = (np.zeros((10, 10)) for _ in range(6))

In [None]:
# Test the performance of pre-trained TinNet models
#
# For every (validation fold, test fold) pair, build a Regression model from
# the data prepared above and record the six losses returned by check_loss()
# into the check_ans_* tables at [idx_test, idx_validation]. Pairs that fail
# are reported and left at 0 so the saved tables keep their original shape
# and contents for the successful pairs.

for idx_validation in range(10):
    for idx_test in range(10):
        try:
            model = Regression(atom_fea,
                               nbr_fea,
                               nbr_fea_idx,
                               d_cen,
                               phys_model='moment',
                               optim_algorithm='AdamW',
                               weight_decay=0.0001,
                               idx_validation=idx_validation,
                               idx_test=idx_test,
                               lr=lr,
                               atom_fea_len=atom_fea_len,
                               n_conv=n_conv,
                               h_fea_len=h_fea_len,
                               n_h=n_h,
                               full_width=full_width,
                               tabulated_d_cen_inf=tabulated_d_cen_inf,
                               tabulated_padding_fillter=tabulated_padding_fillter,
                               tabulated_full_width_inf=tabulated_full_width_inf,
                               tabulated_mulliken=tabulated_mulliken,
                               tabulated_site_index=tabulated_site_index,
                               tabulated_v2dd=tabulated_v2dd,
                               tabulated_v2ds=tabulated_v2ds,
                               batch_size=64)

            # check_loss() is expected to return six values in this order:
            # train MAE, train MSE, val MAE, val MSE, test MAE, test MSE.
            (check_ans_train_mae[idx_test, idx_validation],
             check_ans_train_mse[idx_test, idx_validation],
             check_ans_val_mae[idx_test, idx_validation],
             check_ans_val_mse[idx_test, idx_validation],
             check_ans_test_mae[idx_test, idx_validation],
             check_ans_test_mse[idx_test, idx_validation]) = model.check_loss()

        except Exception as err:
            # Some fold pairs are expected to fail (presumably e.g. when
            # idx_validation == idx_test or no pre-trained model exists for the
            # pair -- confirm). Keep going, but report the failure instead of
            # silently leaving a 0 that looks like a perfect score. Catching
            # Exception (not a bare except) lets Ctrl-C still stop the run.
            print('Skipping idx_validation=%d, idx_test=%d: %s: %s'
                  % (idx_validation, idx_test, type(err).__name__, err))

for filename, table in (('check_ans_train_mae.txt', check_ans_train_mae),
                        ('check_ans_train_mse.txt', check_ans_train_mse),
                        ('check_ans_val_mae.txt', check_ans_val_mae),
                        ('check_ans_val_mse.txt', check_ans_val_mse),
                        ('check_ans_test_mae.txt', check_ans_test_mae),
                        ('check_ans_test_mse.txt', check_ans_test_mse)):
    np.savetxt(filename, table)