# (연구&서연) IT-STGCN – 실험결과시각화

SEOYEON CHOI  
2024-01-08

# Import

In [1]:
import pandas as pd
import numpy as np

import plotly.express as px
import plotly.graph_objects as go
import pickle

import plotly.io as pio

In [2]:
pd.options.plotting.backend = "plotly"
pio.templates.default = "plotly_white"

# Conditions

| Model         | FiveVTS | Chickenpox | Pedalme | Wikimath | Windmillsmall | MontevideoBus |
|----------------|--------|----------|--------|---------|------------|------------|
| Max iter.     | 30      | 30         | 30      | 30       | 30            | 30            |
| Epochs        | 50      | 50         | 50      | 50       | 50            | 50            |
| Lags          | 2       | 4          | 4       | 8        | 8             | 4             |
| Interpolation | linear  | linear     | nearest | linear   | linear        | nearest       |
| Filters       |         |            |         |          |               |               |
| GConvGRU      | 12      | 16         | 12      | 12       | 12            | 12            |
| GConvLSTM     | 12      | 32         | 2       | 64       | 16            | 12            |
| GCLSTM        | 4       | 16         | 4       | 64       | 16            | 12            |
| LRGCN         | 4       | 8          | 8       | 32       | 12            | 2             |
| DyGrEncoder   | 12      | 12         | 12      | 12       | 12            | 12            |
| EvolveGCNH    | No need | No need    | No need | No need  | No need       | No need       |
| EvolveGCNO    | No need | No need    | No need | No need  | No need       | No need       |
| TGCN          | 12      | 12         | 12      | 12       | 12            | 8             |
| DCRNN         | 2       | 16         | 8       | 12       | 4             | 12            |

# Data & 사용자정의함수

In [3]:
df_dataset = pd.DataFrame({
    'dataset':['fivenodes','chickenpox','pedalme','wikimath','windmillsmall','monte'],
    'node':[5,20,15,1068,11,675],
    'time':[200,522,36,731,17472,744]
})
df = pd.read_csv('./Body_Results.csv').iloc[:,1:]\
.merge(df_dataset)\
.assign(method = lambda df: df['method'].map({'STGCN':'STGNN','IT-STGCN':'IT-STGNN','GNAR':'GNAR'}))

In [4]:
df.head()

In [5]:
df2 = df.assign(model = lambda df: df.model.apply(lambda x: '' if x == "GNAR" else x))\
.eval('model = method+model')\
.assign(model = lambda df: df.model.str.replace("STGNN",""))
df2.head()

In [6]:
set(df.method)

{'GNAR', 'IT-STGNN', 'STGNN'}

In [7]:
def show_experiment_spec(df):
    df_cols = ['method','mrate','mtype','lags','nof_filters','inter_method','epoch','model']
    for dataset in df.dataset.unique():
        print(f'dataset: {dataset}')
        for col in df_cols: 
            df_query = df[df['dataset'] == dataset]
            print(f'{col}: {df_query[col].unique().tolist()}')
        print('---')

`-` 데이터세트별 실험셋팅

In [8]:
show_experiment_spec(df)

dataset: fivenodes
method: ['STGNN', 'IT-STGNN', 'GNAR']
mrate: [0.0, 0.7, 0.8, 0.3, 0.5, 0.6, 0.125]
mtype: [nan, 'rand', 'block']
lags: [2]
nof_filters: [12.0, 4.0, 2.0, nan]
inter_method: [nan, 'linear']
epoch: [50.0, nan]
model: ['GConvGRU', 'GConvLSTM', 'GCLSTM', 'DCRNN', 'LRGCN', 'TGCN', 'EvolveGCNO', 'DyGrEncoder', 'EvolveGCNH', 'GNAR']
---
dataset: chickenpox
method: ['STGNN', 'IT-STGNN', 'GNAR']
mrate: [0.0, 0.3, 0.8, 0.5, 0.6, 0.2877697841726618]
mtype: [nan, 'rand', 'block']
lags: [4]
nof_filters: [16.0, 32.0, 8.0, 12.0, nan]
inter_method: [nan, 'linear']
epoch: [50.0, nan]
model: ['GConvGRU', 'GConvLSTM', 'GCLSTM', 'DCRNN', 'LRGCN', 'TGCN', 'EvolveGCNO', 'DyGrEncoder', 'EvolveGCNH', 'GNAR']
---
dataset: pedalme
method: ['STGNN', 'IT-STGNN', 'GNAR']
mrate: [0.0, 0.3, 0.6, 0.5, 0.8, 0.2857142857142857]
mtype: [nan, 'rand', 'block']
lags: [4]
nof_filters: [12.0, 2.0, 4.0, 8.0, nan]
inter_method: [nan, 'nearest', 'linear']
epoch: [50.0, nan]
model: ['GConvGRU', 'GConvLSTM', 'GC

`-` 요약

In [9]:
df_summary = df2.groupby(["method","dataset","mrate","model"]).agg({'mse':'mean'}).reset_index()\
.pivot_table(index=['model'],columns=['dataset'],values='mse')
df_summary

# 시각화1: MissingRate (본문)

In [10]:
big = df.query("mtype=='rand' or mtype.isna()").query("dataset == 'chickenpox'").query("model == 'GConvLSTM'")\
.sort_values(by='mrate')\
.assign(mrate_jittered = lambda df: np.array(df['mrate'])+np.random.randn(len(df['mrate']))*0.01)
small = big.groupby(["dataset","mrate","method"]).agg({'mse':np.median}).reset_index().rename({'mse':'mse_median'},axis=1)
tidydata = big.merge(small)
#---#
fig = px.scatter(
    tidydata,
    y='mse_median',
    x='mrate',
    opacity=0.3,
    color='method',
    width=425,
    height=425,
    hover_data='mrate',
)
fig.data[0]['mode']='markers+lines'
fig.data[0]['marker']['size'] = 6
fig.data[0]['line']['width'] = 1.5
fig.data[0]['line']['dash'] = 'dashdot'
fig.data[1]['mode']='markers+lines'
fig.data[1]['marker']['size'] = 6
fig.data[1]['line']['width'] =1.5
fig.data[1]['line']['dash'] = 'dashdot'
box1 = px.box(
    tidydata.query("method=='STGNN'"),
    y='mse',
    x='mrate',
)
box1.data[0]['opacity']=0.5
box1.data[0]['marker']['color']='#636efa'
box2 = px.box(
    tidydata.query("method=='IT-STGNN'"),
    y='mse',
    x='mrate',
)
box2.data[0]['opacity']=0.5
box2.data[0]['marker']['color']='#EF553B'
fig.add_traces(box1.data)
fig.add_traces(box2.data)
fig.data[0]['showlegend'] =False
fig.data[1]['showlegend'] =False
fig.layout['xaxis']['title']['text']='Missign Rate'
fig.layout['yaxis']['title']['text']='MSE'
fig.layout['legend']['title']['text']=""
fig.layout['title']['text']='Chickenpox/GConvLSTM'
fig

  small = big.groupby(["dataset","mrate","method"]).agg({'mse':np.median}).reset_index().rename({'mse':'mse_median'},axis=1)

# 시각화2: MissingRate (부록)

In [11]:
tidydata = df.query("mtype!='block'").query("method!='GNAR'").query("dataset != 'fivenodes'")\
.groupby(["method","dataset","mrate","model"]).agg({'mse':'mean'}).reset_index()
#---#
fig = px.line(
    tidydata,
    x='mrate',
    y='mse',
    color='method',
    facet_row='model',
    facet_col='dataset',
    width=850,
    height=1000,
    
)
for scatter in fig.data:
    scatter['mode'] = 'lines+markers'
    scatter['line']['dash'] = 'dashdot'
for annotation in fig.layout['annotations']:
    annotation['text'] = annotation['text'].replace('dataset=','')
    annotation['text'] = annotation['text'].replace('model=','')
for k in [k for k in fig.layout if 'xaxis' in k]:
    fig.layout[k]['title']['text'] = None 
for k in [k for k in fig.layout if 'yaxis' in k]:
    fig.layout[k]['title']['text'] = None 
fig.update_yaxes(showticklabels=True,matches=None)
fig.update_xaxes(showticklabels=True,matches=None)

# 시각화3

In [12]:
def func(x):
    if 'IT' in x:
        return 'IT-STGNN'
    elif 'GNAR' in x:
        return 'GNAR'
    else: 
        return 'STGNN'

In [13]:
tidydata = df2.query('mtype != "block"').query('method!="GNAR"').groupby(["method","dataset","mrate","model"]).agg({'mse':'mean'}).reset_index()\
.pivot_table(index=['model'],columns=['dataset'],values='mse').stack().reset_index().rename({0:'mse'},axis=1).assign(
    method = lambda df: df['model'].apply(func)
)
tidydata = pd.concat([df.sort_values('mse').reset_index(drop=True).reset_index() for _,df in tidydata.groupby("dataset")])
#---#
fig = px.bar(
    tidydata,
    x='index',
    y='mse',
    color='method',
    facet_col='dataset',
    facet_col_wrap=2,
    text='model',
    height=800
)
fig

# 시각화4

In [14]:
tidydata = df.query('mtype != "block"').query('dataset != "fivenodes"')\
.groupby(["method","dataset","node","time"]).agg({'mse':'mean'}).reset_index()\
.assign(ratio = lambda df: df['time']/df['node'])\
.pivot_table(index=['dataset','ratio'] ,columns=['method'],values='mse')\
.assign(mse_diff = lambda df: df['STGNN']- df['IT-STGNN']).loc[:,'mse_diff']\
.reset_index()
tidydata
fig = px.scatter(
    tidydata,
    x='ratio',
    log_x=True,
    y='mse_diff',
    text='dataset',
    width=625,
    height=425,
)
fig.data[0]['textposition'] = ['top right']*3 + ['bottom right'] + ['top left']
fig.data[0]['marker']['size'] = 8
fig

# 시각화5

In [15]:
tidydata = df.assign(mtype = df['mtype'].fillna("rand"))\
.query('method != "GNAR"').query('dataset != "windmillsmall"').query('dataset != "wikimath"')\
.groupby(["method","dataset","mtype","mrate"]).agg({'mse':'mean'}).reset_index()\
.sort_values('mrate')\
.assign(mrate = lambda df: df['mrate'].apply(lambda x: f'{x:.3f}'))
#---#
fig = px.bar(
    tidydata,
    x='mrate',
    y='mse',
    color='method',
    facet_col='dataset',
    facet_col_wrap=2,    
    width=850,
    height=650,
    barmode='group',
    hover_data='mtype',
    opacity=0.2
)
fig.update_yaxes(showticklabels=True,matches=None)
fig.update_xaxes(showticklabels=True,matches=None)
for trace in fig.data:
    trace['marker']['opacity']=[0.2, 1, 0.2, 0.2, 0.2, 0.2, 0.2]
fig