In [1]:
import pandas as pd
import joblib

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier

In [3]:
#Load preprocessed data
df = pd.read_csv("../data/processed/cleaned_data.csv")

In [4]:
feature_cols = ['Age', 'Gender', 'family_history', 'benefits',
                'care_options', 'anonymity', 'leave', 'work_interfere']

X = df[feature_cols]
y = df['treatment']

# Train-test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

In [5]:
#Define models
models = {
    "Logistic Regression": LogisticRegression(max_iter=1000),
    "Random Forest": RandomForestClassifier(n_estimators=100, random_state=42),
    "SVM": SVC(kernel='rbf'),
    "KNN": KNeighborsClassifier(n_neighbors=5)
}


In [6]:
#Train & evaluate
results = {}

for name, model in models.items():
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    acc = accuracy_score(y_test, y_pred)
    results[name] = acc

    print(f"\n{name}")
    print("Accuracy:", acc)
    print(classification_report(y_test, y_pred))



Logistic Regression
Accuracy: 0.8015873015873016
              precision    recall  f1-score   support

           0       0.81      0.73      0.77       173
           1       0.79      0.86      0.82       205

    accuracy                           0.80       378
   macro avg       0.80      0.80      0.80       378
weighted avg       0.80      0.80      0.80       378


Random Forest
Accuracy: 0.7962962962962963
              precision    recall  f1-score   support

           0       0.81      0.73      0.77       173
           1       0.79      0.85      0.82       205

    accuracy                           0.80       378
   macro avg       0.80      0.79      0.79       378
weighted avg       0.80      0.80      0.80       378


SVM
Accuracy: 0.8148148148148148
              precision    recall  f1-score   support

           0       0.89      0.68      0.77       173
           1       0.78      0.93      0.84       205

    accuracy                           0.81       378


In [7]:
#Select best model
best_model_name = max(results, key=results.get)
best_model = models[best_model_name]

print("\nBest Model:", best_model_name)


Best Model: SVM


In [12]:
joblib.dump(best_model, "../models/artifacts/best_model.pkl")
joblib.dump(list(X.columns), "../models/artifacts/feature_names.pkl")

print("Model saved successfully!")


Model saved successfully!
