In [None]:
"""
Full ML experiment script for flash butt welding defect prediction.
- Four model families: LogisticRegression, RandomForest, XGBoost (or GradientBoosting fallback), and a PyTorch MLP (TorchMLPClassifier, sklearn-compatible)
- Hyperparameter tuning via RandomizedSearchCV (StratifiedKFold)
- Class imbalance: SMOTE oversampling inside every training pipeline, plus class_weight='balanced' for LogisticRegression and RandomForest
- Multiple metrics: accuracy, precision, recall, f1, roc_auc (classification)
- Regression metrics included in helper for completeness (RMSE, MAE, R2)
- Error analysis: confusion matrices, misclassified class analysis, probability-based error inspection
- Threshold tuning to maximize F1 / precision / recall as needed
- Plots: confusion matrix, ROC curve, Precision-Recall curve, error distribution (probability hist)
- MLflow tracking: params, metrics, model artifacts, plots
"""

import os
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import mlflow
import mlflow.sklearn

from sklearn.model_selection import train_test_split, StratifiedKFold, RandomizedSearchCV
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score,
                             roc_auc_score, roc_curve, precision_recall_curve, confusion_matrix,
                             classification_report)
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline as ImbPipeline
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.base import BaseEstimator, ClassifierMixin

# Optional XGBoost
try:
    from xgboost import XGBClassifier
    xgb_available = True
except Exception:
    xgb_available = False

# -------------------------
# Config
# -------------------------
DATA_PATH = "preprocessed_data.csv"  # input CSV, read with pd.read_csv below
TARGET = "iDefect"                   # target column; the printed class distribution shows labels 0/1
RANDOM_STATE = 42                    # shared seed: train/test split, CV shuffling, SMOTE, models
TEST_SIZE = 0.20                     # fraction of rows held out by train_test_split
N_ITER_SEARCH = 30                   # upper bound on RandomizedSearchCV candidates per model
CV_SPLITS = 5                        # number of StratifiedKFold folds
MLFLOW_EXPERIMENT_NAME = "welding_defect_models"

# Output folders for plots and serialized models (no error if they already exist).
os.makedirs("artifacts/plots", exist_ok=True)
os.makedirs("artifacts/models", exist_ok=True)

# -------------------------
# Helpers
# -------------------------
def regression_metrics(y_true, y_pred):
    """Return RMSE, MAE and R^2 for a set of regression predictions.

    RMSE is computed as ``sqrt(MSE)``: the ``squared=False`` keyword of
    ``mean_squared_error`` was deprecated in scikit-learn 1.4 and removed in 1.6,
    so passing it raises on current releases.

    Returns
    -------
    dict with keys ``"rmse"``, ``"mae"``, ``"r2"``.
    """
    rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    return {"rmse": rmse, "mae": mae, "r2": r2}

def classification_metrics(y_true, y_pred, y_proba=None, pos_label=1):
    """Score hard predictions (and optionally scores) for a binary classifier.

    Precision, recall and F1 are computed for ``pos_label`` and report 0 instead of
    raising when a ratio is undefined. ``roc_auc`` needs ``y_proba`` and at least two
    classes in ``y_true``; otherwise, or if ``roc_auc_score`` fails, it is NaN.

    Returns
    -------
    dict with keys accuracy, precision, recall, f1, roc_auc (in that order).
    """
    per_label = dict(zero_division=0, pos_label=pos_label)
    scores = {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, **per_label),
        "recall": recall_score(y_true, y_pred, **per_label),
        "f1": f1_score(y_true, y_pred, **per_label),
    }

    auc_value = np.nan
    if y_proba is not None and np.unique(y_true).size > 1:
        try:
            auc_value = roc_auc_score(y_true, y_proba)
        except Exception:
            auc_value = np.nan
    scores["roc_auc"] = auc_value
    return scores

def plot_and_save_confusion_matrix(y_true, y_pred, title, fname):
    """Save an annotated heatmap of the confusion matrix of (y_true, y_pred) to ``fname``.

    Rows are true labels and columns predicted labels (sklearn's layout).
    """
    counts = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots(figsize=(5,4))
    sns.heatmap(counts, annot=True, fmt="d", cmap="Blues", ax=ax)
    ax.set_xlabel("Pred")
    ax.set_ylabel("True")
    ax.set_title(title)
    fig.savefig(fname, bbox_inches="tight")
    plt.close(fig)

def plot_and_save_roc(y_true, y_proba, title, fname):
    """Save the ROC curve of ``y_proba`` against ``y_true`` (AUC in the legend) to ``fname``.

    A dashed diagonal marks the chance-level classifier.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_proba)
    fig, ax = plt.subplots(figsize=(6,5))
    auc_label = f"AUC={roc_auc_score(y_true, y_proba):.4f}"
    ax.plot(false_pos_rate, true_pos_rate, label=auc_label)
    ax.plot([0,1], [0,1], "k--")
    ax.set_xlabel("FPR")
    ax.set_ylabel("TPR")
    ax.set_title(title)
    ax.legend(loc="lower right")
    fig.savefig(fname, bbox_inches="tight")
    plt.close(fig)

def plot_and_save_pr(y_true, y_proba, title, fname):
    """Save the precision-recall curve of ``y_proba`` against ``y_true`` to ``fname``.

    The legend reports the average precision (AP), scikit-learn's step-wise summary
    of the area under the PR curve. It is more informative than ROC AUC for the
    imbalanced defect data used here.
    """
    # Local import keeps this helper self-contained; the file header does not import it.
    from sklearn.metrics import average_precision_score

    precision, recall, _ = precision_recall_curve(y_true, y_proba)
    ap = average_precision_score(y_true, y_proba)
    plt.figure(figsize=(6,5))
    plt.plot(recall, precision, label=f"AP={ap:.4f}")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(title)
    plt.legend(loc="lower left")
    plt.savefig(fname, bbox_inches="tight")
    plt.close()

def plot_probability_hist(y_true, y_proba, title, fname):
    """Save overlaid density histograms (with KDE) of ``y_proba`` for each true class.

    ``y_true`` and ``y_proba`` must support boolean-mask indexing (e.g. NumPy arrays),
    because each class's scores are selected with ``y_proba[y_true == label]``.
    """
    fig, ax = plt.subplots(figsize=(6,4))
    for label_value, legend_name in ((0, "class0"), (1, "class1")):
        sns.histplot(y_proba[y_true == label_value], label=legend_name,
                     stat="density", kde=True, ax=ax)
    ax.legend()
    ax.set_title(title)
    ax.set_xlabel("Predicted probability (positive)")
    fig.savefig(fname, bbox_inches="tight")
    plt.close(fig)

# Custom MLP Classifier using PyTorch
# Torch device used by TorchMLP / TorchMLPClassifier for weights and data: CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define GPU-based MLP model
class TorchMLP(nn.Module):
    """Fully connected binary-classification network that outputs P(y=1) per row.

    Layout: for each width in ``hidden_layers`` a ``Linear -> activation`` pair,
    then a final ``Linear(., 1) -> Sigmoid``. The output has shape ``(batch, 1)``
    with values in (0, 1), as expected by ``nn.BCELoss``.

    Parameters
    ----------
    input_dim : int
        Number of input features.
    hidden_layers : sequence of int
        Hidden layer widths; an empty sequence gives a single linear layer.
    activation : {'relu', 'tanh'}
        Hidden-layer non-linearity.

    Raises
    ------
    ValueError
        If ``activation`` is not a supported name.
    """

    _ACTIVATIONS = {'relu': nn.ReLU, 'tanh': nn.Tanh}

    def __init__(self, input_dim, hidden_layers=(100,), activation='relu'):
        super().__init__()
        # Reject unknown names explicitly: silently mapping e.g. 'Relu' to tanh
        # would make a hyperparameter search compare the wrong models.
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f"Unsupported activation {activation!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}"
            )
        act_cls = self._ACTIVATIONS[activation]
        layers = []
        prev = input_dim
        for width in hidden_layers:
            layers.extend([nn.Linear(prev, width), act_cls()])
            prev = width
        layers.extend([nn.Linear(prev, 1), nn.Sigmoid()])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return positive-class probabilities of shape (batch, 1) for input ``x``."""
        return self.model(x)

# Sklearn-compatible PyTorch wrapper
class TorchMLPClassifier(ClassifierMixin, BaseEstimator):
    """Binary classifier with a scikit-learn interface, trained with PyTorch on ``device``.

    Hyperparameter names mirror ``sklearn.neural_network.MLPClassifier`` so the same
    ``clf__*`` search grids work. Labels are assumed to be 0/1: they are passed to
    ``nn.BCELoss`` as float targets and ``predict`` returns 0/1 ints.

    Parameters
    ----------
    hidden_layer_sizes : tuple of int
        Hidden layer widths of the underlying :class:`TorchMLP`.
    activation : str
        Hidden-layer activation forwarded to :class:`TorchMLP`.
    alpha : float
        L2 penalty, applied as Adam ``weight_decay``.
    learning_rate_init : float
        Adam learning rate.
    max_iter : int
        Number of training epochs (full passes over the training data).
    batch_size : int
        Mini-batch size.
    random_state : int or None
        Seed applied to torch's global RNG at the start of :meth:`fit` (weight
        initialisation and batch shuffling). ``None`` leaves the RNG untouched.

    Attributes
    ----------
    classes_ : ndarray
        Sorted unique labels seen in ``fit``; read by sklearn scorers and meta-estimators.
    n_features_in_ : int
        Number of input features seen in ``fit``.
    model_ : TorchMLP
        The trained network.
    """

    def __init__(self, hidden_layer_sizes=(100,), activation='relu', alpha=1e-4, learning_rate_init=1e-3,
                 max_iter=200, batch_size=64, random_state=42):
        # sklearn contract: __init__ only stores hyperparameters. clone() calls it
        # again for every CV fit, so seeding here would not make individual fits reproducible.
        self.hidden_layer_sizes = hidden_layer_sizes
        self.activation = activation
        self.alpha = alpha
        self.learning_rate_init = learning_rate_init
        self.max_iter = max_iter
        self.batch_size = batch_size
        self.random_state = random_state

    @staticmethod
    def _to_tensor(X):
        """Convert array-likes (ndarray, DataFrame, list) to a float32 tensor on ``device``."""
        return torch.as_tensor(np.ascontiguousarray(X, dtype=np.float32), device=device)

    def fit(self, X, y):
        """Train a fresh network on (X, y) with Adam + BCE loss and return ``self``."""
        if self.random_state is not None:
            torch.manual_seed(self.random_state)

        X_t = self._to_tensor(X)
        y_np = np.asarray(y)
        self.classes_ = np.unique(y_np)
        self.n_features_in_ = X_t.shape[1]
        y_t = torch.as_tensor(y_np.astype(np.float32).reshape(-1, 1), device=device)

        loader = DataLoader(TensorDataset(X_t, y_t), batch_size=self.batch_size, shuffle=True)

        self.model_ = TorchMLP(self.n_features_in_, self.hidden_layer_sizes, self.activation).to(device)
        criterion = nn.BCELoss()
        optimizer = optim.Adam(self.model_.parameters(), lr=self.learning_rate_init, weight_decay=self.alpha)

        self.model_.train()
        for _ in range(self.max_iter):
            for Xb, yb in loader:
                optimizer.zero_grad()
                loss = criterion(self.model_(Xb), yb)
                loss.backward()
                optimizer.step()
        return self

    def predict_proba(self, X):
        """Return an (n_samples, 2) array whose columns are P(y=0) and P(y=1)."""
        X_t = self._to_tensor(X)
        self.model_.eval()
        with torch.no_grad():
            pos = self.model_(X_t).cpu().numpy()
        return np.hstack((1 - pos, pos))

    def predict(self, X):
        """Predict 0/1 labels by thresholding P(y=1) at 0.5."""
        return (self.predict_proba(X)[:, 1] >= 0.5).astype(int)

# -------------------------
# Load data
# -------------------------
df = pd.read_csv(DATA_PATH)

categorical_col = "sSizeID"

# --- Define target and features ---
target_col = TARGET
X_raw = df.drop(columns=[target_col])
y = df[target_col]

# --- Split into train & test first, so no preprocessing step learns from test rows ---
# The split depends only on the row count and the stratification labels, so the
# same rows land in train/test as when encoding came first.
X_train_raw, X_test_raw, y_train, y_test = train_test_split(
    X_raw, y, test_size=TEST_SIZE, stratify=y, random_state=RANDOM_STATE
)

# --- One-hot encode sSizeID with categories learned from the training split only ---
# handle_unknown='ignore' encodes a size ID that only occurs in the test split as all zeros.
encoder = OneHotEncoder(sparse_output=False, drop=None, handle_unknown='ignore')
encoder.fit(X_train_raw[[categorical_col]])


def _one_hot_replace(frame):
    """Return ``frame`` with ``categorical_col`` replaced by its one-hot columns.

    The encoded columns are appended after the remaining features and keep
    ``frame``'s index, so rows stay aligned with the matching ``y`` split.
    """
    encoded_part = pd.DataFrame(
        encoder.transform(frame[[categorical_col]]),
        columns=encoder.get_feature_names_out([categorical_col]),
        index=frame.index,
    )
    return pd.concat([frame.drop(columns=[categorical_col]), encoded_part], axis=1)


X_train = _one_hot_replace(X_train_raw)
X_test = _one_hot_replace(X_test_raw)
# Full encoded feature matrix (same train-fitted encoder), kept for later inspection.
X = _one_hot_replace(X_raw)

print("Train class distribution:\n", y_train.value_counts(normalize=True))
print("Test class distribution:\n", y_test.value_counts(normalize=True))

# -------------------------
# CV
# -------------------------
# Shared splitter passed to every RandomizedSearchCV below. Folds are stratified so each
# keeps the training split's class ratio (~15% defects per the printed distribution).
cv = StratifiedKFold(n_splits=CV_SPLITS, shuffle=True, random_state=RANDOM_STATE)

# -------------------------
# Preprocessing pipelines
# -------------------------
# We'll use median imputation + scaling. Use SMOTE in pipeline for training to handle imbalance.
# NOTE(review): `preproc_steps` is not referenced by the pipelines defined in this cell; each of
# them declares its own imputer/scaler steps. Reuse it there or drop it.
preproc_steps = [
    ('imputer', SimpleImputer(strategy='median')),
    ('scaler', StandardScaler())
]

# -------------------------
# Models and hyperparam grids
# -------------------------
# models_to_run maps a model name to (imblearn pipeline, RandomizedSearchCV param distributions).
# SMOTE sits inside each imblearn pipeline, so it resamples only the training folds during CV
# and is skipped at predict time. NOTE: with SMOTE's default sampling strategy the classes are
# already balanced when the classifier is fit, so class_weight='balanced' adds little on top.
models_to_run = {}

# Logistic Regression (linear)
lr_pipe = ImbPipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('scaler', StandardScaler()),
    ('smote', SMOTE(random_state=RANDOM_STATE)),
    ('clf', LogisticRegression(solver='saga', max_iter=5000, class_weight='balanced', random_state=RANDOM_STATE))
])
lr_param_dist = {
    'clf__C': np.logspace(-4, 4, 20),
    'clf__penalty': ['l1', 'l2', 'elasticnet'],
    # l1_ratio is only read when penalty='elasticnet', and None is invalid there (the fit
    # raises and the candidate is wasted), so only concrete ratios are sampled.
    'clf__l1_ratio': [0.1, 0.5, 0.9]
}
models_to_run['LogisticRegression'] = (lr_pipe, lr_param_dist)

# Random Forest (tree-based, no scaling needed)
rf_pipe = ImbPipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('smote', SMOTE(random_state=RANDOM_STATE)),
    ('clf', RandomForestClassifier(n_jobs=-1, class_weight='balanced', random_state=RANDOM_STATE))
])
rf_param_dist = {
    'clf__n_estimators': [100, 200, 400],
    'clf__max_depth': [None, 10, 20],
    'clf__min_samples_split': [2, 5],
    'clf__min_samples_leaf': [1, 2],
    'clf__max_features': ['sqrt', 0.5]
}
models_to_run['RandomForest'] = (rf_pipe, rf_param_dist)

# Gradient boosting: XGBoost when installed, otherwise sklearn's GradientBoostingClassifier
# (previously XGBClassifier was used unconditionally, a NameError without xgboost).
if xgb_available:
    # `device` replaces the deprecated 'gpu_hist' / 'gpu_predictor' options of XGBoost >= 2.0
    # and lets the same code run on CPU-only machines. Releases before 2.0 do not recognise
    # `device` (they warn) and train with CPU 'hist'.
    boost_clf = XGBClassifier(
        tree_method='hist',
        device='cuda' if torch.cuda.is_available() else 'cpu',
        eval_metric='logloss', n_jobs=-1, random_state=RANDOM_STATE,
    )
    boost_name = 'XGBoost'
    column_sampling = {'clf__colsample_bytree': [0.4, 0.6, 0.8]}
else:
    boost_clf = GradientBoostingClassifier(random_state=RANDOM_STATE)
    boost_name = 'GradientBoosting'
    # closest GradientBoostingClassifier analogue of colsample_bytree
    column_sampling = {'clf__max_features': [0.4, 0.6, 0.8]}

xgb_pipe = ImbPipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('smote', SMOTE(random_state=RANDOM_STATE)),
    ('clf', boost_clf)
])
xgb_param_dist = {
    'clf__n_estimators': [100, 200, 400],
    'clf__max_depth': [3, 6, 10],
    'clf__learning_rate': [0.01, 0.05, 0.1],
    'clf__subsample': [0.6, 0.8, 1.0],
    **column_sampling,
}
models_to_run[boost_name] = (xgb_pipe, xgb_param_dist)

# MLP (PyTorch network behind the sklearn-compatible TorchMLPClassifier wrapper)
mlp_pipe = ImbPipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('scaler', StandardScaler()),
    ('smote', SMOTE(random_state=RANDOM_STATE)),
    ('clf', TorchMLPClassifier(max_iter=200, random_state=RANDOM_STATE))
])
mlp_param_dist = {
    'clf__hidden_layer_sizes': [(50,), (100,), (100,50)],
    'clf__activation': ['relu', 'tanh'],
    'clf__alpha': [1e-4, 1e-3],
    'clf__learning_rate_init': [1e-3, 1e-4]
}
models_to_run['MLP'] = (mlp_pipe, mlp_param_dist)

# -------------------------
# MLflow setup
# -------------------------
# Runs started below (one per model) are recorded under this experiment name.
mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)

# results store
# One summary dict per model, appended at the end of the training loop
# (keys: model, best_params, metrics_default, metrics_tuned, best_threshold, ...).
results_list = []  # will store dict with metrics, params, model_name

# -------------------------
# Helper: threshold tuning
# -------------------------
def tune_threshold(y_true, y_proba, metric='f1'):
    """Return the probability threshold that maximizes ``metric`` for label 1.

    Candidates are 0.5 plus every distinct value in ``y_proba``, which are the same
    candidates ``precision_recall_curve`` yields. A sample is predicted positive when
    ``y_proba >= threshold``. Ties go to the smallest threshold. Undefined ratios
    (no predicted / no actual positives) score 0, like ``zero_division=0``.

    All candidates are scored in O(n log n) from sorted scores instead of calling a
    sklearn metric once per threshold (O(n^2)).

    Parameters
    ----------
    y_true : array-like of {0, 1}
    y_proba : array-like of float
        Probability of the positive class for each sample.
    metric : {'f1', 'precision', 'recall'}

    Returns
    -------
    (best_threshold, best_score)

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported names.
    """
    if metric not in ('f1', 'precision', 'recall'):
        raise ValueError(f"metric must be 'f1', 'precision' or 'recall', got {metric!r}")

    labels = np.asarray(y_true).ravel()
    scores = np.asarray(y_proba, dtype=float).ravel()
    candidates = np.unique(np.concatenate(([0.5], scores)))

    # For each candidate t: number of samples (and of true positives) with score >= t.
    sorted_all = np.sort(scores)
    sorted_pos = np.sort(scores[labels == 1])
    n_pos = sorted_pos.size
    pred_pos = sorted_all.size - np.searchsorted(sorted_all, candidates, side='left')
    true_pos = n_pos - np.searchsorted(sorted_pos, candidates, side='left')

    if metric == 'f1':
        numer, denom = 2 * true_pos, pred_pos + n_pos   # F1 = 2TP / (predicted + actual positives)
    elif metric == 'precision':
        numer, denom = true_pos, pred_pos
    else:  # recall
        numer, denom = true_pos, np.full_like(true_pos, n_pos)

    values = np.divide(numer, denom, out=np.zeros(candidates.shape, dtype=float), where=denom > 0)
    best = int(np.argmax(values))  # first maximum -> smallest threshold among ties
    return candidates[best], float(values[best])

# -------------------------
# Iterate models: tune, evaluate, log
# -------------------------
# The F1-optimal decision threshold is chosen on out-of-fold scores for the TRAINING split.
# Choosing it on y_test and then reporting test metrics at that threshold would make the
# "tuned" test scores optimistically biased; here the test set only evaluates it.
from sklearn.base import clone


def _positive_proba(estimator, X_eval):
    """Return one positive-class score in [0, 1] per row of ``X_eval``.

    Uses predict_proba when the estimator exposes it. Otherwise it falls back to a
    min-max scaled decision_function, and to all zeros when neither exists. Errors
    raised *inside* those methods propagate instead of being silently replaced.
    """
    if hasattr(estimator, "predict_proba"):
        return estimator.predict_proba(X_eval)[:, 1]
    if hasattr(estimator, "decision_function"):
        raw = estimator.decision_function(X_eval)
        return (raw - raw.min()) / (raw.max() - raw.min() + 1e-9)
    return np.zeros(len(X_eval))


def _oof_positive_proba(estimator, X_fit, y_fit):
    """Score every training row with a clone of ``estimator`` fitted on the other ``cv`` folds."""
    oof = np.zeros(len(X_fit), dtype=float)
    for fit_idx, held_idx in cv.split(X_fit, y_fit):
        fold_model = clone(estimator).fit(X_fit.iloc[fit_idx], y_fit.iloc[fit_idx])
        oof[held_idx] = _positive_proba(fold_model, X_fit.iloc[held_idx])
    return oof


for model_name, (pipe, param_dist) in models_to_run.items():
    print(f"\n=== RUNNING: {model_name} ===")
    search = RandomizedSearchCV(
        pipe, param_distributions=param_dist,
        # Budget: 3 sampled candidates per tuned hyperparameter, capped at N_ITER_SEARCH.
        n_iter=min(N_ITER_SEARCH, max(1, len(param_dist) * 3)),
        scoring='f1', cv=cv, n_jobs=-1, random_state=RANDOM_STATE, verbose=1
    )
    search.fit(X_train, y_train)
    best = search.best_estimator_
    print("Best params:", search.best_params_)

    # Positive-class scores and default predictions on the held-out test set.
    y_proba = _positive_proba(best, X_test)
    y_pred_default = best.predict(X_test)

    # Threshold tuning for best F1 on out-of-fold training scores (CV_SPLITS extra fits per model).
    oof_proba = _oof_positive_proba(best, X_train, y_train)
    best_thresh, best_f1 = tune_threshold(y_train.values, oof_proba, metric='f1')
    y_pred_thresh = (y_proba >= best_thresh).astype(int)

    # Test metrics at the estimator's default rule and at the tuned threshold.
    metrics_default = classification_metrics(y_test, y_pred_default, y_proba)
    metrics_tuned = classification_metrics(y_test, y_pred_thresh, y_proba)

    # confusion matrix (at the tuned threshold) & plots
    cm_fname = f"artifacts/plots/{model_name}_confusion.png"
    plot_and_save_confusion_matrix(y_test, y_pred_thresh, f"{model_name} Confusion (thresh={best_thresh:.3f})", cm_fname)

    roc_fname = f"artifacts/plots/{model_name}_roc.png"
    if len(np.unique(y_test)) > 1:
        plot_and_save_roc(y_test, y_proba, f"{model_name} ROC", roc_fname)

        pr_fname = f"artifacts/plots/{model_name}_pr.png"
        plot_and_save_pr(y_test, y_proba, f"{model_name} Precision-Recall", pr_fname)
    else:
        # ROC / PR curves are undefined with a single class in y_test.
        roc_fname = None
        pr_fname = None

    prob_hist_fname = f"artifacts/plots/{model_name}_prob_hist.png"
    plot_probability_hist(y_test.values, y_proba, f"{model_name} Pred Prob Dist", prob_hist_fname)

    # Misclassification analysis: high-confidence mistakes (proba >= 0.8 or <= 0.2), most
    # confident first. Indices are row positions within X_test / y_test.
    y_test_arr = y_test.values
    mis_idx = np.flatnonzero(y_test_arr != y_pred_thresh)
    mis_proba = y_proba[mis_idx]
    confident_idx = mis_idx[(mis_proba >= 0.8) | (mis_proba <= 0.2)]
    confident_idx = confident_idx[np.argsort(-np.abs(y_proba[confident_idx] - 0.5), kind="stable")]
    high_conf_mis = [(int(i), y_test_arr[i], y_pred_thresh[i], y_proba[i]) for i in confident_idx]
    # store only a small sample
    high_conf_mis_sample = high_conf_mis[:10]

    # Save model artifact
    model_fname = f"artifacts/models/{model_name}_best.joblib"
    joblib.dump(best, model_fname)

    # Log with MLflow
    with mlflow.start_run(run_name=f"{model_name}_run"):
        # params
        mlflow.log_param("model_name", model_name)
        # best_threshold / best_threshold_f1 come from out-of-fold training predictions
        mlflow.log_param("threshold_tuned_on", "train_oof")
        # log best params (flatten)
        for k, v in search.best_params_.items():
            mlflow.log_param(k, str(v))
        # log test metrics at the tuned threshold
        for metric_name, metric_val in metrics_tuned.items():
            mlflow.log_metric(metric_name, float(metric_val))
        # log default metrics too (prefix 'default_')
        for k, v in metrics_default.items():
            mlflow.log_metric("default_" + k, float(v))
        mlflow.log_metric("best_threshold", float(best_thresh))
        mlflow.log_metric("best_threshold_f1", float(best_f1))

        # log confusion matrix & plots as artifacts
        mlflow.log_artifact(cm_fname, artifact_path="plots")
        if roc_fname:
            mlflow.log_artifact(roc_fname, artifact_path="plots")
            mlflow.log_artifact(pr_fname, artifact_path="plots")
        mlflow.log_artifact(prob_hist_fname, artifact_path="plots")

        # log model
        mlflow.sklearn.log_model(best, artifact_path="model")

        # Log a small CSV sample of high confidence misclassifications for inspection
        if high_conf_mis_sample:
            mis_df = pd.DataFrame(high_conf_mis_sample, columns=["idx_in_test", "true", "pred", "proba"])
            mis_csv = f"artifacts/{model_name}_high_conf_mis.csv"
            mis_df.to_csv(mis_csv, index=False)
            mlflow.log_artifact(mis_csv, artifact_path="analysis")

    # Append results summary (best_threshold_f1 is the out-of-fold training F1)
    results_list.append({
        "model": model_name,
        "best_params": search.best_params_,
        "metrics_default": metrics_default,
        "metrics_tuned": metrics_tuned,
        "best_threshold": best_thresh,
        "best_threshold_f1": best_f1,
        "model_path": model_fname,
        "cm_plot": cm_fname,
        "roc_plot": roc_fname,
        "prob_hist": prob_hist_fname
    })

# -------------------------
# Compare models: results table and plots
# -------------------------
# One row per model with its test metrics at the tuned threshold.
summary_rows = [
    {
        "model": r["model"],
        "accuracy": r["metrics_tuned"]["accuracy"],
        "precision": r["metrics_tuned"]["precision"],
        "recall": r["metrics_tuned"]["recall"],
        "f1": r["metrics_tuned"]["f1"],
        "roc_auc": r["metrics_tuned"]["roc_auc"],
        "best_threshold": r["best_threshold"],
    }
    for r in results_list
]
summary_df = pd.DataFrame(summary_rows).sort_values(by="f1", ascending=False).reset_index(drop=True)
summary_df.to_csv("artifacts/models_summary.csv", index=False)
print("\n=== Models comparison (sorted by F1) ===")
print(summary_df)

# Plot comparison bar chart for key metrics.
# DataFrame.plot creates its own figure, so the size goes to it directly; a separate
# plt.figure() beforehand only left an empty 10x6 figure open.
metrics_ax = summary_df.set_index("model")[["accuracy","precision","recall","f1"]].plot(kind="bar", figsize=(10,6))
metrics_ax.set_title("Model comparison - key metrics")
metrics_ax.set_ylabel("Score")
metrics_ax.set_ylim(0,1)
metrics_ax.legend(loc="lower right")
comparison_fig = metrics_ax.get_figure()
comparison_fig.tight_layout()
comparison_fig.savefig("artifacts/plots/models_comparison_metrics.png")
plt.close(comparison_fig)

# Log inside an explicit, closed run. mlflow.log_artifact with no active run starts one
# implicitly and leaves it open, and a later mlflow.start_run() in the same session then fails.
with mlflow.start_run(run_name="models_comparison"):
    mlflow.log_artifact("artifacts/models_summary.csv", artifact_path="comparison")
    mlflow.log_artifact("artifacts/plots/models_comparison_metrics.png", artifact_path="comparison")

# -------------------------
# Final reporting: best metrics and explanation
# -------------------------
# summary_df is sorted by F1 in descending order, so its first row is the winner.
best_by_f1 = summary_df.iloc[0]
_reported_cols = ["accuracy", "precision", "recall", "f1", "roc_auc", "best_threshold"]
print("\nBest model by F1:", best_by_f1["model"])
print("Metrics:", best_by_f1[_reported_cols].to_dict())


Train class distribution:
 iDefect
0    0.848887
1    0.151113
Name: proportion, dtype: float64
Test class distribution:
 iDefect
0    0.84897
1    0.15103
Name: proportion, dtype: float64

=== RUNNING: LogisticRegression ===
Fitting 5 folds for each of 9 candidates, totalling 45 fits
Best params: {'clf__penalty': 'l1', 'clf__l1_ratio': None, 'clf__C': np.float64(0.0006951927961775605)}





=== RUNNING: RandomForest ===
Fitting 5 folds for each of 15 candidates, totalling 75 fits
Best params: {'clf__n_estimators': 200, 'clf__min_samples_split': 5, 'clf__min_samples_leaf': 1, 'clf__max_features': 'sqrt', 'clf__max_depth': None}





=== RUNNING: XGBoost ===
Fitting 5 folds for each of 15 candidates, totalling 75 fits
Best params: {'clf__subsample': 0.6, 'clf__n_estimators': 400, 'clf__max_depth': 10, 'clf__learning_rate': 0.01, 'clf__colsample_bytree': 0.4}





=== RUNNING: MLP ===
Fitting 5 folds for each of 12 candidates, totalling 60 fits
Best params: {'clf__learning_rate_init': 0.0001, 'clf__hidden_layer_sizes': (50,), 'clf__alpha': 0.0001, 'clf__activation': 'tanh'}





=== Models comparison (sorted by F1) ===
                model  accuracy  precision    recall        f1   roc_auc  \
0        RandomForest  0.889143   0.579000  0.974747  0.726474  0.942702   
1             XGBoost  0.888889   0.579534  0.962963  0.723593  0.943879   
2                 MLP  0.886092   0.575258  0.939394  0.713555  0.943198   
3  LogisticRegression  0.872871   0.543278  0.993266  0.702381  0.920587   

   best_threshold  
0        0.425179  
1        0.436902  
2        0.604681  
3        0.500000  

Best model by F1: RandomForest
Metrics: {'accuracy': 0.8891431477243834, 'precision': 0.579, 'recall': 0.9747474747474747, 'f1': 0.726474278544542, 'roc_auc': 0.9427017000392263, 'best_threshold': 0.4251785714285715}

Explanation:
- We optimized threshold to maximize F1 by default (balanced precision/recall) because in defect detection you often
  want to balance false positives and false negatives. If the domain prefers catching every defect (minimize FN),
  tune thresho

<Figure size 1000x600 with 0 Axes>