In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🏋️‍♂️ 02 – Model Training and Evaluation\n",
    "\n",
    "In this notebook, we load the preprocessed training and test data, train a RandomForest classifier, log the model with MLflow, and evaluate its performance on the test set. We report the test accuracy, a classification report, and a confusion-matrix heatmap, and finally save the trained model to disk with joblib."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "from pathlib import Path\n",
    "\n",
    "import joblib\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from sklearn.metrics import accuracy_score, confusion_matrix, classification_report\n",
    "\n",
    "from src.data_loader import load_and_preprocess_data\n",
    "from src.model_trainer import train_model\n",
    "\n",
    "sns.set(style=\"whitegrid\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the four preprocessed splits from the project data loader\n",
    "X_train, X_test, y_train, y_test = load_and_preprocess_data()\n",
    "\n",
    "# Report the shape of the training and test inputs\n",
    "for caption, split in ((\"Training data shape:\", X_train), (\"Test data shape:\", X_test)):\n",
    "    print(caption, split.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the model on the training split using the project helper `train_model`.\n",
    "# NOTE(review): MLflow logging is presumably performed inside train_model - confirm against src/model_trainer.\n",
    "model = train_model(X_train, y_train)\n",
    "\n",
    "print(\"Model trained successfully.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the trained model on the held-out test set\n",
    "predictions = model.predict(X_test)\n",
    "acc = accuracy_score(y_test, predictions)\n",
    "print(f\"Test Accuracy: {acc:.2f}\")\n",
    "\n",
    "# Text summary of the test-set predictions\n",
    "report = classification_report(y_test, predictions)\n",
    "print(\"\\nClassification Report:\")\n",
    "print(report)\n",
    "\n",
    "# Confusion matrix, drawn on an explicit Axes rather than the implicit current one\n",
    "cm = confusion_matrix(y_test, predictions)\n",
    "fig, ax = plt.subplots()\n",
    "sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)\n",
    "ax.set(title=\"Confusion Matrix\", xlabel=\"Predicted\", ylabel=\"True\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: Save the trained model locally\n",
    "model_path = Path(\"models/rf_model.joblib\")\n",
    "# joblib.dump does not create missing parent directories; without this the\n",
    "# cell raises FileNotFoundError on a fresh checkout that has no models/ folder.\n",
    "model_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "joblib.dump(model, model_path)\n",
    "# as_posix() keeps the printed path identical on every platform\n",
    "print(f\"Model saved as {model_path.as_posix()}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}