In [9]:
from lib.dataloader import get_dataloader
from model.models import STSSL
import yaml
import argparse
from lib.utils import (
    init_seed,
    get_model_params,
    load_graph, 
)

class Args:
    def __init__(self):
        self.data_dir = 'data'
        self.dataset = 'NYCTaxi'
        self.batch_size = 32
        self.test_batch_size = 32
        self.device = 'cuda'
        self.seed = 1
        self.mode = 'train'
        self.best_path = None
        self.debug = False
        self.input_length = 35
        self.graph_file = 'data/NYCTaxi/adj_mx.npz'
        self.num_nodes = 200
        self.d_input = 2
        self.d_output = 2
        self.d_model = 64
        self.dropout = 0.1
        self.percent = 0.1
        self.shm_temp = 0.5
        self.nmb_prototype = 4
        self.yita = 0.5
        self.epochs = 1000
        self.lr_init = 0.001
        self.early_stop = True
        self.early_stop_patience = 35
        self.grad_norm = True
        self.max_grad_norm = 5
        self.use_dwa = True
        self.temp = 2
        self.graph_init = 'shared_lpe_raw'
        self.self_attention_flag = True
        self.cross_attention_flag = False
        self.feedforward_flag = False
        self.layer_norm_flag = False
        self.additional_sa_flag = False
        self.learnable_flag = False
        self.cheb_order = 3
args = Args()
dataloader = get_dataloader(
        data_dir=args.data_dir, 
        dataset=args.dataset, 
        batch_size=args.batch_size, 
        test_batch_size=args.test_batch_size,
        scalar_type='Standard'
    )


model = STSSL(args).to(args.device)
model_parameters = get_model_params([model])

data['x_train'].shape:  (1912, 35, 200, 2) (1912, 1, 200, 2)


In [10]:
import torch
import numpy as np
from lib.metrics import test_metrics

@staticmethod
def test(model, dataloader, scaler, graph, logger, args):
    model.eval()
    y_pred = []
    y_true = []
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(dataloader):
            repr1, repr2 = model(data, graph)                
            pred_output = model.predict(repr1, repr2)

            y_true.append(target)
            y_pred.append(pred_output)
    y_true = scaler.inverse_transform(torch.cat(y_true, dim=0))
    y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))

    test_results = []
    # inflow
    mae, mape, masked_mae, masked_counts, unmasked_counts = test_metrics(y_pred[..., 0], y_true[..., 0])
    print("INFLOW, MAE: {:.2f}, MAPE: {:.4f}%, masked_MAE: {:.2f}, masked_counts: {}, unmasked_counts: {}".format(mae, mape*100, masked_mae, masked_counts, unmasked_counts))
    test_results.append([mae, mape])
    # outflow 
    mae, mape, masked_mae, _, _ = test_metrics(y_pred[..., 1], y_true[..., 1])
    print("OUTFLOW, MAE: {:.2f}, MAPE: {:.4f}%, masked_MAE: {:.2f}".format(mae, mape*100, masked_mae))
    test_results.append([mae, mape]) 

    return np.stack(test_results, axis=0)

In [11]:
# model_path = r"D:\omer\ST-SSL\experiments\NYCTaxi\pred__seed=1\20240607-132543\8_neighbours adj_mx, 2 sa on 128\best_model.pth"
model_path = r"D:\omer\ST-SSL\experiments\NYCTaxi\pred__seed=1\20240607-131720\shared_lpe_raw adj_mx, 2 sa on 128\best_model.pth"
model = STSSL(args).to(args.device)
model_parameters = get_model_params([model])
model.load_state_dict(torch.load(model_path)["model"])
model.to(args.device)
model.eval()

test_results = test(model, dataloader['test'], dataloader['scaler'], load_graph(args.graph_file), None, args)

INFLOW, MAE: 11.68, MAPE: 16.1910%, masked_MAE: 16.86, masked_counts: 33712, unmasked_counts: 75488
OUTFLOW, MAE: 9.64, MAPE: 16.9679%, masked_MAE: 17.96


In [4]:
len(dataloader['test'])

18

In [6]:
32*18*200 - 33712

81488