In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import shap
from catboost import CatBoostRegressor
from lightgbm import LGBMRegressor
from sklearn.ensemble import ExtraTreesRegressor, RandomForestRegressor, StackingRegressor
from sklearn.inspection import permutation_importance, PartialDependenceDisplay
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import MinMaxScaler
from xgboost import XGBRegressor

# Output folders: figures go to result_pic/, Excel tables to result_table/.
# exist_ok=True keeps this cell safe to re-run.
for _output_dir in ('result_pic', 'result_table'):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir

In [None]:
# Data preprocessing: load the raw tables, one CSV per salt (BaSO4, SrSO4).
data_ba = pd.read_csv(filepath_or_buffer='BaSO4 data.csv')
data_sr = pd.read_csv(filepath_or_buffer='SrSO4 data.csv')

def clean_data(df, skip_cols=('Solubility',)):
    """Coerce columns to numeric and drop rows with missing values.

    Every column except those listed in ``skip_cols`` is converted with
    ``pd.to_numeric(..., errors='coerce')``, so entries that cannot be parsed
    become NaN. The per-column null counts after coercion are printed. Then
    every row containing a NaN in any column (skipped columns included) is
    dropped, and the index is reset.

    Parameters
    ----------
    df : pandas.DataFrame
        Raw table, e.g. as read from CSV. It is NOT modified.
    skip_cols : str or iterable of str, default ('Solubility',)
        Columns left untouched by the numeric coercion.
        NOTE(review): the target 'Solubility' is skipped by default, so a
        non-numeric entry there is not turned into NaN and survives the
        dropna. Confirm the target column is always numeric.

    Returns
    -------
    pandas.DataFrame
        A new cleaned frame with a fresh 0..n-1 index.
    """
    if isinstance(skip_cols, str):
        # A bare string would otherwise be iterated as individual characters.
        skip_cols = (skip_cols,)
    skipped = set(skip_cols)

    # Work on a copy so the caller's raw frame keeps its original values/dtypes
    # (the previous version coerced the caller's frame in place).
    cleaned = df.copy()
    for col in cleaned.columns:
        if col in skipped:
            continue
        cleaned[col] = pd.to_numeric(cleaned[col], errors='coerce')

    print(cleaned.isnull().sum())
    return cleaned.dropna().reset_index(drop=True)

# Apply the same cleaning rules to both salts; each call prints its null counts.
data_ba = clean_data(df=data_ba)
data_sr = clean_data(df=data_sr)

In [None]:
import lightgbm as lgb
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import pandas as pd
import numpy as np

def _split_and_scale(X, y, holdout_size, test_fraction, random_state):
    """Split into train/val/test and min-max scale features with train-only statistics."""
    X_train, X_hold, y_train, y_hold = train_test_split(
        X, y, test_size=holdout_size, random_state=random_state)
    X_val, X_test, y_val, y_test = train_test_split(
        X_hold, y_hold, test_size=test_fraction, random_state=random_state)

    # Fit the scaler on the training split only so val/test statistics do not leak.
    scaler = MinMaxScaler()
    X_train = scaler.fit_transform(X_train)
    X_val = scaler.transform(X_val)
    X_test = scaler.transform(X_test)
    return X_train, X_val, X_test, y_train, y_val, y_test


def _save_parity_plot(name, mname, y_true, pred):
    """Save a predicted-vs-true scatter with a y = x reference line."""
    fig, ax = plt.subplots()
    ax.scatter(y_true, pred, alpha=0.7)
    lo, hi = y_true.min(), y_true.max()
    ax.plot([lo, hi], [lo, hi], 'r--')
    ax.set_xlabel('True')
    ax.set_ylabel('Predicted')
    ax.set_title(f'{name}-{mname}')
    fig.savefig(f'result_pic/{name}_{mname}_pred.png')
    plt.close(fig)


def _save_r2_bar(name, result_df):
    """Save a bar chart of the test-set R2 of every model."""
    fig, ax = plt.subplots()
    sns.barplot(data=result_df, x='Model', y='R2', ax=ax)
    ax.set_title(f'{name} R2 Comparison')
    fig.savefig(f'result_pic/{name}_R2_bar.png')
    plt.close(fig)


def model_pipeline(name, data, target, models, holdout_size=0.2, test_fraction=0.5,
                   random_state=42):
    """Train, evaluate and compare several regressors on one dataset.

    ``data`` is split into train / validation / test sets. With the defaults,
    the split is 80 / 10 / 10: ``holdout_size`` of the rows are held out, and
    ``test_fraction`` of those go to test, the rest to validation. Features are
    min-max scaled with statistics from the training split only. Every
    estimator in ``models`` is fitted on the training split (in place).

    Side effects:
      * result_pic/{name}_{model}_pred.png : predicted-vs-true plot per model
      * result_table/{name}_model_compare.xlsx : metrics table
      * result_pic/{name}_R2_bar.png : test R2 bar chart

    The best model is the one with the highest VALIDATION R2. Choosing by
    test R2 would leak the test set into model selection and bias the
    reported test metrics upward.

    Parameters
    ----------
    name : str
        Dataset label used in titles and output file names.
    data : pandas.DataFrame
        Feature columns plus the ``target`` column.
    target : str
        Name of the target column.
    models : dict[str, estimator]
        Display name -> unfitted regressor exposing ``fit``/``predict``.
    holdout_size : float, default 0.2
        Fraction of rows held out from training (val + test).
    test_fraction : float, default 0.5
        Fraction of the held-out rows assigned to the test set.
    random_state : int, default 42
        Seed for both splits.

    Returns
    -------
    tuple
        (X_train, X_val, X_test, y_train, y_val, y_test, best_model_name).
        The X arrays are the scaled numpy arrays; the y values keep the type
        returned by ``train_test_split``.

    Raises
    ------
    ValueError
        If ``models`` is empty.
    """
    if not models:
        raise ValueError('models must contain at least one estimator')

    X = data.drop(target, axis=1)
    y = data[target]
    X_train, X_val, X_test, y_train, y_val, y_test = _split_and_scale(
        X, y, holdout_size, test_fraction, random_state)

    result = []
    for mname, model in models.items():
        model.fit(X_train, y_train)
        pred = model.predict(X_test)
        val_pred = model.predict(X_val)
        result.append({
            'Model': mname,
            # Test-set metrics (reporting only).
            'R2': r2_score(y_test, pred),
            'MAE': mean_absolute_error(y_test, pred),
            'RMSE': np.sqrt(mean_squared_error(y_test, pred)),
            # Validation-set score (used for model selection).
            'Val_R2': r2_score(y_val, val_pred),
        })
        _save_parity_plot(name, mname, y_test, pred)

    result_df = pd.DataFrame(result)
    result_df.to_excel(f'result_table/{name}_model_compare.xlsx', index=False)
    _save_r2_bar(name, result_df)
    print(result_df)

    # Select by column name rather than by position, using validation R2.
    best_model_name = result_df.loc[result_df['Val_R2'].idxmax(), 'Model']
    print(f'>>> {name} 最佳模型: {best_model_name}')

    return X_train, X_val, X_test, y_train, y_val, y_test, best_model_name


In [None]:
# Candidate regressors for each dataset.
# MLPRegressor and the Random Forest inside the stacking ensemble were unseeded
# before; they now get an explicit seed so Restart & Run All reproduces the
# comparison table.
SEED = 42

models_ba = {
    'XGBoost': XGBRegressor(objective='reg:squarederror'),
    'CatBoost': CatBoostRegressor(verbose=0),
    'ExtraTrees': ExtraTreesRegressor(n_estimators=200, random_state=SEED),
    'MLP': MLPRegressor(random_state=SEED),
    'RF': RandomForestRegressor(n_estimators=100, random_state=SEED),
}

models_sr = {
    'XGBoost': XGBRegressor(objective='reg:squarederror'),
    'CatBoost': CatBoostRegressor(verbose=0),
    'LightGBM': LGBMRegressor(),
    'Stacking': StackingRegressor(estimators=[
        ('xgb', XGBRegressor(objective='reg:squarederror')),
        ('rf', RandomForestRegressor(random_state=SEED)),
    ]),
    'MLP': MLPRegressor(random_state=SEED),
    'RF': RandomForestRegressor(n_estimators=100, random_state=SEED),
}

# Model comparison: fit, score and plot every candidate on each dataset.
X_train_ba, X_val_ba, X_test_ba, y_train_ba, y_val_ba, y_test_ba, best_ba = model_pipeline(
    'BaSO4', data_ba, 'Solubility', models_ba)
X_train_sr, X_val_sr, X_test_sr, y_train_sr, y_val_sr, y_test_sr, best_sr = model_pipeline(
    'SrSO4', data_sr, 'Solubility', models_sr)
