In [9]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
import joblib
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# 1. Carregar os dados
data = pd.read_csv('../data/telecom_churn_data.csv')  # Ajuste o caminho conforme necessário

# 2. Verificar os tipos de dados
print(data.dtypes)

# 3. Identificar colunas categóricas
categorical_cols = data.select_dtypes(include=['object']).columns.tolist()
print("Colunas categóricas:", categorical_cols)

# 4. Remover a coluna 'Churn' antes do One-Hot Encoding
churn = data['Churn']  # Armazenar a coluna 'Churn'
data = data.drop(columns=['Churn'])  # Remover 'Churn' do DataFrame

# 5. Aplicar One-Hot Encoding nas colunas categóricas
data = pd.get_dummies(data, columns=categorical_cols, drop_first=True)

# 6. Adicionar a coluna 'Churn' de volta ao DataFrame
data['Churn'] = churn

# 7. Separar as características e o alvo
X = data.drop('Churn', axis=1)  # Características
y = data['Churn']  # Alvo

# 8. Verificar e tratar valores nulos, se necessário
if X.isnull().sum().any():
    X.fillna(X.mean(), inplace=True)

# 9. Dividir os dados em conjuntos de treino e teste
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 10. Treinar o modelo
model = RandomForestClassifier()
model.fit(X_train, y_train)

# 11. Salvar o modelo treinado
joblib.dump(model, 'model.pkl')

# 12. Salvar o conjunto de teste
X_test.to_csv('../data/X_test.csv', index=False)
y_test.to_csv('../data/y_test.csv', index=False)

# 13. Carregar o modelo
model = joblib.load('model.pkl')

# 14. Fazer previsões
y_pred = model.predict(X_test)

# 15. Relatório de classificação
print(classification_report(y_test, y_pred))

# 16. Matriz de confusão
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(10, 7))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()


customerID         object
gender             object
SeniorCitizen       int64
tenure              int64
MonthlyCharges    float64
Churn              object
dtype: object
Colunas categóricas: ['customerID', 'gender', 'Churn']


KeyError: "['Churn'] not in index"