# Customer Churn at a Wizarding School

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report, confusion_matrix, roc_curve

# Load the student records; the path is relative to the working directory.
DATA_PATH = 'data/wizarding_students.csv'
df = pd.read_csv(DATA_PATH)

# Preview the first rows as a quick sanity check on the parsed columns.
df.head()

## Exploratory Data Analysis (EDA)

In [None]:
# Overall count of each value of the target column.
churn_ax = sns.countplot(data=df, x='Churned')
churn_ax.set_title('Churn Distribution')
plt.show()

# Same counts, grouped by house with one bar colour per target value.
house_ax = sns.countplot(data=df, x='House', hue='Churned')
house_ax.set_title('Churn by House')
plt.show()

# Histograms of the two grade columns (DataFrame.hist draws one subplot each).
grade_columns = ['FlyingGrades', 'PotionsGrades']
df[grade_columns].hist(bins=20, figsize=(10, 4))
plt.suptitle("Grade Distributions")
plt.show()

## Modeling

In [None]:
# Build the model matrix: drop the identifier column, then one-hot encode the
# remaining categorical columns (drop_first removes one dummy per category so
# the encoded columns are not perfectly collinear).
df_encoded = pd.get_dummies(df.drop(columns=['StudentID']), drop_first=True)

# NOTE(review): this assumes 'Churned' is numeric/boolean, so get_dummies
# leaves it as a single column with the same name -- confirm against the CSV.
X = df_encoded.drop('Churned', axis=1)
y = df_encoded['Churned']

# Hold out 30% of the rows for evaluation. Stratifying on the target keeps the
# churn rate the same in the train and test partitions, so the reported
# metrics are not skewed by an unlucky split of the minority class.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)


In [None]:
# Baseline linear classifier. fit() returns the estimator itself, so the model
# can be constructed and trained in one expression; the solver is allowed up
# to 1000 iterations.
log_model = LogisticRegression(max_iter=1000).fit(X_train, y_train)
log_preds = log_model.predict(X_test)

# Per-class precision / recall / F1 on the held-out rows.
log_report = classification_report(y_test, log_preds)
print("Logistic Regression Report:\n", log_report)


In [None]:
# Depth-limited decision tree. The tree builder permutes features at each
# split, so an unseeded tree can differ between runs; fixing random_state
# (same seed as the train/test split) makes the report, ROC curve and feature
# importances derived from this model reproducible.
tree_model = DecisionTreeClassifier(max_depth=5, random_state=42)
tree_model.fit(X_train, y_train)
tree_preds = tree_model.predict(X_test)

# Per-class precision / recall / F1 on the held-out rows.
print("Decision Tree Report:\n", classification_report(y_test, tree_preds))


In [None]:
# Confusion matrix for the decision tree on the held-out rows.
# scikit-learn puts the true labels on the rows and the predicted labels on
# the columns, so label the axes accordingly -- without labels the heatmap
# cannot be read unambiguously.
conf_matrix = confusion_matrix(y_test, tree_preds)
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Decision Tree Confusion Matrix')
plt.show()

# ROC curve from the predicted probability of the second class.
# NOTE(review): column 1 of predict_proba is tree_model.classes_[1]; this
# assumes that is the "churned" label -- confirm against the data.
tree_probs = tree_model.predict_proba(X_test)[:, 1]
fpr, tpr, thresholds = roc_curve(y_test, tree_probs)
plt.plot(fpr, tpr, label='Decision Tree')
# Diagonal reference line: the curve of a classifier that guesses at random.
plt.plot([0, 1], [0, 1], 'k--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend()
plt.show()


## Feature Importance

In [None]:
# Importance assigned by the fitted tree to each encoded feature column.
importances = pd.Series(tree_model.feature_importances_, index=X.columns)

# A horizontal bar chart draws the first row at the bottom, so sort the top ten
# ascending before plotting; the most important feature then appears at the
# top of the chart, matching the "Top 10" title.
top_importances = importances.nlargest(10).sort_values()
top_importances.plot(kind='barh')
plt.xlabel('Importance')
plt.title('Top 10 Feature Importances')
plt.show()
