In [1]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import pandas as pd
from sklearn.metrics import (
    make_scorer,
    mean_absolute_error,
    mean_squared_error,
    r2_score
)
from sklearn.model_selection import cross_validate, StratifiedKFold
from sklearn.tree import DecisionTreeRegressor
import numpy as np

In [3]:
def root_mean_squared_error(y_true, y_pred):
    return np.sqrt(mean_squared_error(y_true, y_pred))

METRICS_REGRESSION = {
    "MAE": make_scorer(mean_absolute_error),
    "MSE": make_scorer(mean_squared_error),
    "RMSE": make_scorer(root_mean_squared_error),
    "R2": make_scorer(r2_score)
}

In [4]:
df = pd.read_csv("../cenario4_engine.csv")
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3207 entries, 0 to 3206
Data columns (total 6 columns):
 #   Column      Non-Null Count  Dtype  
---  ------      --------------  -----  
 0   model_year  3207 non-null   int64  
 1   KM          3207 non-null   int64  
 2   HP          2578 non-null   float64
 3   Litros      2891 non-null   float64
 4   Cilindros   2705 non-null   float64
 5   price_eur   3207 non-null   int64  
dtypes: float64(3), int64(3)
memory usage: 150.5 KB


In [5]:
categorical_columns = df.select_dtypes(include='object').columns

# Mapeia as categorias para números
for column in categorical_columns:
    df[column] = df[column].astype('category').cat.codes

In [6]:
df.head()

Unnamed: 0,model_year,KM,HP,Litros,Cilindros,price_eur
0,2014,114263,285.0,3.6,6.0,20020
1,2015,110361,270.0,3.5,6.0,23660
2,2018,112076,208.0,2.5,4.0,21835
3,2021,53913,260.0,2.4,4.0,30940
4,2018,120701,301.0,4.6,8.0,34666


In [7]:
X, y = df.drop("price_eur", axis=1), df["price_eur"]

In [8]:
splitter = StratifiedKFold(n_splits=10, shuffle=True, random_state=1234)

In [9]:
dt = DecisionTreeRegressor(max_depth=3, random_state=1234)
scores = cross_validate(dt, X, y, cv=splitter, scoring=METRICS_REGRESSION)
dt_scores = pd.DataFrame(scores)
pd.DataFrame(dt_scores.mean()).T

Unnamed: 0,fit_time,score_time,test_MAE,test_MSE,test_RMSE,test_R2
0,0.004784,0.002309,20120.747122,5310780000.0,61295.123827,0.025158
