In [1]:
# import libraries
import os
import math
import random
import time
import yaml
import datetime 
import numpy as np
import pandas as pd
from tqdm import trange
from tqdm.notebook import tqdm

import seaborn as sns
import matplotlib.pyplot as plt
import plotly.graph_objects as go

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error

import torch
import torch.nn as nn
from torch.nn import LayerNorm
from torch.utils.data import Dataset, DataLoader
from torch.nn import TransformerEncoder, TransformerDecoder, TransformerEncoderLayer, TransformerDecoderLayer

# set random seed: seed Python's `random`, NumPy's global RNG and PyTorch
# (torch.manual_seed seeds the generators of all torch devices) with one value.
fix_seed = 1111
random.seed(fix_seed)
np.random.seed(fix_seed)
torch.manual_seed(fix_seed)

# set device: prefer CUDA when present, then Apple-silicon MPS, else CPU.
# On a machine without CUDA this selects exactly as before (mps -> cpu).
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(device)

mps


In [2]:
# import data: the preprocessed dataset (3-hourly rows, judging by the displayed
# output), indexed by its 'datetime' column.
# The path can be overridden with the DATA_PATH environment variable so the
# notebook runs outside the author's machine; the default is the original path.
DATA_PATH = os.environ.get(
    'DATA_PATH',
    '/Users/koki/PycharmProjects/MasterThesis/data/preprocessed/201601to202309_10areas_for_analysis.csv',
)
df = pd.read_csv(DATA_PATH)
df['datetime'] = pd.to_datetime(df['datetime'])
df = df.set_index('datetime')
# Time-series steps later in the notebook presumably rely on chronological order.
assert df.index.is_monotonic_increasing, 'datetime index is not sorted'

In [3]:
# covid dummy: 1 for timestamps inside the COVID-19 wave periods below, else 0.
# Each period is materialised on a 3-hour grid, so a row is flagged only if its
# timestamp falls exactly on that grid (00:00, 03:00, ..., 21:00).
# The 2nd wave is not covered by these periods.
# NOTE(review): every other period ends on the last day of a month, but covid_1
# ends on 2020-05-30 (May has 31 days) — confirm 2020-05-31 was not intended.
covid_1 = pd.date_range(start='2020-03-01', end='2020-05-30 21:00', freq='3h')  # 1st wave
covid_2 = pd.date_range(start='2020-11-01', end='2021-04-30 21:00', freq='3h')  # 3rd and 4th waves
covid_3 = pd.date_range(start='2021-07-01', end='2021-09-30 21:00', freq='3h')  # 5th wave
covid_4 = pd.date_range(start='2021-12-01', end='2022-03-31 21:00', freq='3h')  # 6th wave (spread of the Omicron variant)
covid = covid_1.union(covid_2).union(covid_3).union(covid_4)
# Vectorised exact-match membership (same result as `int(d in covid)` per row,
# without a Python-level loop); int64 matches the dtype the lambda produced.
df['covid'] = df.index.isin(covid).astype('int64')

In [4]:
# Drop the one-hot month dummies (month_1..month_12) and the 3-hourly hour-of-day
# dummies (hour_0, hour_3, ..., hour_21). Day-of-week dummies (dow_*) are
# intentionally kept.
# NOTE: df is rebound, so re-running this cell on its own raises KeyError
# (the columns are already gone) — re-run from the data-loading cell instead.
month_dummy_cols = [f'month_{m}' for m in range(1, 13)]
hour_dummy_cols = [f'hour_{h}' for h in range(0, 24, 3)]
df = df.drop(columns=month_dummy_cols + hour_dummy_cols)

In [5]:
# Preview the prepared frame (pandas truncates the rows/columns it renders).
df

Unnamed: 0_level_0,渋谷駅_total,新宿駅_total,町田駅_total,川崎駅_total,立川駅_total,八王子駅_total,北千住駅_total,東京駅_total,赤羽駅_total,自由が丘駅_total,...,八王子_rainfall,八王子_temperature,八王子_windspeed_value,八王子_sunshine_hours,N225_Close,Cases_Tokyo,東京_緊急事態,東京_まん防,東京_weather_霧雨,covid
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2016-01-01 00:00:00,35355,12141,14725,17493,10177,12231,11088,5839,13977,9576,...,0.0,1.5,1.8,0.0,18450.98,0.0,0,0,0,0
2016-01-01 03:00:00,28498,12368,14351,17447,9992,10735,11248,6329,14161,9241,...,0.0,1.1,4.1,0.0,18450.98,0.0,0,0,0,0
2016-01-01 06:00:00,17156,11814,13452,17407,9351,10595,10763,13130,13327,8840,...,0.0,-1.1,1.7,0.0,18450.98,0.0,0,0,0,0
2016-01-01 09:00:00,17727,18039,14696,24128,12199,11015,11835,18117,14174,9030,...,0.0,5.2,1.2,1.0,18450.98,0.0,0,0,0,0
2016-01-01 12:00:00,26222,24975,17675,31771,16030,12370,12966,23981,15539,10533,...,0.0,10.8,2.0,1.0,18450.98,0.0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2023-09-30 09:00:00,35347,42090,21798,28723,25356,17428,17618,37380,20538,15806,...,0.0,24.7,1.8,0.0,31857.62,0.0,0,0,0,0
2023-09-30 12:00:00,84314,92850,36892,44786,47068,24358,27022,71237,26053,25193,...,0.0,24.0,0.8,0.0,31857.62,0.0,0,0,0,0
2023-09-30 15:00:00,110158,113744,41491,49685,52116,25661,29211,78514,29550,27711,...,0.0,26.4,2.7,0.0,31857.62,0.0,0,0,0,0
2023-09-30 18:00:00,112153,106017,41137,47749,48999,26222,31059,65990,30933,26047,...,0.0,25.3,4.8,0.0,31857.62,0.0,0,0,0,0
