Hyperparameter optimization for TinNet models

In [None]:
# Loading modules

from __future__ import print_function, division

import os

import numpy as np
import torch

from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.suggest import ConcurrencyLimiter
from ray.tune.suggest.bayesopt import BayesOptSearch
from ase import io
from ase.db import connect

from tinnet.regression.regression import Regression

In [None]:
# Train the network

class TrainTinNet(tune.Trainable):
    """Ray Tune trainable wrapping a TinNet ``Regression`` model.

    ``_setup`` reads the hyperparameters from the trial config and caches the
    contents of ``../Database.db`` on the instance.  Every ``_train`` call
    builds a new ``Regression`` from that cached data, runs
    ``train(25000)``, writes the four returned error values to text files and
    reports the last of them as ``mean_loss``.
    """

    def _setup(self, config):
        """Store hyperparameters and load the training data.

        Args:
            config: Trial configuration mapping.  Recognised keys (with the
                fallback used when a key is absent): ``lr`` (0.01),
                ``atom_fea_len`` (64), ``n_conv`` (3), ``h_fea_len`` (128)
                and ``n_h`` (1).
        """
        self.lr = config.get('lr', 0.01)
        # Size-like hyperparameters go through int(), so a float sample from
        # the search space is truncated to an integer.
        self.atom_fea_len = int(config.get('atom_fea_len', 64))
        self.n_conv = int(config.get('n_conv', 3))
        self.h_fea_len = int(config.get('h_fea_len', 128))
        self.n_h = int(config.get('n_h', 1))

        # Pull the 'data' payload of every database row once, in row order.
        records = [row['data'] for row in connect('../Database.db').select()]

        def stacked(key, dtype):
            # One array over all rows for the given key.
            return np.array([rec[key] for rec in records], dtype=dtype)

        def per_row(key, dtype):
            # One array per row, kept in a plain Python list.
            return [np.array(rec[key], dtype=dtype) for rec in records]

        self.d_cen = stacked('d_cen', np.float32)
        self.full_width = stacked('full_width', np.float32)
        self.tabulated_d_cen_inf = stacked('tabulated_d_cen_inf', np.float32)
        self.tabulated_full_width_inf = stacked('tabulated_full_width_inf',
                                                np.float32)
        self.tabulated_mulliken = stacked('tabulated_mulliken', np.float32)
        self.tabulated_site_index = stacked('tabulated_site_index', np.int32)
        self.tabulated_v2dd = stacked('tabulated_v2dd', np.float32)
        self.tabulated_v2ds = stacked('tabulated_v2ds', np.float32)

        self.atom_fea = per_row('atom_fea', np.float32)
        self.nbr_fea = per_row('nbr_fea', np.float32)
        # NOTE(review): the index arrays are stored as float32, which is
        # unusual for index data -- confirm this is what Regression expects.
        self.nbr_fea_idx = per_row('nbr_fea_idx', np.float32)
        self.tabulated_padding_fillter = per_row('tabulated_padding_fillter',
                                                 np.int32)

    def _train(self):
        """Build a fresh model, train it and report the resulting loss.

        Returns:
            dict: ``{'mean_loss': value}`` where ``value`` is the fourth
            item returned by ``Regression.train`` (treated here as the test
            MSE).
        """
        # A new Regression replaces self.model on every call.
        self.model = Regression(
            self.atom_fea,
            self.nbr_fea,
            self.nbr_fea_idx,
            self.d_cen,
            # Fixed settings.
            phys_model='moment',
            optim_algorithm='AdamW',
            weight_decay=0.0001,
            idx_validation=0,
            idx_test=1,
            batch_size=64,
            # Hyperparameters taken from the trial config.
            lr=self.lr,
            atom_fea_len=self.atom_fea_len,
            n_conv=self.n_conv,
            h_fea_len=self.h_fea_len,
            n_h=self.n_h,
            # Additional data loaded in _setup.
            full_width=self.full_width,
            tabulated_d_cen_inf=self.tabulated_d_cen_inf,
            tabulated_padding_fillter=self.tabulated_padding_fillter,
            tabulated_full_width_inf=self.tabulated_full_width_inf,
            tabulated_mulliken=self.tabulated_mulliken,
            tabulated_site_index=self.tabulated_site_index,
            tabulated_v2dd=self.tabulated_v2dd,
            tabulated_v2ds=self.tabulated_v2ds)

        val_mae, val_mse, test_mae, test_mse = self.model.train(25000)

        # Every output file shares the same hyperparameter suffix:
        # <lr>_<atom_fea_len>_<n_conv>_<h_fea_len>_<n_h>.txt
        hyperparams = (self.lr, self.atom_fea_len, self.n_conv,
                       self.h_fea_len, self.n_h)
        suffix = '_'.join(str(value) for value in hyperparams) + '.txt'

        # Files are written with relative names, i.e. into the process's
        # current working directory.
        outputs = (
            ('final_ans_val_mae_', val_mae),
            ('final_ans_val_mse_', val_mse),
            ('final_ans_test_mae_', test_mae),
            ('final_ans_test_mse_', test_mse),
        )
        for prefix, value in outputs:
            np.savetxt(prefix + suffix, [value])

        return {'mean_loss': test_mse}

    def _save(self, checkpoint_dir):
        """Return the checkpoint path ``<checkpoint_dir>/model.pth``.

        NOTE(review): nothing is written to that path here, while
        ``_restore`` loads from it -- confirm whether the model state is
        meant to be saved in this method.
        """
        return os.path.join(checkpoint_dir, 'model.pth')

    def _restore(self, checkpoint_path):
        """Load the object stored at ``checkpoint_path`` into ``self.model``.

        NOTE(review): ``self.model`` is only assigned in ``_train``, so this
        presumably requires a prior ``_train`` call -- confirm.
        """
        load_into_model = self.model.load_state_dict
        load_into_model(torch.load(checkpoint_path))

In [None]:
if __name__ == '__main__':

    # Bayesian-optimisation searcher (utility settings passed through
    # verbatim), wrapped so that at most 4 trials run concurrently.
    algo = ConcurrencyLimiter(
        BayesOptSearch(utility_kwargs=dict(kind='ucb', kappa=2.5, xi=0.0)),
        max_concurrent=4)
    scheduler = AsyncHyperBandScheduler()

    # Every dimension is declared as a continuous range; TrainTinNet._setup
    # applies int() to the size-like ones.
    search_space = {
        'lr': tune.loguniform(lower=0.0010, upper=0.004, base=10),
        'atom_fea_len': tune.uniform(lower=10, upper=206),
        'n_conv': tune.uniform(lower=4, upper=11),
        'h_fea_len': tune.uniform(lower=10, upper=101),
        'n_h': tune.uniform(lower=1, upper=5),
    }

    # Stop criteria and per-trial resources handed to tune.run.
    stop_criteria = {'mean_loss': 0.001, 'training_iteration': 20}
    trial_resources = {'cpu': 12, 'gpu': 1}

    analysis = tune.run(
        TrainTinNet,
        name='TrainTinNet',
        config=search_space,
        search_alg=algo,
        scheduler=scheduler,
        metric='mean_loss',
        mode='min',
        stop=stop_criteria,
        resources_per_trial=trial_resources,
        num_samples=500,
        checkpoint_freq=20,
        checkpoint_at_end=True)

    best_config = analysis.get_best_config(metric='mean_loss', mode='min')
    print('Best config is:', best_config)