
# Árvore de Decisão — Consumo de Cerveja (Kaggle)

Este notebook cria um pipeline completo para transformar a base de **Consumo de Cerveja** em um problema de **classificação** (alto vs. baixo), treinar **Árvore de Decisão** com **balanceamento**, avaliar com **acurácia**, **outras métricas** e **matriz de confusão**, além de **extrair as regras** da árvore.  
> Ajuste o caminho do arquivo CSV na célula de *Setup* antes de executar.


In [9]:

import pandas as pd
import numpy as np
from pathlib import Path

from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.pipeline import Pipeline
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    classification_report, confusion_matrix, ConfusionMatrixDisplay, roc_auc_score
)

import matplotlib.pyplot as plt

RANDOM_STATE = 42


In [10]:
df = pd.read_csv('consumo_cerveja.csv', decimal=",")

df.head(5)


Unnamed: 0,Data,Temperatura Media (C),Temperatura Minima (C),Temperatura Maxima (C),Precipitacao (mm),Final de Semana,Consumo de cerveja (litros)
0,01/01/2015,27.3,23.9,32.5,0.0,0,25.461
1,02/01/2015,27.02,24.5,33.5,0.0,0,28.972
2,03/01/2015,24.82,22.4,29.9,0.0,1,30.814
3,04/01/2015,23.98,21.5,28.6,1.2,1,29.799
4,05/01/2015,23.82,21.0,28.3,0.0,0,28.9


In [11]:

# === Detecção "best effort" de colunas ===
def pick_guess(cols, patterns):
    for pat in patterns:
        for c in cols:
            if pat in c.lower():
                return c
    return None

# Tentativas comuns: 'consumo' / 'beer' / 'litro' / 'consumption'
CONSUMO_COL = pick_guess([c.lower() for c in df.columns], ["consumo", "beer", "litro", "consumption"])
if CONSUMO_COL is None:
    raise ValueError(f"Não achei a coluna de consumo. Verifique nomes de colunas: {list(df.columns)}")
# Mapear de volta ao nome original (case-sensitive)
for c in df.columns:
    if c.lower() == CONSUMO_COL:
        CONSUMO_COL = c
        break

DATE_COL = pick_guess([c.lower() for c in df.columns], ["data", "date"])
if DATE_COL is not None:
    for c in df.columns:
        if c.lower() == DATE_COL:
            DATE_COL = c
            break

print("Coluna de consumo detectada:", CONSUMO_COL)
print("Coluna de data detectada:", DATE_COL)

# Parser de data + features de calendário
if DATE_COL and DATE_COL in df.columns:
    df[DATE_COL] = pd.to_datetime(df[DATE_COL], dayfirst=True, errors="coerce")
    df["mes"] = df[DATE_COL].dt.month
    df["dia_semana"] = df[DATE_COL].dt.dayofweek  # 0=segunda ... 6=domingo
    df["fim_semana_auto"] = (df["dia_semana"] >= 5).astype(int)

# Normalização de flags booleanas/categóricas comuns (ex.: fim de semana)
for cand in ["fim_de_semana", "final_de_semana", "weekend", "fds", "is_weekend"]:
    for col in df.columns:
        if col.lower() == cand:
            df[col] = (
                df[col]
                .astype(str).str.strip().str.lower()
                .map({"1":1,"0":0,"true":1,"false":0,"sim":1,"nao":0,"não":0})
                .fillna(np.nan)
            )
            # Se mapeou strings conhecidas, force int; caso contrário, deixa como estava
            if df[col].dropna().isin([0,1]).all():
                df[col] = df[col].astype(int)

df.info()


Coluna de consumo detectada: Consumo de cerveja (litros)
Coluna de data detectada: Data
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 365 entries, 0 to 364
Data columns (total 10 columns):
 #   Column                       Non-Null Count  Dtype         
---  ------                       --------------  -----         
 0   Data                         365 non-null    datetime64[ns]
 1   Temperatura Media (C)        365 non-null    object        
 2   Temperatura Minima (C)       365 non-null    object        
 3   Temperatura Maxima (C)       365 non-null    object        
 4   Precipitacao (mm)            365 non-null    object        
 5   Final de Semana              365 non-null    int64         
 6   Consumo de cerveja (litros)  365 non-null    object        
 7   mes                          365 non-null    int32         
 8   dia_semana                   365 non-null    int32         
 9   fim_semana_auto              365 non-null    int64         
dtypes: datetime64[ns](1), 

In [12]:

# === Transformar problema em classificação: consumo alto vs. baixo ===
# Por padrão, usamos a mediana para criar um threshold equilibrado.
y_threshold = df[CONSUMO_COL].median()  # alternativa: df[CONSUMO_COL].quantile(0.7)
print("Threshold (mediana) para consumo alto:", y_threshold)

df["consumo_alto"] = (df[CONSUMO_COL] >= y_threshold).astype(int)
print("Distribuição de classes (0=baixo, 1=alto):")
display(df["consumo_alto"].value_counts(normalize=True).rename("proporção").round(3))


TypeError: Cannot convert ['25.461' '28.972' '30.814' '29.799' '28.900' '28.218' '29.732' '28.397'
 '24.886' '37.937' '36.254' '25.743' '26.990' '31.825' '25.724' '29.938'
 '37.690' '30.524' '29.265' '35.127' '29.130' '25.795' '21.784' '28.348'
 '31.088' '21.520' '29.972' '22.603' '22.696' '26.845' '27.030' '32.057'
 '24.097' '31.655' '24.738' '19.950' '22.821' '28.893' '29.926' '24.062'
 '21.137' '26.805' '26.389' '24.219' '30.231' '24.968' '25.343' '17.399'
 '21.392' '22.922' '24.567' '30.943' '30.825' '25.692' '26.959' '25.366'
 '22.784' '26.241' '26.467' '27.475' '28.749' '24.146' '22.988' '30.300'
 '22.654' '29.090' '24.619' '20.016' '23.042' '22.933' '22.409' '21.281'
 '28.844' '32.872' '20.903' '26.275' '20.167' '23.628' '24.213' '28.631'
 '25.855' '21.406' '21.617' '22.401' '27.989' '24.974' '29.760' '26.116'
 '25.850' '24.925' '21.979' '22.116' '24.867' '33.450' '32.713' '22.356'
 '21.004' '23.362' '20.298' '24.862' '30.505' '25.070' '22.620' '22.001'
 '23.469' '21.735' '21.593' '33.822' '28.028' '24.304' '31.108' '19.113'
 '23.198' '24.388' '27.420' '30.479' '21.838' '20.812' '19.761' '20.452'
 '17.939' '25.272' '28.049' '25.317' '21.826' '20.680' '19.143' '18.146'
 '25.489' '23.537' '16.956' '19.052' '17.287' '20.300' '20.538' '23.702'
 '28.411' '21.073' '24.215' '19.525' '20.786' '20.429' '27.250' '31.387'
 '26.075' '22.162' '24.258' '24.683' '21.245' '25.937' '26.081' '16.228'
 '20.106' '21.055' '22.772' '25.142' '31.129' '30.498' '26.150' '21.327'
 '22.008' '24.615' '22.375' '29.607' '32.983' '19.119' '21.029' '23.898'
 '24.534' '16.748' '23.055' '28.857' '23.022' '27.146' '17.241' '19.463'
 '21.860' '24.227' '27.594' '24.863' '20.161' '20.824' '19.727' '14.940'
 '24.632' '21.294' '18.448' '21.237' '19.849' '20.740' '25.698' '26.691'
 '33.298' '25.640' '23.937' '28.742' '21.748' '22.032' '24.827' '32.473'
 '20.620' '21.825' '14.343' '19.029' '21.104' '20.738' '25.233' '18.975'
 '19.640' '22.522' '24.227' '24.726' '32.467' '31.663' '25.867' '27.724'
 '22.039' '26.127' '26.580' '31.310' '33.517' '23.181' '24.183' '24.594'
 '22.610' '25.479' '29.621' '26.272' '22.541' '23.070' '26.021' '17.655'
 '23.243' '30.177' '27.518' '23.210' '21.092' '23.357' '17.888' '22.217'
 '31.681' '31.833' '28.441' '22.389' '20.681' '24.222' '19.345' '21.827'
 '23.566' '20.227' '17.075' '16.977' '21.525' '21.454' '21.814' '21.252'
 '20.464' '30.775' '25.343' '33.930' '26.311' '31.836' '34.695' '29.829'
 '26.362' '32.589' '30.345' '29.411' '29.637' '32.184' '17.731' '24.114'
 '28.034' '22.664' '27.488' '24.876' '24.862' '24.679' '22.304' '30.329'
 '33.182' '23.849' '33.330' '34.496' '26.249' '26.523' '26.793' '35.861'
 '27.387' '32.666' '22.199' '24.000' '27.871' '31.139' '23.065' '26.594'
 '27.657' '26.594' '28.084' '27.582' '24.862' '22.634' '31.649' '35.781'
 '24.429' '20.648' '22.741' '21.479' '23.134' '20.575' '24.330' '28.610'
 '28.456' '27.964' '29.569' '29.267' '28.647' '26.836' '29.386' '24.609'
 '26.964' '23.614' '22.960' '20.332' '30.392' '31.933' '21.662' '21.689'
 '25.119' '25.285' '28.979' '34.382' '30.617' '20.238' '24.529' '30.471'
 '28.405' '29.513' '32.451' '32.780' '23.375' '27.713' '27.137' '22.933'
 '30.740' '29.579' '29.188' '28.131' '28.617' '21.062' '24.337' '27.042'
 '32.536' '30.127' '24.834' '26.828' '26.468' '31.572' '26.308' '21.955'
 '32.307' '26.095' '22.309' '20.467' '22.446'] to numeric

In [None]:

# === Seleção de features ===
drop_cols = [CONSUMO_COL, "consumo_alto"]
if DATE_COL and DATE_COL in df.columns:
    drop_cols.append(DATE_COL)

X = df.drop(columns=[c for c in drop_cols if c in df.columns])
y = df["consumo_alto"].copy()

num_cols = X.select_dtypes(include=[np.number]).columns.tolist()
cat_cols = [c for c in X.columns if c not in num_cols]

print("Numéricas:", num_cols)
print("Categóricas:", cat_cols)


KeyError: 'consumo_alto'

In [None]:

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=RANDOM_STATE, stratify=y
)
X_train.shape, X_test.shape, y_train.mean(), y_test.mean()


In [None]:

numeric_tf = Pipeline(steps=[
    ("imputer", SimpleImputer(strategy="median"))
])

categorical_tf = Pipeline(steps=[
    ("imputer", SimpleImputer(strategy="most_frequent")),
    ("ohe", OneHotEncoder(handle_unknown="ignore"))
])

pre = ColumnTransformer(
    transformers=[
        ("num", numeric_tf, num_cols),
        ("cat", categorical_tf, cat_cols)
    ]
)

tree = DecisionTreeClassifier(
    criterion="entropy",
    class_weight="balanced",      # balanceamento direto
    max_depth=6,
    min_samples_leaf=50,
    random_state=RANDOM_STATE
)

model = Pipeline(steps=[("pre", pre), ("tree", tree)])
model.fit(X_train, y_train)

print("Modelo treinado com class_weight='balanced'.")


In [None]:

pred = model.predict(X_test)
proba = model.predict_proba(X_test)[:,1]

acc = accuracy_score(y_test, pred)
prec = precision_score(y_test, pred, zero_division=0)
rec = recall_score(y_test, pred, zero_division=0)
f1  = f1_score(y_test, pred, zero_division=0)
auc = roc_auc_score(y_test, proba)

print(f"Acurácia:  {acc:.3f}")
print(f"Precisão:  {prec:.3f}")
print(f"Recall:    {rec:.3f}")
print(f"F1-Score:  {f1:.3f}")
print(f"ROC-AUC:   {auc:.3f}\n")

print("Relatório de Classificação:")
print(classification_report(y_test, pred, digits=3))

cm = confusion_matrix(y_test, pred)
disp = ConfusionMatrixDisplay(cm, display_labels=["baixo","alto"])
fig, ax = plt.subplots(figsize=(4,4))
disp.plot(ax=ax, values_format="d", colorbar=False)
plt.title("Matriz de Confusão — Árvore (balanced)")
plt.show()


In [None]:

# Obter nomes pós-OneHot
ohe = None
if len(cat_cols) > 0:
    ohe = model.named_steps["pre"].named_transformers_["cat"].named_steps["ohe"]

num_feats = num_cols
cat_feats = ohe.get_feature_names_out(cat_cols).tolist() if ohe is not None else []
all_feats = num_feats + cat_feats

importances = model.named_steps["tree"].feature_importances_
imp = pd.Series(importances, index=all_feats).sort_values(ascending=False)
display(imp.head(15).rename("importância"))


In [None]:

rules_text = export_text(
    model.named_steps["tree"],
    feature_names=all_feats,
    decimals=2,
    show_weights=True
)
print(rules_text)

# with open("regras_arvore.txt", "w", encoding="utf-8") as f:
#     f.write(rules_text)


In [None]:

from sklearn import tree as sktree

plt.figure(figsize=(18, 10))
sktree.plot_tree(
    model.named_steps["tree"],
    feature_names=all_feats,
    class_names=["baixo","alto"],
    filled=True,
    rounded=True,
    max_depth=3  # mostre só o topo para legibilidade
)
plt.title("Árvore de Decisão (topo — 3 níveis)")
plt.show()



## (Opcional) Balanceamento com SMOTE
Use quando a classe positiva estiver **bem minoritária** (por ex., ao usar um quantil alto como threshold).


In [None]:

# Descomente para usar SMOTE
# from imblearn.over_sampling import SMOTE
# from imblearn.pipeline import Pipeline as ImbPipeline
# 
# smote = SMOTE(random_state=RANDOM_STATE, k_neighbors=5)
# 
# tree_sm = DecisionTreeClassifier(
#     criterion="entropy",
#     class_weight=None,      # sem 'balanced' quando usar SMOTE
#     max_depth=6,
#     min_samples_leaf=50,
#     random_state=RANDOM_STATE
# )
# 
# model_sm = ImbPipeline(steps=[("pre", pre), ("smote", smote), ("tree", tree_sm)])
# model_sm.fit(X_train, y_train)
# 
# pred_sm = model_sm.predict(X_test)
# proba_sm = model_sm.predict_proba(X_test)[:,1]
# 
# print("=== Árvore + SMOTE ===")
# print("Acurácia:", accuracy_score(y_test, pred_sm).round(3))
# print("Precisão:", precision_score(y_test, pred_sm, zero_division=0).round(3))
# print("Recall:  ", recall_score(y_test, pred_sm, zero_division=0).round(3))
# print("F1:      ", f1_score(y_test, pred_sm, zero_division=0).round(3))
# print("ROC-AUC: ", roc_auc_score(y_test, proba_sm).round(3))
