In [2]:
import polars as pl
from sklearn.preprocessing import LabelEncoder

from src.data import anime_id_label_encoding, load_data, user_id_label_encoding

# Notebook-wide polars display limits: strings up to 100 characters, 10 rows and
# 100 columns per rendered table. Every setter hands back the Config class, so the
# three calls can be chained into a single expression.
(
    pl.Config
    .set_fmt_str_lengths(100)
    .set_tbl_rows(10)
    .set_tbl_cols(100)
)


polars.config.Config

In [2]:
# Load the three input tables, then re-encode the ID columns with the project helpers.
# NOTE(review): presumably the helpers replace raw user/anime IDs with dense integer
# labels that are consistent across every frame passed in -- confirm in src.data.
train, test, anime = load_data()
train, test = user_id_label_encoding(train, test)
train, test, anime = anime_id_label_encoding(train, test, anime)


In [3]:
# Stack train and test into a single frame. A "diagonal" concat takes the union of the
# columns and fills whatever a frame lacks with null, so the test rows carry a null score.
train_test = pl.concat([train, test], how="diagonal")
train_test


user_id,anime_id,score
i16,i16,i8
0,49,2
0,122,10
0,234,1
0,263,8
0,318,9
…,…,…
1997,1908,
1997,1910,
1997,1915,
1997,1986,


### user_id毎にどのanimeを見ているかをone-hotにする



In [4]:
# Build one row per user describing which anime that user appears with in train_test.
#   1. One-hot encode anime_id over every row of train_test and place user_id next to it.
#   2. Sum the indicator columns per user_id (groups keep first-appearance order).
# Left-joining the result back onto train_test on user_id turns it into a feature block.
ohe_anime_id_per_user_df = (
    pl.concat(
        [train_test.select("user_id"), train_test["anime_id"].to_dummies()],
        how="horizontal",
    )
    .group_by("user_id", maintain_order=True)
    .agg(pl.all().sum())
)

# Display only: the downcast copy is rendered, the stored frame keeps its dtypes.
ohe_anime_id_per_user_df.select(pl.all().shrink_dtype())


user_id,anime_id_0,anime_id_1,anime_id_10,anime_id_100,anime_id_1000,anime_id_1001,anime_id_1002,anime_id_1003,anime_id_1004,anime_id_1005,anime_id_1006,anime_id_1007,anime_id_1008,anime_id_1009,anime_id_101,anime_id_1010,anime_id_1011,anime_id_1012,anime_id_1013,anime_id_1014,anime_id_1015,anime_id_1016,anime_id_1017,anime_id_1018,anime_id_1019,anime_id_102,anime_id_1020,anime_id_1021,anime_id_1022,anime_id_1023,anime_id_1024,anime_id_1025,anime_id_1026,anime_id_1027,anime_id_1028,anime_id_1029,anime_id_103,anime_id_1030,anime_id_1031,anime_id_1032,anime_id_1033,anime_id_1034,anime_id_1035,anime_id_1036,anime_id_1037,anime_id_1038,anime_id_1039,anime_id_104,anime_id_1040,…,anime_id_954,anime_id_955,anime_id_956,anime_id_957,anime_id_958,anime_id_959,anime_id_96,anime_id_960,anime_id_961,anime_id_962,anime_id_963,anime_id_964,anime_id_965,anime_id_966,anime_id_967,anime_id_968,anime_id_969,anime_id_97,anime_id_970,anime_id_971,anime_id_972,anime_id_973,anime_id_974,anime_id_975,anime_id_976,anime_id_977,anime_id_978,anime_id_979,anime_id_98,anime_id_980,anime_id_981,anime_id_982,anime_id_983,anime_id_984,anime_id_985,anime_id_986,anime_id_987,anime_id_988,anime_id_989,anime_id_99,anime_id_990,anime_id_991,anime_id_992,anime_id_993,anime_id_994,anime_id_995,anime_id_996,anime_id_997,anime_id_998,anime_id_999
i16,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,…,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1,0,1,1,0,1,0,0,0,0,0,0,0,0,0,0,0,1,…,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,1,0,0,0,0,0,0,1,0,1,0,0,0,0
2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
1973,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1976,0,1,0,0,0,0,0,0,1,1,0,1,0,0,0,1,0,1,0,0,0,1,0,0,0,1,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,1,0,1,0,1,0,0,…,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,1,0,0,0,1
1984,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1988,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


### anime_id毎にどのuserが見ているかをone-hotにする


In [5]:
# Build one row per anime describing which users appear with that anime in train_test.
#   1. One-hot encode user_id (not anime_id) over every row of train_test and place
#      anime_id next to it.
#   2. Sum the indicator columns per anime_id (groups keep first-appearance order).
# Left-joining the result back onto train_test on anime_id turns it into a feature block.
user_indicator_columns = train_test["user_id"].to_dummies()
anime_key_column = train_test.select("anime_id")

ohe_user_id_per_anime_id_df = (
    pl.concat([anime_key_column, user_indicator_columns], how="horizontal")
    .group_by("anime_id", maintain_order=True)
    .agg(pl.all().sum())
)

# Display only: the downcast copy is rendered, the stored frame keeps its dtypes.
ohe_user_id_per_anime_id_df.select(pl.all().shrink_dtype())


anime_id,user_id_0,user_id_1,user_id_10,user_id_100,user_id_1000,user_id_1001,user_id_1002,user_id_1003,user_id_1004,user_id_1005,user_id_1006,user_id_1007,user_id_1008,user_id_1009,user_id_101,user_id_1010,user_id_1011,user_id_1012,user_id_1013,user_id_1014,user_id_1015,user_id_1016,user_id_1017,user_id_1018,user_id_1019,user_id_102,user_id_1020,user_id_1021,user_id_1022,user_id_1023,user_id_1024,user_id_1025,user_id_1026,user_id_1027,user_id_1028,user_id_1029,user_id_103,user_id_1030,user_id_1031,user_id_1032,user_id_1033,user_id_1034,user_id_1035,user_id_1036,user_id_1037,user_id_1038,user_id_1039,user_id_104,user_id_1040,…,user_id_954,user_id_955,user_id_956,user_id_957,user_id_958,user_id_959,user_id_96,user_id_960,user_id_961,user_id_962,user_id_963,user_id_964,user_id_965,user_id_966,user_id_967,user_id_968,user_id_969,user_id_97,user_id_970,user_id_971,user_id_972,user_id_973,user_id_974,user_id_975,user_id_976,user_id_977,user_id_978,user_id_979,user_id_98,user_id_980,user_id_981,user_id_982,user_id_983,user_id_984,user_id_985,user_id_986,user_id_987,user_id_988,user_id_989,user_id_99,user_id_990,user_id_991,user_id_992,user_id_993,user_id_994,user_id_995,user_id_996,user_id_997,user_id_998,user_id_999
i16,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,…,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8
49,1,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,1,0,0,0,0,0,0,0
122,1,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,1,0,0,0,0,…,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0
234,1,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
263,1,0,0,0,0,0,0,1,0,1,0,0,0,0,0,1,1,0,0,1,1,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,1,1,0,1,1,0,0,0,0,1,0,0,…,0,1,1,0,0,1,0,0,0,1,0,0,1,0,1,1,0,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,1,1,0,0,0,0,1,1,0,0,0,0,0,0,1
318,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
1041,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0
224,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
942,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1679,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


### 特徴量として使用する


In [6]:
# Attach both one-hot blocks to every (user, anime) row, then downcast to the smallest
# dtypes that still hold the values.
# Row order matters here: the test predictions are later written into the submission
# file positionally, so the left joins must keep the rows in train_test order.
train_test_ohe = (
    train_test
    .join(ohe_anime_id_per_user_df, on="user_id", how="left")
    .join(ohe_user_id_per_anime_id_df, on="anime_id", how="left")
    .select(pl.all().shrink_dtype())
)

# Rows with a score are training data; rows without one are the test set, which
# therefore does not need the (all-null) score column.
score_col = pl.col("score")
train_ohe = train_test_ohe.filter(score_col.is_not_null())
test_ohe = train_test_ohe.filter(score_col.is_null()).drop(["score"])


In [7]:
# Preview the training feature matrix; the rows/columns rendered are capped by the
# pl.Config limits set in the first cell.
train_ohe


user_id,anime_id,score,anime_id_0,anime_id_1,anime_id_10,anime_id_100,anime_id_1000,anime_id_1001,anime_id_1002,anime_id_1003,anime_id_1004,anime_id_1005,anime_id_1006,anime_id_1007,anime_id_1008,anime_id_1009,anime_id_101,anime_id_1010,anime_id_1011,anime_id_1012,anime_id_1013,anime_id_1014,anime_id_1015,anime_id_1016,anime_id_1017,anime_id_1018,anime_id_1019,anime_id_102,anime_id_1020,anime_id_1021,anime_id_1022,anime_id_1023,anime_id_1024,anime_id_1025,anime_id_1026,anime_id_1027,anime_id_1028,anime_id_1029,anime_id_103,anime_id_1030,anime_id_1031,anime_id_1032,anime_id_1033,anime_id_1034,anime_id_1035,anime_id_1036,anime_id_1037,anime_id_1038,anime_id_1039,…,user_id_954,user_id_955,user_id_956,user_id_957,user_id_958,user_id_959,user_id_96,user_id_960,user_id_961,user_id_962,user_id_963,user_id_964,user_id_965,user_id_966,user_id_967,user_id_968,user_id_969,user_id_97,user_id_970,user_id_971,user_id_972,user_id_973,user_id_974,user_id_975,user_id_976,user_id_977,user_id_978,user_id_979,user_id_98,user_id_980,user_id_981,user_id_982,user_id_983,user_id_984,user_id_985,user_id_986,user_id_987,user_id_988,user_id_989,user_id_99,user_id_990,user_id_991,user_id_992,user_id_993,user_id_994,user_id_995,user_id_996,user_id_997,user_id_998,user_id_999
i16,i16,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,…,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8
0,49,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,1,0,0,0,0,0,0,0
0,122,10,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0
0,234,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
0,263,8,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,1,1,0,0,1,0,0,0,1,0,0,1,0,1,1,0,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,1,1,0,0,0,0,1,1,0,0,0,0,0,0,1
0,318,9,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
1996,1814,6,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0
1996,1819,7,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0
1996,1850,7,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1996,1967,9,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,1


### subしてみる


In [10]:
import pickle

import lightgbm as lgb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import StratifiedKFold

# ---------------------------------------------------------------------------
# Cross-validated LightGBM regression on the one-hot features, followed by a
# fold-averaged prediction for the test set written out as a submission file.
# ---------------------------------------------------------------------------

# Everything a reader might tune lives here, once, instead of being repeated per fold.
N_SPLITS = 5
SEED = 42
target = "score"
# test_ohe has neither "score" nor "fold", so its columns are exactly the model inputs.
features = test_ohe.columns

# NOTE(review): "subsample_freq" is LightGBM's alias for "bagging_freq", and 0 disables
# bagging, so "bagging_fraction" and "bagging_seed" currently have no effect. Set
# "subsample_freq" to 1 if row subsampling was intended. Values are left unchanged here
# so the CV scores stay comparable with earlier runs.
params = {
    "objective": "regression",
    "metric": "rmse",
    "verbosity": -1,
    "boosting_type": "gbdt",
    "learning_rate": 0.01,
    "num_leaves": 31,
    "min_child_samples": 20,
    "max_depth": -1,
    "subsample_freq": 0,
    "bagging_seed": 0,
    "feature_fraction": 0.9,
    "bagging_fraction": 0.8,
    "reg_alpha": 0.1,
    "reg_lambda": 0.1,
}

# --- Fold assignment --------------------------------------------------------
# Stratify on the score so each fold sees the same score distribution. Only the number
# of samples is taken from X, so a zero placeholder avoids handing the wide feature
# frame to scikit-learn.
skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
y_all = train_ohe[target].to_numpy()
fold_ids = np.full(len(y_all), -1, dtype=np.int32)
for fold, (_, val_index) in enumerate(skf.split(np.zeros(len(y_all)), y_all)):
    fold_ids[val_index] = fold

# Attach the fold number as a column (replaces any "fold" column from a previous run).
train_ohe = train_ohe.with_columns(pl.Series("fold", fold_ids))

# --- Training and evaluation ------------------------------------------------
scores_lgb = []
models_lgb = []
feature_importances = []

for fold in range(N_SPLITS):
    print(f"Training for fold {fold}...")

    # Rows of the current fold are held out for validation; the rest are used to train.
    train_data = train_ohe.filter(pl.col("fold") != fold)
    val_data = train_ohe.filter(pl.col("fold") == fold)

    # Convert the validation split once; it feeds both the Dataset and predict().
    X_val = val_data[features].to_pandas()
    y_val = val_data[target].to_pandas()

    lgb_train = lgb.Dataset(train_data[features].to_pandas(), train_data[target].to_pandas())
    lgb_val = lgb.Dataset(X_val, y_val)

    # Fresh callbacks for every fold.
    callbacks = [lgb.early_stopping(stopping_rounds=500), lgb.log_evaluation(period=100)]
    model_lgb = lgb.train(
        params,
        lgb_train,
        valid_sets=[lgb_val],
        callbacks=callbacks,
        num_boost_round=10000,
    )
    models_lgb.append(model_lgb)

    # Persist the model so it can be reused outside this kernel session.
    with open(f"model_lgb_{fold}.pkl", "wb") as f:
        pickle.dump(model_lgb, f)

    # Score the held-out fold at the early-stopped iteration.
    val_pred_lgb = model_lgb.predict(X_val, num_iteration=model_lgb.best_iteration)
    score_lgb = np.sqrt(mean_squared_error(y_val, val_pred_lgb))
    scores_lgb.append(score_lgb)

    print(f"RMSE for fold {fold}: {score_lgb}")

    feature_importances.append(model_lgb.feature_importance(importance_type="gain"))

# --- CV summary -------------------------------------------------------------
average_score_lgb = np.mean(scores_lgb)

print(f"Average RMSE: {average_score_lgb}")

# Mean gain per feature across the folds, most important first.
average_feature_importance = np.mean(feature_importances, axis=0)
feature_importance_df = pd.DataFrame({"feature": features, "importance": average_feature_importance}).sort_values(
    by="importance", ascending=False
)

print("Feature Importances:")
print(feature_importance_df)

# --- Test prediction and submission -----------------------------------------
# Average the fold models' predictions, using the models still held in memory rather
# than re-reading the pickles that were written above.
X_test = test_ohe[features].to_pandas()
test_pred_lgb = np.mean(
    [model.predict(X_test, num_iteration=model.best_iteration) for model in models_lgb],
    axis=0,
)

submission_df = pl.read_csv("../data/input/sample_submission.csv", try_parse_dates=True)

# Predictions are matched to submission rows purely by position, so the lengths must agree.
if submission_df.height != len(test_pred_lgb):
    raise ValueError(
        f"sample_submission has {submission_df.height} rows but {len(test_pred_lgb)} test predictions were produced"
    )

submission_df = submission_df.with_columns(pl.Series("score", test_pred_lgb))
submission_df.write_csv("../data/output/submission_baseline_ohe.csv")


Training for fold 0...
Training until validation scores don't improve for 500 rounds
[100]	valid_0's rmse: 1.433
[200]	valid_0's rmse: 1.37457
[300]	valid_0's rmse: 1.33347
[400]	valid_0's rmse: 1.30308
[500]	valid_0's rmse: 1.28048
[600]	valid_0's rmse: 1.26373
[700]	valid_0's rmse: 1.25088
[800]	valid_0's rmse: 1.24033
[900]	valid_0's rmse: 1.23184
[1000]	valid_0's rmse: 1.22546
[1100]	valid_0's rmse: 1.22016
[1200]	valid_0's rmse: 1.2164
[1300]	valid_0's rmse: 1.21355
[1400]	valid_0's rmse: 1.21136
[1500]	valid_0's rmse: 1.20931
[1600]	valid_0's rmse: 1.20746
[1700]	valid_0's rmse: 1.20559
[1800]	valid_0's rmse: 1.20378
[1900]	valid_0's rmse: 1.20201
[2000]	valid_0's rmse: 1.20049
[2100]	valid_0's rmse: 1.19903
[2200]	valid_0's rmse: 1.19748
[2300]	valid_0's rmse: 1.19621
[2400]	valid_0's rmse: 1.19475
[2500]	valid_0's rmse: 1.19349
[2600]	valid_0's rmse: 1.1923
[2700]	valid_0's rmse: 1.19099
[2800]	valid_0's rmse: 1.18982
[2900]	valid_0's rmse: 1.18867
[3000]	valid_0's rmse: 1.1876

In [1]:
# Visualize the most important features (mean gain across folds) with seaborn.
# This cell needs `feature_importance_df` from the training cell, so that cell must
# have been run in the current kernel first.
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Number of features to show; it drives both the slice and the title so the two
# cannot drift apart.
TOP_N = 30

# feature_importance_df is already sorted by importance (descending), so head() is the top N.
top_features = feature_importance_df.head(TOP_N)

# One distinct color per bar.
colors = sns.color_palette("husl", len(top_features))

# Draw on an explicit Axes instead of the implicit pyplot state.
fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(x="importance", y="feature", data=top_features, palette=colors, ax=ax)
ax.set_xlabel("Importance")
ax.set_ylabel("Feature")
ax.set_title(f"Top {TOP_N} Feature Importances")
plt.show()


NameError: name 'feature_importance_df' is not defined