In [6]:
import numpy as np
import os, gc, json
import torch.nn
from torch.utils.data import DataLoader
from util.input_data import Dataset
from util.AdaBound import AdaBound
from torch.utils.tensorboard import SummaryWriter

def _next_output_root(root_model, model_type, scale, comment='', max_runs=99):
    """Create and return a fresh, uniquely numbered run directory.

    The directory is ``<root_model>/<model_type>/<scale>_<NN>``, with ``_<comment>``
    appended when ``comment`` is non-empty. ``NN`` is the smallest index in
    ``range(max_runs)`` not already used by an entry of ``<root_model>/<model_type>``.

    An index counts as used when an entry is named exactly ``<scale>_<NN>`` or starts
    with ``<scale>_<NN>_``. Matching on the whole name prefix avoids the false hits of a
    plain substring search (e.g. ``x_metal_FFF_00`` or ``metal_FFF_001`` blocking 00).

    Raises:
        RuntimeError: if every index in ``range(max_runs)`` is already taken. Raising
            prevents a caller from silently reusing (and overwriting) an older run.
    """
    root = os.path.join(root_model, model_type)
    os.makedirs(root, exist_ok=True)
    existing = os.listdir(root)
    for i in range(max_runs):
        prefix = '{}_{:02d}'.format(scale, i)
        if any(name == prefix or name.startswith(prefix + '_') for name in existing):
            continue
        output_root = os.path.join(root, f'{prefix}_{comment}' if comment else prefix)
        os.makedirs(output_root)
        return output_root
    raise RuntimeError(
        f'no free run index for {scale!r} in {root!r} (all {max_runs} indices taken)')


def exec_model(
    scale,
    model_type,
    comment='',
    lr = 1e-5,
    wd = 1e-7,
    tries = 1,
    root_model = 'd:/MODELS/202204/nmm',
    root_data  = 'c:/WORKSPACE_KRICT/DATA/data_snu',
    num_epochs = 300,
    batch_size = 128,
    train_ratio = 0.7,
    valid_ratio = 0.2,
    save_every = 20,
):
    """Train ``DistNN`` on one dataset scale, ``tries`` times with different split seeds.

    Each try ``n`` proceeds as follows:

    * The dataset is split with seed ``35 + n``, and torch is seeded with the same value.
    * A fresh model is trained for ``num_epochs`` epochs with AdaBound and an L1 loss.
    * Per-epoch train/valid loss and MAE are printed and logged to TensorBoard.
    * Every ``save_every`` epochs, and always after the final epoch, the model weights
      are saved and the test-set predictions are written as CSV.

    NOTE(review): relies on the globals ``DistNN`` and ``tr`` (util.trainer_log), which
    this notebook imports in a separate cell. Calling this function before that cell has
    run raises NameError.

    Args:
        scale: dataset variant; loads ``<root_data>/inputdata_<scale>.pickle``.
        model_type: sub-directory of ``root_model`` that receives the run directories.
        comment: optional suffix appended to each run directory name.
        lr: AdaBound learning rate.
        wd: AdaBound weight decay.
        tries: number of independent runs (seeds 35, 36, ...).
        root_model: root directory for run outputs.
        root_data: directory containing the input pickles.
        num_epochs: training epochs per run.
        batch_size: batch size for the train, valid and test loaders.
        train_ratio: training split fraction passed to ``Dataset.train_test_split``.
        valid_ratio: validation split fraction passed to ``Dataset.train_test_split``.
        save_every: checkpoint / test-prediction interval in epochs. A value <= 0 saves
            only after the final epoch.

    Side effects:
        For each try, creates ``<root_model>/<model_type>/<scale>_<NN>[_<comment>]``
        containing ``params.json``, TensorBoard event files, ``model.<epoch>.pt`` and
        ``pred.<epoch>.txt``.

    Raises:
        RuntimeError: if no free run index is left for this scale.
    """
    gc.collect()

    dataset = Dataset()
    dataset.load_dataset(os.path.join(root_data, f'inputdata_{scale}.pickle'), silent=True)

    for n in range(tries):
        rseed = 35 + n
        train_data, valid_data, test_data = dataset.train_test_split(
            train_ratio=train_ratio, valid_ratio=valid_ratio, rseed=rseed)

        # Seed torch as well, so the seed recorded in params.json also reproduces
        # weight initialisation and DataLoader shuffling, not just the split.
        torch.manual_seed(rseed)
        train_data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                                       collate_fn=tr.collate_fn)
        val_data_loader = DataLoader(valid_data, batch_size=batch_size, collate_fn=tr.collate_fn)
        test_data_loader = DataLoader(test_data, batch_size=batch_size, collate_fn=tr.collate_fn)

        model = DistNN(dataset.n_atom_feats, dataset.n_rdf_feature, dataset.n_bdf_feature).cuda()
        optimizer = AdaBound(model.parameters(), lr=lr, weight_decay=wd)
        criterion = torch.nn.L1Loss()

        output_root = _next_output_root(root_model, model_type, scale, comment)
        print(output_root)
        with open(os.path.join(output_root, 'params.json'), 'w') as f:
            json.dump(dict(random_seed=rseed, learning_rate=lr, weight_decay=wd,
                           train_ratio=train_ratio, valid_ratio=valid_ratio,
                           batch_size=batch_size),
                      f, indent=4)

        writer = SummaryWriter(output_root)
        try:
            for epoch in range(1, num_epochs + 1):
                train_loss, train_mae = tr.train(model, optimizer, train_data_loader, criterion)
                valid_loss, valid_mae, _, _, _ = tr.test(model, val_data_loader, criterion)
                print('Epoch [{}/{}]\tTrain/Valid Loss: {:.4f} / {:.4f}\tMAE: {:.4f} / {:.4f}'
                      .format(epoch, num_epochs, train_loss, valid_loss, train_mae, valid_mae))

                writer.add_scalar('train/loss', train_loss, epoch)
                writer.add_scalar('train/MAE', train_mae, epoch)
                writer.add_scalar('valid/loss', valid_loss, epoch)
                writer.add_scalar('valid/MAE', valid_mae, epoch)

                # Periodic checkpoint plus test-set predictions. The final epoch is always
                # saved, so the last model is kept even when num_epochs is not a multiple
                # of save_every.
                periodic = save_every > 0 and epoch % save_every == 0
                if periodic or epoch == num_epochs:
                    torch.save(model.state_dict(),
                               os.path.join(output_root, 'model.{:05d}.pt'.format(epoch)))
                    _, _, idxs, targets, preds = tr.test(model, test_data_loader, criterion)
                    # NOTE(review): assumes idxs/targets/preds are 2-D column arrays, so
                    # hstack yields one CSV row per test sample -- TODO confirm.
                    np.savetxt(os.path.join(output_root, 'pred.{:05d}.txt'.format(epoch)),
                               np.hstack([idxs, targets, preds]), delimiter=',')
        finally:
            # Flush and close the TensorBoard event file even if training is interrupted.
            writer.close()

In [8]:
from model.model_03r import DistNN
import util.trainer_log as tr

# One M03R training run per dataset scale; all runs share the same options.
run_kwargs = dict(model_type='M03R', comment='log', batch_size=256)
for scale in ('metal_FFF', 'metal_TFF', 'metal_TTT'):
    exec_model(scale=scale, **run_kwargs)

d:/MODELS/202204/nmm\M03R\metal_FFF_00_log
Epoch [1/300]	Train/Valid Loss: 2.4689 / 1.8823	MAE: 1.2529 / 1.1900
Epoch [2/300]	Train/Valid Loss: 1.5279 / 1.5210	MAE: 0.8525 / 1.1532
Epoch [3/300]	Train/Valid Loss: 0.7394 / 0.2886	MAE: 0.4728 / 0.2046
Epoch [4/300]	Train/Valid Loss: 0.2910 / 0.2549	MAE: 0.2431 / 0.1900
Epoch [5/300]	Train/Valid Loss: 0.2758 / 0.3123	MAE: 0.2350 / 0.1801
Epoch [6/300]	Train/Valid Loss: 0.2588 / 0.2366	MAE: 0.2104 / 0.1683
Epoch [7/300]	Train/Valid Loss: 0.2343 / 0.2405	MAE: 0.1706 / 0.1603
Epoch [8/300]	Train/Valid Loss: 0.2577 / 0.2418	MAE: 0.1972 / 0.1747
Epoch [9/300]	Train/Valid Loss: 0.2663 / 0.2814	MAE: 0.1824 / 0.1614
Epoch [10/300]	Train/Valid Loss: 0.2585 / 0.2929	MAE: 0.1852 / 0.2107
Epoch [11/300]	Train/Valid Loss: 0.3372 / 0.2259	MAE: 0.2368 / 0.1897
Epoch [12/300]	Train/Valid Loss: 0.3069 / 0.2901	MAE: 0.2015 / 0.1885
Epoch [13/300]	Train/Valid Loss: 0.2700 / 0.2175	MAE: 0.1805 / 0.1619
Epoch [14/300]	Train/Valid Loss: 0.2816 / 0.3345	MAE: 0.