<a href="https://colab.research.google.com/github/mazarimono/pyconapac2023/blob/main/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# モデル作成

- 地下を歩く人の数、地下鉄乗車数、天気から動物園入場者数を推論する簡易なモデルを作成
- 総入園者数が0の日は除外する
- LightGBMでモデル作成

In [1]:
import lightgbm as lgb
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import plotly.express as px

In [2]:
# Load the dataset; the first CSV column becomes the index.
all_path = 'https://raw.githubusercontent.com/mazarimono/pyconapac2023/main/data/all_data.csv'
df = (
    pd.read_csv(all_path, index_col=0)
    # Keep only the rows whose total visitor count is non-zero.
    .loc[lambda frame: frame['A:総入園者数'] != 0]
)
# Convert the index to datetimes (the train/test split below slices it by month).
df.index = pd.to_datetime(df.index)

# Time-based split: data through November 2022 is used for training,
# December 2022 is held out for testing.

random_state = 42

target_col = 'A:総入園者数'
features = [
    'J1', 'さっぽろ（南北線）',
    '気温(℃)', '降水量(mm)', '積雪深(cm)',
    'wod_x', 'month',
]

# pop() removes the target column from df and returns it as the label Series.
label = df.pop(target_col)
data = df.loc[:, features]

# Partial-string date indexing on the datetime index: the slice end '2022-11'
# covers all of November, and '2022-12' selects every December row.
train_data, train_label = data.loc[:'2022-11'], label.loc[:'2022-11']
test_data, test_label = data.loc['2022-12'], label.loc['2022-12']

# Hold out part of the pre-December data as a validation set for early stopping.
# The shared `random_state` is used here (instead of a hard-coded 42) so that
# changing the seed in one place reseeds both the split and LightGBM.
X_train, X_valid, y_train, y_valid = train_test_split(
    train_data, train_label, random_state=random_state
)
train_set = lgb.Dataset(X_train, y_train)
# The validation Dataset is built with the training Dataset as its reference.
valid_set = lgb.Dataset(X_valid, y_valid, reference=train_set)
params = {
    'objective': 'regression',
    'metric': ['l2_root'],  # shows up as "rmse" in the training log
    'random_state': random_state
}
# Stop training once the validation metric has not improved for 10 rounds.
callbacks = [lgb.callback.early_stopping(stopping_rounds=10)]
# valid_sets is documented as a list of Datasets, so pass it as one.
model = lgb.train(params, train_set, valid_sets=[valid_set], callbacks=callbacks)

You can set `force_row_wise=true` to remove the overhead.
And if memory is not enough, you can set `force_col_wise=true`.
[LightGBM] [Info] Total Bins 894
[LightGBM] [Info] Number of data points in the train set: 781, number of used features: 7
[LightGBM] [Info] Start training from score 2354.793854
Training until validation scores don't improve for 10 rounds
Early stopping, best iteration is:
[58]	valid_0's rmse: 1403.04


In [3]:
# Predict December 2022 and put actual ('test') and predicted ('pred') values
# side by side, indexed by date.
pred = model.predict(test_data)
pred_df = pd.DataFrame(
    {'test': test_label.values, 'pred': pred},
    index=test_label.index,
)
# Day of week (Monday=0 ... Sunday=6). Read it from the DatetimeIndex attribute:
# mapping `lambda x: x.weekday` element-wise never calls the method, so it would
# store the bound Timestamp.weekday method instead of an integer.
pred_df['wod'] = pred_df.index.weekday
pred_df

Unnamed: 0,test,pred,wod
2022-12-01,403.0,424.487066,3
2022-12-02,590.0,772.233595,4
2022-12-03,1215.0,2008.688834,5
2022-12-04,1549.0,2148.735927,6
2022-12-05,718.0,1010.695424,0
2022-12-06,1055.0,380.973484,1
2022-12-07,613.0,1180.5992,2
2022-12-08,679.0,578.182952,3
2022-12-09,1025.0,718.303707,4
2022-12-10,1868.0,1512.929716,5


In [4]:
# Evaluate the December predictions: RMSE (square root of the MSE) and R^2.
mse = mean_squared_error(test_label.values, pred)
rmse = np.sqrt(mse)
r2 = r2_score(test_label, pred)
print(f'rmse: {rmse}', f'R2: {r2}', sep='\n')

rmse: 407.1580291519024
R2: 0.19283471141934383


In [6]:
# Line chart of the actual ('test') and predicted ('pred') columns over the test dates.
comparison = pred_df.loc[:, ['test', 'pred']]
px.line(comparison, title='予測と実際')

In [7]:
# feature_imporanceの観察
imp = model.feature_importance(importance_type='gain')
imp_df = pd.DataFrame({'imp': imp})
imp_df.index = features
imp_df

Unnamed: 0,imp
J1,3024880000.0
さっぽろ（南北線）,3479828000.0
気温(℃),3519829000.0
降水量(mm),1081654000.0
積雪深(cm),92329120.0
wod_x,3148884000.0
month,1569504000.0


In [8]:
# Bar chart of the gain importances, sorted ascending by importance.
# Title typo fixed: "imporance" -> "importance".
px.bar(imp_df.sort_values('imp'), title='feature_importance(gain)')

In [9]:
# Display df. The target column is no longer present here because the earlier
# `df.pop(target_col)` removed it from df in place.
df

Unnamed: 0,B:有料入園者数,C:無料入園者数,D:Cのうち幼児,E:Cのうち小学生,F:Cのうち中学生,G:Cのうち障がい者,H:Cのうち市内６５歳以上,I:CのうちD～H以外,month,wod_x,...,気温(℃),風向(度:0～359),風速(m/s),降水量(mm),積雪深(cm),J1,J2,J3,J4,J5
2019-04-01,2057.0,1878.0,856.0,603.0,116.0,66.0,165.0,72.0,4,0,...,2.682639,234.916667,0.460417,0.0,0.000000,2242.454545,1436.597403,1837.558442,1514.935065,981.948052
2019-04-02,1253.0,1380.0,582.0,490.0,72.0,53.0,118.0,65.0,4,1,...,1.634722,267.562500,0.700694,0.0,0.006944,2346.155844,1449.025974,2040.376623,1572.909091,1076.051948
2019-04-03,2015.0,2138.0,809.0,904.0,97.0,64.0,185.0,79.0,4,2,...,1.883333,304.916667,0.872917,1.0,0.541667,2329.000000,1500.454545,2024.870130,1586.116883,1116.272727
2019-04-04,3092.0,3708.0,1634.0,1188.0,158.0,113.0,486.0,129.0,4,3,...,4.932639,223.041667,0.319444,0.0,0.000000,1960.389610,1305.636364,1778.259740,1596.519481,1059.467532
2019-04-05,918.0,919.0,337.0,342.0,72.0,33.0,72.0,63.0,4,4,...,4.656250,203.743056,0.743750,1.0,0.013889,2208.532468,1498.506494,1916.805195,1603.818182,1131.363636
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2022-12-23,383.0,114.0,50.0,56.0,5.0,1.0,0.0,2.0,12,4,...,1.179861,283.152778,1.184028,26.5,14.625000,2097.935065,1261.038961,1850.727273,1430.441558,820.792208
2022-12-24,931.0,211.0,114.0,48.0,18.0,10.0,6.0,15.0,12,5,...,2.925694,332.451389,2.509722,1.5,13.840278,1900.350649,1338.051948,1681.376623,1421.025974,590.909091
2022-12-25,1199.0,392.0,190.0,152.0,15.0,11.0,9.0,15.0,12,6,...,1.916667,324.736111,2.406250,13.5,16.159722,1896.389610,1279.753247,1678.532468,1351.415584,980.025974
2022-12-26,1034.0,541.0,244.0,186.0,32.0,20.0,13.0,46.0,12,0,...,3.347222,303.194444,1.485417,0.0,14.013889,1873.012987,1259.662338,1685.168831,1281.987013,772.103896
