In [2]:
# -*- coding: utf-8 -*-
import pandas as pd
import pickle
import tensorflow as tf
import datetime

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import EarlyStopping, TensorBoard
from tensorflow.keras.optimizers import Adam

def load_data(file_path):
    try:
        data = pd.read_csv(file_path)
        return data
    except FileNotFoundError:
        print(f"文件 {file_path} 未找到，请检查文件路径。")
        return None

def preprocess_data(data):
    # Drop irrelevant columns
    data = data.drop(columns=['RowNumber', 'CustomerId', 'Surname'], axis=1)
    
    # Check for missing values
    if data.isnull().any().any():
        print("数据包含缺失值，请处理缺失值。")
        return None
    
    # Encode the categorical variables
    label_encoder_gender = LabelEncoder()
    data['Gender'] = label_encoder_gender.fit_transform(data['Gender'])
    
    onehot_encoder_geo = OneHotEncoder()
    geo_encoded = onehot_encoder_geo.fit_transform(data[['Geography']])
    geo_encoded_df = pd.DataFrame(geo_encoded.toarray(), columns=onehot_encoder_geo.get_feature_names_out(['Geography']))
    
    # Save the encoders
    with open('label_encoder_gender.pkl', 'wb') as f:
        pickle.dump(label_encoder_gender, f)
    with open('onehot_encoder.pkl', 'wb') as f:
        pickle.dump(onehot_encoder_geo, f)
    
    # Divide the data into features and target
    X = pd.concat([data.drop(columns=['Exited', 'Geography'], axis=1), geo_encoded_df], axis=1)
    y = data['Exited']
    
    return X, y

def split_and_scale_data(X, y):
    # Split the data into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    
    # Scale the data
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)
    
    # Save the scaler
    with open('scaler.pkl', 'wb') as f:
        pickle.dump(scaler, f)
    
    return X_train, X_test, y_train, y_test

# Load the dataset
file_path = 'Churn_Modelling.csv'
data = load_data(file_path)
if data is not None:
    # Preprocess the data
    X, y = preprocess_data(data)
    if X is not None and y is not None:
        # Split and scale the data
        X_train, X_test, y_train, y_test = split_and_scale_data(X, y)
        print("数据预处理完成，训练集和测试集已准备好。")
        print(f"训练集包含 {X_train.shape[0]} 个样本，测试集包含 {X_test.shape[0]} 个样本。")
        print(f"特征维度为 {X_train.shape[1]}。")
        print(f"目标变量的类别数为 {y.nunique()}。")
        print(f"目标变量的类别分布为：\n{y.value_counts()}")
        print(f"训练集目标变量的类别分布为：\n{y_train.value_counts()}")
        print(f"测试集目标变量的类别分布为：\n{y_test.value_counts()}")
        print("数据预处理完成。")
        print("数据已保存为 pickle 文件。")
        print("label_encoder_gender.pkl")
        print("onehot_encoder_geo.pkl")
        print("scaler.pkl")
        print("train_test_data.pkl")

## Build the ANN model
def build_model(input_dim):
    model = Sequential()
    model.add(Dense(units=64, activation='relu', input_dim=input_dim)) ## Hidden layer 1 connected with input layer
    model.add(Dense(units=32, activation='relu')) ## Hidden layer 2 connected with hidden layer 1
    model.add(Dense(units=1, activation='sigmoid')) ## Output layer connected with hidden layer 2
    
    return model
# build_model(X_train.shape[1]).summary()

## compile the model
def compile_model(model):
    model.compile(optimizer=Adam(learning_rate=0.01), loss='binary_crossentropy', metrics=['accuracy'])
    return model
compile_model(build_model(X_train.shape[1])).summary()  

## Train the model
def train_model(model, X_train, y_train):
    log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)
    ## Setup early stopping 
    early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)  
    model.fit(X_train, y_train, epochs=10, batch_size=32, validation_split=0.1, callbacks=[early_stopping, tensorboard_callback])
    return model
model = train_model(compile_model(build_model(X_train.shape[1])), X_train, y_train)
model.save('churn_model.h5')

## Evaluate the model
def evaluate_model(model, X_test, y_test):
    loss, accuracy = model.evaluate(X_test, y_test)
    print(f"Loss: {loss}")
    print(f"Accuracy: {accuracy}")
evaluate_model(train_model(compile_model(build_model(X_train.shape[1])), X_train, y_train), X_test, y_test)

## Load tensorboard extension
%load_ext tensorboard
%tensorboard --logdir logs/fit





数据预处理完成，训练集和测试集已准备好。
训练集包含 8000 个样本，测试集包含 2000 个样本。
特征维度为 12。
目标变量的类别数为 2。
目标变量的类别分布为：
Exited
0    7963
1    2037
Name: count, dtype: int64
训练集目标变量的类别分布为：
Exited
0    6356
1    1644
Name: count, dtype: int64
测试集目标变量的类别分布为：
Exited
0    1607
1     393
Name: count, dtype: int64
数据预处理完成。
数据已保存为 pickle 文件。
label_encoder_gender.pkl
onehot_encoder_geo.pkl
scaler.pkl
train_test_data.pkl
Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_9 (Dense)             (None, 64)                832       
                                                                 
 dense_10 (Dense)            (None, 32)                2080      
                                                                 
 dense_11 (Dense)            (None, 1)                 33        
                                                                 
Total params: 2,945
Trainable params: 2,945
Non-trainable params: 0




Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10




Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Loss: 27.951988220214844
Accuracy: 0.6859999895095825
The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 86508), started 0:01:27 ago. (Use '!kill 86508' to kill it.)