In [1]:
import json,os,shutil,yaml,torch
import torch.nn as nn


from models import create_model
from utils.train import train_model, test_model
from utils.util import get_optimizer, get_loss, get_scheduler
from utils.data_container import get_data_loader
from utils.preprocess import  preprocessing_for_metric



def train(conf, data_category):
    """Train the model described by ``conf`` and then evaluate it.

    The function prints the configuration, builds the loss, model, trainer,
    optimizer and scheduler, recreates the output folders, stores a copy of
    the configuration next to the checkpoints, and finally calls
    ``train_model`` followed by ``test_model``.

    Args:
        conf: Configuration mapping. The keys read here are ``name``,
            ``tag``, ``model``, ``optimizer``, ``scheduler``, ``loss``,
            ``data``, ``preprocess`` and ``train``. Of the per-optimizer
            settings only ``lr`` is forwarded to ``get_optimizer``.
        data_category: Sequence of strings (e.g. ``['bike']``). It is passed
            to the preprocessing/model/data helpers and concatenated to form
            part of the output folder name.

    Returns:
        None. Results are produced as side effects by ``train_model`` and
        ``test_model``.

    Warning:
        Any existing ``save/<name>/<dataset>_<categories>/<tag>`` and
        ``run/<name>/<dataset>_<categories>/<tag>`` folders are deleted
        before training starts.
    """
    print(json.dumps(conf, indent=4))

    # NOTE(review): the visible GPU is hard-coded to physical index 1 and
    # conf['device'] is ignored (the original alternative was
    # `str(conf['device'])`) -- confirm which one is intended.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(1)
    # Index 0 here refers to the first *visible* device.
    device = torch.device(0)

    model_name = conf['model']['name']
    optimizer_name = conf['optimizer']['name']
    data_set = conf['data']['dataset']
    scheduler_name = conf['scheduler']['name']

    loss = get_loss(**conf['loss'])
    loss.to(device)

    support = preprocessing_for_metric(data_category=data_category,
                                       dataset=data_set,
                                       Normal_Method=conf['data']['Normal_Method'],
                                       _len=conf['data']['_len'],
                                       **conf['preprocess'])
    model, trainer = create_model(model_name,
                                  loss,
                                  conf['model'][model_name],
                                  data_category,
                                  device,
                                  support)

    optimizer = get_optimizer(optimizer_name, model.parameters(),
                              conf['optimizer'][optimizer_name]['lr'])
    scheduler = get_scheduler(scheduler_name, optimizer,
                              **conf['scheduler'][scheduler_name])

    # DataParallel requires the wrapped module's parameters and buffers to
    # already live on device_ids[0], so move the model unconditionally and
    # only then wrap it when several GPUs are visible.
    model.to(device)
    if torch.cuda.device_count() > 1:
        print("use ", torch.cuda.device_count(), "GPUS")
        model = nn.DataParallel(model)

    run_id = f'{data_set}_{"".join(data_category)}'
    save_folder = os.path.join('save', conf['name'], run_id, conf['tag'])
    run_folder = os.path.join('run', conf['name'], run_id, conf['tag'])

    # Start from clean output folders. Removal is best-effort
    # (ignore_errors=True), so creation must tolerate a folder that could
    # not be fully removed instead of raising FileExistsError.
    for folder in (save_folder, run_folder):
        shutil.rmtree(folder, ignore_errors=True)
        os.makedirs(folder, exist_ok=True)

    # Keep a copy of the exact configuration alongside the saved model.
    with open(os.path.join(save_folder, 'config.yaml'), 'w') as _f:
        yaml.safe_dump(conf, _f)

    data_loader, normal = get_data_loader(**conf['data'],
                                          data_category=data_category,
                                          device=device,
                                          model_name=model_name)

    train_model(model=model,
                dataloaders=data_loader,
                trainer=trainer,
                optimizer=optimizer,
                normal=normal,
                scheduler=scheduler,
                folder=save_folder,
                tensorboard_folder=run_folder,
                device=device,
                **conf['train'])
    test_model(folder=save_folder,
               trainer=trainer,
               model=model,
               normal=normal,
               dataloaders=data_loader,
               device=device)


if __name__ == '__main__':
    # Command-line parsing (--config / --con / --stage) is currently
    # disabled; the configuration name and the data categories are fixed
    # below instead.
    config_name = 'evoconv2-config'
    categories = ['bike']

    # Configurations are stored as config/<name>.yaml.
    config_path = os.path.join('config', config_name + '.yaml')
    with open(config_path) as config_file:
        settings = yaml.safe_load(config_file)

    train(settings, categories)


{
    "name": "Evoconv2",
    "tag": "train1",
    "device": 0,
    "data": {
        "dataset": "nogrid",
        "batch_size": 32,
        "X_list": [
            12,
            11,
            10,
            9,
            8,
            7,
            6,
            5,
            4,
            3,
            2,
            1
        ],
        "Y_list": [
            0,
            1,
            2,
            3,
            4,
            5,
            6,
            7,
            8,
            9,
            10,
            11
        ],
        "_len": [
            672,
            672
        ],
        "Normal_Method": "Standard"
    },
    "preprocess": {
        "hidden_size": 20,
        "normalized_category": "randomwalk"
    },
    "train": {
        "epochs": 200,
        "max_grad_norm": 5,
        "early_stop_steps": 10
    },
    "optimizer": {
        "name": "Adam",
        "Adam": {
            "lr": 0.0005,
            "weight_decay": 0.0005,
            

train epoch:   0, train loss: 4.03245: : 94it [00:18,  5.17it/s]
validate epoch:   0, validate loss: 4.84322: : 21it [00:00, 24.38it/s]


Better model at epoch 0 recorded.


test  epoch:   0, test  loss: 5.05964: : 21it [00:00, 25.09it/s]
train epoch:   1, train loss: 3.18362: : 94it [00:18,  5.09it/s]
validate epoch:   1, validate loss: 4.50127: : 21it [00:00, 24.22it/s]


Better model at epoch 1 recorded.


test  epoch:   1, test  loss: 4.58771: : 21it [00:00, 24.25it/s]
train epoch:   2, train loss: 2.57099: : 94it [00:18,  5.09it/s]
validate epoch:   2, validate loss: 3.93373: : 21it [00:00, 23.26it/s]


Better model at epoch 2 recorded.


test  epoch:   2, test  loss: 4.06912: : 21it [00:00, 23.47it/s]
train epoch:   3, train loss: 2.43413: : 94it [00:18,  5.06it/s]
validate epoch:   3, validate loss: 3.86365: : 21it [00:00, 23.86it/s]


Better model at epoch 3 recorded.


test  epoch:   3, test  loss: 3.95097: : 21it [00:00, 24.00it/s]
train epoch:   4, train loss: 2.36485: : 94it [00:18,  5.04it/s]
validate epoch:   4, validate loss: 3.77715: : 21it [00:00, 23.82it/s]


Better model at epoch 4 recorded.


test  epoch:   4, test  loss: 3.90704: : 21it [00:00, 23.62it/s]
train epoch:   5, train loss: 2.31471: : 94it [00:18,  5.05it/s]
validate epoch:   5, validate loss: 3.67512: : 21it [00:00, 23.17it/s]


Better model at epoch 5 recorded.


test  epoch:   5, test  loss: 3.81424: : 21it [00:00, 23.24it/s]
train epoch:   6, train loss: 2.27164: : 94it [00:18,  5.07it/s]
validate epoch:   6, validate loss: 3.76392: : 21it [00:00, 23.25it/s]
test  epoch:   6, test  loss: 3.83059: : 21it [00:00, 23.49it/s]
train epoch:   7, train loss: 2.23666: : 94it [00:18,  5.08it/s]
validate epoch:   7, validate loss: 3.61424: : 21it [00:00, 23.61it/s]


Better model at epoch 7 recorded.


test  epoch:   7, test  loss: 3.77588: : 21it [00:00, 24.14it/s]
train epoch:   8, train loss: 2.19641: : 94it [00:18,  5.07it/s]
validate epoch:   8, validate loss: 3.73963: : 21it [00:00, 24.21it/s]
test  epoch:   8, test  loss: 3.88836: : 21it [00:00, 23.78it/s]
train epoch:   9, train loss: 2.16411: : 94it [00:18,  5.10it/s]
validate epoch:   9, validate loss: 3.71803: : 21it [00:00, 23.67it/s]
test  epoch:   9, test  loss: 3.66235: : 21it [00:00, 24.14it/s]
train epoch:  10, train loss: 2.1478: : 94it [00:18,  5.14it/s] 
validate epoch:  10, validate loss: 3.75021: : 21it [00:00, 24.50it/s]
test  epoch:  10, test  loss: 3.90184: : 21it [00:00, 23.69it/s]
train epoch:  11, train loss: 2.13332: : 94it [00:18,  5.08it/s]
validate epoch:  11, validate loss: 3.61969: : 21it [00:00, 23.94it/s]
test  epoch:  11, test  loss: 3.77611: : 21it [00:00, 23.25it/s]
train epoch:  12, train loss: 2.12137: : 94it [00:18,  5.15it/s]
validate epoch:  12, validate loss: 4.07126: : 21it [00:00, 24.08i

Better model at epoch 13 recorded.


test  epoch:  13, test  loss: 3.52041: : 21it [00:00, 23.43it/s]
train epoch:  14, train loss: 2.11689: : 94it [00:18,  5.06it/s]
validate epoch:  14, validate loss: 3.42892: : 21it [00:00, 23.89it/s]
test  epoch:  14, test  loss: 3.57165: : 21it [00:00, 23.93it/s]
train epoch:  15, train loss: 2.11848: : 94it [00:20,  4.69it/s]
validate epoch:  15, validate loss: 3.65103: : 21it [00:00, 23.18it/s]
test  epoch:  15, test  loss: 3.73981: : 21it [00:00, 23.39it/s]
train epoch:  16, train loss: 2.16337: : 94it [00:20,  4.66it/s]
validate epoch:  16, validate loss: 3.45228: : 21it [00:00, 23.81it/s]
test  epoch:  16, test  loss: 3.55079: : 21it [00:00, 23.85it/s]
train epoch:  17, train loss: 2.16607: : 94it [00:18,  5.12it/s]
validate epoch:  17, validate loss: 3.89601: : 21it [00:00, 24.05it/s]
test  epoch:  17, test  loss: 4.08273: : 21it [00:00, 23.95it/s]
train epoch:  18, train loss: 2.20238: : 94it [00:18,  5.09it/s]
validate epoch:  18, validate loss: 3.39927: : 21it [00:00, 23.67i

Better model at epoch 19 recorded.


test  epoch:  19, test  loss: 3.17132: : 21it [00:00, 23.75it/s]
train epoch:  20, train loss: 2.17513: : 94it [00:18,  5.08it/s]
validate epoch:  20, validate loss: 3.05775: : 21it [00:00, 23.81it/s]


Better model at epoch 20 recorded.


test  epoch:  20, test  loss: 3.1855: : 21it [00:00, 24.17it/s] 
train epoch:  21, train loss: 2.20703: : 94it [00:18,  5.08it/s]
validate epoch:  21, validate loss: 2.89842: : 21it [00:00, 24.03it/s]


Better model at epoch 21 recorded.


test  epoch:  21, test  loss: 2.99493: : 21it [00:00, 23.62it/s]
train epoch:  22, train loss: 2.23088: : 94it [00:18,  5.14it/s]
validate epoch:  22, validate loss: 2.92805: : 21it [00:00, 24.00it/s]
test  epoch:  22, test  loss: 3.06281: : 21it [00:00, 23.12it/s]
train epoch:  23, train loss: 2.2928: : 94it [00:18,  5.06it/s] 
validate epoch:  23, validate loss: 2.81972: : 21it [00:00, 23.76it/s]


Better model at epoch 23 recorded.


test  epoch:  23, test  loss: 2.91204: : 21it [00:00, 23.42it/s]
train epoch:  24, train loss: 2.31453: : 94it [00:18,  5.07it/s]
validate epoch:  24, validate loss: 2.94844: : 21it [00:00, 23.72it/s]
test  epoch:  24, test  loss: 3.08155: : 21it [00:00, 23.94it/s]
train epoch:  25, train loss: 2.31486: : 94it [00:18,  5.11it/s]
validate epoch:  25, validate loss: 2.85655: : 21it [00:00, 24.09it/s]
test  epoch:  25, test  loss: 2.95855: : 21it [00:00, 23.78it/s]
train epoch:  26, train loss: 2.30103: : 94it [00:18,  5.12it/s]
validate epoch:  26, validate loss: 2.84261: : 21it [00:00, 23.58it/s]
test  epoch:  26, test  loss: 2.97707: : 21it [00:00, 23.92it/s]
train epoch:  27, train loss: 2.29467: : 94it [00:18,  5.08it/s]
validate epoch:  27, validate loss: 2.78856: : 21it [00:00, 24.23it/s]


Better model at epoch 27 recorded.


test  epoch:  27, test  loss: 2.91454: : 21it [00:00, 23.85it/s]
train epoch:  28, train loss: 2.30832: : 94it [00:18,  5.11it/s]
validate epoch:  28, validate loss: 2.80712: : 21it [00:00, 23.78it/s]
test  epoch:  28, test  loss: 2.92788: : 21it [00:00, 23.91it/s]
train epoch:  29, train loss: 2.30344: : 94it [00:18,  5.09it/s]
validate epoch:  29, validate loss: 2.73655: : 21it [00:00, 23.94it/s]


Better model at epoch 29 recorded.


test  epoch:  29, test  loss: 2.84966: : 21it [00:00, 24.08it/s]
train epoch:  30, train loss: 2.30469: : 94it [00:18,  5.09it/s]
validate epoch:  30, validate loss: 2.81835: : 21it [00:00, 23.62it/s]
test  epoch:  30, test  loss: 2.96549: : 21it [00:00, 23.15it/s]
train epoch:  31, train loss: 2.30248: : 94it [00:18,  5.09it/s]
validate epoch:  31, validate loss: 2.75764: : 21it [00:00, 23.51it/s]
test  epoch:  31, test  loss: 2.88399: : 21it [00:00, 23.83it/s]
train epoch:  32, train loss: 2.30587: : 94it [00:18,  5.07it/s]
validate epoch:  32, validate loss: 2.78812: : 21it [00:00, 23.52it/s]
test  epoch:  32, test  loss: 2.91382: : 21it [00:00, 23.58it/s]
train epoch:  33, train loss: 2.29851: : 94it [00:18,  5.09it/s]
validate epoch:  33, validate loss: 2.76151: : 21it [00:00, 23.99it/s]
test  epoch:  33, test  loss: 2.88916: : 21it [00:00, 23.69it/s]
train epoch:  34, train loss: 2.30471: : 94it [00:18,  5.12it/s]
validate epoch:  34, validate loss: 2.76932: : 21it [00:00, 23.86i

cost 925.2539465436712 seconds
model of epoch 29 successfully saved at `save/Evoconv2/nogrid_bike/train1/best_model.pkl`


21it [00:00, 26.11it/s]


test results:
{
    "MAE": {
        "horizon-0": 1.6422478423134004,
        "horizon-1": 1.6604370818195249,
        "horizon-2": 1.6855596587204797,
        "horizon-3": 1.705145652646017,
        "horizon-4": 1.729362078095959,
        "horizon-5": 1.7614194265538932,
        "horizon-6": 1.7965538894155764,
        "horizon-7": 1.8199155542248375,
        "horizon-8": 1.8344108700897193,
        "horizon-9": 1.8426612463911918,
        "horizon-10": 1.8549867944384613,
        "horizon-11": 1.8767794858569422
    },
    "RMSE": {
        "horizon-0": 2.5946983499756686,
        "horizon-1": 2.63771437043956,
        "horizon-2": 2.702657295996356,
        "horizon-3": 2.7616343460411956,
        "horizon-4": 2.8180497434612652,
        "horizon-5": 2.8852206399162124,
        "horizon-6": 2.958441310368821,
        "horizon-7": 2.9975772051391507,
        "horizon-8": 3.023009919604782,
        "horizon-9": 3.035125960464815,
        "horizon-10": 3.05490210038858,
        "horizo