# 1. Data Preprocessing

# 1.1 Data Imputation

# In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.impute import KNNImputer
from scipy.stats import wasserstein_distance
import os

# Global matplotlib style shared by every figure in this section.
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 20,
    'font.weight': 'bold',
    'axes.labelsize': 20,
    'axes.labelweight': 'bold',
    'axes.titlesize': 22,
    'axes.titleweight': 'bold',
    'legend.fontsize': 18,
    'legend.title_fontsize': 18,
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,
    # Draw minus signs with the ASCII hyphen instead of the Unicode minus glyph.
    'axes.unicode_minus': False,
})

# Output folder for the imputation figures (created if it does not exist yet).
plot_dir = "E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/picture/fill/"
os.makedirs(plot_dir, exist_ok=True)

# ========== Load Data ========== #
df = pd.read_csv("E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/data/leakage_type_selected_en.csv")

# ========== Select Numeric Columns with Missing Rate < 80% ========== #
numeric_cols = df.select_dtypes(include=np.number).columns
missing_info = df[numeric_cols].isna().mean()  # fraction of NaN per numeric column
# Only partially missing columns are imputed; columns missing >= 80% are left untouched.
to_impute_cols = missing_info[(missing_info > 0) & (missing_info < 0.8)].index.tolist()

# ========== Visualize Missing Rates ========== #
plt.figure(figsize=(10, 7))
sorted_missing = missing_info[to_impute_cols].sort_values(ascending=False)
bars = plt.barh(sorted_missing.index, sorted_missing.values, color='steelblue')
plt.title("Missing Rate of Numerical Features", fontweight='bold')
plt.xlabel("Missing Rate", fontweight='bold')
plt.ylabel("Numerical Features", fontweight='bold')

# Annotate missing rate values just right of each bar, vertically centred on it
# (bottom edge + half the bar height, combined with va='center').
for bar in bars:
    width = bar.get_width()
    plt.text(width + 0.001, bar.get_y() + bar.get_height() / 2,
             f"{width:.2%}", va='center', fontsize=12)

plt.tight_layout()
plt.savefig(os.path.join(plot_dir, "missing_rates.png"), dpi=300)
plt.show()

# ========== Distribution Comparison and Imputation ========== #
best_methods = {}       # feature name -> winning imputation method ('Mean' / 'Median' / 'KNN')
imputed_df = df.copy()  # receives the imputed columns; df itself is left unchanged

def get_decimal_places(series):
    """Return the maximum number of decimal places among the non-null values of *series*.

    Each value's ``str`` form is parsed with :class:`decimal.Decimal`, so scientific
    notation is handled correctly (``1e-05`` -> 5, ``1.5e-05`` -> 6) instead of being
    mis-counted by splitting on ``'.'``. Integers, large values such as ``1e+20`` and
    non-finite values count as 0 places. An empty or all-null series returns 0.
    """
    from decimal import Decimal, InvalidOperation

    def _places(value):
        try:
            exponent = Decimal(str(value)).as_tuple().exponent
        except InvalidOperation:
            return 0
        # inf/nan report a non-integer exponent marker; they have no decimal places.
        if not isinstance(exponent, int):
            return 0
        return max(0, -exponent)

    return max((_places(v) for v in series.dropna()), default=0)

def is_integer(series):
    """Tell whether every non-null value of *series* has no fractional part.

    Returns the (numpy) boolean from ``Series.all``; an empty or all-null series
    counts as integral.
    """
    observed = series.dropna()
    whole = observed.map(lambda value: float(value).is_integer())
    return whole.all()

def plot_distribution_comparison(feature):
    """Choose the best imputation for *feature*, write it into ``imputed_df`` and plot the comparison.

    Three candidate fills of ``df[feature]`` are built: column mean, column median, and a
    3-nearest-neighbour fill (``KNNImputer``) that uses all numeric columns as context.
    Each candidate is rounded to the precision of the observed values (or to whole numbers,
    as nullable ``Int64``, when every observed value is integral). Each candidate is scored
    by the Wasserstein distance between the observed (non-missing) values and the filled
    column. The smallest distance wins; ties resolve in the order Mean, Median, KNN.

    Module-level state used:
      * reads ``df``, ``numeric_cols`` and ``plot_dir``;
      * sets ``best_methods[feature]`` to the winning method name;
      * replaces ``imputed_df[feature]`` with the winning candidate, exactly as scored;
      * saves ``<plot_dir>/<feature>_impute_compare.png`` and shows the figure.

    Parameters
    ----------
    feature : str
        Numeric column of ``df`` with at least one observed value.
    """
    original = df[feature].dropna()

    decimal_places = get_decimal_places(original)
    is_int = is_integer(original)

    # Mean / median imputation (fillna returns new Series; df is not modified).
    mean_filled = df[feature].fillna(df[feature].mean())
    median_filled = df[feature].fillna(df[feature].median())

    # KNN imputation over the numeric columns. KNNImputer discards columns that are
    # entirely NaN, which would make its output narrower than numeric_cols and break the
    # DataFrame construction, so only columns with at least one value are passed in.
    # NOTE: this fit is the same for every feature and could be computed once.
    knn_cols = [col for col in numeric_cols if df[col].notna().any()]
    imputer = KNNImputer(n_neighbors=3)
    knn_filled_array = imputer.fit_transform(df[knn_cols])
    # Keep df's index so the result aligns row-for-row when assigned back.
    knn_filled_df = pd.DataFrame(knn_filled_array, columns=knn_cols, index=df.index)
    knn_filled = knn_filled_df[feature]

    # Preserve original decimal precision or integer format
    if is_int:
        mean_filled = mean_filled.round(0).astype('Int64')
        median_filled = median_filled.round(0).astype('Int64')
        knn_filled = knn_filled.round(0).astype('Int64')
    else:
        mean_filled = mean_filled.round(decimal_places)
        median_filled = median_filled.round(decimal_places)
        knn_filled = knn_filled.round(decimal_places)

    # Insertion order fixes tie-breaking in min(): Mean, then Median, then KNN.
    candidates = {'Mean': mean_filled, 'Median': median_filled, 'KNN': knn_filled}

    # Compare using Wasserstein distance
    dists = {name: wasserstein_distance(original, filled) for name, filled in candidates.items()}
    best_method = min(dists, key=dists.get)
    best_methods[feature] = best_method

    # Apply exactly the candidate that was scored (same rounding and dtype). Plain column
    # assignment is used because `imputed_df[feature].fillna(..., inplace=True)` is chained
    # assignment, which does not update imputed_df under pandas Copy-on-Write.
    imputed_df[feature] = candidates[best_method]

    # Plot distributions
    plt.figure(figsize=(12, 9))
    ax = plt.gca()
    # Thicker frame around the plotting area.
    for spine in ax.spines.values():
        spine.set_linewidth(3)
    sns.kdeplot(original, label='Original (non-missing)', linewidth=4)
    sns.kdeplot(mean_filled, label='Mean Imputation', linestyle='--', linewidth=3)
    sns.kdeplot(median_filled, label='Median Imputation', linestyle='-.', linewidth=3)
    sns.kdeplot(knn_filled, label='KNN Imputation', linestyle=':', linewidth=3)

    plt.title(f"{feature} - Distribution Comparison\nBest Method: {best_method}", fontweight='bold')
    plt.xlabel("Value", fontweight='bold')
    plt.ylabel("Density", fontweight='bold')
    plt.legend(loc='lower right')
    plt.tight_layout()
    # Save the plot
    plot_filename = os.path.join(plot_dir, f"{feature}_impute_compare.png")
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()  # Close figure to avoid overlap between features


# ========== Process All Selected Columns ========== #
for feature_name in to_impute_cols:
    plot_distribution_comparison(feature_name)

# ========== Print Best Methods ========== #
print("Best Imputation Method for Each Feature:")
for feature_name, method_name in best_methods.items():
    print(f"{feature_name}: {method_name}")

# Save best methods to CSV (one row per imputed feature).
best_methods_df = pd.DataFrame(
    {
        "Feature": list(best_methods.keys()),
        "Best_Imputation_Method": list(best_methods.values()),
    }
)
best_methods_df.to_csv(
    "E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/data/impute_summary.csv",
    index=False,
    encoding='utf-8-sig',
)

# ========== Save Imputed Dataset ========== #
output_path = "E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/data/leakage_type_selected_fill_en.csv"
imputed_df.to_csv(output_path, index=False, encoding='utf-8-sig')
print(f"Imputed dataset saved to: {output_path}")

# ========== Show Info and Description ========== #
df_des = imputed_df.describe()  # computed once, printed and saved below
print("\nData Summary After Imputation:")
imputed_df.info()
print("\nDescriptive Statistics:")
print(df_des)
df_des.to_csv(
    'E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/data/leakage_type_selected_fill_en_des.csv',
    encoding='utf-8-sig',
)

# 1.2 Correlation Analysis

# In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import chi2_contingency
from sklearn.feature_selection import mutual_info_classif
from collections import Counter
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import Rectangle

# Global matplotlib style for the correlation figures.
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 20,
    'font.weight': 'bold',
    'axes.labelsize': 20,
    'axes.labelweight': 'bold',
    'axes.titlesize': 22,
    'axes.titleweight': 'bold',
    'legend.fontsize': 18,
    'legend.title_fontsize': 18,
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,
    # ASCII hyphen instead of the Unicode minus glyph.
    'axes.unicode_minus': False,
})

# -----------------------------------------------
# ---------- Association Calculation ----------
# -----------------------------------------------
def cramers_v(confusion_matrix, correction=False):
    """Return Cramér's V association for an r x k contingency table.

    V = sqrt(chi2 / (n * min(r - 1, k - 1))), where chi2 is Pearson's chi-squared
    statistic of the table and n its total count; V lies in [0, 1].

    Parameters
    ----------
    confusion_matrix : pandas.DataFrame or array-like
        Table of counts, e.g. ``pd.crosstab(a, b)``.
    correction : bool, default False
        Forwarded to ``scipy.stats.chi2_contingency``. When True, SciPy applies
        Yates' continuity correction to 2x2 tables (its own default). That shrinks V,
        so a binary variable against itself would score below 1. Cramér's V is
        defined on the uncorrected statistic, hence False here.

    Returns
    -------
    float
        Cramér's V, or NaN when either variable has fewer than two observed
        categories (V is undefined; the old code divided 0 by 0 there).
    """
    table = np.asarray(confusion_matrix, dtype=float)
    r, k = table.shape
    if min(r, k) < 2:
        return float('nan')
    chi2 = chi2_contingency(table, correction=correction)[0]
    n = table.sum()
    return float(np.sqrt((chi2 / n) / min(k - 1, r - 1)))

# -----------------------------------------------
# ---------- Visualization Functions ----------
# -----------------------------------------------
def plot_heatmap(matrix, title, save_path=None):
    """Draw *matrix* as an annotated heatmap and outline its largest off-diagonal cells in red.

    The peak value is taken from the strict upper triangle; every off-diagonal cell
    equal to it is outlined. Saved to *save_path* at 300 dpi when given, then shown.
    """
    plt.figure(figsize=(10, 8))
    peak = matrix.values[np.triu_indices_from(matrix, 1)].max()
    sns.heatmap(matrix, annot=True, fmt=".2f", cmap="YlGnBu", square=True,
                linewidths=0.5, cbar=True)

    n_rows, n_cols = matrix.shape
    peak_cells = [(col, row)
                  for row in range(n_rows)
                  for col in range(n_cols)
                  if col != row and matrix.iloc[row, col] == peak]
    heat_ax = plt.gca()
    for col, row in peak_cells:
        heat_ax.add_patch(plt.Rectangle((col, row), 1, 1, fill=False, edgecolor='red', lw=2))

    plt.title(title, fontsize=16, fontweight='bold')
    plt.xticks(rotation=45, fontweight='bold')
    plt.yticks(rotation=0, fontweight='bold')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def plot_heatmap_overlay(matrix, title, save_path=None):
    """Draw only the strict lower triangle of *matrix* and outline its largest cells in orange.

    The diagonal and upper triangle are masked out, since they are redundant for a
    symmetric association matrix. The maximum is taken over the visible strict-lower
    cells, so the outline always lands on a displayed cell, even for non-symmetric
    input. Saved to *save_path* (300 dpi) when given, then shown.
    """
    plt.figure(figsize=(10, 8))
    # True = hidden by seaborn: diagonal + upper triangle.
    mask = np.triu(np.ones(matrix.shape, dtype=bool))
    max_val = matrix.values[np.tril_indices_from(matrix.values, -1)].max()
    sns.heatmap(matrix, mask=mask, annot=True, fmt=".2f", cmap="YlGnBu", square=True,
                linewidths=.5, cbar_kws={"shrink": .6}, alpha=0.9)
    for y in range(matrix.shape[0]):
        for x in range(matrix.shape[1]):
            # Unmasked cells are exactly the strict lower triangle (x < y).
            if not mask[y, x] and matrix.iloc[y, x] == max_val:
                plt.gca().add_patch(plt.Rectangle((x, y), 1, 1, fill=False, edgecolor='orange', lw=2))
    # The displayed part is the lower triangle (previously mislabelled "Upper").
    plt.title(title + " (Lower Triangle)", fontsize=16, fontweight='bold')
    plt.xticks(rotation=45, fontweight='bold')
    plt.yticks(rotation=0, fontweight='bold')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def plot_ringed_heatmap(matrix, title, save_path=None):
    """Heatmap of *matrix* (no colour bar) with a ring in every cell sized by |value|.

    Ring radius is 0.15 + |value| / 3. The ring of each off-diagonal cell equal to the
    strict-upper-triangle maximum is red; all others are black. Saved to *save_path*
    (300 dpi) when given, then shown.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(matrix, annot=True, fmt=".2f", cmap="coolwarm", square=True,
                linewidths=1, linecolor='white', cbar=False)
    peak = matrix.values[np.triu_indices_from(matrix, 1)].max()

    n_rows, n_cols = matrix.shape
    for row in range(n_rows):
        for col in range(n_cols):
            cell = matrix.iloc[row, col]
            ring_color = 'red' if (cell == peak and col != row) else 'black'
            ring = plt.Circle((col + 0.5, row + 0.5),
                              radius=0.15 + abs(cell) * 500 / 1500,
                              color=ring_color, fill=False, lw=1.5)
            ax.add_patch(ring)

    plt.title(title + " (Ring Enhanced)", fontsize=16, fontweight='bold')
    plt.xticks(rotation=45, fontweight='bold')
    plt.yticks(rotation=0, fontweight='bold')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def plot_cluster_heatmap(matrix, title, save_path=None, xtick_rotation=15):
    """Cluster *matrix* with ``sns.clustermap`` and draw it with thick dendrograms.

    X tick labels are rotated by *xtick_rotation* (right-aligned) and y tick labels by
    -70 degrees. If a row named 'LossType' is present after reordering, that whole row
    is outlined in red. Saved to *save_path* (400 dpi) when given, then shown.
    """
    grid = sns.clustermap(
        matrix,
        cmap="PuBuGn",
        annot=True,
        fmt=".2f",
        figsize=(10, 10),
        square=True,
        cbar_kws={"shrink": .6},
        dendrogram_ratio=(0.08, 0.08),
        cbar_pos=(0.009, 0.8, 0.03, 0.18),
    )

    # Thicker dendrogram branches.
    for dendro_ax in (grid.ax_row_dendrogram, grid.ax_col_dendrogram):
        for branch_lines in dendro_ax.collections:
            branch_lines.set_linewidth(2)

    plt.suptitle(title + " (Clustered)", fontsize=16, y=1.02, fontweight='bold')

    heat_ax = grid.ax_heatmap
    heat_ax.set_xticklabels(heat_ax.get_xticklabels(), rotation=xtick_rotation, ha='right')
    heat_ax.set_yticklabels(heat_ax.get_yticklabels(), rotation=-70)

    # Row order after clustering; used to locate the target row.
    row_order = grid.data2d.index.tolist()
    if 'LossType' in row_order:
        heat_ax.add_patch(Rectangle(
            xy=(0, row_order.index('LossType')),
            width=matrix.shape[1],
            height=1,
            fill=False,
            edgecolor='red',
            linewidth=4,
        ))

    if save_path:
        plt.savefig(save_path, dpi=400, bbox_inches='tight')
    plt.show()
    
def plot_bubble_heatmap(matrix, title, save_path=None):
    """Bubble chart of *matrix*: one bubble per cell, area proportional to its value.

    Off-diagonal bubbles equal to the strict-upper-triangle maximum get a red edge;
    all others get a black edge. Each bubble is labelled with its value. Row 0 is
    drawn at the top. Saved to *save_path* (300 dpi) when given, then shown.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    peak = matrix.values[np.triu_indices_from(matrix, 1)].max()

    n_rows, n_cols = matrix.shape
    for row in range(n_rows):
        for col in range(n_cols):
            cell = matrix.iloc[row, col]
            outline = 'red' if (cell == peak and col != row) else 'k'
            ax.scatter(col, row, s=cell * 1500, c='skyblue', edgecolors=outline,
                       alpha=0.7, linewidths=1.5)
            ax.text(col, row, f"{cell:.2f}", va='center', ha='center', fontsize=10)

    col_count = len(matrix.columns)
    row_count = len(matrix.index)
    ax.set_xticks(np.arange(col_count))
    ax.set_yticks(np.arange(row_count))
    ax.set_xticklabels(matrix.columns, rotation=45, fontweight='bold')
    ax.set_yticklabels(matrix.index, fontweight='bold')
    ax.set_xlim(-0.5, col_count - 0.5)
    ax.set_ylim(-0.5, row_count - 0.5)
    plt.title(title + " (Bubble Heatmap)", fontsize=16, fontweight='bold')
    ax.invert_yaxis()
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def plot_fancy_diagonal_heatmap(matrix, title, save_path=None):
    """Annotated heatmap of *matrix* with every diagonal cell outlined in red.

    Saved to *save_path* at 400 dpi when given, then shown.
    """
    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, annot=True, fmt=".2f", cmap="PuBuGn", square=True,
                linewidths=1, linecolor='white')
    heat_ax = plt.gca()
    diagonal_boxes = [plt.Rectangle((k, k), 1, 1, fill=False, edgecolor='red', lw=2)
                      for k in range(len(matrix))]
    for box in diagonal_boxes:
        heat_ax.add_patch(box)

    plt.title(title + " (Diagonal Highlighted)", fontsize=16, fontweight='bold')
    plt.xticks(rotation=45, fontweight='bold')
    plt.yticks(rotation=0, fontweight='bold')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=400, bbox_inches='tight')
    plt.show()

def plot_3d_bar_heatmap(matrix, title, save_path=None):
    """Render *matrix* as a 3-D bar chart with one bar per cell and height equal to the cell value.

    The bar for cell (row i, column j) has its footprint starting at x=j, y=i. The x axis
    is therefore labelled with ``matrix.columns`` and the y axis with ``matrix.index``, and
    ticks sit at the centre of each footprint. This also holds for non-square matrices.
    As in the other plots of this cell, the largest off-diagonal value (strict upper
    triangle) is highlighted in crimson; all other bars are steel blue. Saved to
    *save_path* (300 dpi) when given, then shown.
    """
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    values = matrix.values
    n_rows, n_cols = values.shape
    # meshgrid(cols, rows) has shape (n_rows, n_cols), so its C-order ravel pairs
    # element-for-element with values.ravel().
    col_grid, row_grid = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    x, y = col_grid.ravel(), row_grid.ravel()
    top = values.ravel()

    # Exclude the diagonal (trivially 1 for an association matrix) from the highlight.
    off_diag = values[np.triu_indices_from(values, 1)]
    max_val = off_diag.max() if off_diag.size else None
    colors = ['crimson' if (v == max_val and r != c) else 'steelblue'
              for v, r, c in zip(top, y, x)]
    bottom = np.zeros_like(top)
    width = depth = 0.8
    ax.bar3d(x, y, bottom, width, depth, top, shade=True, color=colors)
    # Centre the ticks on the bars (bar3d anchors bars at their corner).
    ax.set_xticks(np.arange(n_cols) + width / 2)
    ax.set_xticklabels(matrix.columns, rotation=30, fontweight='bold')
    ax.set_yticks(np.arange(n_rows) + depth / 2)
    ax.set_yticklabels(matrix.index, rotation=10, fontweight='bold')
    ax.set_title(title + " (3D Bar Chart)", fontsize=16, fontweight='bold')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

# -----------------------------------------------
# ---------- Main Process ----------
# -----------------------------------------------
# Load data
data = pd.read_csv('E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/data/leakage_type_selected_fill_en.csv')
#data = pd.read_csv('E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/data/leakage_type_selected_fill_en_IQR.csv')

# Categorical variables whose pairwise association is measured.
cat_cols = ['WorkingCondition', 'LossFormation', 'Lithology', 'LossSeverity', 'LossType']

# Compute Cramér's V matrix. V does not depend on argument order (a transposed table has
# the same chi2 and the same min(r-1, k-1)). Each unordered pair is therefore computed
# once and mirrored. This keeps the matrix exactly symmetric, which the plotting helpers
# rely on when they take the maximum from one triangle and outline it anywhere.
cramer_matrix = pd.DataFrame(np.nan, index=cat_cols, columns=cat_cols)
for i, col1 in enumerate(cat_cols):
    for col2 in cat_cols[i:]:
        assoc = cramers_v(pd.crosstab(data[col1], data[col2]))
        cramer_matrix.loc[col1, col2] = assoc
        cramer_matrix.loc[col2, col1] = assoc

cramer_matrix = cramer_matrix.astype(float)
path = 'E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/picture/corr/'
# savefig does not create folders; make sure the output directory exists.
os.makedirs(path, exist_ok=True)
# ---------- Visualization: every variant is rendered and saved ----------
plot_heatmap(cramer_matrix, "Cramér's V", save_path=path+"cramer_heatmap.png")
plot_heatmap_overlay(cramer_matrix, "Cramér's V", save_path=path+"cramer_overlay.png")
plot_ringed_heatmap(cramer_matrix, "Cramér's V", save_path=path+"cramer_ring.png")
plot_cluster_heatmap(cramer_matrix, "Cramér's V", save_path=path+"cramer_cluster.png")
plot_bubble_heatmap(cramer_matrix, "Cramér's V", save_path=path+"cramer_bubble.png")
plot_fancy_diagonal_heatmap(cramer_matrix, "Cramér's V", save_path=path+"cramer_diagonal.png")
plot_3d_bar_heatmap(cramer_matrix, "Cramér's V", save_path=path+"cramer_3dbar.png")


# 2. Model Comparison

# In [None]:
# -*- coding: utf-8 -*-
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.ticker as mtick

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, classification_report

from lightgbm import LGBMClassifier
from catboost import CatBoostClassifier

from sklearn.preprocessing import StandardScaler
from matplotlib import font_manager

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, average_precision_score, roc_auc_score, precision_recall_curve, roc_curve, auc, confusion_matrix, classification_report
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression 
import itertools
import os

# Optional CJK font (SimHei) for the sans-serif family. The path is machine-specific,
# and FontProperties.get_name() must open the font file, so the font is only
# registered when the file exists (instead of raising FileNotFoundError).
font_path = 'C:/Users/mumu/AppData/Local/Microsoft/Windows/Fonts/SimHei.ttf'
font_prop = font_manager.FontProperties(fname=font_path)
if os.path.exists(font_path):
    plt.rcParams['font.sans-serif'] = [font_prop.get_name()]

# Global figure style for this section.
# NOTE(review): 'font.family' names a specific serif face, so the sans-serif list
# above presumably only matters for text that explicitly asks for 'sans-serif'.
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 20,
    'font.weight': 'bold',
    'axes.labelsize': 20,
    'axes.labelweight': 'bold',
    'axes.titlesize': 22,
    'axes.titleweight': 'bold',
    'legend.fontsize': 18,
    'legend.title_fontsize': 18,
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,
    'axes.unicode_minus': False,
})

def savefig_with_border(fig, axes, path, dpi=350):
    """Give every axes a 2-pt black frame, then save *fig* to *path*.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to save.
    axes : Axes, list of Axes, or numpy array of Axes of any shape
        Axes whose spines are restyled. A 2-D grid such as ``plt.subplots(2, 2)``
        returns is flattened first, so every panel is handled.
    path : str or path-like
        Output file, written with ``bbox_inches='tight'``.
    dpi : int, default 350
        Output resolution.
    """
    # atleast_1d(...).ravel() turns a single Axes, a list, or an N-d array of Axes into
    # one flat sequence (iterating a 2-D array directly would yield rows, not Axes).
    for ax in np.atleast_1d(axes).ravel():
        for spine in ax.spines.values():
            spine.set_linewidth(2)
            spine.set_edgecolor('black')
    fig.savefig(path, dpi=dpi, bbox_inches='tight')

        
# Two views of the same dataset: label/one-hot encoded for LightGBM, and raw
# (categorical strings) for CatBoost, which is given cat_features below.
encoded_data = pd.read_csv('E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/data/leakage_type_selected_enc_en.csv')
raw_data = pd.read_csv('E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/data/leakage_type_selected_fill_en.csv')
#encoded_data = pd.read_csv('E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/data/leakage_type_selected_enc_en_IQR.csv')
#raw_data = pd.read_csv('E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/data/leakage_type_selected_fill_en_IQR.csv')

X_encoded = encoded_data.drop(columns=['LossType'])
y_encoded = encoded_data['LossType']

X_raw = raw_data.drop(columns=['LossType'])
y_raw = raw_data['LossType']

cat_features = ['WorkingCondition','LossFormation','Lithology','LossSeverity']

# Same test_size/random_state on both views gives the same row split.
# NOTE(review): assumes both CSVs hold the same rows in the same order — confirm.
X_train_enc, X_test_enc, y_train_enc, y_test_enc = train_test_split(X_encoded, y_encoded, test_size=0.2, random_state=46)
X_train_raw, X_test_raw, y_train_raw, y_test_raw = train_test_split(X_raw, y_raw, test_size=0.2, random_state=46)

# Scaler is fitted on the training split only, then applied to the test split.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train_enc)
X_test_scaled = scaler.transform(X_test_enc)

lgb_model = LGBMClassifier(n_estimators=1000, learning_rate=0.05, random_state=46)
lgb_model.fit(X_train_scaled, y_train_enc)
lgb_train_pred = lgb_model.predict(X_train_scaled)
lgb_test_pred = lgb_model.predict(X_test_scaled)

# Original (Chinese) target labels -> English names.
label_map = {
    '裂缝漏失': 'FractureLoss',
    '断层漏失': 'FaultLoss',
    '孔洞缝漏失': 'CavityFractureLoss',
    '渗透性漏失': 'PermeabilityLoss'
}
# Series.replace leaves values that are not keys of label_map unchanged, so labels that
# are already English survive (Series.map would turn them into NaN and break fit()).
y_train_raw = y_train_raw.replace(label_map)
y_test_raw = y_test_raw.replace(label_map)
catboost_model = CatBoostClassifier(
    iterations=300,
    learning_rate=0.05,
    depth=4,
    random_state=46,
    verbose=100,
    loss_function='MultiClass',
    task_type="GPU",
    devices='0',
    # NOTE(review): no eval_set is passed to fit(), so early stopping presumably has
    # nothing to monitor — confirm whether an eval_set was intended.
    early_stopping_rounds=50
)
catboost_model.fit(X_train_raw, y_train_raw, cat_features=cat_features)
# CatBoost's MultiClass predict() returns a column of labels; flatten it to 1-D.
cb_train_pred = np.ravel(catboost_model.predict(X_train_raw))
cb_test_pred = np.ravel(catboost_model.predict(X_test_raw))

def collect_metrics(name, y_train, y_train_pred, y_test, y_test_pred):
    """Return two metric rows (Train, then Test) for model *name*.

    Each row is a dict with keys Model, Dataset, Accuracy, Macro F1-score and
    Weighted F1-score, ready to be fed to ``pd.DataFrame``.
    """
    rows = []
    splits = (("Train", y_train, y_train_pred), ("Test", y_test, y_test_pred))
    for split_name, truth, predicted in splits:
        rows.append({
            "Model": name, "Dataset": split_name,
            "Accuracy": accuracy_score(truth, predicted),
            "Macro F1-score": f1_score(truth, predicted, average='macro'),
            "Weighted F1-score": f1_score(truth, predicted, average='weighted'),
        })
    return rows

metrics = []
for model_label, y_tr, pred_tr, y_te, pred_te in (
    ("LightGBM", y_train_enc, lgb_train_pred, y_test_enc, lgb_test_pred),
    ("CatBoost", y_train_raw, cb_train_pred, y_test_raw, cb_test_pred),
):
    metrics.extend(collect_metrics(model_label, y_tr, pred_tr, y_te, pred_te))

metrics_df = pd.DataFrame(metrics)

# One grouped bar chart (Train vs Test per model) for each metric, each with its own palette.
model_fig_dir = 'E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/picture/model/'
for metric_name, palette in (("Accuracy", "Set2"), ("Macro F1-score", "Set1"), ("Weighted F1-score", "Set3")):
    fig, ax = plt.subplots(figsize=(12, 7))
    sns.barplot(data=metrics_df, x="Model", y=metric_name, hue="Dataset",
                palette=palette, edgecolor='black', ax=ax)
    plt.title(f"{metric_name} Comparison Between Train and Test Sets")
    plt.ylabel(metric_name)
    plt.xticks(rotation=0)
    plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
    plt.legend(loc='lower right')

    # Percentage label on top of every non-empty bar.
    for p in ax.patches:
        height = p.get_height()
        if height > 0:
            ax.annotate(f"{height:.2%}", (p.get_x() + p.get_width() / 2., height),
                        ha='center', va='bottom')
    plt.tight_layout()
    savefig_with_border(fig, ax, f'{model_fig_dir}{metric_name}-compare.png')
    plt.show()

print("LightGBM Classification Report：")
print(classification_report(y_test_enc, lgb_test_pred))

print("CatBoost Classification Report：")
print(classification_report(y_test_raw, cb_test_pred))

def plot_train_test_confusion_side_by_side(models_preds_train, models_preds_test, y_trues_train, y_trues_test, labels, model_names, save_dir):
    """For each model, draw its train and test confusion matrices side by side and save them.

    Parameters
    ----------
    models_preds_train, models_preds_test : list of array-like
        Predicted labels, one entry per model, in *model_names* order.
    y_trues_train, y_trues_test : list of array-like
        True labels, indexed in parallel with the predictions (entry i belongs to
        ``model_names[i]``).
    labels : list
        Class labels in display order. They are also passed to ``confusion_matrix``,
        so the counts are arranged in this same order.
    model_names : list of str
        Names used in the titles and output file names.
    save_dir : str
        Output folder (created if needed); writes ``confusion_matrix_<name>.png``.
    """
    os.makedirs(save_dir, exist_ok=True)

    for i, name in enumerate(model_names):
        fig, axes = plt.subplots(1, 2, figsize=(18, 7))
        panels = (
            ('Train', y_trues_train[i], models_preds_train[i]),
            ('Test', y_trues_test[i], models_preds_test[i]),
        )
        for ax, (split_name, y_true, y_pred) in zip(axes, panels):
            # labels= pins the row/column order to `labels`. Without it sklearn uses
            # the sorted label set, which would not match display_labels, and a class
            # missing from one split would shrink the matrix.
            cm = confusion_matrix(y_true, y_pred, labels=labels)
            disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
            disp.plot(ax=ax, cmap='Blues', values_format='d')
            ax.set_title(f'{name} - {split_name}', fontsize=24)
            ax.set_xlabel("Predicted Label", fontsize=24, labelpad=10)
            ax.set_ylabel("True Label", fontsize=24, labelpad=10)
            ax.tick_params(axis='x', labelrotation=20, labelsize=22)
            ax.tick_params(axis='y', labelrotation=20, labelsize=22)

        plt.tight_layout()
        savefig_with_border(fig, axes, f'{save_dir}/confusion_matrix_{name}.png')
        plt.show()
        plt.close()

class_names = ['FractureLoss', 'FaultLoss', 'CavityFractureLoss', 'PermeabilityLoss']
inv_label_map = {v: k for k, v in label_map.items()}  # English name -> original label
# Integer codes of the encoded dataset -> English class names.
# NOTE(review): assumed to match how LossType was encoded in *_enc_en.csv — confirm.
code_to_label = {0: 'FractureLoss', 1: 'FaultLoss', 2: 'CavityFractureLoss', 3: 'PermeabilityLoss'}


def decode_labels(encoded_list):
    """Map integer class codes to their English names via ``code_to_label``."""
    return [code_to_label[i] for i in encoded_list]


# CatBoost predictions are flattened in case predict() returned an (n, 1) column.
models_preds_train_labels = [
    decode_labels(lgb_train_pred),
    list(np.ravel(cb_train_pred)),
]
models_preds_test_labels = [
    decode_labels(lgb_test_pred),
    list(np.ravel(cb_test_pred)),
]

y_train_enc_labels = decode_labels(y_train_enc)
y_test_enc_labels = decode_labels(y_test_enc)
# One ground-truth list per model, in model_names order. LightGBM is scored against the
# encoded split; CatBoost against the raw split it was trained on and predicted for.
plot_train_test_confusion_side_by_side(
    models_preds_train=models_preds_train_labels,
    models_preds_test=models_preds_test_labels,
    y_trues_train=[y_train_enc_labels, list(y_train_raw)],
    y_trues_test=[y_test_enc_labels, list(y_test_raw)],
    labels=list(code_to_label.values()),
    model_names=['LightGBM', 'CatBoost'],
    save_dir='E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/picture/model'
)

def _one_vs_rest(y, classes):
    """Return the (n_samples, n_classes) 0/1 indicator matrix of *y* against *classes*.

    ``label_binarize`` returns a single column for exactly two classes; that case is
    expanded so that column j always corresponds to ``classes[j]``.
    """
    y_bin = label_binarize(y, classes=classes)
    if len(classes) == 2:
        y_bin = np.hstack([1 - y_bin, y_bin])
    return y_bin


def _plot_roc_curve(ax, y_true01, y_score, label, linestyle, color):
    """Plot one ROC curve on *ax* with its AUC appended to *label*.

    Skipped when *y_true01* does not contain both 0 and 1 (ROC/AUC undefined).
    """
    if len(np.unique(y_true01)) < 2:
        return
    fpr, tpr, _ = roc_curve(y_true01, y_score)
    ax.plot(fpr, tpr, linestyle=linestyle, color=color,
            label=f"{label} (AUC={auc(fpr, tpr):.2f})")


def plot_multi_class_roc(model, X_train, y_train, X_test, y_test, model_name, class_labels=None):
    """Plot one-vs-rest ROC curves of *model* per class: train solid, test dashed.

    Column j of ``model.predict_proba`` is scored against class j, so the class list
    must follow the model's column order:

    * *class_labels*, if given: labels of the probability columns, in column order,
      written in the same vocabulary as *y_train*/*y_test* (e.g. class names when the
      model was trained on integer codes);
    * otherwise ``model.classes_``, which the labels in *y_train*/*y_test* must
      belong to.

    Saved to ``.../picture/model/roc-<model_name>.png`` and shown.

    Raises
    ------
    ValueError
        If *class_labels* is omitted and the labels in y_train/y_test are not in
        ``model.classes_``, so probability columns cannot be matched to classes.
    """
    if class_labels is not None:
        classes = np.asarray(class_labels)
    else:
        classes = np.asarray(model.classes_)
        seen = set(np.concatenate([np.asarray(y_train), np.asarray(y_test)]).tolist())
        if not seen <= set(classes.tolist()):
            raise ValueError(
                f"plot_multi_class_roc({model_name!r}): labels in y_train/y_test do not match "
                "model.classes_; pass class_labels in predict_proba column order."
            )

    y_train_bin = _one_vs_rest(y_train, classes)
    y_test_bin = _one_vs_rest(y_test, classes)
    # np.asarray also normalises list outputs and never squeezes a one-row batch.
    y_score_train = np.asarray(model.predict_proba(X_train))
    y_score_test = np.asarray(model.predict_proba(X_test))

    colors = plt.cm.tab10.colors
    fig, ax = plt.subplots(figsize=(10, 8))
    for j, cls in enumerate(classes):
        color = colors[j % len(colors)]
        _plot_roc_curve(ax, y_train_bin[:, j], y_score_train[:, j], f'Train - Class {cls}', '-', color)
        _plot_roc_curve(ax, y_test_bin[:, j], y_score_test[:, j], f'Test - Class {cls}', '--', color)

    ax.plot([0, 1], [0, 1], 'k--', lw=1)
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate(FPR)')
    ax.set_ylabel('True Positive Rate(TPR)')
    ax.set_title(f'Multi-Classification ROC Curve (Train vs Test) of {model_name} ')
    ax.legend(loc="lower right", fontsize='small')
    ax.grid(True)
    fig.tight_layout()
    savefig_with_border(fig, ax, f'E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/picture/model/roc-{model_name}.png')
    plt.show()


def plot_catboost_roc_compare(
    y_train, y_train_prob,
    y_test, y_test_prob,
    class_names,
    title="Multi-Classification ROC Curve (Train vs Test) of CatBoost",
    figsize=(10, 8)):
    """Plot train (dashed) vs test (solid) one-vs-rest ROC curves from precomputed probabilities.

    *class_names* must list the classes in the column order of *y_train_prob* and
    *y_test_prob* (for a fitted estimator: its ``classes_``). Curves are skipped for a
    class absent from a split. Saved to ``.../picture/model/roc-catboost.png`` and shown.
    """
    y_train_bin = _one_vs_rest(y_train, class_names)
    y_test_bin = _one_vs_rest(y_test, class_names)
    y_train_prob = np.asarray(y_train_prob)
    y_test_prob = np.asarray(y_test_prob)

    fig, ax = plt.subplots(figsize=figsize)
    colors = plt.cm.tab10.colors
    for i, cls in enumerate(class_names):
        color = colors[i % len(colors)]
        _plot_roc_curve(ax, y_train_bin[:, i], y_train_prob[:, i], f"Train - Class {cls}", '--', color)
        _plot_roc_curve(ax, y_test_bin[:, i], y_test_prob[:, i], f"Test  - Class {cls}", '-', color)

    ax.plot([0, 1], [0, 1], 'k--', lw=1)
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title(title)
    ax.legend(loc='lower right', fontsize='small')
    ax.grid(True)
    fig.tight_layout()
    savefig_with_border(fig, ax, 'E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/picture/model/roc-catboost.png')
    plt.show()


# LightGBM was trained on integer codes: name its probability columns (in
# lgb_model.classes_ order) so they line up with the string labels used for y.
plot_multi_class_roc(lgb_model, X_train_scaled, y_train_enc_labels, X_test_scaled, y_test_enc_labels, "LightGBM",
                     class_labels=[code_to_label[c] for c in lgb_model.classes_])
y_prob_train_cb = catboost_model.predict_proba(X_train_raw)
y_prob_test_cb = catboost_model.predict_proba(X_test_raw)
# predict_proba columns follow catboost_model.classes_; use that order rather than
# assuming the sorted label order.
class_names = np.asarray(catboost_model.classes_)
plot_catboost_roc_compare(
    y_train=y_train_raw,
    y_train_prob=y_prob_train_cb,
    y_test=y_test_raw,
    y_test_prob=y_prob_test_cb,
    class_names=class_names,
    title="Multi-Classification ROC Curve (Train vs Test) of CatBoost"
)


# 3. CatBoost

# In [None]:
from sklearn import metrics
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

import numpy as np

model = catboost_model
model.fit(X_train_raw, y_train_raw, cat_features=cat_features)

y_pred_train_cb = model.predict(X_train_raw)
y_prob_train_cb = model.predict_proba(X_train_raw)
y_pred_test_cb = model.predict(X_test_raw)
y_prob_test_cb = model.predict_proba(X_test_raw)


def _ovr_roc_auc(y_true, y_prob, classes):
    """Return the one-vs-rest ROC AUC for probabilities whose columns follow ``classes``.

    ``roc_auc_score`` (without ``labels``) assumes the columns of ``y_prob`` are in
    sorted label order. When ``classes`` is not sorted, the columns are permuted to
    sorted order and passed with the matching sorted ``labels``; when it already is
    sorted, the plain call is used unchanged.
    """
    classes = np.asarray(classes)
    order = np.argsort(classes, kind='stable')
    if np.array_equal(order, np.arange(order.size)):
        return roc_auc_score(y_true, y_prob, multi_class='ovr')
    return roc_auc_score(y_true, np.asarray(y_prob)[:, order],
                         multi_class='ovr', labels=classes[order])


def _print_split_metrics(split, y_true, y_pred, y_prob, classes, leading=''):
    """Print the classification report and summary scores for one data split.

    split: 'Train' or 'Test'; inserted into the section headers.
    classes: column order of ``y_prob`` (the model's ``classes_``).
    leading: text printed before the first header (e.g. a blank line).
    """
    print(f"{leading}== {split}set Classification_report ==")
    print(metrics.classification_report(y_true, y_pred, digits=3))

    print(f"== {split}set Indicators ==")
    print("Accuracy score   :", accuracy_score(y_true, y_pred))
    for avg in ('macro', 'micro'):
        print(f"Precision ({avg}):", precision_score(y_true, y_pred, average=avg))
        print(f"Recall    ({avg}):", recall_score(y_true, y_pred, average=avg))
        print(f"F1-score  ({avg}):", f1_score(y_true, y_pred, average=avg))
    print("F1-score  (weighted):", f1_score(y_true, y_pred, average='weighted'))

    try:
        print("ROC AUC (ovr)    :", _ovr_roc_auc(y_true, y_prob, classes))
    except ValueError as err:
        # Report why it was skipped instead of hiding the cause.
        print("ROC AUC skipped (need predict_proba and all classes present)")
        print("  reason:", err)


_print_split_metrics('Train', y_train_raw, y_pred_train_cb, y_prob_train_cb, model.classes_)
_print_split_metrics('Test', y_test_raw, y_pred_test_cb, y_prob_test_cb, model.classes_, leading='\n')


# 4. SMOTENC

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from imblearn.over_sampling import SMOTENC
from sklearn.preprocessing import LabelEncoder
from collections import Counter
import os

# Publication styling for the SMOTENC figures: bold Times New Roman, large text.
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 20,
    'font.weight': 'bold',
    'axes.labelsize': 20,
    'axes.labelweight': 'bold',
    'axes.titlesize': 24,
    'axes.titleweight': 'bold',
    'legend.fontsize': 18,
    'legend.title_fontsize': 18,
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,
    'axes.unicode_minus': False,
})

csv_path = 'E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/data/leakage_type_selected_fill_en.csv'
#csv_path = 'E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/data/leakage_type_selected_fill_en_IQR.csv'
target_column = 'LossType'
categorical_features = ['WorkingCondition', 'LossFormation', 'Lithology', 'LossSeverity']
# Classes below int(max_count * sampling_ratio) are oversampled up to that count.
sampling_ratio = 0.7

save_dir = 'E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/picture/smotenc/'
os.makedirs(save_dir, exist_ok=True)

df = pd.read_csv(csv_path)

# Chinese loss-type labels -> English class names.
loss_type_mapping = {
    '裂缝漏失': 'FractureLoss',
    '断层漏失': 'FaultLoss',
    '孔洞缝漏失': 'CavityFractureLoss',
    '渗透性漏失': 'PermeabilityLoss'
}
# Labels missing from the mapping (e.g. values that are already English) are kept
# as-is instead of silently becoming NaN; any such labels are reported.
raw_target = df[target_column].copy()
unmapped_labels = sorted(set(raw_target.dropna()) - set(loss_type_mapping), key=str)
if unmapped_labels:
    print(f"{target_column} values not in loss_type_mapping (kept as-is):", unmapped_labels)
df[target_column] = raw_target.map(loss_type_mapping).fillna(raw_target)

X = df.drop(columns=[target_column])
y = df[target_column]

# Label-encode the categorical columns for SMOTENC; the fitted encoders are kept
# so the resampled codes can be decoded back to category values later.
X_encoded = X.copy()
encoders = {}
for col in categorical_features:
    le = LabelEncoder()
    X_encoded[col] = le.fit_transform(X[col])
    encoders[col] = le

from sklearn.neighbors import NearestNeighbors

class_counts = Counter(y)
print("原始类别分布：", class_counts)
max_count = max(class_counts.values())
target_count = int(max_count * sampling_ratio)
# Only classes currently below target_count are oversampled, each up to target_count.
sampling_strategy = {cls: target_count for cls, count in class_counts.items() if count < target_count}
print("生成的 sampling_strategy:", sampling_strategy)

cat_feature_indices = [X.columns.get_loc(col) for col in categorical_features]
# SMOTENC's own n_jobs argument is deprecated since imbalanced-learn 0.10 (scheduled
# for removal); parallelism is configured on the neighbours estimator instead.
# n_neighbors=4 reproduces k_neighbors=3, because imblearn adds one neighbour for
# the sample itself when given an int.
knn_estimator = NearestNeighbors(n_neighbors=4, n_jobs=-1)
smote_nc = SMOTENC(categorical_features=cat_feature_indices,
                   sampling_strategy=sampling_strategy,
                   k_neighbors=knn_estimator,
                   random_state=42)
X_resampled, y_resampled = smote_nc.fit_resample(X_encoded, y)

# Long-format frame (one row per sample) so seaborn can group counts by dataset.
df_orig = pd.DataFrame({'Class': y, 'Dataset': 'Original'})
df_resampled = pd.DataFrame({'Class': y_resampled, 'Dataset': 'Resampled'})
df_all = pd.concat([df_orig, df_resampled], ignore_index=True)
df_all['Class'] = df_all['Class'].astype(str)
desired_order = ['FractureLoss', 'FaultLoss', 'CavityFractureLoss', 'PermeabilityLoss']

plt.figure(figsize=(13, 8))
ax = sns.countplot(data=df_all, x='Class', hue='Dataset', order=desired_order, edgecolor='black')
ax.set_title('Class Distribution Comparison: Original vs Resampled')
ax.set_xlabel('Class')
ax.set_ylabel('Count')
plt.setp(ax.get_xticklabels(), rotation=10, fontsize=18)
plt.setp(ax.get_yticklabels(), fontsize=18)
ax.figure.tight_layout()
ax.legend(loc='upper right', title='Dataset')
# Write the count just above every bar with a positive height (NaN heights skipped).
for patch in ax.patches:
    height = patch.get_height()
    if not height > 0:
        continue
    ax.text(patch.get_x() + patch.get_width() / 2, height + 1, int(height), ha='center', va='bottom')
# Thicker black frame around the axes.
for spine in ax.spines.values():
    spine.set_linewidth(2)
    spine.set_edgecolor('black')
ax.figure.savefig('E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/picture/smotenc/comparison_losstype.png', dpi=350, bbox_inches='tight')
plt.show()

# Ensure column access by name works on the resampled features.
if not isinstance(X_resampled, pd.DataFrame):
    X_resampled = pd.DataFrame(X_resampled, columns=X.columns)

# One figure per int64/float64 feature overlaying original and resampled KDEs.
numeric_cols = X.select_dtypes(include=['int64', 'float64']).columns
for col in numeric_cols:
    fig, ax = plt.subplots(figsize=(12, 6))
    for source_label, series in (('Original', X[col]), ('Resampled', X_resampled[col])):
        sns.kdeplot(data=series, label=source_label, fill=True, alpha=0.4, linewidth=2, ax=ax)
    ax.set_title(f'Distribution Comparison: {col}')
    ax.set_xlabel(col)
    ax.set_ylabel('Density')
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    for spine in ax.spines.values():
        spine.set_linewidth(2)
        spine.set_edgecolor('black')
    filename = f"{save_dir}kde_{col}.png"
    fig.savefig(filename, dpi=350, bbox_inches='tight')
    plt.show()

def plot_categorical_comparison_english(column_name, df_orig, df_resampled, mapping_dict):
    """Plot category counts of one column before and after resampling, side by side.

    Category values are translated through ``mapping_dict`` for the x-axis; values
    without a translation keep their original label so they are still counted.
    The figure is saved as ``cat_comparison_<column_name>.png`` in the module-level
    ``save_dir`` and then shown.

    Parameters
    ----------
    column_name : str
        Column present in both frames.
    df_orig, df_resampled : pandas.DataFrame
        Feature data before and after resampling.
    mapping_dict : dict
        Raw category value -> display label.
    """
    df_all = pd.concat(
        [
            df_orig[[column_name]].assign(Source='Original'),
            df_resampled[[column_name]].assign(Source='Resampled'),
        ],
        axis=0,
        ignore_index=True,  # unique index so the label fallback below aligns row-by-row
    )
    raw_values = df_all[column_name]
    # Untranslated categories fall back to the raw value instead of NaN, which
    # countplot would otherwise silently drop.
    df_all['EnglishLabel'] = raw_values.map(mapping_dict).fillna(raw_values)

    # Per-column layout: (figsize, x tick rotation, x tick horizontal alignment).
    layouts = {
        'LossFormation': ((15, 8), 40, 'right'),
        'WorkingCondition': ((12, 8), 30, 'right'),
    }
    figsize, rotation, ha = layouts.get(column_name, ((12, 8), 0, 'center'))
    fontsize = 18

    plt.figure(figsize=figsize)
    ax = sns.countplot(data=df_all, x='EnglishLabel', hue='Source')

    plt.title(f'Distribution Comparison of {column_name}')
    plt.xlabel(column_name)
    plt.ylabel('Count')

    plt.xticks(rotation=rotation, ha=ha, fontsize=fontsize)

    plt.legend(title='Dataset', fontsize=17, title_fontsize=18)

    # Count label above every bar with a positive height.
    for patch in ax.patches:
        height = patch.get_height()
        if height > 0:
            ax.text(patch.get_x() + patch.get_width()/2, height + 1, int(height), ha='center', va='bottom', fontsize=17)

    plt.tight_layout()

    # Thicker black frame around the axes.
    for spine in ax.spines.values():
        spine.set_linewidth(2)
        spine.set_edgecolor('black')

    filename = os.path.join(save_dir, f'cat_comparison_{column_name}.png')
    plt.savefig(filename, dpi=350, bbox_inches='tight')
    plt.show()

# Decode the label-encoded categorical columns back to their category values.
X_resampled_df = pd.DataFrame(X_resampled, columns=X.columns)
for col in categorical_features:
    le = encoders[col]
    X_resampled_df[col] = le.inverse_transform(X_resampled_df[col].astype(int))

numerical_features=[col for col in X.columns if col not in categorical_features]
# X_resampled keeps its 2-decimal rounding for any later use of it.
X_resampled[numerical_features]=X_resampled[numerical_features].round(2)
# Columns that were integer in the input are rounded straight to int (no prior
# 2-decimal rounding, which could change the result via double rounding).
int_columns=X.select_dtypes(include=['int64']).columns
int_columns=[col for col in int_columns if col not in categorical_features]
X_resampled_df[int_columns] = X_resampled_df[int_columns].round().astype(int)
# The remaining numeric columns of X_resampled_df, the frame written to CSV below,
# are rounded to 2 decimals.
float_columns = [col for col in numerical_features if col not in int_columns]
if float_columns:
    X_resampled_df[float_columns] = X_resampled_df[float_columns].round(2)

# Chinese category value -> English display label, per categorical column.
working_condition_map = {
    '注水泥': 'Cementing', '划眼': 'Reaming', '钻进': 'Drilling', '倒划眼': 'BackReaming',
    '循环': 'Circulation', '压井': 'WellKilling', '下套管': 'RunningCasing', '注水泥堵漏': 'CementPlugging',
    '堵漏': 'LossControl', '下钻': 'RunInHole', '起钻': 'PullOutHole', '固井前循环': 'PreCementCirculation',
    '测井': 'Logging', '地漏试验': 'LeakageTest'
}
loss_formation_map = {
    '东二上段': 'Dong2_Upper', '古生界': 'Paleozoic', '东二下段': 'Dong2_Lower', '中生界': 'Mesozoic',
    '东一段': 'Dong1', '馆陶组': 'Guantao', '东三段': 'Dong3', '孔店组': 'Kongdian',
    '太古界': 'Archean', '明化镇': 'Minghuazhen', '东二段': 'Dong2', '沙一段': 'Sha1',
    '潜山': 'Qianshan', '沙二段': 'Sha2', '沙三段': 'Sha3', '东营组': 'Dongying',
    '东上段': 'Dong_Upper', '东下段': 'Dong_Lower', '沙河街': 'Shahejie', '明下段': 'Ming_Lower',
    '明上段': 'Ming_Upper', '平原组': 'Pingyuan', '沙四段': 'Sha4', '沙三中段': 'Sha3_Mid',
    '沙三下': 'Sha3_Lower'
}
lithology_map = {
    '泥岩': 'Mudstone', '灰岩': 'Limestone', '砂岩': 'Sandstone', '火成岩': 'Igneous',
    '砾岩': 'Conglomerate', '变质岩': 'Metamorphic'
}
loss_severity_map = {
    '严重井漏': 'SevereLoss', '微漏': 'MinorLoss', '中漏': 'ModerateLoss', '小漏': 'SlightLoss'
}

mapping_dicts = {
    'WorkingCondition': working_condition_map,
    'LossFormation': loss_formation_map,
    'Lithology': lithology_map,
    'LossSeverity': loss_severity_map
}

for col in categorical_features:
    plot_categorical_comparison_english(col, X, X_resampled_df, mapping_dicts[col])

# Resampled features plus target, written out with a BOM so Excel reads the UTF-8 text.
resampled_data = pd.concat([X_resampled_df.reset_index(drop=True),
                            pd.Series(y_resampled, name=target_column).reset_index(drop=True)], axis=1)

resampled_csv_path = 'E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/data/leakage_type_selected_fill_en_resampled.csv'
resampled_data.to_csv(resampled_csv_path, index=False, encoding='utf-8-sig')
print(f'采样后数据已保存至：{resampled_csv_path}')

# 5. SHAP

# 5.1 Global

In [None]:
import shap
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Publication styling for the global SHAP figure, set group by group.
plt.rc('font', family='Times New Roman', size=20, weight='bold')
plt.rc('axes', labelsize=20, labelweight='bold', titlesize=24, titleweight='bold',
       unicode_minus=False)
plt.rc('legend', fontsize=18, title_fontsize=18)
plt.rc(('xtick', 'ytick'), labelsize=18)

def savefig_with_border(fig, axes, path, dpi=350):
    """Give every axes a black frame of linewidth 2, then save ``fig`` to ``path``.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to save.
    axes : Axes, list/tuple of Axes, or numpy.ndarray of Axes
        Axes whose spines are restyled. Arrays of any shape (e.g. the 2-D grid
        returned by ``plt.subplots(nrows, ncols)``) are flattened.
    path : str or path-like
        Output file.
    dpi : int, default 350
        Output resolution.
    """
    if isinstance(axes, np.ndarray):
        # .flat visits every element, so 2-D subplot grids work as well as 1-D rows.
        targets = axes.flat
    elif isinstance(axes, (list, tuple)):
        targets = axes
    else:
        targets = (axes,)
    for ax in targets:
        for spine in ax.spines.values():
            spine.set_linewidth(2)
            spine.set_edgecolor('black')
    fig.savefig(path, dpi=dpi, bbox_inches='tight')

explainer=shap.TreeExplainer(model)
shap_values=explainer.shap_values(X_test_raw)
# Depending on the shap version, multi-class output is either a list with one
# (samples, features) array per class or one (samples, features, classes) array.
# Normalise to the 3-D form, which this cell and the class-level cell index as [:, :, k].
if isinstance(shap_values, list):
    shap_values = np.stack(shap_values, axis=-1)

# |SHAP| averaged over classes (axis 2), then over samples (axis 0).
mean_abs_shap = np.mean(np.abs(shap_values), axis=2)  # shape: (samples, features)

feature_importance = np.mean(mean_abs_shap, axis=0)  # shape: (features,)
feature_names = X_test_raw.columns

shap_df = pd.DataFrame({
    'Feature': feature_names,
    'SHAP Importance': feature_importance
}).sort_values(by='SHAP Importance', ascending=False)

# Horizontal bars for the top_n features; reversed so the largest is drawn at the top.
top_n = 30
shap_df_top = shap_df.head(top_n)
fig,ax=plt.subplots(figsize=(13,9))
bars = plt.barh(
    shap_df_top['Feature'][::-1],
    shap_df_top['SHAP Importance'][::-1],
    color='skyblue'
)

# Write each bar's value just past its end.
for bar in bars:
    width = bar.get_width()
    plt.text(width + 0.001, bar.get_y() + bar.get_height() / 2,
             f'{width:.4f}', va='center', fontsize=17)

plt.xlabel("Mean SHAP Value (|impact|)")
plt.title("SHAP Feature Importance (All Classes Averaged)")
plt.tight_layout()
savefig_with_border(fig,ax,'E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/picture/shap/overall_bar_with_values.png')
plt.show()


# 5.2 Class-level

In [None]:
import shap
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib.colors as mcolors
# Publication styling for the class-level SHAP figures (uniform 20-pt text).
_class_level_style = (
    ('font.family', 'Times New Roman'),
    ('font.size', 20),
    ('font.weight', 'bold'),
    ('axes.labelsize', 20),
    ('axes.labelweight', 'bold'),
    ('axes.titlesize', 20),
    ('axes.titleweight', 'bold'),
    ('legend.fontsize', 20),
    ('legend.title_fontsize', 20),
    ('xtick.labelsize', 20),
    ('ytick.labelsize', 20),
    ('axes.unicode_minus', False),
)
for _key, _value in _class_level_style:
    plt.rcParams[_key] = _value

def savefig_with_border(fig, axes, path, dpi=350):
    """Give every axes a black frame of linewidth 2, then save ``fig`` to ``path``.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to save.
    axes : Axes, list/tuple of Axes, or numpy.ndarray of Axes
        Axes whose spines are restyled. Arrays of any shape (e.g. the 2-D grid
        returned by ``plt.subplots(nrows, ncols)``) are flattened.
    path : str or path-like
        Output file.
    dpi : int, default 350
        Output resolution.
    """
    if isinstance(axes, np.ndarray):
        # .flat visits every element, so 2-D subplot grids work as well as 1-D rows.
        targets = axes.flat
    elif isinstance(axes, (list, tuple)):
        targets = axes
    else:
        targets = (axes,)
    for ax in targets:
        for spine in ax.spines.values():
            spine.set_linewidth(2)
            spine.set_edgecolor('black')
    fig.savefig(path, dpi=dpi, bbox_inches='tight')

# Test-set labels in order of first appearance (informational only).
print(pd.Series(y_test_raw).unique())
# shap_values[:, :, k] is expected to follow the model's class order rather than
# the order above; print it so the per-class titles used below can be checked.
print(model.classes_)

def custom_colormap():
    """Return matplotlib's built-in ``viridis`` colormap (not called elsewhere in this cell)."""
    palette = plt.cm.viridis
    return palette

# Linear blue -> red colormap named "blue_red" (not referenced elsewhere in this cell).
cmap = mcolors.LinearSegmentedColormap.from_list(name="blue_red", colors=["blue", "red"])

def add_legend():
    """Add a three-entry, frameless legend anchored below the current axes.

    The scatter calls receive no data; they only create the colored legend
    handles (gray, red, blue) with their labels.
    """
    entries = (
        ('gray', 'Character Variables'),
        ('red', 'Positive Impact'),
        ('blue', 'Negative Impact'),
    )
    for colour, text in entries:
        plt.scatter([], [], color=colour, label=text, s=100)
    plt.legend(loc='upper center', fontsize=15, bbox_to_anchor=(0.4, -0.05), ncol=3, frameon=False)

def _plot_class_summary(class_shap_values, class_label, filename):
    """Draw, save and show the SHAP summary plot for one class.

    class_shap_values: the (samples, features) slice ``shap_values[:, :, k]``.
    class_label: class name shown in the figure title.
    filename: image file name inside the ``picture/shap`` output folder.
    """
    shap.summary_plot(class_shap_values, X_test_raw, show=False, max_display=len(X_test_raw.columns))
    plt.suptitle(f"SHAP Summary Plot for {class_label}", fontsize=20, fontweight='bold')
    plt.subplots_adjust(top=0.94)  # leave room for the suptitle
    add_legend()
    plt.savefig('E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/picture/shap/' + filename,
                dpi=350, bbox_inches='tight')
    plt.show()


# NOTE(review): index k of shap_values[:, :, k] presumably follows the model's class
# order (model.classes_); the hard-coded labels below assume that order — confirm.
shap_values_class_0 = shap_values[:, :, 0]
_plot_class_summary(shap_values_class_0, "FaultLoss", "faultloss.png")

shap_values_class_1 = shap_values[:, :, 1]
_plot_class_summary(shap_values_class_1, "FractureLoss", "fractureloss.png")

shap_values_class_2 = shap_values[:, :, 2]
_plot_class_summary(shap_values_class_2, "PermeabilityLoss", "permeabilityloss.png")

shap_values_class_3 = shap_values[:, :, 3]
_plot_class_summary(shap_values_class_3, "CavityFractureLoss", "cavityfractureloss.png")


# 5.3 Sample-level

In [None]:
# --- Mapping dictionaries ---
# Chinese category value -> English display label for each categorical column.
# Same content as the maps defined in the SMOTENC section, repeated here
# (presumably so this cell can be run on its own).
working_condition_map = {
    '注水泥': 'Cementing', '划眼': 'Reaming', '钻进': 'Drilling', '倒划眼': 'BackReaming',
    '循环': 'Circulation', '压井': 'WellKilling', '下套管': 'RunningCasing', '注水泥堵漏': 'CementPlugging',
    '堵漏': 'LossControl', '下钻': 'RunInHole', '起钻': 'PullOutHole', '固井前循环': 'PreCementCirculation',
    '测井': 'Logging', '地漏试验': 'LeakageTest'
}
loss_formation_map = {
    '东二上段': 'Dong2_Upper', '古生界': 'Paleozoic', '东二下段': 'Dong2_Lower', '中生界': 'Mesozoic',
    '东一段': 'Dong1', '馆陶组': 'Guantao', '东三段': 'Dong3', '孔店组': 'Kongdian',
    '太古界': 'Archean', '明化镇': 'Minghuazhen', '东二段': 'Dong2', '沙一段': 'Sha1',
    '潜山': 'Qianshan', '沙二段': 'Sha2', '沙三段': 'Sha3', '东营组': 'Dongying',
    '东上段': 'Dong_Upper', '东下段': 'Dong_Lower', '沙河街': 'Shahejie', '明下段': 'Ming_Lower',
    '明上段': 'Ming_Upper', '平原组': 'Pingyuan', '沙四段': 'Sha4', '沙三中段': 'Sha3_Mid',
    '沙三下': 'Sha3_Lower'
}
lithology_map = {
    '泥岩': 'Mudstone', '灰岩': 'Limestone', '砂岩': 'Sandstone', '火成岩': 'Igneous',
    '砾岩': 'Conglomerate', '变质岩': 'Metamorphic'
}
loss_severity_map = {
    '严重井漏': 'SevereLoss', '微漏': 'MinorLoss', '中漏': 'ModerateLoss', '小漏': 'SlightLoss'
}

# Column name -> its translation map (consumed by map_and_fill below).
mapping_dicts = {
    'WorkingCondition': working_condition_map,
    'LossFormation': loss_formation_map,
    'Lithology': lithology_map,
    'LossSeverity': loss_severity_map
}

def map_and_fill(df, mapping_dicts, fill_na_value='Unknown'):
    """Translate columns of ``df`` in place through per-column lookup dicts.

    For every ``column -> mapping`` entry whose column exists in ``df``, each value
    is replaced by ``mapping[value]``. Values absent from the mapping, and values
    that were already missing, become ``fill_na_value``. Columns not named in
    ``mapping_dicts`` are untouched, and entries for columns ``df`` lacks are ignored.

    Returns the same (mutated) DataFrame.
    """
    present_columns = [name for name in mapping_dicts if name in df.columns]
    for name in present_columns:
        df[name] = df[name].map(mapping_dicts[name]).fillna(fill_na_value)
    return df

# Test features with categorical values translated to English for readable plots.
X_test_mapped = map_and_fill(X_test_raw.copy(), mapping_dicts)
print('shap_values shape:', shap_values.shape)
print('X_test_mapped shape:', X_test_mapped.shape)

# NOTE(review): must follow the class order of shap_values' last axis — confirm.
class_names=['FaultLoss','FractureLoss','PermeabilityLoss','CavityFractureLoss']
sample_index = 9  # test-set row to explain
num_features = X_test_mapped.shape[1]  # only used by the commented-out max_display option
for class_index in range(shap_values.shape[2]):
    fig = plt.figure(figsize=(10,6))
    # show=False keeps shap from displaying the figure before the title below is
    # added, so the displayed and the saved figure both carry it.
    shap.decision_plot(
        explainer.expected_value[class_index],
        shap_values[:, :, class_index][sample_index:sample_index+1],
        features=X_test_mapped.iloc[sample_index:sample_index+1],
        feature_names=X_test_mapped.columns.tolist(),
        feature_order='none',
        show=False,
        #max_display=num_features
    )
    fig.suptitle(f"Sample {sample_index} - Decision Plot for {class_names[class_index]}",fontsize=20,fontweight='bold')
    fig.savefig(f"E:/jupyter/lost_circulation/records/paper-bhyt/leakage_type/thesis/En/picture/shap/decision_plot_sample{sample_index}_class{class_index}.png", bbox_inches='tight', dpi=350)
    plt.show()
    plt.close(fig)
      