toleranceに応じた基本的なscoringについては以下のdiscussionに書かれている。しかし、複数のピークが立っている場合やラベルがない区間での挙動の理解が曖昧なので、テストを書いて確認してみる。
https://www.kaggle.com/competitions/child-mind-institute-detect-sleep-states/discussion/438936


In [1]:

import jupyter_black
import polars as pl
import seaborn as sns

from src.utils.metrics import event_detection_ap

# Auto-format code cells with black, and apply seaborn's default plot theme.
# (`sns.set()` is only a deprecated alias of `sns.set_theme()`; the dead, commented-out
# `plt.style.use(...)` line was dropped since `plt` is never imported.)
jupyter_black.load()
sns.set_theme()

In [2]:
# Ground-truth sleep events. Every row that contains a null is dropped
# (presumably events whose step/timestamp were not recorded — confirm against the raw CSV).
TRAIN_EVENTS_CSV = (
    "/home/kuto/kaggle/kaggle-sleep-v2/data/child-mind-institute-detect-sleep-states/train_events.csv"
)
labels = pl.read_csv(TRAIN_EVENTS_CSV).drop_nulls()

In [3]:
# Work on a single series so that every experiment below is easy to reason about.
series_id = "7476c0bd18d2"

in_series = pl.col("series_id") == series_id
target_events = labels.filter(in_series)
target_events

series_id,night,event,step,timestamp
str,i64,str,i64,str
"""7476c0bd18d2""",4,"""onset""",56340,"""2019-03-02T23:…"
"""7476c0bd18d2""",4,"""wakeup""",62412,"""2019-03-03T07:…"
"""7476c0bd18d2""",5,"""onset""",73836,"""2019-03-03T23:…"
"""7476c0bd18d2""",5,"""wakeup""",75348,"""2019-03-04T01:…"


## 基本的な挙動

In [4]:
# Shift every predicted event by `diff` steps. The score depends on the tolerances:
# the further a prediction is from its GT, the lower the score.
for diff in [0, 12, 36, 60, 90, 120, 150, 180, 240, 300, 360]:
    pred_events = target_events.select(["series_id", "event", "step"]).with_columns(
        pl.lit(1).alias("score"),
        pl.col("step") + diff,
    )
    score = event_detection_ap(target_events.to_pandas(), pred_events.to_pandas())
    print(f"diff:{diff}, score:{score}")

diff:0, score:1.0
diff:12, score:0.9
diff:36, score:0.8
diff:60, score:0.7
diff:90, score:0.6
diff:120, score:0.5
diff:150, score:0.4
diff:180, score:0.3
diff:240, score:0.2
diff:300, score:0.1
diff:360, score:0.0


In [5]:
# Shift only the onset predictions by `diff`; wakeups stay exactly on their GT.
# Only half of the events lose score, so the drop is smaller than in the uniform-shift case.
for diff in [0, 12, 36, 60, 90, 120, 150, 180, 240, 300, 360]:
    onset_shifted_step = (
        pl.when(pl.col("event") == "onset")
        .then(pl.col("step") + diff)
        .otherwise(pl.col("step"))
        .alias("step")
    )
    pred_events = target_events.select(["series_id", "event", "step"]).with_columns(
        pl.lit(1).alias("score"),
        onset_shifted_step,
    )
    score = event_detection_ap(target_events.to_pandas(), pred_events.to_pandas())
    print(f"diff:{diff}, score:{score}")

diff:0, score:1.0
diff:12, score:0.95
diff:36, score:0.9
diff:60, score:0.85
diff:90, score:0.8
diff:120, score:0.75
diff:150, score:0.7
diff:180, score:0.65
diff:240, score:0.6
diff:300, score:0.55
diff:360, score:0.5


## ラベルありのところの検証

In [6]:
# Each GT event gets two predictions:
#   - a higher-scored one (0.6) placed `diff` steps away from the GT, and
#   - a lower-scored one (0.5) exactly on the GT step.
# Matching walks predictions in descending score order. For each prediction it looks for the
# nearest not-yet-matched GT and pairs them.
# So the shifted 0.6 peak is evaluated first against each tolerance. The on-GT 0.5 peak only
# claims the GTs that the 0.6 peak failed to match, and the result is averaged over tolerances.
# When the 0.6 peak is out of tolerance, it is a false positive ranked above the on-GT hit.
# That is why the score levels off at 0.5 for large diffs instead of dropping to 0.
for diff in [0, 12, 36, 60, 90, 120, 150, 180, 240, 300, 360, 500]:
    on_gt = target_events.select(["series_id", "event", "step"]).with_columns(
        pl.lit(0.5).alias("score"),
    )
    off_gt = on_gt.with_columns(
        pl.col("step") + diff,
        pl.lit(0.6).alias("score"),
    )
    pred_events = pl.concat([on_gt, off_gt])
    score = event_detection_ap(target_events.to_pandas(), pred_events.to_pandas())
    # Print the offset too (like the shift loops above) so each score can be read against its diff.
    print(f"diff:{diff}, score:{score}")

score:1.0
score:0.95
score:0.9
score:0.85
score:0.8
score:0.75
score:0.7
score:0.65
score:0.6
score:0.55
score:0.5
score:0.5


In [7]:
# GT events exist, but only the first two (night 4) are predicted. The night-5 events go
# undetected and count as misses, which lowers the averaged score.
pred_events = (
    target_events.select(["series_id", "event", "step"])
    .with_columns(pl.lit(1.0).alias("score"))
    .head(2)
)
print(pred_events)
score = event_detection_ap(target_events.to_pandas(), pred_events.to_pandas())
print(f"score:{score}")

shape: (2, 4)
┌──────────────┬────────┬───────┬───────┐
│ series_id    ┆ event  ┆ step  ┆ score │
│ ---          ┆ ---    ┆ ---   ┆ ---   │
│ str          ┆ str    ┆ i64   ┆ f64   │
╞══════════════╪════════╪═══════╪═══════╡
│ 7476c0bd18d2 ┆ onset  ┆ 56340 ┆ 1.0   │
│ 7476c0bd18d2 ┆ wakeup ┆ 62412 ┆ 1.0   │
└──────────────┴────────┴───────┴───────┘
score:0.5


In [8]:
# Add a copy of every prediction far away from any GT (+100000 steps), with a LOWER score (0.9)
# than the true detections (1.0).
# The extra predictions are false positives. Because they rank below every true positive,
# they do not lower the precision reached at full recall, so the score stays at 1.0.
_true_preds = target_events.select(["series_id", "event", "step"]).with_columns(
    pl.lit(1.0).alias("score"),
)
# Shifted copy of every prediction: no GT exists near these steps.
_far_preds = _true_preds.with_columns(pl.col("step") + 100000, pl.col("score") - 0.1)
pred_events = pl.concat([_true_preds, _far_preds])
print(pred_events)
score = event_detection_ap(target_events.to_pandas(), pred_events.to_pandas())
print(f"score:{score}")

shape: (8, 4)
┌──────────────┬────────┬────────┬───────┐
│ series_id    ┆ event  ┆ step   ┆ score │
│ ---          ┆ ---    ┆ ---    ┆ ---   │
│ str          ┆ str    ┆ i64    ┆ f64   │
╞══════════════╪════════╪════════╪═══════╡
│ 7476c0bd18d2 ┆ onset  ┆ 56340  ┆ 1.0   │
│ 7476c0bd18d2 ┆ wakeup ┆ 62412  ┆ 1.0   │
│ 7476c0bd18d2 ┆ onset  ┆ 73836  ┆ 1.0   │
│ 7476c0bd18d2 ┆ wakeup ┆ 75348  ┆ 1.0   │
│ 7476c0bd18d2 ┆ onset  ┆ 156340 ┆ 0.9   │
│ 7476c0bd18d2 ┆ wakeup ┆ 162412 ┆ 0.9   │
│ 7476c0bd18d2 ┆ onset  ┆ 173836 ┆ 0.9   │
│ 7476c0bd18d2 ┆ wakeup ┆ 175348 ┆ 0.9   │
└──────────────┴────────┴────────┴───────┘
score:1.0


In [9]:
# The previous experiment with the scores swapped:
#   - the true detections get the LOWER score (0.9);
#   - the far-away copies (+100000 steps, no GT nearby) get the HIGHER score (1.0).
# These false positives now rank above every true positive and do hurt the score.
_true_preds = target_events.select(["series_id", "event", "step"]).with_columns(
    pl.lit(0.9).alias("score"),
)
# Shifted copy of every prediction: no GT exists near these steps.
_far_preds = _true_preds.with_columns(pl.col("step") + 100000, pl.col("score") + 0.1)
pred_events = pl.concat([_true_preds, _far_preds])
print(pred_events)
score = event_detection_ap(target_events.to_pandas(), pred_events.to_pandas())
print(f"score:{score}")

shape: (8, 4)
┌──────────────┬────────┬────────┬───────┐
│ series_id    ┆ event  ┆ step   ┆ score │
│ ---          ┆ ---    ┆ ---    ┆ ---   │
│ str          ┆ str    ┆ i64    ┆ f64   │
╞══════════════╪════════╪════════╪═══════╡
│ 7476c0bd18d2 ┆ onset  ┆ 56340  ┆ 0.9   │
│ 7476c0bd18d2 ┆ wakeup ┆ 62412  ┆ 0.9   │
│ 7476c0bd18d2 ┆ onset  ┆ 73836  ┆ 0.9   │
│ 7476c0bd18d2 ┆ wakeup ┆ 75348  ┆ 0.9   │
│ 7476c0bd18d2 ┆ onset  ┆ 156340 ┆ 1.0   │
│ 7476c0bd18d2 ┆ wakeup ┆ 162412 ┆ 1.0   │
│ 7476c0bd18d2 ┆ onset  ┆ 173836 ┆ 1.0   │
│ 7476c0bd18d2 ┆ wakeup ┆ 175348 ┆ 1.0   │
└──────────────┴────────┴────────┴───────┘
score:0.5


In [10]:
# Far-away copies of every GT (+100000 steps, no GT nearby) with the SAME score (1.0) as the
# true detections. This is the tie case between the two "no GT here" cells above, where the
# false positives had a lower (0.9) or a higher (1.0 vs 0.9) score than the true detections.
# NOTE(review): the result is cut off in the saved output. Presumably tied false positives are
# ranked together with the true positives and lower precision — re-run to confirm.
# `night` is shifted by 2 but `timestamp` is not.
# NOTE(review): assumes event_detection_ap ignores both columns — confirm.
pred_events = target_events.clone()
pred_events = pred_events.with_columns(
    pl.lit(1.0).alias("score"),
)
pred_events = pl.concat(
    [pred_events, pred_events.with_columns(pl.col("step") + 100000, pl.col("night") + 2)]
)
print(pred_events)
score = event_detection_ap(target_events.to_pandas(), pred_events.to_pandas())
print(f"score:{score}")

shape: (8, 6)
┌──────────────┬───────┬────────┬────────┬──────────────────────────┬───────┐
│ series_id    ┆ night ┆ event  ┆ step   ┆ timestamp                ┆ score │
│ ---          ┆ ---   ┆ ---    ┆ ---    ┆ ---                      ┆ ---   │
│ str          ┆ i64   ┆ str    ┆ i64    ┆ str                      ┆ f64   │
╞══════════════╪═══════╪════════╪════════╪══════════════════════════╪═══════╡
│ 7476c0bd18d2 ┆ 4     ┆ onset  ┆ 56340  ┆ 2019-03-02T23:30:00-0500 ┆ 1.0   │
│ 7476c0bd18d2 ┆ 4     ┆ wakeup ┆ 62412  ┆ 2019-03-03T07:56:00-0500 ┆ 1.0   │
│ 7476c0bd18d2 ┆ 5     ┆ onset  ┆ 73836  ┆ 2019-03-03T23:48:00-0500 ┆ 1.0   │
│ 7476c0bd18d2 ┆ 5     ┆ wakeup ┆ 75348  ┆ 2019-03-04T01:54:00-0500 ┆ 1.0   │
│ 7476c0bd18d2 ┆ 6     ┆ onset  ┆ 156340 ┆ 2019-03-02T23:30:00-0500 ┆ 1.0   │
│ 7476c0bd18d2 ┆ 6     ┆ wakeup ┆ 162412 ┆ 2019-03-03T07:56:00-0500 ┆ 1.0   │
│ 7476c0bd18d2 ┆ 7     ┆ onset  ┆ 173836 ┆ 2019-03-03T23:48:00-0500 ┆ 1.0   │
│ 7476c0bd18d2 ┆ 7     ┆ wakeup ┆ 175348 ┆ 2019-03

In [11]:
# Every prediction is duplicated verbatim: identical step, identical score (1.0).
# The score drops even though every GT has an exact hit, so beware of emitting duplicate
# peaks. Presumably each GT can be matched only once, so the second copy of every event
# becomes a false positive tied with a true positive.
_scored = target_events.with_columns(pl.lit(1.0).alias("score"))
pred_events = pl.concat([_scored, _scored])
print(pred_events)
score = event_detection_ap(target_events.to_pandas(), pred_events.to_pandas())
print(f"score:{score}")

shape: (8, 6)
┌──────────────┬───────┬────────┬───────┬──────────────────────────┬───────┐
│ series_id    ┆ night ┆ event  ┆ step  ┆ timestamp                ┆ score │
│ ---          ┆ ---   ┆ ---    ┆ ---   ┆ ---                      ┆ ---   │
│ str          ┆ i64   ┆ str    ┆ i64   ┆ str                      ┆ f64   │
╞══════════════╪═══════╪════════╪═══════╪══════════════════════════╪═══════╡
│ 7476c0bd18d2 ┆ 4     ┆ onset  ┆ 56340 ┆ 2019-03-02T23:30:00-0500 ┆ 1.0   │
│ 7476c0bd18d2 ┆ 4     ┆ wakeup ┆ 62412 ┆ 2019-03-03T07:56:00-0500 ┆ 1.0   │
│ 7476c0bd18d2 ┆ 5     ┆ onset  ┆ 73836 ┆ 2019-03-03T23:48:00-0500 ┆ 1.0   │
│ 7476c0bd18d2 ┆ 5     ┆ wakeup ┆ 75348 ┆ 2019-03-04T01:54:00-0500 ┆ 1.0   │
│ 7476c0bd18d2 ┆ 4     ┆ onset  ┆ 56340 ┆ 2019-03-02T23:30:00-0500 ┆ 1.0   │
│ 7476c0bd18d2 ┆ 4     ┆ wakeup ┆ 62412 ┆ 2019-03-03T07:56:00-0500 ┆ 1.0   │
│ 7476c0bd18d2 ┆ 5     ┆ onset  ┆ 73836 ┆ 2019-03-03T23:48:00-0500 ┆ 1.0   │
│ 7476c0bd18d2 ┆ 5     ┆ wakeup ┆ 75348 ┆ 2019-03-04T01:54:00-

In [12]:
# One event (row 1, the night-4 wakeup) is predicted twice at the same step: once with
# score 1.0 and once with 0.5.
# The higher-scored copy claims the GT. The leftover 0.5 copy ranks below every true
# positive, so the score is still perfect.
_all_hits = target_events.with_columns(pl.lit(1.0).alias("score"))
_weak_copy = _all_hits.slice(1, 1).with_columns(pl.lit(0.5).alias("score"))
pred_events = pl.concat([_all_hits, _weak_copy]).sort("step")

print(pred_events)
score = event_detection_ap(target_events.to_pandas(), pred_events.to_pandas())
print(f"score:{score}")

shape: (5, 6)
┌──────────────┬───────┬────────┬───────┬──────────────────────────┬───────┐
│ series_id    ┆ night ┆ event  ┆ step  ┆ timestamp                ┆ score │
│ ---          ┆ ---   ┆ ---    ┆ ---   ┆ ---                      ┆ ---   │
│ str          ┆ i64   ┆ str    ┆ i64   ┆ str                      ┆ f64   │
╞══════════════╪═══════╪════════╪═══════╪══════════════════════════╪═══════╡
│ 7476c0bd18d2 ┆ 4     ┆ onset  ┆ 56340 ┆ 2019-03-02T23:30:00-0500 ┆ 1.0   │
│ 7476c0bd18d2 ┆ 4     ┆ wakeup ┆ 62412 ┆ 2019-03-03T07:56:00-0500 ┆ 1.0   │
│ 7476c0bd18d2 ┆ 4     ┆ wakeup ┆ 62412 ┆ 2019-03-03T07:56:00-0500 ┆ 0.5   │
│ 7476c0bd18d2 ┆ 5     ┆ onset  ┆ 73836 ┆ 2019-03-03T23:48:00-0500 ┆ 1.0   │
│ 7476c0bd18d2 ┆ 5     ┆ wakeup ┆ 75348 ┆ 2019-03-04T01:54:00-0500 ┆ 1.0   │
└──────────────┴───────┴────────┴───────┴──────────────────────────┴───────┘
score:1.0


In [13]:
# All predictions sit exactly on their GT with score 0.5. There is also an extra copy of the
# night-4 wakeup (row 1), shifted by 12 steps, with a higher score (1.0).
# The shifted, higher-scored copy is matched first, so its step error is what counts.
# The exact 0.5 copy left over presumably becomes a false positive tied with the other
# wakeup hit, and the score falls below 1.0.
_on_gt = target_events.with_columns(pl.lit(0.5).alias("score"))
_shifted_copy = _on_gt.slice(1, 1).with_columns(
    pl.lit(1.0).alias("score"),
    pl.col("step") + 12,
)
pred_events = pl.concat([_on_gt, _shifted_copy]).sort("step")

print(pred_events)
score = event_detection_ap(target_events.to_pandas(), pred_events.to_pandas())
print(f"score:{score}")

shape: (5, 6)
┌──────────────┬───────┬────────┬───────┬──────────────────────────┬───────┐
│ series_id    ┆ night ┆ event  ┆ step  ┆ timestamp                ┆ score │
│ ---          ┆ ---   ┆ ---    ┆ ---   ┆ ---                      ┆ ---   │
│ str          ┆ i64   ┆ str    ┆ i64   ┆ str                      ┆ f64   │
╞══════════════╪═══════╪════════╪═══════╪══════════════════════════╪═══════╡
│ 7476c0bd18d2 ┆ 4     ┆ onset  ┆ 56340 ┆ 2019-03-02T23:30:00-0500 ┆ 0.5   │
│ 7476c0bd18d2 ┆ 4     ┆ wakeup ┆ 62412 ┆ 2019-03-03T07:56:00-0500 ┆ 0.5   │
│ 7476c0bd18d2 ┆ 4     ┆ wakeup ┆ 62424 ┆ 2019-03-03T07:56:00-0500 ┆ 1.0   │
│ 7476c0bd18d2 ┆ 5     ┆ onset  ┆ 73836 ┆ 2019-03-03T23:48:00-0500 ┆ 0.5   │
│ 7476c0bd18d2 ┆ 5     ┆ wakeup ┆ 75348 ┆ 2019-03-04T01:54:00-0500 ┆ 0.5   │
└──────────────┴───────┴────────┴───────┴──────────────────────────┴───────┘
score:0.9083333333333333


In [14]:
# Predictions that do not alternate onset/wakeup. Every event is predicted twice:
#   - on its GT with score 0.5, and
#   - 12 steps later with score 0.5 * 1.2.
# After sorting, two onsets (or two wakeups) therefore follow each other.
# The higher-scored prediction is the one adopted, so predictions need not alternate.
_base = target_events.with_columns(pl.lit(0.5).alias("score"))
_later = _base.with_columns(pl.col("score") * 1.2, pl.col("step") + 12)
pred_events = pl.concat([_base, _later]).sort("step")

print(pred_events)
score = event_detection_ap(target_events.to_pandas(), pred_events.to_pandas())
print(f"score:{score}")

shape: (8, 6)
┌──────────────┬───────┬────────┬───────┬──────────────────────────┬───────┐
│ series_id    ┆ night ┆ event  ┆ step  ┆ timestamp                ┆ score │
│ ---          ┆ ---   ┆ ---    ┆ ---   ┆ ---                      ┆ ---   │
│ str          ┆ i64   ┆ str    ┆ i64   ┆ str                      ┆ f64   │
╞══════════════╪═══════╪════════╪═══════╪══════════════════════════╪═══════╡
│ 7476c0bd18d2 ┆ 4     ┆ onset  ┆ 56340 ┆ 2019-03-02T23:30:00-0500 ┆ 0.5   │
│ 7476c0bd18d2 ┆ 4     ┆ onset  ┆ 56352 ┆ 2019-03-02T23:30:00-0500 ┆ 0.6   │
│ 7476c0bd18d2 ┆ 4     ┆ wakeup ┆ 62412 ┆ 2019-03-03T07:56:00-0500 ┆ 0.5   │
│ 7476c0bd18d2 ┆ 4     ┆ wakeup ┆ 62424 ┆ 2019-03-03T07:56:00-0500 ┆ 0.6   │
│ 7476c0bd18d2 ┆ 5     ┆ onset  ┆ 73836 ┆ 2019-03-03T23:48:00-0500 ┆ 0.5   │
│ 7476c0bd18d2 ┆ 5     ┆ onset  ┆ 73848 ┆ 2019-03-03T23:48:00-0500 ┆ 0.6   │
│ 7476c0bd18d2 ┆ 5     ┆ wakeup ┆ 75348 ┆ 2019-03-04T01:54:00-0500 ┆ 0.5   │
│ 7476c0bd18d2 ┆ 5     ┆ wakeup ┆ 75360 ┆ 2019-03-04T01:54:00-

In [15]:
# Every event is also predicted as an *onset* 12 steps later, with a higher score (0.5 * 1.2).
# Around each wakeup, this puts a higher-scored onset right next to the true wakeup, which
# hurts the score.
_base = target_events.with_columns(pl.lit(0.5).alias("score"))
_extra_onsets = _base.with_columns(
    pl.lit("onset").alias("event"),
    pl.col("score") * 1.2,
    pl.col("step") + 12,
)
pred_events = pl.concat([_base, _extra_onsets]).sort("step")

print(pred_events)
score = event_detection_ap(target_events.to_pandas(), pred_events.to_pandas())
print(f"score:{score}")

shape: (8, 6)
┌──────────────┬───────┬────────┬───────┬──────────────────────────┬───────┐
│ series_id    ┆ night ┆ event  ┆ step  ┆ timestamp                ┆ score │
│ ---          ┆ ---   ┆ ---    ┆ ---   ┆ ---                      ┆ ---   │
│ str          ┆ i64   ┆ str    ┆ i64   ┆ str                      ┆ f64   │
╞══════════════╪═══════╪════════╪═══════╪══════════════════════════╪═══════╡
│ 7476c0bd18d2 ┆ 4     ┆ onset  ┆ 56340 ┆ 2019-03-02T23:30:00-0500 ┆ 0.5   │
│ 7476c0bd18d2 ┆ 4     ┆ onset  ┆ 56352 ┆ 2019-03-02T23:30:00-0500 ┆ 0.6   │
│ 7476c0bd18d2 ┆ 4     ┆ wakeup ┆ 62412 ┆ 2019-03-03T07:56:00-0500 ┆ 0.5   │
│ 7476c0bd18d2 ┆ 4     ┆ onset  ┆ 62424 ┆ 2019-03-03T07:56:00-0500 ┆ 0.6   │
│ 7476c0bd18d2 ┆ 5     ┆ onset  ┆ 73836 ┆ 2019-03-03T23:48:00-0500 ┆ 0.5   │
│ 7476c0bd18d2 ┆ 5     ┆ onset  ┆ 73848 ┆ 2019-03-03T23:48:00-0500 ┆ 0.6   │
│ 7476c0bd18d2 ┆ 5     ┆ wakeup ┆ 75348 ┆ 2019-03-04T01:54:00-0500 ┆ 0.5   │
│ 7476c0bd18d2 ┆ 5     ┆ onset  ┆ 75360 ┆ 2019-03-04T01:54:00-

In [16]:
# Every event is also predicted as an *onset* 12 steps later, but with a LOWER score
# (0.5 * 0.8). A lower-scored onset next to a wakeup causes no problem.
_base = target_events.with_columns(pl.lit(0.5).alias("score"))
_extra_onsets = _base.with_columns(
    pl.lit("onset").alias("event"),
    pl.col("score") * 0.8,
    pl.col("step") + 12,
)
pred_events = pl.concat([_base, _extra_onsets]).sort("step")

print(pred_events)
score = event_detection_ap(target_events.to_pandas(), pred_events.to_pandas())
print(f"score:{score}")

shape: (8, 6)
┌──────────────┬───────┬────────┬───────┬──────────────────────────┬───────┐
│ series_id    ┆ night ┆ event  ┆ step  ┆ timestamp                ┆ score │
│ ---          ┆ ---   ┆ ---    ┆ ---   ┆ ---                      ┆ ---   │
│ str          ┆ i64   ┆ str    ┆ i64   ┆ str                      ┆ f64   │
╞══════════════╪═══════╪════════╪═══════╪══════════════════════════╪═══════╡
│ 7476c0bd18d2 ┆ 4     ┆ onset  ┆ 56340 ┆ 2019-03-02T23:30:00-0500 ┆ 0.5   │
│ 7476c0bd18d2 ┆ 4     ┆ onset  ┆ 56352 ┆ 2019-03-02T23:30:00-0500 ┆ 0.4   │
│ 7476c0bd18d2 ┆ 4     ┆ wakeup ┆ 62412 ┆ 2019-03-03T07:56:00-0500 ┆ 0.5   │
│ 7476c0bd18d2 ┆ 4     ┆ onset  ┆ 62424 ┆ 2019-03-03T07:56:00-0500 ┆ 0.4   │
│ 7476c0bd18d2 ┆ 5     ┆ onset  ┆ 73836 ┆ 2019-03-03T23:48:00-0500 ┆ 0.5   │
│ 7476c0bd18d2 ┆ 5     ┆ onset  ┆ 73848 ┆ 2019-03-03T23:48:00-0500 ┆ 0.4   │
│ 7476c0bd18d2 ┆ 5     ┆ wakeup ┆ 75348 ┆ 2019-03-04T01:54:00-0500 ┆ 0.5   │
│ 7476c0bd18d2 ┆ 5     ┆ onset  ┆ 75360 ┆ 2019-03-04T01:54:00-