# Game Churn Prediction AI - ML Churn Prediction

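This notebook loads the cleaned dataset and splits it into training and test sets. It then trains two classifiers (Random Forest and Logistic Regression), compares their test-set performance, and saves the Random Forest model for downstream use.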


In [None]:
from pathlib import Path

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Global plot styling for every figure in this notebook: seaborn defaults on a white grid.
sns.set_theme(style="whitegrid")


### 1. Load Clean Dataset


In [None]:
# Read the cleaned dataset (path is relative to this notebook) and preview the first rows.
df = pd.read_csv(filepath_or_buffer='../data/clean_data.csv')
df.head(n=5)


### 2. Split Data into Train/Test Sets


In [None]:
# Separate the target label from the feature matrix.
y = df['Churn']
X = df.drop(columns='Churn')

# Hold out 20% of rows for testing; stratify=y keeps the churn ratio equal in both splits.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
    stratify=y,
)
print("Training data shape: " + str(X_train.shape))
print("Testing data shape: " + str(X_test.shape))


### 3. Train Models (Random Forest & Logistic Regression)


In [None]:
# Random Forest: tree splits are invariant to feature scale, so raw features are used directly.
rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
rf_model.fit(X_train, y_train)
rf_preds = rf_model.predict(X_test)

# Logistic Regression: standardize features first. Both the default lbfgs solver and the
# L2 penalty are scale-sensitive, so unscaled inputs tend to converge slowly (hence the
# high max_iter) and penalize features unevenly. The scaler lives inside the pipeline, so
# it is fit on the training split only and reapplied unchanged at predict time.
lr_model = make_pipeline(
    StandardScaler(),
    LogisticRegression(max_iter=1000, random_state=42),
)
lr_model.fit(X_train, y_train)
lr_preds = lr_model.predict(X_test)


### 4. Compare Performance


In [None]:
def print_metrics(y_true, y_pred, model_name, pos_label=1):
    """Print accuracy, precision, recall and F1 for one model's predictions.

    Parameters
    ----------
    y_true : array-like
        Ground-truth labels.
    y_pred : array-like
        Predicted labels, aligned with ``y_true``.
    model_name : str
        Name shown in the printed header.
    pos_label : int or str, default 1
        Label treated as the positive (churn) class for precision, recall and F1.

    Returns
    -------
    dict
        Metric name -> score, so models can also be compared programmatically.
    """
    scores = {
        "Accuracy": accuracy_score(y_true, y_pred),
        # zero_division=0 reports 0.0 instead of raising UndefinedMetricWarning when a
        # model never predicts (or the data never contains) the positive class.
        "Precision": precision_score(y_true, y_pred, pos_label=pos_label, zero_division=0),
        "Recall": recall_score(y_true, y_pred, pos_label=pos_label, zero_division=0),
        "F1-score": f1_score(y_true, y_pred, pos_label=pos_label, zero_division=0),
    }
    print(f"--- {model_name} ---")
    for metric, value in scores.items():
        # Left-align "Name:" to 11 chars so the scores line up in one column.
        print(f"{metric + ':':<11}{value:.4f}")
    print()
    return scores

# Assign the returned dicts so the cell does not echo them as output.
rf_scores = print_metrics(y_test, rf_preds, "Random Forest")
lr_scores = print_metrics(y_test, lr_preds, "Logistic Regression")


### 5. Plot Confusion Matrix (Random Forest)


In [None]:
# Random Forest confusion matrix on the held-out test set.
# sklearn convention: rows are actual classes, columns are predicted classes.
cm = confusion_matrix(y_test, rf_preds)

fig, ax = plt.subplots()
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
ax.set_title('Confusion Matrix - Random Forest')
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
plt.show()


### 6. Show Feature Importance & Explain Top Churn Factors


In [None]:
# Impurity-based importances from the fitted Random Forest, sorted highest first.
feature_importances = (
    pd.DataFrame({'Feature': X.columns, 'Importance': rf_model.feature_importances_})
    .sort_values(by='Importance', ascending=False)
)

fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(x='Importance', y='Feature', data=feature_importances.head(10), ax=ax)
ax.set_title('Top 10 Feature Importances')
plt.show()

# Importance measures how much the forest relies on a feature when splitting. It does not
# say whether higher values raise or lower churn risk, so report the score rather than
# attaching the same generic claim to every feature.
print("\n--- Top Churn Factors Explained ---")
print("The features at the top of the chart are the ones the Random Forest relies on most when predicting churn.")
print("Importance does not indicate direction: compare each feature against churn to see whether it raises or lowers risk.")
for rank, factor in enumerate(feature_importances.head(3).itertuples(index=False), start=1):
    print(f"{rank}. {factor.Feature} (importance: {factor.Importance:.3f})")


### 7. Save the Random Forest Model


In [None]:
# Persist the trained Random Forest plus the column order it was trained on
# (presumably used by inference code to align input columns; verify with the consumer).
# Both files are pickles: only ever load them from a trusted location.
MODEL_DIR = Path('../models')
MODEL_DIR.mkdir(parents=True, exist_ok=True)  # joblib.dump fails if the directory is missing

model_path = MODEL_DIR / 'churn_model.pkl'
features_path = MODEL_DIR / 'model_features.pkl'
joblib.dump(rf_model, model_path)
joblib.dump(X.columns.tolist(), features_path)

# as_posix() keeps the printed paths identical across operating systems.
print(f"Model successfully saved to '{model_path.as_posix()}'")
print(f"Feature list saved to '{features_path.as_posix()}'")
