## cityB validation

### 設定

In [None]:
import pandas as pd
import numpy as np
import geobleu

### 定数定義

In [None]:
# Daytime window is 07:00-20:00. `t` is a time-slot index with two slots
# per hour (slot 14 is 07:00 and slot 40 is 20:00), so hour * 2 gives the slot.
MORNING_T = 7 * 2   # 07:00
NIGHT_T = 20 * 2    # 20:00

# Days per week; used as the modulus for the day-of-week feature (fixed).
DOW_COUNT = 7

In [None]:
# Location of the city B challenge data (gzip-compressed CSV).
INPUT_PATH = (
    "../../01_public/humob-challenge-2024/input/"
    "cityB_challengedata.csv.gz"
)

### データ読み込み・分割

In [None]:
# Load the whole city B dataset and show the first three rows.
df_city_b = pd.read_csv(filepath_or_buffer=INPUT_PATH)
df_city_b.iloc[:3]

In [None]:
# Day-of-week index: d modulo DOW_COUNT. Which weekday d == 0 falls on is
# not known from this notebook; dow is only used as a grouping key.
df_city_b["dow"] = df_city_b["d"] % DOW_COUNT

# Label each time slot "daytime" when MORNING_T <= t < NIGHT_T, otherwise
# "nighttime". np.where evaluates the whole column in one vectorized pass.
# Series.apply would call a Python lambda once per row of the full dataset.
_is_daytime = (df_city_b["t"] >= MORNING_T) & (df_city_b["t"] < NIGHT_T)
df_city_b["t_label"] = np.where(_is_daytime, "daytime", "nighttime")
df_city_b.head(3)

In [None]:
# Users with uid 20000..21999 (both ends inclusive) form the validation set.
_is_valid_uid = (df_city_b["uid"] >= 20000) & (df_city_b["uid"] <= 21999)
df_city_b_valid = df_city_b[_is_valid_uid]
df_city_b_valid.iloc[:3]

In [None]:
# Split the validation users' records at day 60:
#   d <  60 -> history used to build the mode lookup tables
#   d >= 60 -> ground truth that the predictions are scored against
_day = df_city_b_valid["d"]
df_city_b_train = df_city_b_valid[_day < 60]
df_city_b_answer = df_city_b_valid[_day >= 60]

### 欠損値補完テーブル作成

In [None]:
# Most frequent (x, y) cell per uid × dow × t in the training history.
# x and y are counted jointly as one location. Taking the mode of x and
# the mode of y separately can pair coordinates from different visits and
# produce a cell the user never visited. Counting pairs with groupby.size
# is also much faster than calling Series.mode in a Python lambda per group.
# Ties between equally frequent cells go to the smallest x, then smallest y.
df_dow_t_mode = (
    df_city_b_train
    .groupby(["uid", "dow", "t", "x", "y"])
    .size()
    .reset_index(name="count")
    .sort_values(["count", "x", "y"], ascending=[False, True, True])
    .drop_duplicates(subset=["uid", "dow", "t"], keep="first")
    .sort_values(["uid", "dow", "t"])
    .drop(columns="count")
    .reset_index(drop=True)
    .rename(columns={"x": "dow_t_x", "y": "dow_t_y"})
)
df_dow_t_mode.head(3)

In [None]:
# Most frequent (x, y) cell per uid × t in the training history.
# x and y are counted jointly as one location. Taking separate per-axis
# modes can produce a cell the user never visited. Counting pairs with
# groupby.size is also much faster than a per-group Series.mode lambda.
# Ties between equally frequent cells go to the smallest x, then smallest y.
df_t_mode = (
    df_city_b_train
    .groupby(["uid", "t", "x", "y"])
    .size()
    .reset_index(name="count")
    .sort_values(["count", "x", "y"], ascending=[False, True, True])
    .drop_duplicates(subset=["uid", "t"], keep="first")
    .sort_values(["uid", "t"])
    .drop(columns="count")
    .reset_index(drop=True)
    .rename(columns={"x": "t_x", "y": "t_y"})
)
df_t_mode.head(3)

In [None]:
# Most frequent (x, y) cell per uid × dow × t_label (daytime/nighttime)
# in the training history.
# x and y are counted jointly as one location. Taking separate per-axis
# modes can produce a cell the user never visited. Counting pairs with
# groupby.size is also much faster than a per-group Series.mode lambda.
# Ties between equally frequent cells go to the smallest x, then smallest y.
df_dow_t_label_mode = (
    df_city_b_train
    .groupby(["uid", "dow", "t_label", "x", "y"])
    .size()
    .reset_index(name="count")
    .sort_values(["count", "x", "y"], ascending=[False, True, True])
    .drop_duplicates(subset=["uid", "dow", "t_label"], keep="first")
    .sort_values(["uid", "dow", "t_label"])
    .drop(columns="count")
    .reset_index(drop=True)
    .rename(columns={"x": "dow_t_label_x", "y": "dow_t_label_y"})
)
df_dow_t_label_mode.head(3)

In [None]:
# Most frequent (x, y) cell per uid × t_label (daytime/nighttime) in the
# training history.
# x and y are counted jointly as one location. Taking separate per-axis
# modes can produce a cell the user never visited. Counting pairs with
# groupby.size is also much faster than a per-group Series.mode lambda.
# Ties between equally frequent cells go to the smallest x, then smallest y.
df_t_label_mode = (
    df_city_b_train
    .groupby(["uid", "t_label", "x", "y"])
    .size()
    .reset_index(name="count")
    .sort_values(["count", "x", "y"], ascending=[False, True, True])
    .drop_duplicates(subset=["uid", "t_label"], keep="first")
    .sort_values(["uid", "t_label"])
    .drop(columns="count")
    .reset_index(drop=True)
    .rename(columns={"x": "t_label_x", "y": "t_label_y"})
)
df_t_label_mode.head(3)

In [None]:
# Most frequent (x, y) cell per uid over the whole training history. This
# is the coarsest fallback table.
# x and y are counted jointly as one location. Taking separate per-axis
# modes can produce a cell the user never visited. Counting pairs with
# groupby.size is also much faster than a per-group Series.mode lambda.
# Ties between equally frequent cells go to the smallest x, then smallest y.
df_uid_mode = (
    df_city_b_train
    .groupby(["uid", "x", "y"])
    .size()
    .reset_index(name="count")
    .sort_values(["count", "x", "y"], ascending=[False, True, True])
    .drop_duplicates(subset=["uid"], keep="first")
    .sort_values(["uid"])
    .drop(columns="count")
    .reset_index(drop=True)
    .rename(columns={"x": "uid_x", "y": "uid_y"})
)
df_uid_mode.head(3)

### 予測

In [None]:
# Attach every mode lookup table to the answer rows. Each table is joined on
# its own key columns. Left joins keep answer rows with no match, leaving
# NaN in that table's columns.
_lookups = (
    (df_dow_t_mode, ["uid", "dow", "t"]),
    (df_t_mode, ["uid", "t"]),
    (df_dow_t_label_mode, ["uid", "dow", "t_label"]),
    (df_t_label_mode, ["uid", "t_label"]),
    (df_uid_mode, ["uid"]),
)
df_city_b_pred = df_city_b_answer
for _table, _keys in _lookups:
    df_city_b_pred = df_city_b_pred.merge(_table, on=_keys, how="left")

In [None]:
# Final prediction per axis: take the most specific lookup value available,
# falling back column by column. Each later column only fills rows that are
# still NaN. Order: uid×dow×t, uid×t, uid×dow×t_label, uid×t_label, uid.
_priority = {
    "pred_x": ["dow_t_x", "t_x", "dow_t_label_x", "t_label_x", "uid_x"],
    "pred_y": ["dow_t_y", "t_y", "dow_t_label_y", "t_label_y", "uid_y"],
}
for _target, (_first, *_fallbacks) in _priority.items():
    _pred = df_city_b_pred[_first]
    for _col in _fallbacks:
        _pred = _pred.fillna(df_city_b_pred[_col])
    df_city_b_pred[_target] = _pred

In [None]:
# Final frame: ground-truth (x, y) next to the predicted (pred_x, pred_y).
_result_columns = ["uid", "d", "t", "x", "y", "pred_x", "pred_y"]
df_city_b_pred.loc[:, _result_columns]

### 精度検証

In [None]:
# Per-user GEO-BLEU and DTW scores; they are averaged in the next cell.
list_geobleu_val = []
list_dtw_val = []

# One groupby pass instead of re-scanning the whole frame with a boolean
# mask for each of the 2000 uids. Groups come out sorted by uid, the same
# order as the original 20000..21999 loop. df_city_b_pred contains only
# validation users, so no extra uid filter is needed. A uid with no rows in
# the answer period has no group, so the scorers are never called with
# empty sequences.
for _uid, df_uid in df_city_b_pred.groupby("uid", sort=True):
    # (d, t, x, y) tuples for the prediction and for the ground truth.
    list_pred = list(
        df_uid[["d", "t", "pred_x", "pred_y"]].itertuples(index=False, name=None)
    )
    list_answer = list(
        df_uid[["d", "t", "x", "y"]].itertuples(index=False, name=None)
    )

    # Prediction is passed first and ground truth second, as in the original.
    geobleu_val = geobleu.calc_geobleu(list_pred, list_answer, processes=3)
    list_geobleu_val.append(geobleu_val)

    dtw_val = geobleu.calc_dtw(list_pred, list_answer, processes=3)
    list_dtw_val.append(dtw_val)


In [None]:
# Report the mean of each per-user score.
for _metric, _scores in (("geobleu", list_geobleu_val), ("dtw", list_dtw_val)):
    print(f"{_metric}:{np.mean(_scores)}")