In [1]:
from config.cfg_algorithm import CFG_MODEL
from config.cfg_general import CFG_GENERAL
from config.cfg_dataset import CFG_DATASET
import os
from runner.builder import build_model
from data.data_loader import build_dataset
from utils.__init__ import config_md5
import torch

data = ["PEMS08", "T-Drive", "CHIBike", "NYCTaxi"][1]  # index picks the benchmark dataset
algo = "STID"
# NOTE(review): outputs saved further down (logger name "NYCTaxi_STDCN_200") came from a
# different data/algo selection than the one above — Restart & Run All before sharing.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 200          # in this section only used to name the checkpoint directory
val_interval = 1
out_dir = "./checkpoints"

# Model-specific argument object, built from the general + per-dataset configs.
model_args = CFG_MODEL[algo](
    data,
    CFG_DATASET[data]["NUM_NODES"],
    CFG_GENERAL.DATASET.HISTORY_SEQ_LEN,
    CFG_GENERAL.DATASET.FUTURE_SEQ_LEN,
    not CFG_GENERAL.DATASET.NORM_EACH_CHANNEL,
    CFG_DATASET[data].get("STEPS_PER_DAY", None),
)

# Checkpoints live in <out_dir>/<data>_<algo>_<epochs>/<md5 of the configs>, so runs with
# different configurations do not overwrite each other.
md5 = config_md5(CFG_GENERAL, CFG_DATASET[data], model_args)
ckpt_save_dir = os.path.join(out_dir, "_".join([data, algo, str(epochs)]), md5)
# exist_ok=True replaces the racy isdir()-then-makedirs() pattern.
os.makedirs(ckpt_save_dir, exist_ok=True)

models = build_model(CFG_GENERAL, model_args)
datasets = build_dataset(CFG_GENERAL, CFG_DATASET[data])

# --- Smoke test: one forward pass on a single training batch --------------------
# The loader yields (future, history) in that order.
future_data, history_data = next(iter(datasets["train"]))
future_data, history_data = future_data.to(device), history_data.to(device)

model = models["model"].to(device)
forward_features = models["forward_features"]
# Call the model the same way the test loop does: feature-sliced history/future inputs
# passed by keyword. no_grad because we only inspect output shapes here.
with torch.no_grad():
    out = model(
        history_data=history_data[:, :, :, forward_features],
        future_data=future_data[:, :, :, forward_features],
        batch_seen=0,
        epoch=0,
        train=False,
    )
out.shape, future_data.shape, history_data.shape


CUDA initialization: The NVIDIA driver on your system is too old (found version 11060). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)



(torch.Size([16, 1, 1024, 1]),
 torch.Size([16, 1, 1024, 3]),
 torch.Size([16, 6, 1024, 3]))

In [3]:
from utils.checkpoint import resume_model
from runner.builder import build_meter
from utils.logger import init_logger
from utils.meter_pool import MeterPool
import time
from data.transform import re_standard_transform
from tqdm import tqdm
from utils.metrics import masked_mae, masked_mape, masked_rmse
import numpy as np

# --- Load the best-validation checkpoint and switch to evaluation mode ----------
logger = init_logger(f"{data}_{algo}_{epochs}", ckpt_save_dir)
ckpt_path = f'{algo}_{data}_best_val_MAE.pt'   # file name resolved inside ckpt_save_dir
checkpoint_dict = resume_model(ckpt_save_dir, ckpt_path)
models["model"].load_state_dict(checkpoint_dict['model_state_dict'], strict=True)
logger.info('start test')
models["model"].eval()

# --- Evaluation settings -----------------------------------------------------------
mode = "test"
epoch = 0          # only forwarded to meters.plt_meters below
metrics = CFG_GENERAL.METRICS
tensbd = None      # no TensorBoard writer in this notebook

# Secondary null/mask values for the ad-hoc `results` summary. These differ from the
# dataset's own datasets['null_val'] / datasets['mask_val'] used for the logged meters,
# so the two sets of numbers are not directly comparable.
ALT_NULL_VAL = np.nan
ALT_MASK_VAL = 5

forward_features = models["forward_features"]
target_features = models["target_features"]
data_loader = datasets[mode]
scaler = datasets["scaler"]

model = models["model"]
meters = MeterPool()
build_meter(meters, metrics, mode)

if mode == "test":
    reals = []
    preds = []

test_start_time = time.perf_counter()   # monotonic clock for measuring durations
# results[metric] holds ONE value per batch. Averaging these later gives a mean of batch
# means: a smaller final batch weighs as much as a full one, and the mean of per-batch
# RMSEs is not the global RMSE.
results = {k: [] for k in metrics.keys()}

# Pure inference: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for future_data, history_data in tqdm(data_loader):
        history_data = history_data.to(device)
        future_data = future_data.to(device)

        history_data = history_data[:, :, :, forward_features]
        future_data_4_dec = future_data[:, :, :, forward_features]
        prediction_data = model(history_data=history_data, future_data=future_data_4_dec,
                                batch_seen=1, epoch=None, train=False)
        # Presumably inverts the scaler's standardisation so metrics are in original units.
        prediction = re_standard_transform(prediction_data[:, :, :, target_features], **scaler["args"])
        real_value = re_standard_transform(future_data[:, :, :, target_features], **scaler["args"])

        if mode == "test":
            # Keep channel 0 of the target features on CPU for later inspection.
            reals.append(real_value[..., 0].detach().cpu())
            preds.append(prediction[..., 0].detach().cpu())

        for metric_name, metric_func in metrics.items():
            metric_item = metric_func(prediction, real_value, datasets['null_val'], datasets['mask_val'])
            # NOTE(review): how MeterPool aggregates these per-batch values is not visible here.
            meters.update(f"{mode}_{metric_name}", metric_item.item())
            results[metric_name].append(
                metric_func(prediction, real_value, ALT_NULL_VAL, ALT_MASK_VAL).item()
            )

test_end_time = time.perf_counter()
meters.update(f"{mode}_time", test_end_time - test_start_time)
meters.print_meters(logger)
meters.plt_meters(epoch, tensbd)


2024-09-02 08:50:20,087 - NYCTaxi_STDCN_200 - INFO - start test
100%|██████████| 219/219 [00:02<00:00, 96.65it/s]
2024-09-02 08:50:22,357 - NYCTaxi_STDCN_200 - INFO - Result : [test_time: 2.27 (s), test_MAE: 14.247, test_MAPE: 14.131, test_RMSE: 22.690]


In [4]:
import numpy as np

# Mean of the per-batch metric values collected in `results`, labelled by metric name
# (rendered as the cell's rich output instead of bare, unlabeled prints).
# NOTE(review): this is an unweighted mean of batch means. It is also computed with the
# secondary null/mask values (np.nan, 5) from the evaluation loop, so it will not match
# the test_* values in the logged "Result" line.
{metric_name: float(np.mean(values)) for metric_name, values in results.items()}

12.982404467177718
16.782948328479783
21.42200029277366
