In [1]:
# Automatically reload edited local modules (e.g. metric_utils) before each cell runs,
# so changes to the helper code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
sys.path.append('..')
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

from dataset.nyc_taxi.nyctaxi_dataset import NYCTaxiDataset

from metric_utils import *

In [3]:
import warnings
warnings.filterwarnings('ignore')

# Main Experiments: Fully Observed Input

## YellowCab & GreenCab

In [4]:
# Fully observed input (observation ratio 1.0): LaTeX table for YellowCab and GreenCab
# at 30-minute resolution, with RMSE and MAE per horizon.
nyctaxi_all_results = collect_results(['NYCTaxi', 'NYCTaxiGreen'], [1.0,])
modelnames = [
    'HA', 'Static',
    'GRU', 'Informer_nodt', 'Graph WaveNet', 'MTGNN', 'KoopmanAE',
    'Multitask_GWplusCKO_gate_ups_ds_convfusion',
]
# Raw strings keep the literal backslash of the LaTeX macro explicitly instead of
# relying on Python's deprecated handling of the unrecognised escape "\m".
df_lat = latex_results_multihorizon_multikr(
    nyctaxi_all_results, modelnames, datanames=('NYCTaxi', 'NYCTaxiGreen'),
    reslist=('30min',), metriclist=('rmse', 'mae',),
    horizon_name_dict={1: '30min', 12: '6h', 48: '1d', 240: '5d', 480: '10d'},
    model_name_dict={
        'Informer_nodt': 'Informer',
        'Multitask_GWplusCKO_gate': r'\modelshortname (no ups/ds)',
        'Multitask_GWplusCKO_gate_ups_ds_convfusion': r'\modelshortname',
    },
    data_name_dict={'NYCTaxi': 'YellowCab', 'NYCTaxiGreen': 'GreenCab'},
    drop_obs_ratio=True,
)
# Printed (rather than displayed) so the generated LaTeX appears as plain text.
print(df_lat)

mean: [123.28514811 123.25635562]
var: [65351.27791198 58547.75001079]
2019-04-27 00:00:00
mean: [6.88223632 6.87954762]
var: [243.2088257  186.28749241]
2019-04-27 00:00:00
\begin{tabular}{cccccccccccccc}
\toprule
         &     &      &      HA &                       Static &            GRU &                           Informer &                      Graph WaveNet &          MTGNN &      KoopmanAE &         \modelshortname &    RelErr &  RelErrGW \\
Data & Horizon & Metric &         &                              &                &                                    &                                    &                &                &                         &           &           \\
\midrule
\multirow{8}{*}{YellowCab} & \multirow{2}{*}{30min} & MAE &  19.428 &  \underline{\textit{12.499}} &  22.690(1.848) &                      20.240(1.153) &                      20.261(0.839) &  22.258(1.167) &  15.451(0.349) &  \textbf{12.265(0.641)} &   -1.87\% &  -39.46\% \\
         &     

## Solar Energy

In [5]:
# Fully observed input (observation ratio 1.0): LaTeX table for Solar Energy at
# 10-minute resolution, with RMSE and MAE per horizon.
# NOTE(review): the dataset key 'Solar Energy 10min ' has a trailing space; it is kept
# verbatim, presumably because it must match the stored results key.
solar_energy_all_results = collect_results(['Solar Energy 10min '], [1.0,])
modelnames = [
    'HA', 'Static',
    'GRU', 'Informer', 'Graph WaveNet', 'MTGNN', 'KoopmanAE',
    'Multitask_GWplusCKO_gate_ups_ds_convfusion',
]
# Raw strings keep the literal backslash of the LaTeX macro explicitly instead of
# relying on Python's deprecated handling of the unrecognised escape "\m".
df_lat = latex_results_multihorizon_multikr(
    solar_energy_all_results, modelnames, datanames=('Solar Energy 10min ',),
    reslist=('10min',), metriclist=('rmse', 'mae',),
    horizon_name_dict={1: '10min', 6: '1h', 36: '6h', 432: '3d'},
    model_name_dict={
        'Informer_nodt': 'Informer',
        'Multitask_GWplusCKO_gate': r'\modelshortname (no ups/ds)',
        'Multitask_GWplusCKO_gate_ups_ds_convfusion': r'\modelshortname',
    },
    data_name_dict={
        'NYCTaxi': 'YellowCab',
        'NYCTaxiGreen': 'GreenCab',
        'Solar Energy 10min ': 'Solar Energy',
    },
    drop_obs_ratio=True,
)
# Printed (rather than displayed) so the generated LaTeX appears as plain text.
print(df_lat)

mean: [23.68199604]
var: [7941.12710593]
2006-10-10 00:00:00
\begin{tabular}{cccccccccccccc}
\toprule
             &    &      &                            HA &   Static &              GRU &                             Informer &   Graph WaveNet &            MTGNN &       KoopmanAE &          \modelshortname &    RelErr &  RelErrGW \\
Data & Horizon & Metric &                               &          &                  &                                      &                 &                  &                 &                          &           &           \\
\midrule
\multirow{8}{*}{Solar Energy} & \multirow{2}{*}{10min} & MAE &   \underline{\textit{53.066}} &   63.137 &    93.171(3.893) &                        68.768(4.067) &  124.946(0.226) &   117.190(5.833) &  151.358(2.996) &   \textbf{45.365(1.094)} &  -14.51\% &  -63.69\% \\
             &    & RMSE &  \underline{\textit{143.272}} &  188.282 &   147.998(4.483) &                       148.671(6.116) &  310.096(0.197) &  29

# Partially Observed Input (YellowCab)

In [8]:
# Partially observed input for YellowCab: observation ratios 0.8 / 0.6 / 0.4 / 0.2,
# reported only for horizon_list=[480] (labelled '10d' in horizon_name_dict).
nyctaxi_partial_results = collect_results(['NYCTaxi',], [0.8, 0.6, 0.4, 0.2])
# GRU is deliberately omitted from this table (it was commented out of the list).
# NOTE(review): the reason for excluding GRU here is not recorded — confirm.
modelnames = [
    'HA', 'Static',
    'Informer_nodt', 'Graph WaveNet', 'MTGNN', 'KoopmanAE',
    'Multitask_GWplusCKO_gate_ups_ds_convfusion',
]
# Raw strings keep the literal backslash of the LaTeX macro explicitly instead of
# relying on Python's deprecated handling of the unrecognised escape "\m".
print(latex_results_multihorizon_multikr(
    nyctaxi_partial_results, modelnames, datanames=('NYCTaxi',),
    reslist=('30min',), metriclist=('rmse', 'mae',),
    horizon_name_dict={1: '30min', 12: '6h', 48: '1d', 240: '5d', 480: '10d'},
    model_name_dict={
        'Informer_nodt': 'Informer',
        'Multitask_GWplusCKO_gate': r'\modelshortname (no ups/ds)',
        'Multitask_GWplusCKO_gate_ups_ds_convfusion': r'\modelshortname',
    },
    data_name_dict={'NYCTaxi': 'YellowCab', 'NYCTaxiGreen': 'GreenCab'},
    drop_obs_ratio=False, horizon_list=[480], dropdata=True, drophorizon=True,
))

mean: [123.28514811 123.25635562]
var: [65351.27791198 58547.75001079]
2019-04-27 00:00:00
mean: [123.28514811 123.25635562]
var: [65351.27791198 58547.75001079]
2019-04-27 00:00:00
mean: [123.28514811 123.25635562]
var: [65351.27791198 58547.75001079]
2019-04-27 00:00:00
mean: [123.28514811 123.25635562]
var: [65351.27791198 58547.75001079]
2019-04-27 00:00:00


KeyError: '30min'

In [10]:
# Re-issues the YellowCab partial-observation table call. It reads
# `nyctaxi_partial_results` and `modelnames` as set by the preceding cell.
# NOTE(review): the preceding cell makes this identical call and raised KeyError '30min'
# in the saved run while this one succeeded — keep only one once the cause is confirmed.
# Raw strings keep the literal backslash of the LaTeX macro explicitly instead of
# relying on Python's deprecated handling of the unrecognised escape "\m".
print(latex_results_multihorizon_multikr(
    nyctaxi_partial_results, modelnames, datanames=('NYCTaxi',),
    reslist=('30min',), metriclist=('rmse', 'mae',),
    horizon_name_dict={1: '30min', 12: '6h', 48: '1d', 240: '5d', 480: '10d'},
    model_name_dict={
        'Informer_nodt': 'Informer',
        'Multitask_GWplusCKO_gate': r'\modelshortname (no ups/ds)',
        'Multitask_GWplusCKO_gate_ups_ds_convfusion': r'\modelshortname',
    },
    data_name_dict={'NYCTaxi': 'YellowCab', 'NYCTaxiGreen': 'GreenCab'},
    drop_obs_ratio=False, horizon_list=[480], dropdata=True, drophorizon=True,
))

\begin{tabular}{ccccccccccc}
\toprule
    &      &      HA &                       Static &       Informer &                       Graph WaveNet &          MTGNN &                           KoopmanAE &         \modelshortname &    RelErr &  RelErrGW \\
Obs Ratio & Metric &         &                              &                &                                     &                &                                     &                         &           &           \\
\midrule
\multirow{2}{*}{0.8} & MAE &  21.952 &  \underline{\textit{14.206}} &  19.042(0.982) &                       16.528(0.291) &  18.426(0.878) &                       18.434(1.573) &  \textbf{12.560(0.655)} &  -11.59\% &  -24.01\% \\
    & RMSE &  35.522 &  \underline{\textit{26.661}} &  29.209(1.293) &                       31.099(0.435) &  32.725(2.018) &                       29.794(2.198) &  \textbf{22.579(0.712)} &  -15.31\% &  -27.40\% \\
\cline{1-11}
\multirow{2}{*}{0.6} & MAE &  21.971 &  \underline{\text

# Ablation Study

In [None]:
# Ablation study on YellowCab (30-minute resolution): the full model versus variants
# without self-attention, without the Koopman component, and without upsampling/downsampling.
modelnames = [
    'Multitask_nosa',
    'Multitask_ds_ups_convfusion',
    'Multitask_GWplusCKO_gate',
    'Multitask_GWplusCKO_gate_ups_ds_convfusion',
]
all_results_sn = parse_model_results(
    dataname='NYCTaxi',
    modelnames=modelnames,
    horizons=[1, 12, 48, 480],
    res_list=('30min',),
    threshold=1e-6,
    corresponding=True,
    target_node_res=0,
)
pd.options.display.max_colwidth = 100
# Raw string keeps the literal backslash of the LaTeX macro explicitly instead of
# relying on Python's deprecated handling of the unrecognised escape "\m".
print(latex_results_multihorizon(
    all_results_sn, modelnames, reslist=('30min',), metriclist=('mae', 'rmse',),
    horizon_name_dict={1: '30min', 12: '6h', 48: '1d', 240: '5d', 480: '10d'},
    model_name_dict={
        'Multitask_ds_ups_convfusion': 'w/o Koopman',
        'Multitask_GWplusCKO_gate': 'w/o ups/ds',
        'Multitask_GWplusCKO_gate_ups_ds_convfusion': r'\modelshortname',
        'Multitask_nosa': 'w/o Self-Attn',
    },
    with_relative=True,
))

In [None]:
# Ablation study on YellowCab (30-minute resolution), same configuration as the cell above.
# NOTE(review): this cell is an exact copy of the previous ablation cell — consider removing it.
modelnames = [
    'Multitask_nosa',
    'Multitask_ds_ups_convfusion',
    'Multitask_GWplusCKO_gate',
    'Multitask_GWplusCKO_gate_ups_ds_convfusion',
]
all_results_sn = parse_model_results(
    dataname='NYCTaxi',
    modelnames=modelnames,
    horizons=[1, 12, 48, 480],
    res_list=('30min',),
    threshold=1e-6,
    corresponding=True,
    target_node_res=0,
)
pd.options.display.max_colwidth = 100
# Raw string keeps the literal backslash of the LaTeX macro explicitly instead of
# relying on Python's deprecated handling of the unrecognised escape "\m".
print(latex_results_multihorizon(
    all_results_sn, modelnames, reslist=('30min',), metriclist=('mae', 'rmse',),
    horizon_name_dict={1: '30min', 12: '6h', 48: '1d', 240: '5d', 480: '10d'},
    model_name_dict={
        'Multitask_ds_ups_convfusion': 'w/o Koopman',
        'Multitask_GWplusCKO_gate': 'w/o ups/ds',
        'Multitask_GWplusCKO_gate_ups_ds_convfusion': r'\modelshortname',
        'Multitask_nosa': 'w/o Self-Attn',
    },
    with_relative=True,
))

# Additional Results for Rebuttal

## With Partially Observed Input (NYCTaxiGreen and Solar Energy 10min)

In [11]:
# Partially observed input for GreenCab: observation ratios 0.8 / 0.6 / 0.4 / 0.2,
# reported only for horizon_list=[480] (labelled '10d' in horizon_name_dict).
# Stored under its own name so the YellowCab `nyctaxi_partial_results` used by the
# earlier partial-observation cells is not silently overwritten with GreenCab data.
nyctaxi_green_partial_results = collect_results(['NYCTaxiGreen',], [0.8, 0.6, 0.4, 0.2])
modelnames = [
    'HA', 'Static',
    'GRU', 'Informer_nodt', 'Graph WaveNet', 'MTGNN', 'KoopmanAE',
    'Multitask_GWplusCKO_gate_ups_ds_convfusion',
]
# Raw strings keep the literal backslash of the LaTeX macro explicitly instead of
# relying on Python's deprecated handling of the unrecognised escape "\m".
print(latex_results_multihorizon_multikr(
    nyctaxi_green_partial_results, modelnames, datanames=('NYCTaxiGreen',),
    reslist=('30min',), metriclist=('rmse', 'mae',),
    horizon_name_dict={1: '30min', 12: '6h', 48: '1d', 240: '5d', 480: '10d'},
    model_name_dict={
        'Informer_nodt': 'Informer',
        'Multitask_GWplusCKO_gate': r'\modelshortname (no ups/ds)',
        'Multitask_GWplusCKO_gate_ups_ds_convfusion': r'\modelshortname',
    },
    data_name_dict={'NYCTaxi': 'YellowCab', 'NYCTaxiGreen': 'GreenCab'},
    drop_obs_ratio=False, horizon_list=[480], dropdata=True, drophorizon=True,
))

mean: [6.88223632 6.87954762]
var: [243.2088257  186.28749241]
2019-04-27 00:00:00
mean: [6.88223632 6.87954762]
var: [243.2088257  186.28749241]
2019-04-27 00:00:00
mean: [6.88223632 6.87954762]
var: [243.2088257  186.28749241]
2019-04-27 00:00:00
mean: [6.88223632 6.87954762]
var: [243.2088257  186.28749241]
2019-04-27 00:00:00
\begin{tabular}{ccccccccccccc}
\toprule
    &      &     HA & Static &           GRU &      Informer &                      Graph WaveNet &         MTGNN &     KoopmanAE &        \modelshortname &    RelErr &  RelErrGW \\
Obs Ratio & Metric &        &        &               &               &                                    &               &               &                        &           &           \\
\midrule
\multirow{2}{*}{0.8} & MAE &  3.765 &  2.084 &  2.666(0.044) &  1.962(0.069) &  \underline{\textit{1.770(0.007)}} &  2.086(0.111) &  2.782(0.553) &  \textbf{1.733(0.021)} &   -2.09\% &   -2.09\% \\
    & RMSE &  5.703 &  3.092 &  3.898(0.075) &  2

In [12]:
# Partially observed input for Solar Energy (10-minute resolution): observation ratios
# 0.8 / 0.6 / 0.4 / 0.2, reported only for horizon_list=[432] ('3d' in horizon_name_dict).
# Stored under its own name instead of reusing `solar_energy_all_results`, which holds
# the fully observed results and would otherwise be silently overwritten.
solar_energy_partial_results = collect_results(['Solar Energy 10min '], [0.8, 0.6, 0.4, 0.2,])
modelnames = [
    'HA', 'Static',
    'GRU', 'Informer', 'Graph WaveNet', 'MTGNN', 'KoopmanAE',
    'Multitask_GWplusCKO_gate_ups_ds_convfusion',
]
# Raw strings keep the literal backslash of the LaTeX macro explicitly instead of
# relying on Python's deprecated handling of the unrecognised escape "\m".
df_lat = latex_results_multihorizon_multikr(
    solar_energy_partial_results, modelnames, datanames=('Solar Energy 10min ',),
    reslist=('10min',), metriclist=('rmse', 'mae',),
    horizon_name_dict={1: '10min', 6: '1h', 36: '6h', 432: '3d'},
    model_name_dict={
        'Informer_nodt': 'Informer',
        'Multitask_GWplusCKO_gate': r'\modelshortname (no ups/ds)',
        'Multitask_GWplusCKO_gate_ups_ds_convfusion': r'\modelshortname',
    },
    data_name_dict={
        'NYCTaxi': 'YellowCab',
        'NYCTaxiGreen': 'GreenCab',
        'Solar Energy 10min ': 'Solar Energy',
    },
    drop_obs_ratio=False, horizon_list=[432], dropdata=True, drophorizon=True,
)
# Printed (rather than displayed) so the generated LaTeX appears as plain text.
print(df_lat)

mean: [23.68199604]
var: [7941.12710593]
2006-10-10 00:00:00
mean: [23.68199604]
var: [7941.12710593]
2006-10-10 00:00:00
mean: [23.68199604]
var: [7941.12710593]
2006-10-10 00:00:00
mean: [23.68199604]
var: [7941.12710593]
2006-10-10 00:00:00
\begin{tabular}{ccccccccccccc}
\toprule
    &      &       HA &   Static &              GRU &                             Informer &   Graph WaveNet &                                 MTGNN &       KoopmanAE &          \modelshortname &    RelErr &  RelErrGW \\
Obs Ratio & Metric &          &          &                  &                                      &                 &                                       &                 &                          &           &           \\
\midrule
\multirow{2}{*}{0.8} & MAE &  200.729 &  261.134 &   118.284(8.741) &   \underline{\textit{91.457(2.998)}} &  135.627(0.021) &                        122.947(8.833) &  147.520(0.747) &   \textbf{68.340(1.260)} &  -25.28\% &  -49.61\% \\
    & RMSE &  255.24

In [None]:
# Fully observed input (observation ratio 1.0): LaTeX table for YellowCab and GreenCab
# at 30-minute resolution, with RMSE and MAE per horizon.
# NOTE(review): same call as the YellowCab/GreenCab cell in the main-experiments section —
# consider removing one copy.
nyctaxi_all_results = collect_results(['NYCTaxi', 'NYCTaxiGreen'], [1.0,])
modelnames = [
    'HA', 'Static',
    'GRU', 'Informer_nodt', 'Graph WaveNet', 'MTGNN', 'KoopmanAE',
    'Multitask_GWplusCKO_gate_ups_ds_convfusion',
]
# Raw strings keep the literal backslash of the LaTeX macro explicitly instead of
# relying on Python's deprecated handling of the unrecognised escape "\m".
df_lat = latex_results_multihorizon_multikr(
    nyctaxi_all_results, modelnames, datanames=('NYCTaxi', 'NYCTaxiGreen'),
    reslist=('30min',), metriclist=('rmse', 'mae',),
    horizon_name_dict={1: '30min', 12: '6h', 48: '1d', 240: '5d', 480: '10d'},
    model_name_dict={
        'Informer_nodt': 'Informer',
        'Multitask_GWplusCKO_gate': r'\modelshortname (no ups/ds)',
        'Multitask_GWplusCKO_gate_ups_ds_convfusion': r'\modelshortname',
    },
    data_name_dict={'NYCTaxi': 'YellowCab', 'NYCTaxiGreen': 'GreenCab'},
    drop_obs_ratio=True,
)
# Printed (rather than displayed) so the generated LaTeX appears as plain text.
print(df_lat)

In [None]:
# Fully observed input (observation ratio 1.0): LaTeX table for Solar Energy at
# 10-minute resolution, with RMSE and MAE per horizon.
# NOTE(review): same call as the Solar Energy cell in the main-experiments section —
# consider removing one copy.
solar_energy_all_results = collect_results(['Solar Energy 10min '], [1.0,])
modelnames = [
    'HA', 'Static',
    'GRU', 'Informer', 'Graph WaveNet', 'MTGNN', 'KoopmanAE',
    'Multitask_GWplusCKO_gate_ups_ds_convfusion',
]
# Raw strings keep the literal backslash of the LaTeX macro explicitly instead of
# relying on Python's deprecated handling of the unrecognised escape "\m".
df_lat = latex_results_multihorizon_multikr(
    solar_energy_all_results, modelnames, datanames=('Solar Energy 10min ',),
    reslist=('10min',), metriclist=('rmse', 'mae',),
    horizon_name_dict={1: '10min', 6: '1h', 36: '6h', 432: '3d'},
    model_name_dict={
        'Informer_nodt': 'Informer',
        'Multitask_GWplusCKO_gate': r'\modelshortname (no ups/ds)',
        'Multitask_GWplusCKO_gate_ups_ds_convfusion': r'\modelshortname',
    },
    data_name_dict={
        'NYCTaxi': 'YellowCab',
        'NYCTaxiGreen': 'GreenCab',
        'Solar Energy 10min ': 'Solar Energy',
    },
    drop_obs_ratio=True,
)
# Printed (rather than displayed) so the generated LaTeX appears as plain text.
print(df_lat)