# In [None]:
# %% [markdown]
# Treinando MLP (scikit-learn) com GridSearch + holdout e avaliação completa
# - Explora arquiteturas, ativações, otimizadores, batch_size, L2 (alpha) e early stopping
# - Mantém pipeline de pré-processamento (imputação, OHE e padronização)
# - Salva resultados e modelo
# - (Opcional) NN do zero em NumPy para demonstrar o loop de treinamento

# %%
import os
import json
import math
import warnings
from typing import Optional, List, Tuple

import numpy as np
import pandas as pd
import joblib
import matplotlib.pyplot as plt

from sklearn.model_selection import (
    train_test_split,
    GridSearchCV,
    StratifiedKFold
)
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    average_precision_score,
    f1_score,
    classification_report,
    confusion_matrix
)
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.exceptions import ConvergenceWarning
import pathlib

# =========================
# General settings
# =========================
RANDOM_STATE = 42
CV_SPLITS = 5
# Fraction held out as the final test set; the grid search cross-validates on
# the remaining 85% (train/validation).
TEST_SIZE = 0.15

# Scorer names, selected according to the number of target classes.
PRIMARY_SCORING_BINARY = "roc_auc"
PRIMARY_SCORING_MULTICLASS = "roc_auc_ovr"

# Input data lives in "<parent of the current working directory>/data".
path_data = pathlib.Path.cwd().parent.joinpath("data")
DATA_PATH = path_data.joinpath("train.csv")

# Name of the target column; when None it is inferred from the data.
TARGET: Optional[str] = "Churn"

# False selects the compact (fast) grid, True the larger (exhaustive) one.
EXHAUSTIVE = True

# Directory where artifacts are written (created up front if missing).
OUTPUT_DIR = "./mlp_gridsearch_outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Silence scikit-learn's ConvergenceWarning for the rest of the script.
warnings.simplefilter("ignore", category=ConvergenceWarning)


# =========================
# Utilidades
# =========================
COMMON_TARGET_CANDIDATES = [
    "churn",
    "Churn",
    "CHURN",
    "target",
    "TARGET",
    "label",
    "Label",
    "y",
    "Exited",
    "is_churn",
    "default",
]


def infer_target_column(df: pd.DataFrame) -> str:
    """Guess which column of *df* is the prediction target.

    Heuristics, applied in order:

    1. a column whose name matches (case-insensitively) one of
       ``COMMON_TARGET_CANDIDATES``;
    2. the first column whose name contains ``"churn"``;
    3. the first column with exactly two distinct non-null values (a notice
       is printed when several such columns exist).

    Raises:
        ValueError: if none of the heuristics finds a column.
    """
    # Lower-cased name -> original name, for the case-insensitive lookup.
    by_lowercase = {}
    for column in df.columns:
        by_lowercase[column.lower()] = column

    for candidate in COMMON_TARGET_CANDIDATES:
        key = candidate.lower()
        if key in by_lowercase:
            return by_lowercase[key]

    # Columns holding exactly two distinct non-null values.
    two_valued = [
        column for column in df.columns
        if len(df[column].dropna().unique()) == 2
    ]

    # A name containing "churn" wins over the purely value-based guess.
    churn_named = next(
        (column for column in df.columns if "churn" in column.lower()),
        None,
    )
    if churn_named is not None:
        return churn_named

    if not two_valued:
        raise ValueError(
            "Não foi possível inferir a coluna-alvo. "
            "Defina TARGET manualmente (ex.: TARGET = 'Churn')."
        )

    if len(two_valued) > 1:
        # Ambiguous: fall back to the first candidate, but say so.
        print(f"[AVISO] Múltiplas colunas binárias candidatas: {two_valued}. "
              f"Usando {two_valued[0]}. Defina TARGET manualmente para garantir.")

    return two_valued[0]


def build_preprocessor(df: pd.DataFrame, target_col: str) -> Tuple[ColumnTransformer, List[str], List[str]]:
    """Build the feature-preprocessing ``ColumnTransformer`` for *df*.

    - numeric/bool columns: median imputation followed by standardization;
    - object/category columns: most-frequent imputation followed by a dense
      one-hot encoding that ignores categories unseen during fit.

    Columns matching neither group are dropped by the transformer.

    Returns:
        The (unfitted) transformer, the numeric column names and the
        categorical column names, in that order.
    """
    # The target, when present, must not be treated as a feature.
    features = df if target_col not in df.columns else df.drop(columns=[target_col])

    numeric_view = features.select_dtypes(include=["number", "float", "int", "bool"])
    categorical_view = features.select_dtypes(include=["object", "category"])
    numeric_features = numeric_view.columns.tolist()
    categorical_features = categorical_view.columns.tolist()

    # The dense-output keyword was renamed: try `sparse_output` first and fall
    # back to the older `sparse` spelling when it is rejected.
    try:
        encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
    except TypeError:
        encoder = OneHotEncoder(handle_unknown="ignore", sparse=False)

    numeric_steps = [
        ("imputer", SimpleImputer(strategy="median")),
        ("scaler", StandardScaler(with_mean=True)),
    ]
    categorical_steps = [
        ("imputer", SimpleImputer(strategy="most_frequent")),
        ("ohe", encoder),
    ]

    column_transformer = ColumnTransformer(
        transformers=[
            ("num", Pipeline(steps=numeric_steps), numeric_features),
            ("cat", Pipeline(steps=categorical_steps), categorical_features),
        ],
        remainder="drop",
    )
    return column_transformer, numeric_features, categorical_features


def primary_scoring_for_target(y: pd.Series) -> str:
    """Return the primary scorer name for target *y*.

    Targets with more than two distinct non-null values get the multiclass
    scorer; everything else gets the binary one.
    """
    is_multiclass = y.nunique(dropna=True) > 2
    return PRIMARY_SCORING_MULTICLASS if is_multiclass else PRIMARY_SCORING_BINARY


# =========================
# Load data
# =========================
df = pd.read_csv(DATA_PATH)
print(f"Shape do dataset: {df.shape}")

# Use the configured target column, or infer one when none was given.
if TARGET is None:
    TARGET = infer_target_column(df)
print(f"Coluna alvo: {TARGET}")

# When the target is textual (Yes/No etc.), try to map it to {0, 1}.
y_raw = df[TARGET]
if y_raw.dtype == object:
    map_yes = {"yes": 1, "y": 1, "true": 1, "sim": 1, "churn": 1, "1": 1}
    map_no = {"no": 0, "n": 0, "false": 0, "nao": 0, "não": 0, "retained": 0, "0": 0}
    y_lower = y_raw.astype(str).str.lower().str.strip()
    # Labels found in neither dictionary become NaN.
    mapped = y_lower.map({**map_yes, **map_no})

    # Adopt the mapping only when it covers almost every row.
    if mapped.notna().mean() > 0.9:
        unmapped = mapped.isna()
        if unmapped.any():
            # A leftover NaN would make astype(int) raise, so rows whose label
            # could not be mapped (including missing targets) are removed from
            # df itself, keeping X and y aligned.
            print(f"[AVISO] Removendo {int(unmapped.sum())} linha(s) com alvo não mapeável para 0/1.")
            df = df.loc[~unmapped].copy()
            mapped = mapped.loc[~unmapped]
        y = mapped.astype(int)
    else:
        # Keep the raw labels (multiclass, or two classes with other names).
        y = y_raw.copy()
else:
    y = y_raw.copy()

X = df.drop(columns=[TARGET])


# =========================
# Split: held-out test set
# =========================
# Stratified on y so both partitions keep the class proportions.
(
    X_train_val,
    X_test,
    y_train_val,
    y_test,
) = train_test_split(
    X,
    y,
    stratify=y,
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
)

print(f"Split -> train/val: {X_train_val.shape}, test: {X_test.shape}")

# =========================
# Preprocessor + Pipeline
# =========================
# Column groups are detected on the full frame; the target is excluded inside
# build_preprocessor.
preprocessor, num_cols, cat_cols = build_preprocessor(df, TARGET)

# Base estimator; the grid search overrides the remaining hyperparameters.
base_clf = MLPClassifier(
    shuffle=True,
    tol=1e-4,
    max_iter=200,
    random_state=RANDOM_STATE,
)

# Preprocessing and classifier are chained so both are fitted per CV fold.
pipe = Pipeline(
    steps=[
        ("pre", preprocessor),
        ("clf", base_clf),
    ]
)

# =========================
# Hyperparameter grid
# =========================
# "Compact" (fast) and "exhaustive" (slower) value lists.
if not EXHAUSTIVE:
    hidden_sizes = [(32,), (64,), (128,), (64, 32), (128, 64)]
    batch_sizes = [32, 64, 128]
    alphas = [1e-5, 1e-4, 1e-3]
    lr_init_sgd = [1e-2, 1e-3]
    lr_init_adam = [1e-3, 5e-4]
else:
    hidden_sizes = [(32,), (64,), (128,), (256,), (64, 32), (128, 64), (128, 64, 32), (256, 128)]
    batch_sizes = [32, 64, 128, 256]
    alphas = [1e-6, 1e-5, 1e-4, 1e-3]
    lr_init_sgd = [5e-2, 1e-2, 5e-3, 1e-3]
    lr_init_adam = [3e-3, 1e-3, 5e-4, 1e-4]

# Settings shared by both SGD sub-grids (mini-batch size set by batch_size).
sgd_common = {
    "clf__solver": ["sgd"],
    "clf__hidden_layer_sizes": hidden_sizes,
    "clf__activation": ["relu", "tanh", "logistic"],  # logistic == sigmoid
    "clf__batch_size": batch_sizes,
    "clf__learning_rate": ["constant", "adaptive"],
    "clf__learning_rate_init": lr_init_sgd,
    "clf__alpha": alphas,
    "clf__early_stopping": [True, False],  # explore the use of early stopping
    "clf__validation_fraction": [0.15],
}

param_grid = [
    # SGD without momentum. nesterovs_momentum is only used when momentum > 0,
    # so crossing it with momentum=0.0 would just fit every such configuration
    # twice; it is pinned to a single value here.
    {
        **sgd_common,
        "clf__momentum": [0.0],
        "clf__nesterovs_momentum": [False],
    },
    # SGD with momentum: classical vs. Nesterov.
    {
        **sgd_common,
        "clf__momentum": [0.9],
        "clf__nesterovs_momentum": [True, False],
    },
    # Adam
    {
        "clf__solver": ["adam"],
        "clf__hidden_layer_sizes": hidden_sizes,
        "clf__activation": ["relu", "tanh", "logistic"],
        "clf__batch_size": batch_sizes,
        "clf__learning_rate_init": lr_init_adam,
        "clf__alpha": alphas,
        "clf__beta_1": [0.9],
        "clf__beta_2": [0.999],
        "clf__epsilon": [1e-8],
        "clf__early_stopping": [True, False],
        "clf__validation_fraction": [0.15],
    },
]

# =========================
# Scoring / CV
# =========================
PRIMARY_SCORING = primary_scoring_for_target(y_train_val)

# The dict keys name the result columns; the values are scikit-learn scorer
# names picked according to the number of classes.
# NOTE(review): the "average_precision" scorer (multiclass) and the "f1" scorer
# (binary labels other than 0/1) may not accept every target — confirm against
# the installed scikit-learn version.
if y_train_val.nunique() > 2:
    scoring = {
        "accuracy": "accuracy",
        "roc_auc": "roc_auc_ovr",
        "average_precision": "average_precision",
        "f1": "f1_macro",
    }
else:
    scoring = {
        "accuracy": "accuracy",
        "roc_auc": "roc_auc",
        "average_precision": "average_precision",
        "f1": "f1",
    }

cv = StratifiedKFold(shuffle=True, n_splits=CV_SPLITS, random_state=RANDOM_STATE)

# =========================
# Sample weights (class balancing) – optional but useful
# =========================
# Local import: only this section needs the helper.
from sklearn.utils.validation import has_fit_parameter

# Default: train without weights.
sample_weights = None
fit_params = {}

# Not every scikit-learn release lets MLPClassifier.fit take sample_weight.
# Forwarding it to a release that does not would fail inside the grid search
# (error_score="raise"), so the fit signature is checked before opting in.
if has_fit_parameter(base_clf, "sample_weight"):
    try:
        sample_weights = compute_sample_weight(class_weight="balanced", y=y_train_val)
    except (TypeError, ValueError) as exc:
        # Best effort: fall back to unweighted training, but say why.
        print(f"[AVISO] Não foi possível calcular pesos amostrais: {exc}")
    else:
        fit_params = {"clf__sample_weight": sample_weights}
else:
    print("[AVISO] MLPClassifier.fit desta versão do scikit-learn não aceita "
          "sample_weight; treinando sem pesos amostrais.")

# =========================
# Grid Search
# =========================
# With a dict of scorers, `refit` must be one of the dict's KEYS. The ROC AUC
# scorer is registered under the key "roc_auc" for binary and multiclass
# targets alike (only the scorer behind the key differs), so that key is used
# unconditionally; "roc_auc_ovr" is a scorer name, not a key.
grid = GridSearchCV(
    estimator=pipe,
    param_grid=param_grid,
    scoring=scoring,
    refit="roc_auc",
    cv=cv,
    n_jobs=-1,
    verbose=2,
    error_score="raise"
)

print(f"Iniciando GridSearchCV com métrica de refit: {grid.refit}")
grid.fit(X_train_val, y_train_val, **fit_params)

# Persist the full CV table, best configurations first. Result columns are
# named after the scoring keys, hence mean_test_<refit key>.
cv_results_df = pd.DataFrame(grid.cv_results_).sort_values(by=f"mean_test_{grid.refit}", ascending=False)
cv_results_path = os.path.join(OUTPUT_DIR, "gridsearch_results.csv")
cv_results_df.to_csv(cv_results_path, index=False)
print(f"Resultados do GridSearch salvos em: {cv_results_path}")

print("\n=== Melhor conjunto de hiperparâmetros (CV) ===")
print(json.dumps(grid.best_params_, indent=2))
print(f"\nMelhor {grid.refit} (CV): {grid.best_score_:.4f}")

# =========================
# Evaluation on the held-out test set
# =========================
best_model: Pipeline = grid.best_estimator_
# Column j of predict_proba corresponds to model_classes[j].
model_classes = best_model.named_steps["clf"].classes_
is_binary_target = y_train_val.nunique() <= 2

y_pred = best_model.predict(X_test)
proba_ok = False
try:
    y_proba = best_model.predict_proba(X_test)
    proba_ok = True
except (AttributeError, NotImplementedError) as exc:
    # Only "the estimator has no probabilities" is tolerated; any other error
    # is a real failure and is allowed to propagate.
    print(f"[AVISO] predict_proba indisponível: {exc}")
    y_proba = None

acc = accuracy_score(y_test, y_pred)
roc = np.nan
ap = np.nan
if is_binary_target:
    # y_proba[:, 1] is the probability of the last class, so that class is the
    # positive label for every binary metric (1 for a 0/1 target).
    positive_label = model_classes[-1]
    if proba_ok:
        roc = roc_auc_score(y_test, y_proba[:, 1])
        ap = average_precision_score(y_test, y_proba[:, 1], pos_label=positive_label)
    f1 = f1_score(y_test, y_pred, average="binary", pos_label=positive_label)
else:
    if proba_ok:
        roc = roc_auc_score(y_test, y_proba, multi_class="ovr", labels=model_classes)
        # One-hot encode y_test against the model's class order so its columns
        # line up with y_proba even when a class is absent from the test set.
        y_test_onehot = (np.asarray(y_test)[:, None] == model_classes[None, :]).astype(int)
        ap = average_precision_score(y_test_onehot, y_proba)
    f1 = f1_score(y_test, y_pred, average="macro")

cm = confusion_matrix(y_test, y_pred)
report = classification_report(y_test, y_pred)

print("\n=== Métricas no conjunto de TESTE ===")
print(f"Accuracy : {acc:.4f}")
print(f"ROC AUC  : {roc:.4f}" if not np.isnan(roc) else "ROC AUC  : n/a (sem predict_proba)")
print(f"AvgPrec  : {ap:.4f}" if not np.isnan(ap) else "AvgPrec  : n/a (sem predict_proba)")
print(f"F1       : {f1:.4f}")
print("\nConfusion Matrix:\n", cm)
print("\nClassification Report:\n", report)

# Persist the metrics; NaN (metric unavailable) is stored as null.
metrics_path = os.path.join(OUTPUT_DIR, "test_metrics.json")
with open(metrics_path, "w", encoding="utf-8") as f:
    json.dump({
        "accuracy": float(acc),
        "roc_auc": None if np.isnan(roc) else float(roc),
        "average_precision": None if np.isnan(ap) else float(ap),
        "f1": float(f1),
        "confusion_matrix": cm.tolist(),
        "best_params": grid.best_params_
    }, f, indent=2, ensure_ascii=False)
print(f"Métricas de teste salvas em: {metrics_path}")

# =========================
# Training-loss curve of the best MLP
# =========================
# Plotting is best effort: a failure is reported and the script carries on.
try:
    clf_best: MLPClassifier = best_model.named_steps["clf"]
    loss_curve = getattr(clf_best, "loss_curve_", None)
    if loss_curve is not None and len(loss_curve) > 0:
        # One point per recorded loss value, numbered from 1.
        epochs_axis = range(1, len(loss_curve) + 1)
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.plot(epochs_axis, loss_curve, linewidth=2)
        ax.set_xlabel("Época")
        ax.set_ylabel("Loss (treino)")
        ax.set_title("Curva de perda - melhor MLP")
        ax.grid(True, alpha=0.3)
        fig_path = os.path.join(OUTPUT_DIR, "best_model_loss_curve.png")
        fig.tight_layout()
        fig.savefig(fig_path, dpi=150)
        plt.close(fig)
        print(f"Curva de perda salva em: {fig_path}")
except Exception as e:
    print("[AVISO] Não foi possível plotar a loss_curve:", e)

# =========================
# Save the final model
# =========================
# The whole pipeline is dumped, so preprocessing travels with the classifier.
model_path = os.path.join(OUTPUT_DIR, "best_mlp_pipeline.joblib")
joblib.dump(best_model, model_path)
print(f"Pipeline (pré-processamento + MLP) salvo em: {model_path}")


# %% [markdown]
# -----------------------------------------------
# (Opcional) Rede Neural do zero (NumPy) – 1 camada oculta
# Para demonstrar forward -> loss -> backprop -> atualização de pesos
# -----------------------------------------------

# %%
RUN_SCRATCH = False  # set to True to also train the from-scratch NumPy network for comparison

if RUN_SCRATCH:
    # GridSearchCV's refit already fitted this preprocessor on X_train_val, so
    # transform() yields the imputed/scaled/one-hot matrix without re-fitting
    # (and mutating) the pipeline that was just saved to disk.
    X_proc = best_model.named_steps["pre"].transform(X_train_val)

    # Encode labels as indices into the sorted array of classes. `classes` is
    # kept so predictions can be mapped back to the original labels; for a
    # 0/1 target the indices are the labels themselves.
    classes, y_idx = np.unique(y_train_val.values, return_inverse=True)
    y_idx = np.asarray(y_idx).ravel()

    if y_train_val.nunique() == 2:
        # Binary: a single sigmoid output trained against a {0, 1} column.
        y_proc = y_idx.astype(np.float64).reshape(-1, 1)
        n_out = 1
        is_binary = True
    else:
        # Multiclass: softmax output trained against one-hot rows.
        n_out = len(classes)
        y_proc = np.eye(n_out)[y_idx].astype(np.float64)
        is_binary = False

    # Internal split providing the validation set used for early stopping.
    X_tr, X_val, y_tr, y_val = train_test_split(
        X_proc, y_proc, test_size=0.15, stratify=y_train_val, random_state=RANDOM_STATE
    )

    # Simple hyperparameters
    n_in = X_tr.shape[1]
    n_hidden = 64
    epochs = 100
    batch_size = 64
    lr = 1e-3
    l2 = 1e-4
    patience = 10
    activation = "relu"  # "relu" | "tanh" | "sigmoid"

    rng = np.random.default_rng(RANDOM_STATE)

    def init_weights(fan_in, fan_out, nonlin: str):
        """Draw a (fan_in, fan_out) weight matrix from N(0, std^2).

        std is sqrt(2 / fan_in) for ReLU and sqrt(1 / fan_in) otherwise.
        """
        if nonlin == "relu":
            std = math.sqrt(2.0 / fan_in)
        else:
            std = math.sqrt(1.0 / fan_in)
        return rng.normal(0.0, std, size=(fan_in, fan_out))

    W1 = init_weights(n_in, n_hidden, activation)
    b1 = np.zeros((1, n_hidden))
    W2 = init_weights(n_hidden, n_out, activation)
    b2 = np.zeros((1, n_out))

    def sigmoid(z):
        """Logistic function; the input is clipped so np.exp cannot overflow."""
        return 1.0 / (1.0 + np.exp(-np.clip(z, -500.0, 500.0)))

    def act(x, kind):
        """Apply the hidden-layer activation named by *kind*."""
        if kind == "relu":
            return np.maximum(0, x)
        if kind == "tanh":
            return np.tanh(x)
        if kind == "sigmoid":
            return sigmoid(x)
        raise ValueError(kind)

    def act_grad(x, kind):
        """Derivative of the activation *kind* evaluated at pre-activation x."""
        if kind == "relu":
            return (x > 0).astype(x.dtype)
        if kind == "tanh":
            y = np.tanh(x)
            return 1 - y**2
        if kind == "sigmoid":
            s = sigmoid(x)
            return s * (1 - s)
        raise ValueError(kind)

    def forward(Xb):
        """Forward pass; returns the output probabilities and a cache for backprop."""
        Z1 = Xb @ W1 + b1
        A1 = act(Z1, activation)
        Z2 = A1 @ W2 + b2
        if is_binary:
            # sigmoid output (paired with binary cross-entropy)
            A2 = sigmoid(Z2)
        else:
            # softmax, shifted by the row maximum for numerical stability
            Z2_shift = Z2 - Z2.max(axis=1, keepdims=True)
            expZ = np.exp(Z2_shift)
            A2 = expZ / expZ.sum(axis=1, keepdims=True)
        cache = (Xb, Z1, A1, Z2, A2)
        return A2, cache

    def compute_loss(A2, yb):
        """Mean cross-entropy of predictions A2 against yb, plus the L2 penalty."""
        eps = 1e-12  # keeps log() away from zero
        if is_binary:
            # binary cross-entropy
            loss = -np.mean(yb * np.log(A2 + eps) + (1 - yb) * np.log(1 - A2 + eps))
        else:
            loss = -np.mean(np.sum(yb * np.log(A2 + eps), axis=1))
        l2_term = 0.5 * l2 * (np.sum(W1**2) + np.sum(W2**2))
        return loss + l2_term

    def backward(cache, A2, yb):
        """Backpropagate; returns the gradients (dW1, db1, dW2, db2)."""
        Xb, Z1, A1, Z2, _ = cache
        m = Xb.shape[0]

        # Sigmoid + BCE and softmax + CE share the same output-layer gradient.
        dZ2 = (A2 - yb) / m

        dW2 = A1.T @ dZ2 + l2 * W2
        db2 = dZ2.sum(axis=0, keepdims=True)

        dA1 = dZ2 @ W2.T
        dZ1 = dA1 * act_grad(Z1, activation)
        dW1 = Xb.T @ dZ1 + l2 * W1
        db1 = dZ1.sum(axis=0, keepdims=True)

        return dW1, db1, dW2, db2

    def iterate_minibatches(Xa, ya, bs, shuffle=True):
        """Yield (X, y) mini-batches of at most *bs* rows, optionally shuffled."""
        idx = np.arange(Xa.shape[0])
        if shuffle:
            rng.shuffle(idx)
        for start in range(0, Xa.shape[0], bs):
            sl = idx[start:start+bs]
            yield Xa[sl], ya[sl]

    best_val = np.inf
    best_params = None
    patience_cnt = 0
    train_curve, val_curve = [], []

    for epoch in range(1, epochs + 1):
        # Training pass
        for Xb, yb in iterate_minibatches(X_tr, y_tr, batch_size, shuffle=True):
            A2, cache = forward(Xb)
            dW1, db1, dW2, db2 = backward(cache, A2, yb)
            # Plain SGD update
            W1 -= lr * dW1
            b1 -= lr * db1
            W2 -= lr * dW2
            b2 -= lr * db2

        # Loss on the full training and validation splits
        A2_tr, _ = forward(X_tr)
        A2_val, _ = forward(X_val)
        loss_tr = compute_loss(A2_tr, y_tr)
        loss_val = compute_loss(A2_val, y_val)
        train_curve.append(loss_tr)
        val_curve.append(loss_val)

        if epoch % 5 == 0 or epoch == 1:
            print(f"[NumPy NN] Época {epoch:03d} | loss_tr={loss_tr:.4f} | loss_val={loss_val:.4f}")

        # Early stopping: snapshot the weights on every real improvement and
        # stop after `patience` consecutive epochs without one.
        if loss_val < best_val - 1e-5:
            best_val = loss_val
            best_params = (W1.copy(), b1.copy(), W2.copy(), b2.copy())
            patience_cnt = 0
        else:
            patience_cnt += 1
            if patience_cnt >= patience:
                print(f"[NumPy NN] Early stopping na época {epoch}")
                break

    # Restore the weights with the best validation loss.
    if best_params is not None:
        W1, b1, W2, b2 = best_params

    # Loss curves
    plt.figure(figsize=(6, 4))
    plt.plot(train_curve, label="treino")
    plt.plot(val_curve, label="val")
    plt.xlabel("Época")
    plt.ylabel("Loss")
    plt.title("Curva de perda - NumPy NN")
    plt.legend()
    plt.grid(True, alpha=0.3)
    np_fig_path = os.path.join(OUTPUT_DIR, "numpy_nn_loss_curve.png")
    plt.tight_layout()
    plt.savefig(np_fig_path, dpi=150)
    plt.close()
    print(f"[NumPy NN] Curva de perda salva em: {np_fig_path}")

    # Test-set evaluation (same fitted preprocessor)
    X_test_proc = best_model.named_steps["pre"].transform(X_test)
    A2_test, _ = forward(X_test_proc)
    if is_binary:
        y_proba_np = A2_test.ravel()
        # Map the 0/1 decisions back to the original labels; the network's
        # output is the probability of classes[1], hence the positive label.
        y_pred_np = classes[(y_proba_np >= 0.5).astype(int)]
        pos_label = classes[1]
        acc_np = accuracy_score(y_test, y_pred_np)
        roc_np = roc_auc_score(y_test, y_proba_np)
        ap_np = average_precision_score(y_test, y_proba_np, pos_label=pos_label)
        f1_np = f1_score(y_test, y_pred_np, pos_label=pos_label)

        print(f"[NumPy NN] Test -> acc={acc_np:.4f} | roc_auc={roc_np:.4f} | ap={ap_np:.4f} | f1={f1_np:.4f}")
    else:
        # argmax yields class indices; translate them to the original labels
        # before comparing against y_test.
        y_pred_np = classes[A2_test.argmax(axis=1)]
        acc_np = accuracy_score(y_test, y_pred_np)
        f1_np = f1_score(y_test, y_pred_np, average="macro")
        print(f"[NumPy NN] Test -> acc={acc_np:.4f} | f1_macro={f1_np:.4f}")
