# (연구&서연) IT-STGCN – 실험결과시각화

SEOYEON CHOI  
2024-01-08

# Imports

# Conditions

| Setting / Model | FiveVTS | Chickenpox | Pedalme | Wikimath | Windmillsmall | MontevideoBus |
|-----------------|---------|------------|---------|----------|---------------|---------------|
| Max iter.     | 30      | 30         | 30      | 30       | 30            | 30            |
| Epochs        | 50      | 50         | 50      | 50       | 50            | 50            |
| Lags          | 2       | 4          | 4       | 8        | 8             | 4             |
| Interpolation | linear  | linear     | nearest | linear   | linear        | nearest       |
| Filters       |         |            |         |          |               |               |
| GConvGRU      | 12      | 16         | 12      | 12       | 12            | 12            |
| GConvLSTM     | 12      | 32         | 2       | 64       | 16            | 12            |
| GCLSTM        | 4       | 16         | 4       | 64       | 16            | 12            |
| LRGCN         | 4       | 8          | 8       | 32       | 12            | 2             |
| DyGrEncoder   | 12      | 12         | 12      | 12       | 12            | 12            |
| EvolveGCNH    | No need | No need    | No need | No need  | No need       | No need       |
| EvolveGCNO    | No need | No need    | No need | No need  | No need       | No need       |
| TGCN          | 12      | 12         | 12      | 12       | 12            | 8             |
| DCRNN         | 2       | 16         | 8       | 12       | 4             | 12            |

# Import

In [4]:
import pandas as pd
import numpy as np

import plotly.express as px
import plotly.graph_objects as go
import pickle

import plotly.io as pio

In [5]:
# Use the clean white plotly theme for every figure, and route pandas'
# `.plot` accessor through plotly (visualization 2 relies on this).
pio.templates.default = "plotly_white"
pd.set_option("plotting.backend", "plotly")

# Data

In [59]:
# Size of each benchmark graph: number of nodes and number of time steps.
df_dataset = pd.DataFrame(
    [
        ('fivenodes', 5, 200),
        ('chickenpox', 20, 522),
        ('pedalme', 15, 36),
        ('wikimath', 1068, 731),
        ('windmillsmall', 11, 17472),
        ('monte', 675, 744),
    ],
    columns=['dataset', 'node', 'time'],
)

# Experiment results, with the node/time counts attached through an inner merge
# on the shared `dataset` column.
df = (
    pd.read_csv('./Body_Results.csv')
    .iloc[:, 1:]  # drop the first column (presumably a saved row index)
    .merge(df_dataset)
)

# 시각화1: MissingRate (본문)

In [87]:
# Main-text figure: MSE vs. missing rate for Wikimath / GConvLSTM under random
# missingness. Rows whose `mtype` is missing are kept as well (the query below
# includes them). Layers, in trace order:
#   1. per-run MSE as faint points, x-jittered so overlapping runs stay visible;
#   2. per-method mean MSE as dash-dot lines with markers;
#   3. one box plot of per-run MSE per method.
# A fixed method->color map keeps each method's points, line and box the same
# color regardless of which method happens to appear first in the data.
METHOD_COLORS = {'STGCN': '#636efa', 'IT-STGCN': '#EF553B'}
# Seeded generator so the jitter (and therefore the figure) is reproducible.
rng = np.random.default_rng(0)

big = (
    df.query("mtype=='rand' or mtype.isna()")
    .query("dataset == 'wikimath'")
    .query("model == 'GConvLSTM'")
    .sort_values(by='mrate')
    .assign(
        mrate_jittered=lambda d: np.asarray(d['mrate']) + rng.standard_normal(len(d)) * 0.01
    )
)
small = (
    big.groupby(["dataset", "mrate", "method"])
    .agg({'mse': 'mean'})
    .reset_index()
    .rename(columns={'mse': 'mse_mean'})
)
# Attach each run's group mean (merge keys: dataset, mrate, method).
tidydata = big.merge(small)

# Layer 1: individual runs. Their legend entries are hidden because the mean
# lines in layer 2 carry the legend.
fig = px.scatter(
    tidydata,
    y='mse',
    x='mrate_jittered',
    opacity=0.1,
    color='method',
    color_discrete_map=METHOD_COLORS,
    width=750,
    height=500,
)
fig.update_traces(showlegend=False)

# Layer 2: per-method mean MSE, styled uniformly for however many methods exist.
_fig1 = px.scatter(
    tidydata,
    y='mse_mean',
    x='mrate',
    color='method',
    color_discrete_map=METHOD_COLORS,
)
_fig1.update_traces(mode='markers+lines', marker_size=10, line_width=3, line_dash='dashdot')
fig.add_traces(_fig1.data)

# Layer 3: one box plot per method, colored to match that method's points and line.
for method_name, box_color in METHOD_COLORS.items():
    _box = px.box(
        tidydata[tidydata['method'] == method_name],
        y='mse',
        x='mrate',
    )
    _box.update_traces(opacity=0.7, marker_color=box_color)
    fig.add_traces(_box.data)

fig.update_layout(
    xaxis_title_text='Missing Rate',
    yaxis_title_text='MSE',
    legend_title_text="",
    title_text='Wikimath/GConvLSTM',
)
fig

# 시각화2: MissingRate (부록)

In [88]:
# Appendix grid: mean MSE vs. missing rate under random missingness, one panel
# per (model, dataset) and one dash-dot line per method. GNAR and the
# windmillsmall dataset are excluded from this grid.
fig = (
    df.query("mtype=='rand' and method != 'GNAR' and dataset != 'windmillsmall'")
    .groupby(["method", "dataset", "mrate", "model"])
    .agg({'mse': 'mean'})
    .reset_index()
    .plot.line(
        x='mrate',
        y='mse',
        color='method',
        facet_row='model',
        facet_col='dataset',
        width=750,
        height=1000,
    )
)
fig.update_traces(mode='lines+markers', line_dash='dashdot')
# Strip the "dataset=" / "model=" prefixes from the facet labels.
fig.for_each_annotation(
    lambda ann: ann.update(text=ann.text.replace('dataset=', '').replace('model=', ''))
)
# Drop per-panel axis titles and give every panel its own independent,
# labelled axes (the facet panels use different MSE scales).
fig.update_yaxes(title_text=None, showticklabels=True, matches=None)
fig.update_xaxes(title_text=None, showticklabels=True, matches=None)

# 시각화3

In [121]:
# Number of non-null entries per column for every (dataset, missing rate) pair.
df.groupby(['dataset', 'mrate']).count()

In [131]:
# Mean/std of MSE at missing rates 0 and 0.8 (excluding GNAR and windmillsmall),
# plus node/time columns perturbed by Gaussian noise scaled by each group's MSE std.
# Noise columns are created in order (node2, then time2), drawing from the global
# NumPy random state.
(
    df.query('method != "GNAR" and (mrate == 0 or mrate == 0.8) and dataset != "windmillsmall"')
    .groupby(['method', 'dataset', 'mrate', 'model', 'node', 'time'])
    .agg(mse_mean=('mse', 'mean'), mse_std=('mse', 'std'))
    .reset_index()
    .assign(
        node2=lambda tbl: tbl['node'] + np.random.randn(len(tbl)) * tbl['mse_std'],
        time2=lambda tbl: tbl['time'] + np.random.randn(len(tbl)) * tbl['mse_std'],
    )
)

# 시각화

In [70]:
# All result rows for the fivenodes dataset.
df.query("dataset == 'fivenodes'")