## Computing class weights

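We hypothesize that class balance is what makes the CV and LB trends diverge between the 1d/2d (or 4ep/18ep) setups.
To test this, we train with a per-class weight chosen to cancel out the effect of class balance, and check how the CV/LB trends change.
Each class's share of the labels is computed by averaging the per-EEG label probabilities over all EEGs, weighted by the per-EEG weight. The class weight is the inverse of that share, normalized so that the weights average to 1.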
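Written out, the computation in the cells below is roughly the following. The notation is ours, not from `src.preprocess`: $w_i$ is the per-EEG weight (`weight_per_eeg`), $y_{i,c}$ is the per-EEG probability of class $c$ (`{label}_prob_per_eeg`, presumably the label distribution produced by `process_label`), and $C = 6$ is the number of classes.

$$
p_c = \frac{\sum_i w_i\, y_{i,c}}{\sum_i w_i},
\qquad
\omega_c = C \cdot \frac{1/p_c}{\sum_{k=1}^{C} 1/p_k}
$$

Here $p_c$ corresponds to `{label}_prob` and $\omega_c$ to `weight`. Since $\omega_c\, p_c = C / \sum_k 1/p_k$ is the same for every class, each class carries the same total weight, and $\frac{1}{C}\sum_c \omega_c = 1$.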

In [7]:
from pathlib import Path

import matplotlib.pyplot as plt
import polars as pl

from src.preprocess import process_label
from src.constant import LABELS


plt.style.use("tableau-colorblind10")

In [2]:
data_dir = Path("../../../input/hms-harmful-brain-activity-classification")
list(data_dir.glob("*"))

[PosixPath('../../../input/hms-harmful-brain-activity-classification/test_eegs'),
 PosixPath('../../../input/hms-harmful-brain-activity-classification/test.csv'),
 PosixPath('../../../input/hms-harmful-brain-activity-classification/example_figures'),
 PosixPath('../../../input/hms-harmful-brain-activity-classification/test_spectrograms'),
 PosixPath('../../../input/hms-harmful-brain-activity-classification/train_spectrograms'),
 PosixPath('../../../input/hms-harmful-brain-activity-classification/sample_submission.csv'),
 PosixPath('../../../input/hms-harmful-brain-activity-classification/train_eegs'),
 PosixPath('../../../input/hms-harmful-brain-activity-classification/train.csv')]

In [25]:
metadata = pl.read_csv(data_dir / "train.csv")
metadata = process_label(metadata)
# One row per EEG: its label distribution and its weight.
metadata = metadata.group_by("eeg_id").agg(
    *[pl.col(f"{label}_prob_per_eeg").first() for label in LABELS],
    pl.col("weight_per_eeg").first(),
)
# Weighted votes per class: label probability x EEG weight.
metadata = metadata.with_columns(
    pl.col(f"{label}_prob_per_eeg").mul(pl.col("weight_per_eeg")).alias(f"{label}_vote_per_eeg")
    for label in LABELS
)
# Sum over all EEGs, then divide by the total weight -> weighted class prior.
meta_sum = metadata.sum()
label_prob = meta_sum.with_columns(
    pl.col(f"{label}_vote_per_eeg")
    .truediv(pl.col("weight_per_eeg"))
    .alias(f"{label}_prob")
    for label in LABELS
).select("^.*_prob$")
label_prob

| seizure_prob | lpd_prob | gpd_prob | lrda_prob | grda_prob | other_prob |
|---|---|---|---|---|---|
| f64 | f64 | f64 | f64 | f64 | f64 |
| 0.10051 | 0.163856 | 0.111289 | 0.068821 | 0.098882 | 0.456643 |
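The weighted-average step can be reproduced without `process_label` on toy data. This is a minimal sketch, assuming each EEG already has a label distribution and a scalar weight; the helper `weighted_class_prior` and the toy records are hypothetical, not part of the project code.

```python
from typing import Sequence

LABELS = ["seizure", "lpd", "gpd", "lrda", "grda", "other"]


def weighted_class_prior(
    probs: Sequence[Sequence[float]], weights: Sequence[float]
) -> list[float]:
    """Weighted average of per-EEG label distributions (hypothetical helper).

    probs[i][c] is the probability of class c for EEG i,
    weights[i] is that EEG's weight.
    """
    total_weight = sum(weights)
    return [
        sum(w * p[c] for p, w in zip(probs, weights)) / total_weight
        for c in range(len(LABELS))
    ]


# Toy data: three EEGs with made-up label distributions and weights.
toy_probs = [
    [1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.5, 0.0, 0.0, 0.0, 0.5],
    [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
]
toy_weights = [2.0, 4.0, 2.0]

prior = weighted_class_prior(toy_probs, toy_weights)
print([round(p, 3) for p in prior])  # → [0.25, 0.25, 0.0, 0.0, 0.0, 0.5]
```

Because every per-EEG distribution sums to 1, the resulting prior also sums to 1, which matches the `label_prob` row above (its entries sum to 1 up to rounding).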


In [32]:
prob = label_prob.to_numpy()[0]
# Inverse-frequency weights, rescaled to sum to the number of classes (6), i.e. mean 1.
weight = 1 / prob
weight /= weight.sum()
weight *= len(prob)
weight.round(3)

array([1.151, 0.706, 1.039, 1.681, 1.17 , 0.253])
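The sketch below recomputes these weights from the `label_prob` values printed above (copied to 6 decimals, so they match only up to rounding) and checks that $\omega_c p_c$ is equal across classes. It also shows one possible way to feed them into training: a per-class weighted KL divergence. That loss, and the helper names `inverse_frequency_weights` and `weighted_kl`, are hypothetical illustrations, not necessarily how the training code uses `weight`.

```python
import math

# Class priors copied from the label_prob output above
# (order: seizure, lpd, gpd, lrda, grda, other).
PRIOR = [0.10051, 0.163856, 0.111289, 0.068821, 0.098882, 0.456643]


def inverse_frequency_weights(prior: list[float]) -> list[float]:
    """1/p_c, rescaled so the weights sum to the number of classes (mean 1)."""
    inv = [1.0 / p for p in prior]
    scale = len(prior) / sum(inv)
    return [w * scale for w in inv]


def weighted_kl(
    target: list[float], pred: list[float], weights: list[float], eps: float = 1e-15
) -> float:
    """Hypothetical class-weighted KL(target || pred) = sum_c w_c * t_c * log(t_c / p_c).

    Terms with t_c == 0 contribute 0, as in the standard KL divergence;
    pred is clamped to eps to avoid log(0).
    """
    return sum(
        w * t * math.log(t / max(p, eps))
        for t, p, w in zip(target, pred, weights)
        if t > 0
    )


weights = inverse_frequency_weights(PRIOR)
print([round(w, 3) for w in weights])
# w_c * p_c is the same for every class, so each class carries equal total weight.
print([round(w * p, 4) for w, p in zip(weights, PRIOR)])

# Example: experts split 50/50 between lpd and other; the model predicts uniform.
target = [0.0, 0.5, 0.0, 0.0, 0.0, 0.5]
uniform = [1 / 6] * 6
kl_plain = weighted_kl(target, uniform, [1.0] * 6)  # log(3) ≈ 1.0986
kl_weighted = weighted_kl(target, uniform, weights)
print(round(kl_plain, 4), round(kl_weighted, 4))
```

Since lpd and other both have weights below 1, this example EEG contributes less to the weighted loss than to the unweighted one, while EEGs labeled with rare classes such as lrda would contribute more.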