In [None]:
# Notebook-wide configuration. `experiment_number` is interpolated into the
# file name of the saved model so artifacts from different runs stay apart.
CFG = {"experiment_number": 1}

from srs import *

# Load the feature tables for the train and test splits.
# NOTE(review): `pl` is not imported explicitly in this notebook; presumably it
# is polars re-exported by `from srs import *` — confirm.
train = pl.read_parquet('data/train_main_features.parquet')
test = pl.read_parquet('data/test_main_features.parquet')

# Targets live in a separate file from the train features.
target = pl.read_parquet('data/train_target.parquet')

# Quick sanity check of the (rows, columns) shape of each split.
# The Russian labels read "Training data:" and "Test data:".
print('Тренировочные данные:', train.shape)
print('Тестовые данные:', test.shape)

# Categorical columns are recognised by their "cat_feature" name prefix.
cat_feature_names = [name for name in train.columns if name.startswith("cat_feature")]

# Apply the same categorical cast to both splits so their columns stay consistent.
# NOTE(review): `cast_cat_features` comes from outside this notebook (presumably
# the `srs` star import); its exact target dtype is not visible here.
train = cast_cat_features(train, cat_feature_names)
test = cast_cat_features(test, cat_feature_names)

# `customer_id` is dropped from both tables below, so features and labels are
# matched purely by row position. Fail loudly if the two tables do not list
# customers in exactly the same order; otherwise the model would silently be
# trained on labels belonging to other customers.
if train["customer_id"].to_list() != target["customer_id"].to_list():
    raise ValueError(
        "train and target rows are not aligned on customer_id; "
        "join or sort them on customer_id before building the Pool"
    )

# Training pool: all feature columns except the identifier, every remaining
# target column as a label, and the categorical columns declared by name.
train_pool = Pool(data = train.drop("customer_id").to_pandas(),
                  label = target.drop("customer_id").to_pandas(),
                  cat_features = cat_feature_names)

# NOTE(review): `get_base_CatBoostClassifier` is defined outside this notebook
# (presumably in `srs`); its hyper-parameters are not visible here.
model = get_base_CatBoostClassifier()

In [None]:
from pathlib import Path

# Train on the pool built in the previous cell.
model.fit(train_pool)

# Persist the fitted model under models/, creating the directory on first use.
# The experiment number is part of the file name.
model_path = Path("models") / f"exp_{CFG['experiment_number']}_catboost_model.cbm"
model_path.parent.mkdir(parents=True, exist_ok=True)
model.save_model(str(model_path))
print(f"Model saved to: {model_path}")

In [None]:
# Reload the saved model into a fresh classifier; this rebinds `model`.
# Requires `model_path` to already exist in the kernel (it is set in the
# training/saving cell).
model = CatBoostClassifier()
model.load_model(str(model_path))

# Test pool mirrors the training pool: identifier dropped, same categorical
# columns, but no labels.
test_pool = Pool(data = test.drop("customer_id").to_pandas(), 
                 cat_features = cat_feature_names)

# "RawFormulaVal" asks CatBoost for raw model scores rather than class labels
# or probabilities (see the CatBoost `predict` documentation).
test_predict = model.predict(test_pool, prediction_type = "RawFormulaVal")

# Display the prediction array's shape as the cell output.
test_predict.shape

In [61]:
# One prediction column per target column: same suffix, with "target_"
# swapped for "predict_". Non-target columns (e.g. the id) are skipped.
predict_schema = [
    name.replace("target_", "predict_")
    for name in target.columns
    if name.startswith("target_")
]

# Wrap the prediction array in a polars frame using those column names.
catboost_predictions = pl.DataFrame(test_predict, schema=predict_schema)

# Preview the first rows as the cell output.
catboost_predictions.head(5)

predict_1_1,predict_1_2,predict_1_3,predict_1_4,predict_1_5,predict_2_1,predict_2_2,predict_2_3,predict_2_4,predict_2_5,predict_2_6,predict_2_7,predict_2_8,predict_3_1,predict_3_2,predict_3_3,predict_3_4,predict_3_5,predict_4_1,predict_5_1,predict_5_2,predict_6_1,predict_6_2,predict_6_3,predict_6_4,predict_6_5,predict_7_1,predict_7_2,predict_7_3,predict_8_1,predict_8_2,predict_8_3,predict_9_1,predict_9_2,predict_9_3,predict_9_4,predict_9_5,predict_9_6,predict_9_7,predict_9_8,predict_10_1
f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
-4.559634,-5.490853,-3.682293,-3.696798,-6.377251,-4.698343,-3.925055,-6.864407,-4.644582,-6.041761,-5.306951,-8.498892,-10.326492,-2.117544,-3.225609,-6.416635,-7.057265,-9.00383,-4.980861,-4.800834,-5.807614,-4.759972,-4.719907,-5.224776,-5.52965,-8.442456,-2.790907,-3.40613,-5.383466,-3.463931,-3.20201,-3.96947,-5.522536,-2.835501,-3.592803,-5.626597,-4.600432,-0.826411,-2.292409,-4.566064,-0.624437
-4.525017,-5.424643,-3.600475,-3.663936,-6.481639,-4.698258,-3.87817,-6.936742,-4.645305,-6.135729,-5.380085,-8.337462,-10.183376,-2.203137,-3.597962,-6.640242,-7.668241,-9.528925,-5.153873,-4.985227,-5.928929,-4.757343,-4.775738,-5.236387,-5.556886,-8.324126,-2.859017,-3.515063,-5.523479,-3.379077,-3.198607,-3.882064,-5.586788,-2.723239,-3.643311,-5.593574,-4.575498,-0.778346,-2.265025,-4.476992,-0.654095
-4.236281,-5.217968,-3.718711,-3.652767,-6.019878,-5.293401,-3.742374,-7.368946,-4.510376,-6.034828,-5.158191,-9.000582,-10.222584,-2.262793,-3.455289,-6.670321,-7.669997,-9.098556,-5.156947,-4.889535,-5.912771,-4.659181,-4.814485,-5.226197,-5.420546,-8.381626,-2.971064,-3.012568,-5.33779,-3.654234,-3.037381,-3.915707,-5.821074,-3.501836,-3.764656,-5.877993,-5.007001,-1.049139,-2.805007,-5.363468,-0.362334
-5.107494,-5.747457,-3.87031,-3.988364,-6.613195,-4.781696,-4.434008,-7.223677,-4.649338,-6.461804,-5.293109,-8.814394,-9.972295,-2.315931,-3.816451,-6.643194,-8.339884,-10.589949,-5.334445,-4.888459,-5.859745,-4.840233,-4.480154,-4.971953,-5.841156,-9.015239,-2.464894,-3.341553,-5.357212,-4.255379,-3.26839,-4.015814,-5.678406,-2.961717,-3.568433,-5.278345,-4.324995,-0.696825,-2.157653,-4.039829,-0.652303
-4.48209,-5.395573,-3.571668,-3.641002,-6.268344,-4.670509,-3.762274,-6.943392,-4.550145,-6.070182,-5.286359,-8.725736,-10.104682,-2.229007,-3.662391,-6.55999,-7.705663,-9.952746,-5.139228,-4.917467,-5.864986,-4.77367,-4.696359,-5.208519,-5.754242,-8.778719,-2.784675,-3.323991,-5.338707,-3.977604,-3.055069,-3.632175,-5.412687,-2.838579,-3.566938,-5.623631,-4.54547,-0.736783,-2.24171,-4.585213,-0.668844


In [62]:
# Submission frame: the test customer ids followed by every prediction column.
# hstack appends columns by position, so predictions must be in test row order.
submit = test.select("customer_id").hstack(catboost_predictions)

In [63]:
# Preview the first five submission rows (ids alongside their predictions).
submit.head(5)

customer_id,predict_1_1,predict_1_2,predict_1_3,predict_1_4,predict_1_5,predict_2_1,predict_2_2,predict_2_3,predict_2_4,predict_2_5,predict_2_6,predict_2_7,predict_2_8,predict_3_1,predict_3_2,predict_3_3,predict_3_4,predict_3_5,predict_4_1,predict_5_1,predict_5_2,predict_6_1,predict_6_2,predict_6_3,predict_6_4,predict_6_5,predict_7_1,predict_7_2,predict_7_3,predict_8_1,predict_8_2,predict_8_3,predict_9_1,predict_9_2,predict_9_3,predict_9_4,predict_9_5,predict_9_6,predict_9_7,predict_9_8,predict_10_1
i32,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
1750001,-4.559634,-5.490853,-3.682293,-3.696798,-6.377251,-4.698343,-3.925055,-6.864407,-4.644582,-6.041761,-5.306951,-8.498892,-10.326492,-2.117544,-3.225609,-6.416635,-7.057265,-9.00383,-4.980861,-4.800834,-5.807614,-4.759972,-4.719907,-5.224776,-5.52965,-8.442456,-2.790907,-3.40613,-5.383466,-3.463931,-3.20201,-3.96947,-5.522536,-2.835501,-3.592803,-5.626597,-4.600432,-0.826411,-2.292409,-4.566064,-0.624437
1750002,-4.525017,-5.424643,-3.600475,-3.663936,-6.481639,-4.698258,-3.87817,-6.936742,-4.645305,-6.135729,-5.380085,-8.337462,-10.183376,-2.203137,-3.597962,-6.640242,-7.668241,-9.528925,-5.153873,-4.985227,-5.928929,-4.757343,-4.775738,-5.236387,-5.556886,-8.324126,-2.859017,-3.515063,-5.523479,-3.379077,-3.198607,-3.882064,-5.586788,-2.723239,-3.643311,-5.593574,-4.575498,-0.778346,-2.265025,-4.476992,-0.654095
1750003,-4.236281,-5.217968,-3.718711,-3.652767,-6.019878,-5.293401,-3.742374,-7.368946,-4.510376,-6.034828,-5.158191,-9.000582,-10.222584,-2.262793,-3.455289,-6.670321,-7.669997,-9.098556,-5.156947,-4.889535,-5.912771,-4.659181,-4.814485,-5.226197,-5.420546,-8.381626,-2.971064,-3.012568,-5.33779,-3.654234,-3.037381,-3.915707,-5.821074,-3.501836,-3.764656,-5.877993,-5.007001,-1.049139,-2.805007,-5.363468,-0.362334
1750004,-5.107494,-5.747457,-3.87031,-3.988364,-6.613195,-4.781696,-4.434008,-7.223677,-4.649338,-6.461804,-5.293109,-8.814394,-9.972295,-2.315931,-3.816451,-6.643194,-8.339884,-10.589949,-5.334445,-4.888459,-5.859745,-4.840233,-4.480154,-4.971953,-5.841156,-9.015239,-2.464894,-3.341553,-5.357212,-4.255379,-3.26839,-4.015814,-5.678406,-2.961717,-3.568433,-5.278345,-4.324995,-0.696825,-2.157653,-4.039829,-0.652303
1750005,-4.48209,-5.395573,-3.571668,-3.641002,-6.268344,-4.670509,-3.762274,-6.943392,-4.550145,-6.070182,-5.286359,-8.725736,-10.104682,-2.229007,-3.662391,-6.55999,-7.705663,-9.952746,-5.139228,-4.917467,-5.864986,-4.77367,-4.696359,-5.208519,-5.754242,-8.778719,-2.784675,-3.323991,-5.338707,-3.977604,-3.055069,-3.632175,-5.412687,-2.838579,-3.566938,-5.623631,-4.54547,-0.736783,-2.24171,-4.585213,-0.668844


In [None]:
from pathlib import Path

# Build the submission path from the experiment number. The `f` prefix is
# essential: as a plain string the braces are kept literally, so every
# experiment would write to the same oddly named file.
submit_path = Path("submits") / f"exp_{CFG['experiment_number']}_submit.parquet"

# Make sure the output directory exists before writing, mirroring how the
# models/ directory is prepared when the model is saved.
submit_path.parent.mkdir(parents=True, exist_ok=True)

submit.write_parquet(submit_path)
print(f"Submission saved to: {submit_path}")