In [1]:
!pip install catboost


Collecting catboost
  Downloading catboost-1.2.5-cp311-cp311-macosx_11_0_universal2.whl.metadata (1.2 kB)
Downloading catboost-1.2.5-cp311-cp311-macosx_11_0_universal2.whl (26.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m26.2/26.2 MB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0mm
[?25hInstalling collected packages: catboost
Successfully installed catboost-1.2.5


In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import catboost
from catboost import CatBoostClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from sklearn.preprocessing import StandardScaler
#cv
from sklearn.model_selection import cross_val_score


train = pd.read_csv('wineq_train.csv')

train.head()    


Unnamed: 0,fixed acidity,volatile acidity,citric acid,residual sugar,chlorides,free sulfur dioxide,total sulfur dioxide,density,pH,sulphates,alcohol,quality
0,7.0,0.27,0.36,20.7,0.045,45.0,170.0,1.001,3.0,0.45,8.8,6
1,6.3,0.3,0.34,1.6,0.049,14.0,132.0,0.994,3.3,0.49,9.5,6
2,8.1,0.28,0.4,6.9,0.05,30.0,97.0,0.9951,3.26,0.44,10.1,6
3,7.2,0.23,0.32,8.5,0.058,47.0,186.0,0.9956,3.19,0.4,9.9,6
4,7.2,0.23,0.32,8.5,0.058,47.0,186.0,0.9956,3.19,0.4,9.9,6


In [7]:
import pandas as pd
from sklearn.model_selection import train_test_split, GridSearchCV, StratifiedKFold
from catboost import CatBoostClassifier, Pool
from sklearn.metrics import mean_squared_error
from tqdm import tqdm

# Split the data into X and y
X = train.drop(columns=['quality'])
y = train['quality']

# Split the data into train and test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create data pools
train_pool = Pool(X_train, y_train)
test_pool = Pool(X_test, y_test)

# Initialize the model
model = CatBoostClassifier(loss_function='MultiClass', random_seed=42, verbose=0)

# Define the parameter grid
param_grid = {
    'iterations': [500, 1000],
    'learning_rate': [0.01, 0.05, 0.1],
    'depth': [4, 6, 8, 10],
    'l2_leaf_reg': [1, 3, 5, 7, 9],
    'border_count': [32, 64, 128, 256]
}

# Initialize StratifiedKFold
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Initialize GridSearchCV
grid_search = GridSearchCV(estimator=model, param_grid=param_grid, cv=cv, scoring='neg_root_mean_squared_error', verbose=2, n_jobs=-1)

# Fit GridSearchCV
grid_search.fit(X_train, y_train)

# Get the best parameters
best_params = grid_search.best_params_

# Initialize the model with the best parameters
best_model = CatBoostClassifier(**best_params, loss_function='MultiClass', random_seed=42, verbose=200)

# Fit the model with the best parameters
best_model.fit(train_pool, eval_set=test_pool)

# Feature selection using manual LOFO
def calculate_lofo_importance(X, y, model, cv):
    baseline_score = -mean_squared_error(y, model.predict(X), squared=False)
    feature_importance = {}
    
    for feature in tqdm(X.columns, desc="Calculating LOFO importance"):
        X_loo = X.drop(columns=[feature])
        scores = []
        
        for train_idx, val_idx in cv.split(X_loo, y):
            X_train_cv, X_val_cv = X_loo.iloc[train_idx], X_loo.iloc[val_idx]
            y_train_cv, y_val_cv = y.iloc[train_idx], y.iloc[val_idx]
            
            model.fit(X_train_cv, y_train_cv)
            y_pred_cv = model.predict(X_val_cv)
            score = -mean_squared_error(y_val_cv, y_pred_cv, squared=False)
            scores.append(score)
        
        feature_importance[feature] = baseline_score - np.mean(scores)
    
    importance_df = pd.DataFrame.from_dict(feature_importance, orient='index', columns=['importance'])
    importance_df = importance_df.sort_values(by='importance', ascending=False)
    
    return importance_df

# Calculate LOFO importance
importance_df = calculate_lofo_importance(X_train, y_train, best_model, cv)

# Select features with positive importance
selected_features = importance_df[importance_df["importance"] > 0].index.tolist()
X_train_selected = X_train[selected_features]
X_test_selected = X_test[selected_features]

# Create new data pools with selected features
train_pool_selected = Pool(X_train_selected, y_train)
test_pool_selected = Pool(X_test_selected, y_test)

# Fit the model again with the selected features
best_model.fit(train_pool_selected, eval_set=test_pool_selected)

# Predict and evaluate the model
y_pred = best_model.predict(X_test_selected)
rmse = mean_squared_error(y_test, y_pred, squared=False)

# You can add more evaluation metrics here
print("RMSE:", rmse)
print("Classification Report:\n", classification_report(y_test, y_pred))


Fitting 5 folds for each of 480 candidates, totalling 2400 fits




[CV] END border_count=32, depth=4, iterations=500, l2_leaf_reg=1, learning_rate=0.05; total time=   1.3s
[CV] END border_count=32, depth=4, iterations=500, l2_leaf_reg=1, learning_rate=0.01; total time=   1.3s
[CV] END border_count=32, depth=4, iterations=500, l2_leaf_reg=1, learning_rate=0.05; total time=   1.3s
[CV] END border_count=32, depth=4, iterations=500, l2_leaf_reg=1, learning_rate=0.01; total time=   1.3s
[CV] END border_count=32, depth=4, iterations=500, l2_leaf_reg=1, learning_rate=0.01; total time=   1.4s
[CV] END border_count=32, depth=4, iterations=500, l2_leaf_reg=1, learning_rate=0.05; total time=   1.4s
[CV] END border_count=32, depth=4, iterations=500, l2_leaf_reg=1, learning_rate=0.01; total time=   1.4s
[CV] END border_count=32, depth=4, iterations=500, l2_leaf_reg=1, learning_rate=0.01; total time=   1.5s
[CV] END border_count=32, depth=4, iterations=500, l2_leaf_reg=1, learning_rate=0.1; total time=   1.3s
[CV] END border_count=32, depth=4, iterations=500, l2_le



499:	learn: 0.2503945	test: 0.8315067	best: 0.8315067 (499)	total: 2.81s	remaining: 0us

bestTest = 0.8315067111
bestIteration = 499



Calculating LOFO importance:   0%|          | 0/11 [00:00<?, ?it/s]

0:	learn: 1.8677289	total: 5.8ms	remaining: 2.9s




200:	learn: 0.5493386	total: 991ms	remaining: 1.47s
400:	learn: 0.3005561	total: 1.97s	remaining: 486ms
499:	learn: 0.2334710	total: 2.44s	remaining: 0us
0:	learn: 1.8676649	total: 5.26ms	remaining: 2.62s




200:	learn: 0.5497373	total: 1.01s	remaining: 1.5s
400:	learn: 0.3009809	total: 1.99s	remaining: 492ms
499:	learn: 0.2383952	total: 2.48s	remaining: 0us
0:	learn: 1.8687618	total: 4.78ms	remaining: 2.38s




200:	learn: 0.5397517	total: 1.05s	remaining: 1.57s
400:	learn: 0.2936206	total: 2.04s	remaining: 504ms
499:	learn: 0.2297865	total: 2.52s	remaining: 0us
0:	learn: 1.8662591	total: 5.01ms	remaining: 2.5s




200:	learn: 0.5354045	total: 970ms	remaining: 1.44s
400:	learn: 0.2969210	total: 1.95s	remaining: 481ms
499:	learn: 0.2314764	total: 2.43s	remaining: 0us
0:	learn: 1.8663867	total: 4.69ms	remaining: 2.34s




200:	learn: 0.5311893	total: 991ms	remaining: 1.47s
400:	learn: 0.2878822	total: 2.02s	remaining: 498ms


Calculating LOFO importance:   9%|▉         | 1/11 [00:12<02:08, 12.83s/it]

499:	learn: 0.2231441	total: 2.66s	remaining: 0us
0:	learn: 1.8717445	total: 5.58ms	remaining: 2.78s




200:	learn: 0.5668292	total: 1.2s	remaining: 1.78s
400:	learn: 0.3091981	total: 2.22s	remaining: 548ms
499:	learn: 0.2427430	total: 2.71s	remaining: 0us
0:	learn: 1.8746630	total: 5.15ms	remaining: 2.57s




200:	learn: 0.5746098	total: 1.01s	remaining: 1.51s
400:	learn: 0.3167333	total: 2.05s	remaining: 505ms
499:	learn: 0.2478780	total: 2.56s	remaining: 0us
0:	learn: 1.8729152	total: 4.95ms	remaining: 2.47s




200:	learn: 0.5582323	total: 1.02s	remaining: 1.52s
400:	learn: 0.3102097	total: 2.3s	remaining: 567ms
499:	learn: 0.2429203	total: 2.79s	remaining: 0us
0:	learn: 1.8696959	total: 6.07ms	remaining: 3.03s




200:	learn: 0.5509603	total: 1.01s	remaining: 1.5s
400:	learn: 0.3037198	total: 2.05s	remaining: 505ms
499:	learn: 0.2397111	total: 2.54s	remaining: 0us
0:	learn: 1.8715008	total: 4.99ms	remaining: 2.49s




200:	learn: 0.5602859	total: 1.07s	remaining: 1.6s
400:	learn: 0.3039666	total: 2.13s	remaining: 527ms


Calculating LOFO importance:  18%|█▊        | 2/11 [00:26<02:00, 13.35s/it]

499:	learn: 0.2384716	total: 2.77s	remaining: 0us
0:	learn: 1.8682723	total: 5.77ms	remaining: 2.88s




200:	learn: 0.5560399	total: 1.17s	remaining: 1.75s
400:	learn: 0.3057409	total: 2.27s	remaining: 560ms
499:	learn: 0.2377393	total: 2.79s	remaining: 0us
0:	learn: 1.8677277	total: 5.32ms	remaining: 2.65s




200:	learn: 0.5566351	total: 1.2s	remaining: 1.78s
400:	learn: 0.3103061	total: 2.61s	remaining: 645ms
499:	learn: 0.2429077	total: 3.24s	remaining: 0us
0:	learn: 1.8668561	total: 9.43ms	remaining: 4.7s




200:	learn: 0.5423901	total: 1.14s	remaining: 1.7s
400:	learn: 0.2972213	total: 2.39s	remaining: 591ms
499:	learn: 0.2329113	total: 2.93s	remaining: 0us
0:	learn: 1.8629977	total: 5.67ms	remaining: 2.83s




200:	learn: 0.5353593	total: 1.17s	remaining: 1.75s
400:	learn: 0.2996115	total: 2.47s	remaining: 609ms
499:	learn: 0.2324159	total: 3.05s	remaining: 0us
0:	learn: 1.8670656	total: 5.59ms	remaining: 2.79s




200:	learn: 0.5446873	total: 1.19s	remaining: 1.77s
400:	learn: 0.2986393	total: 2.53s	remaining: 624ms


Calculating LOFO importance:  27%|██▋       | 3/11 [00:42<01:54, 14.37s/it]

499:	learn: 0.2314501	total: 3.1s	remaining: 0us
0:	learn: 1.8694981	total: 7.97ms	remaining: 3.98s




200:	learn: 0.5513174	total: 1.07s	remaining: 1.59s
400:	learn: 0.2967744	total: 2.07s	remaining: 511ms
499:	learn: 0.2289893	total: 2.58s	remaining: 0us
0:	learn: 1.8686498	total: 5.49ms	remaining: 2.74s




200:	learn: 0.5560745	total: 1.03s	remaining: 1.53s
400:	learn: 0.3051002	total: 2.17s	remaining: 536ms
499:	learn: 0.2360819	total: 2.68s	remaining: 0us
0:	learn: 1.8658033	total: 5.12ms	remaining: 2.56s




200:	learn: 0.5442097	total: 1.03s	remaining: 1.54s
400:	learn: 0.2923738	total: 2.11s	remaining: 522ms
499:	learn: 0.2287866	total: 2.63s	remaining: 0us
0:	learn: 1.8636791	total: 4.82ms	remaining: 2.41s




200:	learn: 0.5301475	total: 1.12s	remaining: 1.66s
400:	learn: 0.2887909	total: 2.16s	remaining: 534ms
499:	learn: 0.2241800	total: 2.68s	remaining: 0us
0:	learn: 1.8678652	total: 5.34ms	remaining: 2.66s




200:	learn: 0.5373573	total: 1.1s	remaining: 1.64s
400:	learn: 0.2845929	total: 2.25s	remaining: 555ms


Calculating LOFO importance:  36%|███▋      | 4/11 [00:55<01:38, 14.09s/it]

499:	learn: 0.2186695	total: 2.77s	remaining: 0us
0:	learn: 1.8686144	total: 5.91ms	remaining: 2.95s




200:	learn: 0.5454318	total: 1.04s	remaining: 1.55s
400:	learn: 0.3014097	total: 2.27s	remaining: 561ms
499:	learn: 0.2354544	total: 2.85s	remaining: 0us
0:	learn: 1.8668332	total: 4.97ms	remaining: 2.48s




200:	learn: 0.5514466	total: 1.09s	remaining: 1.62s
400:	learn: 0.2990559	total: 2.19s	remaining: 540ms
499:	learn: 0.2347042	total: 2.76s	remaining: 0us
0:	learn: 1.8643381	total: 5.73ms	remaining: 2.86s




200:	learn: 0.5495318	total: 1.1s	remaining: 1.64s
400:	learn: 0.2993610	total: 2.24s	remaining: 552ms
499:	learn: 0.2367743	total: 2.77s	remaining: 0us
0:	learn: 1.8634277	total: 4.82ms	remaining: 2.41s




200:	learn: 0.5311477	total: 1.04s	remaining: 1.55s
400:	learn: 0.2966687	total: 2.15s	remaining: 530ms
499:	learn: 0.2301623	total: 2.68s	remaining: 0us
0:	learn: 1.8680640	total: 5.7ms	remaining: 2.84s




200:	learn: 0.5392361	total: 1.08s	remaining: 1.61s
400:	learn: 0.2916828	total: 2.15s	remaining: 530ms


Calculating LOFO importance:  45%|████▌     | 5/11 [01:09<01:24, 14.05s/it]

499:	learn: 0.2264038	total: 2.65s	remaining: 0us
0:	learn: 1.8686144	total: 4.58ms	remaining: 2.29s




200:	learn: 0.5629709	total: 1.09s	remaining: 1.62s
400:	learn: 0.3062177	total: 2.3s	remaining: 568ms
499:	learn: 0.2387791	total: 2.81s	remaining: 0us
0:	learn: 1.8668332	total: 4.96ms	remaining: 2.48s




200:	learn: 0.5587655	total: 1.2s	remaining: 1.78s
400:	learn: 0.3037252	total: 2.46s	remaining: 608ms
499:	learn: 0.2377467	total: 3.05s	remaining: 0us
0:	learn: 1.8643381	total: 4.78ms	remaining: 2.38s




200:	learn: 0.5605809	total: 1.36s	remaining: 2.03s
400:	learn: 0.3017853	total: 2.67s	remaining: 659ms
499:	learn: 0.2315254	total: 3.29s	remaining: 0us
0:	learn: 1.8634277	total: 5.53ms	remaining: 2.76s




200:	learn: 0.5454354	total: 1.26s	remaining: 1.87s
400:	learn: 0.3035240	total: 2.51s	remaining: 620ms
499:	learn: 0.2376578	total: 3.11s	remaining: 0us
0:	learn: 1.8680640	total: 7.05ms	remaining: 3.52s




200:	learn: 0.5428916	total: 1.3s	remaining: 1.93s
400:	learn: 0.2969456	total: 2.58s	remaining: 637ms


Calculating LOFO importance:  55%|█████▍    | 6/11 [01:25<01:13, 14.71s/it]

499:	learn: 0.2291925	total: 3.32s	remaining: 0us
0:	learn: 1.8661752	total: 5.43ms	remaining: 2.71s




200:	learn: 0.5584944	total: 1.13s	remaining: 1.69s
400:	learn: 0.2995513	total: 2.31s	remaining: 571ms
499:	learn: 0.2310240	total: 2.92s	remaining: 0us
0:	learn: 1.8658145	total: 5.11ms	remaining: 2.55s




200:	learn: 0.5510564	total: 1.28s	remaining: 1.91s
400:	learn: 0.3012572	total: 2.74s	remaining: 676ms
499:	learn: 0.2352545	total: 3.38s	remaining: 0us
0:	learn: 1.8653900	total: 17.6ms	remaining: 8.76s




200:	learn: 0.5489283	total: 1.41s	remaining: 2.09s
400:	learn: 0.2906350	total: 2.72s	remaining: 671ms
499:	learn: 0.2290656	total: 3.47s	remaining: 0us
0:	learn: 1.8616462	total: 5.07ms	remaining: 2.53s




200:	learn: 0.5361366	total: 1.39s	remaining: 2.06s
400:	learn: 0.2955245	total: 2.65s	remaining: 656ms
499:	learn: 0.2269923	total: 3.22s	remaining: 0us
0:	learn: 1.8680640	total: 5.29ms	remaining: 2.64s




200:	learn: 0.5328969	total: 1.22s	remaining: 1.81s
400:	learn: 0.2915794	total: 2.39s	remaining: 589ms


Calculating LOFO importance:  64%|██████▎   | 7/11 [01:42<01:01, 15.30s/it]

499:	learn: 0.2268297	total: 2.99s	remaining: 0us
0:	learn: 1.8660952	total: 5.27ms	remaining: 2.63s




200:	learn: 0.5325366	total: 1.26s	remaining: 1.88s
400:	learn: 0.2830233	total: 2.65s	remaining: 654ms
499:	learn: 0.2174152	total: 3.34s	remaining: 0us
0:	learn: 1.8651263	total: 4.97ms	remaining: 2.48s




200:	learn: 0.5315806	total: 1.47s	remaining: 2.19s
400:	learn: 0.2868304	total: 2.81s	remaining: 695ms
499:	learn: 0.2234946	total: 3.49s	remaining: 0us
0:	learn: 1.8658806	total: 4.8ms	remaining: 2.4s




200:	learn: 0.5298787	total: 1.28s	remaining: 1.91s
400:	learn: 0.2842602	total: 2.58s	remaining: 637ms
499:	learn: 0.2218918	total: 3.18s	remaining: 0us
0:	learn: 1.8615671	total: 5.41ms	remaining: 2.7s




200:	learn: 0.5240857	total: 1.21s	remaining: 1.8s
400:	learn: 0.2851625	total: 2.38s	remaining: 588ms
499:	learn: 0.2219641	total: 2.99s	remaining: 0us
0:	learn: 1.8674383	total: 5.15ms	remaining: 2.57s




200:	learn: 0.5253061	total: 1.07s	remaining: 1.59s
400:	learn: 0.2791367	total: 2.12s	remaining: 524ms


Calculating LOFO importance:  73%|███████▎  | 8/11 [01:58<00:46, 15.57s/it]

499:	learn: 0.2150518	total: 2.7s	remaining: 0us
0:	learn: 1.8660952	total: 25.5ms	remaining: 12.7s




200:	learn: 0.5609395	total: 1.39s	remaining: 2.06s
400:	learn: 0.3059680	total: 2.69s	remaining: 664ms
499:	learn: 0.2381970	total: 3.31s	remaining: 0us
0:	learn: 1.8651263	total: 15.2ms	remaining: 7.59s




200:	learn: 0.5525676	total: 1.32s	remaining: 1.96s
400:	learn: 0.3028373	total: 2.58s	remaining: 638ms
499:	learn: 0.2379012	total: 3.16s	remaining: 0us
0:	learn: 1.8658806	total: 5.41ms	remaining: 2.7s




200:	learn: 0.5469538	total: 1.12s	remaining: 1.66s
400:	learn: 0.2999430	total: 2.26s	remaining: 557ms
499:	learn: 0.2330160	total: 2.82s	remaining: 0us
0:	learn: 1.8615671	total: 5.53ms	remaining: 2.76s




200:	learn: 0.5379160	total: 1.19s	remaining: 1.77s
400:	learn: 0.2974577	total: 2.35s	remaining: 580ms
499:	learn: 0.2318790	total: 2.92s	remaining: 0us
0:	learn: 1.8646294	total: 5.83ms	remaining: 2.91s




200:	learn: 0.5427429	total: 1.11s	remaining: 1.65s
400:	learn: 0.2959936	total: 2.29s	remaining: 565ms


Calculating LOFO importance:  82%|████████▏ | 9/11 [02:13<00:31, 15.52s/it]

499:	learn: 0.2294971	total: 2.84s	remaining: 0us
0:	learn: 1.8660952	total: 5.05ms	remaining: 2.52s




200:	learn: 0.5562403	total: 1.17s	remaining: 1.74s
400:	learn: 0.3097594	total: 2.3s	remaining: 567ms
499:	learn: 0.2395708	total: 2.85s	remaining: 0us
0:	learn: 1.8644555	total: 7.05ms	remaining: 3.52s




200:	learn: 0.5544124	total: 1.2s	remaining: 1.78s
400:	learn: 0.3093188	total: 2.36s	remaining: 583ms
499:	learn: 0.2441208	total: 2.92s	remaining: 0us
0:	learn: 1.8658806	total: 5.08ms	remaining: 2.54s




200:	learn: 0.5422712	total: 1.13s	remaining: 1.68s
400:	learn: 0.3017752	total: 2.24s	remaining: 554ms
499:	learn: 0.2364508	total: 2.81s	remaining: 0us
0:	learn: 1.8615671	total: 5.83ms	remaining: 2.91s




200:	learn: 0.5362713	total: 1.09s	remaining: 1.63s
400:	learn: 0.2977304	total: 2.19s	remaining: 541ms
499:	learn: 0.2344067	total: 2.73s	remaining: 0us
0:	learn: 1.8646294	total: 5.04ms	remaining: 2.51s




200:	learn: 0.5338822	total: 1.17s	remaining: 1.75s
400:	learn: 0.2935191	total: 2.29s	remaining: 565ms


Calculating LOFO importance:  91%|█████████ | 10/11 [02:28<00:15, 15.21s/it]

499:	learn: 0.2295068	total: 2.83s	remaining: 0us
0:	learn: 1.8698291	total: 4.86ms	remaining: 2.42s




200:	learn: 0.5461156	total: 1.08s	remaining: 1.61s
400:	learn: 0.2893728	total: 2.21s	remaining: 545ms
499:	learn: 0.2252415	total: 2.8s	remaining: 0us
0:	learn: 1.8694072	total: 5.31ms	remaining: 2.65s




200:	learn: 0.5426120	total: 1.11s	remaining: 1.66s
400:	learn: 0.2951680	total: 2.21s	remaining: 546ms
499:	learn: 0.2275348	total: 2.76s	remaining: 0us
0:	learn: 1.8699554	total: 6.48ms	remaining: 3.23s




200:	learn: 0.5378472	total: 1.12s	remaining: 1.67s
400:	learn: 0.2900397	total: 2.24s	remaining: 554ms
499:	learn: 0.2275070	total: 2.8s	remaining: 0us
0:	learn: 1.8688508	total: 6.03ms	remaining: 3.01s




200:	learn: 0.5266734	total: 1.11s	remaining: 1.66s
400:	learn: 0.2863375	total: 2.27s	remaining: 560ms
499:	learn: 0.2218273	total: 2.81s	remaining: 0us
0:	learn: 1.8680814	total: 4.92ms	remaining: 2.45s




200:	learn: 0.5343951	total: 1.1s	remaining: 1.64s
400:	learn: 0.2879869	total: 2.35s	remaining: 581ms


Calculating LOFO importance: 100%|██████████| 11/11 [02:42<00:00, 14.79s/it]

499:	learn: 0.2219707	total: 2.88s	remaining: 0us





0:	learn: 1.8664119	test: 1.8712757	best: 1.8712757 (0)	total: 4.37ms	remaining: 2.18s
200:	learn: 0.5722594	test: 0.9060651	best: 0.9060651 (200)	total: 1.1s	remaining: 1.63s
400:	learn: 0.3216893	test: 0.8458424	best: 0.8458424 (400)	total: 2.33s	remaining: 575ms
499:	learn: 0.2560347	test: 0.8358186	best: 0.8355334 (495)	total: 2.93s	remaining: 0us

bestTest = 0.835533401
bestIteration = 495

Shrink model to first 496 iterations.
RMSE: 0.6667792697696852
Classification Report:
               precision    recall  f1-score   support

           3       0.00      0.00      0.00         3
           4       0.50      0.18      0.26        28
           5       0.72      0.69      0.70       238
           6       0.65      0.80      0.72       313
           7       0.70      0.58      0.64       127
           8       0.73      0.38      0.50        29
           9       0.00      0.00      0.00         2

    accuracy                           0.68       740
   macro avg       0.47   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [11]:
X_train_selected = X[selected_features]
#X_test_selected = X_test[selected_features]

# Create new data pools with selected features
train_pool_selected = Pool(X_train_selected, y)
#test_pool_selected = Pool(X_test_selected, y_test)

# Fit the model again with the selected features
best_model.fit(train_pool_selected)

val = pd.read_csv('../wineq_validation.csv')
X_val = val[selected_features]

y_val = best_model.predict(X_val)

#write to txt 
with open('wineq_predictionscb.txt', 'w') as f:
    for item in y_val:
        f.write("%s\n" % item)


0:	learn: 1.8649334	total: 41.2ms	remaining: 20.5s
200:	learn: 0.5834450	total: 1.39s	remaining: 2.07s
400:	learn: 0.3466127	total: 2.65s	remaining: 655ms
499:	learn: 0.2762604	total: 3.27s	remaining: 0us


In [10]:
best_params

{'border_count': 128,
 'depth': 8,
 'iterations': 500,
 'l2_leaf_reg': 1,
 'learning_rate': 0.05}