In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
from sklearn.pipeline import Pipeline
from scikeras.wrappers import KerasClassifier
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import EarlyStopping
import pickle

# Load the churn dataset and drop the RowNumber, CustomerId and Surname columns,
# which are not used as model features.
data = pd.read_csv('data/Churn_Modelling.csv')
data = data.drop(columns=['RowNumber', 'CustomerId', 'Surname'])

# Encode Gender as integer codes. The fitted encoder is kept in a variable so it
# can be pickled further down and reused on new inputs.
Label_encoder_gender = LabelEncoder()
data['Gender'] = Label_encoder_gender.fit_transform(data['Gender'])

# One-hot encode Geography. handle_unknown='ignore' makes categories that were not
# seen during fit encode as all zeros instead of raising at transform time.
onehot_encoder_geo = OneHotEncoder(handle_unknown='ignore')
geo_encoded = onehot_encoder_geo.fit_transform(data['Geography'].values.reshape(-1, 1)).toarray()
# Reuse data's index: pd.concat(axis=1) aligns rows by index label, so a fresh
# RangeIndex here would silently produce NaN rows if `data` were ever filtered
# or re-indexed before this point.
geo_encoded_df = pd.DataFrame(
    geo_encoded,
    columns=onehot_encoder_geo.get_feature_names_out(['Geography']),
    index=data.index,
)

# Replace the raw Geography column with its one-hot columns, then separate the
# feature matrix from the Exited target.
data = pd.concat([data.drop('Geography', axis=1), geo_encoded_df], axis=1)
X = data.drop('Exited', axis=1)
y = data['Exited']

# Hold out 20% of the rows for testing; random_state makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Fit the scaler on the training split only and apply the same transform to the
# test split. Both X_train and X_test are NumPy arrays from here on.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Persist the fitted gender encoder, geography encoder and scaler so the same
# preprocessing can be reapplied later. Files are written in the listed order;
# any failure is reported with a printed message rather than raised.
artifacts_to_save = {
    'label_encoder_gender.pkl': Label_encoder_gender,
    'onehot_encoder_geo.pkl': onehot_encoder_geo,
    'scaler.pkl': scaler,
}
try:
    for artifact_path, artifact in artifacts_to_save.items():
        with open(artifact_path, 'wb') as artifact_file:
            pickle.dump(artifact, artifact_file)
    print("编码器和标准化器保存成功。")
except Exception as e:
    print(f"保存编码器和标准化器时出现错误: {e}")

# Model factory handed to KerasClassifier so the grid search can vary its hyperparameters
def create_model(activation='relu', neurons=2, optimizer='adam', layers=1, meta=None):
    """Build and compile a fully connected binary classifier.

    Parameters
    ----------
    activation : str
        Activation used by every hidden Dense layer.
    neurons : int
        Number of units in each hidden layer.
    optimizer : str
        Optimizer passed to ``model.compile``.
    layers : int
        Number of hidden layers. Values of 1 or less still produce one hidden
        layer, because the first hidden layer is always added.
    meta : dict or None
        Metadata that scikeras supplies to model-building functions that accept
        a ``meta`` argument. When it contains ``n_features_in_`` that value sets
        the input width; otherwise the width is read from the notebook-level
        ``X_train`` array, so direct calls without ``meta`` behave as before.

    Returns
    -------
    A compiled ``Sequential`` model with one sigmoid output unit, using
    binary cross-entropy loss and tracking accuracy.
    """
    if meta is not None and 'n_features_in_' in meta:
        n_features = meta['n_features_in_']
    else:
        # Fallback for direct calls: relies on X_train existing in the notebook namespace.
        n_features = X_train.shape[1]

    model = Sequential()
    # Declare the input with an explicit Input layer instead of Dense(input_dim=...),
    # which newer Keras versions warn about for Sequential models.
    model.add(tf.keras.Input(shape=(n_features,)))
    model.add(Dense(neurons, activation=activation))  # first hidden layer
    for _ in range(layers - 1):
        model.add(Dense(neurons, activation=activation))

    model.add(Dense(1, activation='sigmoid'))
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
    return model

# Early-stopping callback: stop a fit once the training loss (monitor='loss') has
# not improved for 2 consecutive epochs, and restore the best weights seen.
early_stopping = EarlyStopping(monitor='loss', patience=2, restore_best_weights=True)

# Create the KerasClassifier. random_state seeds every fit so that repeated runs
# of the search report comparable scores; the train/test split above is already
# seeded with the same value.
model = KerasClassifier(
    model=create_model,
    verbose=0,
    callbacks=[early_stopping],
    random_state=42,
)

# Define the hyperparameters to search over. Keys prefixed with 'model__' are
# routed to create_model; 'epochs' is a KerasClassifier fit parameter.
param_grid = {
    'model__activation': ['relu', 'sigmoid'],
    'model__neurons': [1, 2, 4, 8, 16],
    'model__layers': [1, 2],
    'model__optimizer': ['adam', 'sgd'],
    'epochs': [50, 100],
}
# Perform the grid search over the hyperparameters:
# 80 parameter combinations x 3 cross-validation folds, run on all available cores.
grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1, cv=3)
grid_result = grid.fit(X_train, y_train)

# Print the best mean cross-validated score and the parameters that produced it
print(f'Best accuracy: {grid_result.best_score_} using {grid_result.best_params_}')


编码器和标准化器保存成功。


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  current = self.get_monitor_value(logs)
  current = self.get_monitor_value(logs)
  current = self.get_monitor_value(logs)
  current = self.get_monitor_value(logs)
  current = self.get_monitor_value(logs)
  c

Best accuracy: 0.8593743861537188 using {'epochs': 50, 'model__activation': 'relu', 'model__layers': 2, 'model__neurons': 4, 'model__optimizer': 'adam'}
