{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ef8397e2",
   "metadata": {},
   "source": [
    "# Objective\n",
    "Study advanced techniques for preparing a dataset and tuning hyperparameters, using the k-nearest neighbors method as an example.\n",
    "\n",
    "# Assignment\n",
    "Complete the following steps:\n",
    "\n",
    "1. Choose a dataset for a classification or regression task.\n",
    "2. If necessary, remove or impute missing values and encode categorical features.\n",
    "3. Split the data into training and test sets with `train_test_split`.\n",
    "4. Train a k-nearest neighbors model with an arbitrarily chosen value of the hyperparameter K.\n",
    "5. Evaluate the model with three metrics suitable for the task.\n",
    "6. Build the model and evaluate it with cross-validation. Run experiments with three different cross-validation strategies.\n",
    "7. Tune the hyperparameter with `GridSearchCV` and cross-validation.\n",
    "8. Compare the quality of the tuned model with the model from step 4.\n",
    "\n",
    "# Procedure\n",
    "Import the required libraries and configure plotting:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "b0de3704",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "sns.set(style=\"ticks\")\n",
    "\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, median_absolute_error, r2_score\n",
    "from sklearn.neighbors import KNeighborsRegressor, KNeighborsClassifier\n",
    "from sklearn.model_selection import train_test_split, GridSearchCV\n",
    "from sklearn.impute import SimpleImputer, MissingIndicator\n",
    "from sklearn.preprocessing import LabelEncoder, OneHotEncoder, MinMaxScaler, StandardScaler, Normalizer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cdd70703",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv('bank_dataset.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03b62b20",
   "metadata": {},
   "source": [
    "## Data preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f5a2f859",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = df.copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fca6ab0a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userid</th>\n",
       "      <th>score</th>\n",
       "      <th>City</th>\n",
       "      <th>Gender</th>\n",
       "      <th>Age</th>\n",
       "      <th>Objects</th>\n",
       "      <th>Balance</th>\n",
       "      <th>Products</th>\n",
       "      <th>CreditCard</th>\n",
       "      <th>Loyalty</th>\n",
       "      <th>estimated_salary</th>\n",
       "      <th>Churn</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>15677338</td>\n",
       "      <td>619</td>\n",
       "      <td>Ярославль</td>\n",
       "      <td>Ж</td>\n",
       "      <td>42</td>\n",
       "      <td>2</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>101348.88</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>15690047</td>\n",
       "      <td>608</td>\n",
       "      <td>Рыбинск</td>\n",
       "      <td>Ж</td>\n",
       "      <td>41</td>\n",
       "      <td>1</td>\n",
       "      <td>83807.86</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>112542.58</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>15662040</td>\n",
       "      <td>502</td>\n",
       "      <td>Ярославль</td>\n",
       "      <td>Ж</td>\n",
       "      <td>42</td>\n",
       "      <td>8</td>\n",
       "      <td>159660.80</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>113931.57</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>15744090</td>\n",
       "      <td>699</td>\n",
       "      <td>Ярославль</td>\n",
       "      <td>Ж</td>\n",
       "      <td>39</td>\n",
       "      <td>1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>93826.63</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>15780624</td>\n",
       "      <td>850</td>\n",
       "      <td>Рыбинск</td>\n",
       "      <td>Ж</td>\n",
       "      <td>43</td>\n",
       "      <td>2</td>\n",
       "      <td>125510.82</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>79084.10</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     userid  score       City Gender  Age  Objects    Balance  Products  \\\n",
       "0  15677338    619  Ярославль      Ж   42        2        NaN         1   \n",
       "1  15690047    608    Рыбинск      Ж   41        1   83807.86         1   \n",
       "2  15662040    502  Ярославль      Ж   42        8  159660.80         3   \n",
       "3  15744090    699  Ярославль      Ж   39        1        NaN         2   \n",
       "4  15780624    850    Рыбинск      Ж   43        2  125510.82         1   \n",
       "\n",
       "   CreditCard  Loyalty  estimated_salary  Churn  \n",
       "0           1        1         101348.88      1  \n",
       "1           0        1         112542.58      0  \n",
       "2           1        0         113931.57      1  \n",
       "3           0        0          93826.63      0  \n",
       "4           1        1          79084.10      0  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "242b0189",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode categorical features\n",
    "\n",
    "df[\"City\"] = df[\"City\"].astype('category')\n",
    "\n",
    "df[\"Gender\"] = df[\"Gender\"].astype('category')\n",
    "\n",
    "\n",
    "# Store the integer code of each category in a new column via the .cat accessor\n",
    "df[\"City_cat\"] = df[\"City\"].cat.codes\n",
    "df[\"Gender_cat\"] = df[\"Gender\"].cat.codes\n",
    "\n",
    "\n"
   ]
  },
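  {
   "cell_type": "markdown",
   "id": "sketch-onehot-lead",
   "metadata": {},
   "source": [
    "`cat.codes` turns each city into an integer; pandas sorts the categories lexicographically, so the codes follow that order. For a distance-based method such as kNN this introduces an artificial ordering: two cities with codes 0 and 2 end up twice as far apart as cities with codes 0 and 1. A possible alternative is one-hot encoding. The cell below is a minimal sketch under that assumption: the `demo` frame is hypothetical and only reuses city names seen in the preview above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-onehot",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hypothetical toy frame; city names are taken from the dataset preview above\n",
    "demo = pd.DataFrame({\"City\": [\"Ярославль\", \"Рыбинск\", \"Ярославль\"]})\n",
    "\n",
    "# Ordinal codes: categories are sorted, so Рыбинск -> 0, Ярославль -> 1\n",
    "demo[\"City_cat\"] = demo[\"City\"].astype(\"category\").cat.codes\n",
    "print(demo[\"City_cat\"].tolist())  # [1, 0, 1]\n",
    "\n",
    "# One-hot alternative: one 0/1 column per city, no ordering between cities\n",
    "demo_onehot = pd.get_dummies(demo[\"City\"], prefix=\"City\", dtype=int)\n",
    "print(demo_onehot)"
   ]
  },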
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0da382a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df.drop(['City', 'Gender'], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c43a509d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userid</th>\n",
       "      <th>score</th>\n",
       "      <th>Age</th>\n",
       "      <th>Objects</th>\n",
       "      <th>Balance</th>\n",
       "      <th>Products</th>\n",
       "      <th>CreditCard</th>\n",
       "      <th>Loyalty</th>\n",
       "      <th>estimated_salary</th>\n",
       "      <th>Churn</th>\n",
       "      <th>City_cat</th>\n",
       "      <th>Gender_cat</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>15677338</td>\n",
       "      <td>619</td>\n",
       "      <td>42</td>\n",
       "      <td>2</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>101348.88</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>15690047</td>\n",
       "      <td>608</td>\n",
       "      <td>41</td>\n",
       "      <td>1</td>\n",
       "      <td>83807.86</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>112542.58</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>15662040</td>\n",
       "      <td>502</td>\n",
       "      <td>42</td>\n",
       "      <td>8</td>\n",
       "      <td>159660.80</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>113931.57</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>15744090</td>\n",
       "      <td>699</td>\n",
       "      <td>39</td>\n",
       "      <td>1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>93826.63</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>15780624</td>\n",
       "      <td>850</td>\n",
       "      <td>43</td>\n",
       "      <td>2</td>\n",
       "      <td>125510.82</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>79084.10</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     userid  score  Age  Objects    Balance  Products  CreditCard  Loyalty  \\\n",
       "0  15677338    619   42        2        NaN         1           1        1   \n",
       "1  15690047    608   41        1   83807.86         1           0        1   \n",
       "2  15662040    502   42        8  159660.80         3           1        0   \n",
       "3  15744090    699   39        1        NaN         2           0        0   \n",
       "4  15780624    850   43        2  125510.82         1           1        1   \n",
       "\n",
       "   estimated_salary  Churn  City_cat  Gender_cat  \n",
       "0         101348.88      1         2           0  \n",
       "1         112542.58      0         1           0  \n",
       "2         113931.57      1         2           0  \n",
       "3          93826.63      0         2           0  \n",
       "4          79084.10      0         1           0  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0ee25f77",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "userid                 0\n",
       "score                  0\n",
       "Age                    0\n",
       "Objects                0\n",
       "Balance             3617\n",
       "Products               0\n",
       "CreditCard             0\n",
       "Loyalty                0\n",
       "estimated_salary       0\n",
       "Churn                  0\n",
       "City_cat               0\n",
       "Gender_cat             0\n",
       "dtype: int64"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.isna().sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7bcfc59d",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df.dropna()"
   ]
  },
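  {
   "cell_type": "markdown",
   "id": "sketch-impute-lead",
   "metadata": {},
   "source": [
    "`dropna()` removes every row with a missing `Balance`: 3617 rows, more than a third of the data (6383 rows remain, as the statistics below show). Step 2 of the assignment also allows imputation. A minimal sketch, assuming median imputation is acceptable for `Balance`: `SimpleImputer` with `add_indicator=True` fills the gaps and appends a 0/1 column marking the imputed rows, so the model can still use the fact that the value was missing. In the real notebook the imputer would have to be fitted on the training part only. The toy values reuse the first five `Balance` entries from the preview."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-impute",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.impute import SimpleImputer\n",
    "\n",
    "# Toy column: the first five Balance values from the preview (NaN = missing)\n",
    "balance = pd.DataFrame({\"Balance\": [np.nan, 83807.86, 159660.80, np.nan, 125510.82]})\n",
    "\n",
    "# Median imputation plus a missing-value indicator column\n",
    "imputer = SimpleImputer(strategy=\"median\", add_indicator=True)\n",
    "filled = imputer.fit_transform(balance)\n",
    "\n",
    "# Column 0: Balance with NaN replaced by the median 125510.82; column 1: 1.0 where it was missing\n",
    "print(filled)"
   ]
  },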
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "cad1ef2c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>count</th>\n",
       "      <th>mean</th>\n",
       "      <th>std</th>\n",
       "      <th>min</th>\n",
       "      <th>25%</th>\n",
       "      <th>50%</th>\n",
       "      <th>75%</th>\n",
       "      <th>max</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>userid</th>\n",
       "      <td>6383.0</td>\n",
       "      <td>1.573310e+07</td>\n",
       "      <td>71929.130555</td>\n",
       "      <td>15608437.00</td>\n",
       "      <td>1.567094e+07</td>\n",
       "      <td>15732262.00</td>\n",
       "      <td>1.579584e+07</td>\n",
       "      <td>15858426.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>score</th>\n",
       "      <td>6383.0</td>\n",
       "      <td>6.511385e+02</td>\n",
       "      <td>96.934609</td>\n",
       "      <td>350.00</td>\n",
       "      <td>5.840000e+02</td>\n",
       "      <td>652.00</td>\n",
       "      <td>7.180000e+02</td>\n",
       "      <td>850.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Age</th>\n",
       "      <td>6383.0</td>\n",
       "      <td>3.919771e+01</td>\n",
       "      <td>10.476208</td>\n",
       "      <td>18.00</td>\n",
       "      <td>3.200000e+01</td>\n",
       "      <td>38.00</td>\n",
       "      <td>4.400000e+01</td>\n",
       "      <td>92.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Objects</th>\n",
       "      <td>6383.0</td>\n",
       "      <td>4.979633e+00</td>\n",
       "      <td>2.909514</td>\n",
       "      <td>0.00</td>\n",
       "      <td>2.000000e+00</td>\n",
       "      <td>5.00</td>\n",
       "      <td>8.000000e+00</td>\n",
       "      <td>10.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Balance</th>\n",
       "      <td>6383.0</td>\n",
       "      <td>1.198275e+05</td>\n",
       "      <td>30095.056462</td>\n",
       "      <td>3768.69</td>\n",
       "      <td>1.001820e+05</td>\n",
       "      <td>119839.69</td>\n",
       "      <td>1.395123e+05</td>\n",
       "      <td>250898.09</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Products</th>\n",
       "      <td>6383.0</td>\n",
       "      <td>1.386025e+00</td>\n",
       "      <td>0.577011</td>\n",
       "      <td>1.00</td>\n",
       "      <td>1.000000e+00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>2.000000e+00</td>\n",
       "      <td>4.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>CreditCard</th>\n",
       "      <td>6383.0</td>\n",
       "      <td>6.992010e-01</td>\n",
       "      <td>0.458641</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>1.000000e+00</td>\n",
       "      <td>1.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Loyalty</th>\n",
       "      <td>6383.0</td>\n",
       "      <td>5.135516e-01</td>\n",
       "      <td>0.499855</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>1.000000e+00</td>\n",
       "      <td>1.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>estimated_salary</th>\n",
       "      <td>6383.0</td>\n",
       "      <td>1.007174e+05</td>\n",
       "      <td>57380.316584</td>\n",
       "      <td>11.58</td>\n",
       "      <td>5.173685e+04</td>\n",
       "      <td>101139.30</td>\n",
       "      <td>1.495966e+05</td>\n",
       "      <td>199970.74</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Churn</th>\n",
       "      <td>6383.0</td>\n",
       "      <td>2.407959e-01</td>\n",
       "      <td>0.427600</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>1.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>City_cat</th>\n",
       "      <td>6383.0</td>\n",
       "      <td>1.013630e+00</td>\n",
       "      <td>0.894271</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>2.000000e+00</td>\n",
       "      <td>2.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Gender_cat</th>\n",
       "      <td>6383.0</td>\n",
       "      <td>5.473915e-01</td>\n",
       "      <td>0.497788</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.000000e+00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>1.000000e+00</td>\n",
       "      <td>1.00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   count          mean           std          min  \\\n",
       "userid            6383.0  1.573310e+07  71929.130555  15608437.00   \n",
       "score             6383.0  6.511385e+02     96.934609       350.00   \n",
       "Age               6383.0  3.919771e+01     10.476208        18.00   \n",
       "Objects           6383.0  4.979633e+00      2.909514         0.00   \n",
       "Balance           6383.0  1.198275e+05  30095.056462      3768.69   \n",
       "Products          6383.0  1.386025e+00      0.577011         1.00   \n",
       "CreditCard        6383.0  6.992010e-01      0.458641         0.00   \n",
       "Loyalty           6383.0  5.135516e-01      0.499855         0.00   \n",
       "estimated_salary  6383.0  1.007174e+05  57380.316584        11.58   \n",
       "Churn             6383.0  2.407959e-01      0.427600         0.00   \n",
       "City_cat          6383.0  1.013630e+00      0.894271         0.00   \n",
       "Gender_cat        6383.0  5.473915e-01      0.497788         0.00   \n",
       "\n",
       "                           25%          50%           75%          max  \n",
       "userid            1.567094e+07  15732262.00  1.579584e+07  15858426.00  \n",
       "score             5.840000e+02       652.00  7.180000e+02       850.00  \n",
       "Age               3.200000e+01        38.00  4.400000e+01        92.00  \n",
       "Objects           2.000000e+00         5.00  8.000000e+00        10.00  \n",
       "Balance           1.001820e+05    119839.69  1.395123e+05    250898.09  \n",
       "Products          1.000000e+00         1.00  2.000000e+00         4.00  \n",
       "CreditCard        0.000000e+00         1.00  1.000000e+00         1.00  \n",
       "Loyalty           0.000000e+00         1.00  1.000000e+00         1.00  \n",
       "estimated_salary  5.173685e+04    101139.30  1.495966e+05    199970.74  \n",
       "Churn             0.000000e+00         0.00  0.000000e+00         1.00  \n",
       "City_cat          0.000000e+00         1.00  2.000000e+00         2.00  \n",
       "Gender_cat        0.000000e+00         1.00  1.000000e+00         1.00  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.describe().T"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "005e6e18",
   "metadata": {},
   "source": [
    "## Splitting the data\n",
    "Split the data into the target column and the features:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7938b99e",
   "metadata": {},
   "source": [
    "When building predictive models, the source data are usually split into a training set and a test set. \n",
    "The **training set** is used to learn the mathematical relationship between the response variable and the predictors, while the **test (hold-out) set** is used to estimate how well the model predicts new data, i.e. data that were not used to train it.\n",
    "In our case the response variable is customer churn (`Churn`), and the predictors are all the other features that may influence a customer's decision to leave."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "305b17e1",
   "metadata": {},
   "source": [
    "The bank collects data about its customers and wants to find out which customers are most likely to leave. For this purpose, a large table of historical data was compiled.\n",
    "According to the task, the target feature is customer churn (`Churn`). We pass this table to the model as its \"training material\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "452ca9ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "y = df['Churn']  # Target: churn flag (0/1)\n",
    "X = df.drop('Churn', axis=1)  # Features: all remaining columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "045086e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X, y, test_size=0.25, random_state=45)\n",
    "# random_state seeds the random number generator, so the split is reproducible between runs\n"
   ]
  },
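  {
   "cell_type": "markdown",
   "id": "sketch-stratify-lead",
   "metadata": {},
   "source": [
    "The target is imbalanced: about 24% of customers churned (the mean of `Churn` in the statistics above is 0.24). With a purely random split the share of churners may differ somewhat between the training and test sets; the `stratify` parameter of `train_test_split` keeps it the same in both parts. A minimal sketch on synthetic labels (the `*_demo` arrays are hypothetical, not the bank data):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-stratify",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Synthetic labels with roughly the same churn share as above (about 24% ones)\n",
    "rng = np.random.default_rng(0)\n",
    "y_demo = (rng.random(1000) < 0.24).astype(int)\n",
    "X_demo = rng.random((1000, 3))\n",
    "\n",
    "# stratify=y_demo keeps the class proportions (almost) equal in both parts\n",
    "Xd_tr, Xd_te, yd_tr, yd_te = train_test_split(\n",
    "    X_demo, y_demo, test_size=0.25, random_state=45, stratify=y_demo)\n",
    "print(y_demo.mean(), yd_tr.mean(), yd_te.mean())"
   ]
  },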
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f9c21d6c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((4787, 11), (4787,))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Training set size\n",
    "X_train.shape, y_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "84c273c6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1596, 11), (1596,))"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test set size\n",
    "X_test.shape, y_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "0201b41e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.unique(y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ec1a0ef0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.unique(y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2501ea1b",
   "metadata": {},
   "source": [
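    "## k-nearest neighbors model with an arbitrarily chosen hyperparameter *K*\n",
    "First, scale the features (kNN relies on distances, so features with large ranges such as `Balance` would otherwise dominate), then write a function that computes the metrics of the fitted model:"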
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6684030d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userid</th>\n",
       "      <th>score</th>\n",
       "      <th>Age</th>\n",
       "      <th>Objects</th>\n",
       "      <th>Balance</th>\n",
       "      <th>Products</th>\n",
       "      <th>CreditCard</th>\n",
       "      <th>Loyalty</th>\n",
       "      <th>estimated_salary</th>\n",
       "      <th>City_cat</th>\n",
       "      <th>Gender_cat</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>4787.000000</td>\n",
       "      <td>4787.000000</td>\n",
       "      <td>4787.000000</td>\n",
       "      <td>4787.000000</td>\n",
       "      <td>4787.000000</td>\n",
       "      <td>4787.000000</td>\n",
       "      <td>4787.000000</td>\n",
       "      <td>4787.000000</td>\n",
       "      <td>4787.000000</td>\n",
       "      <td>4787.000000</td>\n",
       "      <td>4787.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>0.497183</td>\n",
       "      <td>0.602214</td>\n",
       "      <td>0.287423</td>\n",
       "      <td>0.498997</td>\n",
       "      <td>0.439925</td>\n",
       "      <td>0.129030</td>\n",
       "      <td>0.697514</td>\n",
       "      <td>0.512847</td>\n",
       "      <td>0.508251</td>\n",
       "      <td>0.506476</td>\n",
       "      <td>0.544809</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>0.287851</td>\n",
       "      <td>0.193475</td>\n",
       "      <td>0.140587</td>\n",
       "      <td>0.291413</td>\n",
       "      <td>0.128051</td>\n",
       "      <td>0.192955</td>\n",
       "      <td>0.459382</td>\n",
       "      <td>0.499887</td>\n",
       "      <td>0.285996</td>\n",
       "      <td>0.447354</td>\n",
       "      <td>0.498040</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>0.246767</td>\n",
       "      <td>0.466934</td>\n",
       "      <td>0.189189</td>\n",
       "      <td>0.200000</td>\n",
       "      <td>0.356705</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.261767</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>0.492614</td>\n",
       "      <td>0.609218</td>\n",
       "      <td>0.270270</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.439208</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.513190</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>0.749771</td>\n",
       "      <td>0.739479</td>\n",
       "      <td>0.351351</td>\n",
       "      <td>0.800000</td>\n",
       "      <td>0.523497</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.750887</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            userid        score          Age      Objects      Balance  \\\n",
       "count  4787.000000  4787.000000  4787.000000  4787.000000  4787.000000   \n",
       "mean      0.497183     0.602214     0.287423     0.498997     0.439925   \n",
       "std       0.287851     0.193475     0.140587     0.291413     0.128051   \n",
       "min       0.000000     0.000000     0.000000     0.000000     0.000000   \n",
       "25%       0.246767     0.466934     0.189189     0.200000     0.356705   \n",
       "50%       0.492614     0.609218     0.270270     0.500000     0.439208   \n",
       "75%       0.749771     0.739479     0.351351     0.800000     0.523497   \n",
       "max       1.000000     1.000000     1.000000     1.000000     1.000000   \n",
       "\n",
       "          Products   CreditCard      Loyalty  estimated_salary     City_cat  \\\n",
       "count  4787.000000  4787.000000  4787.000000       4787.000000  4787.000000   \n",
       "mean      0.129030     0.697514     0.512847          0.508251     0.506476   \n",
       "std       0.192955     0.459382     0.499887          0.285996     0.447354   \n",
       "min       0.000000     0.000000     0.000000          0.000000     0.000000   \n",
       "25%       0.000000     0.000000     0.000000          0.261767     0.000000   \n",
       "50%       0.000000     1.000000     1.000000          0.513190     0.500000   \n",
       "75%       0.333333     1.000000     1.000000          0.750887     1.000000   \n",
       "max       1.000000     1.000000     1.000000          1.000000     1.000000   \n",
       "\n",
       "        Gender_cat  \n",
       "count  4787.000000  \n",
       "mean      0.544809  \n",
       "std       0.498040  \n",
       "min       0.000000  \n",
       "25%       0.000000  \n",
       "50%       1.000000  \n",
       "75%       1.000000  \n",
       "max       1.000000  "
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Min-max scaling; the scaler is fitted on the training set only\n",
    "scaler = MinMaxScaler().fit(X_train)\n",
    "X_train = pd.DataFrame(scaler.transform(X_train), columns = X_train.columns)\n",
    "X_test = pd.DataFrame(scaler.transform(X_test), columns = X_train.columns)\n",
    "X_train.describe()"
   ]
  },
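  {
   "cell_type": "markdown",
   "id": "sketch-pipeline-lead",
   "metadata": {},
   "source": [
    "Two remarks on this step. First, fitting the scaler on the training set only is correct; for the cross-validation experiments in steps 6 and 7 the same rule has to hold inside every fold, and wrapping the scaler and the model in a `Pipeline` does that automatically. Second, `X` still contains `userid`, an arbitrary identifier that after scaling takes part in the distance computation like any other feature; it is probably better dropped, e.g. `X = df.drop(['Churn', 'userid'], axis=1)` (this would change the outputs above). The sketch below illustrates the first point on synthetic data from `make_classification`, not on the bank table:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-pipeline",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import make_classification\n",
    "from sklearn.model_selection import StratifiedKFold, cross_val_score\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "\n",
    "# Synthetic, imbalanced binary data standing in for the bank table\n",
    "X_demo, y_demo = make_classification(n_samples=500, n_features=10, weights=[0.76], random_state=0)\n",
    "\n",
    "# Inside each CV split the scaler is refitted on the training folds only\n",
    "pipe = make_pipeline(MinMaxScaler(), KNeighborsClassifier(n_neighbors=10))\n",
    "cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)\n",
    "scores = cross_val_score(pipe, X_demo, y_demo, cv=cv)\n",
    "print(scores, scores.mean())"
   ]
  },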
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "8e441123",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_model(model):\n",
    "    # Predict once and reuse the predictions for every metric\n",
    "    y_pred = model.predict(X_test)\n",
    "    print(\"mean_absolute_error:\", mean_absolute_error(y_test, y_pred))\n",
    "    print(\"mean_squared_error:\", mean_squared_error(y_test, y_pred))\n",
    "    print(\"median_absolute_error:\", median_absolute_error(y_test, y_pred))\n",
    "    print(\"r2_score:\", r2_score(y_test, y_pred))"
   ]
  },
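  {
   "cell_type": "markdown",
   "id": "sketch-clf-metrics-lead",
   "metadata": {},
   "source": [
    "The metrics in `test_model` are regression metrics. Because `Churn` is a binary class label, step 5 could also be covered with classification metrics. A hedged sketch of a hypothetical helper `test_classifier` (not part of the original notebook) that uses `KNeighborsClassifier` with accuracy, F1 and ROC AUC, shown on synthetic data rather than the bank table:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-clf-metrics",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import make_classification\n",
    "from sklearn.metrics import accuracy_score, f1_score, roc_auc_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "\n",
    "def test_classifier(model, X_te, y_te):\n",
    "    \"\"\"Hypothetical helper: three classification metrics for a fitted binary classifier.\"\"\"\n",
    "    y_pred = model.predict(X_te)\n",
    "    # With the default uniform weights: share of class-1 labels among the K neighbors\n",
    "    y_score = model.predict_proba(X_te)[:, 1]\n",
    "    return {\n",
    "        \"accuracy\": accuracy_score(y_te, y_pred),\n",
    "        \"f1\": f1_score(y_te, y_pred),\n",
    "        \"roc_auc\": roc_auc_score(y_te, y_score),\n",
    "    }\n",
    "\n",
    "\n",
    "# Synthetic imbalanced data standing in for the bank table\n",
    "X_demo, y_demo = make_classification(n_samples=500, weights=[0.76], random_state=0)\n",
    "Xd_tr, Xd_te, yd_tr, yd_te = train_test_split(\n",
    "    X_demo, y_demo, test_size=0.25, random_state=45, stratify=y_demo)\n",
    "clf = KNeighborsClassifier(n_neighbors=10).fit(Xd_tr, yd_tr)\n",
    "print(test_classifier(clf, Xd_te, yd_te))"
   ]
  },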
  {
   "cell_type": "markdown",
   "id": "7841d5d7",
   "metadata": {},
   "source": [
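    "Let us try k-nearest neighbors with K = 10. Since `Churn` takes only the values 0 and 1, `KNeighborsRegressor` predicts the share of churned customers among the 10 nearest neighbors, i.e. a number between 0 and 1:"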
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "b3baa0b9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsRegressor(n_neighbors=10)"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reg_10 = KNeighborsRegressor(n_neighbors=10)\n",
    "reg_10.fit(X_train, y_train)"
   ]
  },
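  {
   "cell_type": "markdown",
   "id": "sketch-reg-vs-clf-lead",
   "metadata": {},
   "source": [
    "To see what `reg_10` actually predicts: with the default uniform weights the regressor averages the 0/1 labels of the K nearest neighbors, which should coincide with the class-1 probability returned by `KNeighborsClassifier.predict_proba`. A small check on synthetic data (not the bank table); the names `reg` and `clf` are illustrative only:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-reg-vs-clf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor\n",
    "\n",
    "# Synthetic binary data standing in for the churn labels\n",
    "X_demo, y_demo = make_classification(n_samples=300, random_state=0)\n",
    "reg = KNeighborsRegressor(n_neighbors=10).fit(X_demo, y_demo)\n",
    "clf = KNeighborsClassifier(n_neighbors=10).fit(X_demo, y_demo)\n",
    "\n",
    "# Both models search the same 10 neighbors, so the two outputs should match\n",
    "reg_pred = reg.predict(X_demo[:5])\n",
    "clf_proba = clf.predict_proba(X_demo[:5])[:, 1]\n",
    "print(reg_pred)\n",
    "print(clf_proba)\n",
    "print(np.allclose(reg_pred, clf_proba))"
   ]
  },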
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "163ed1fc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean_absolute_error: 0.2895363408521303\n",
      "mean_squared_error: 0.15360275689223057\n",
      "median_absolute_error: 0.2\n",
      "r2_score: 0.12241210313232487\n"
     ]
    }
   ],
   "source": [
    "test_model(reg_10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9208d3a1",
   "metadata": {},
   "source": [
    "1) mean_absolute_error: 0.289. The closer the value is to zero, the better the regression.\n",
    "\n",
    "2) mean_squared_error: 0.15. The closer the value is to zero, the better the model.\n",
    "\n",
    "3) median_absolute_error: 0.2. Half of the absolute errors $|y - \\hat{y}|$ are no larger than 0.2.\n",
    "\n",
    "4) r2_score: 0.12. $R^2 = 1$ means a perfect fit, and $R^2 = 0$ corresponds to always predicting the mean of $y$. A value close to 0 means that the spread of the errors of the predictions $\\hat{y}$ is almost as large as the spread of the target $y$ itself, i.e. the model describes the data poorly: it explains only about 12% of the variance of the target."
   ]
  }
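  ,
  {
   "cell_type": "markdown",
   "id": "sketch-metrics-lead",
   "metadata": {},
   "source": [
    "The four metrics can be reproduced by hand from their definitions: $MAE = \\frac{1}{n}\\sum_i |y_i - \\hat{y}_i|$, $MSE = \\frac{1}{n}\\sum_i (y_i - \\hat{y}_i)^2$, MedAE is the median of $|y_i - \\hat{y}_i|$, and $R^2 = 1 - \\frac{\\sum_i (y_i - \\hat{y}_i)^2}{\\sum_i (y_i - \\bar{y})^2}$. A small sketch with made-up labels and predictions (hypothetical values, not taken from the model above) compares the manual formulas with `sklearn.metrics`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-metrics",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, median_absolute_error, r2_score\n",
    "\n",
    "# Made-up 0/1 labels and kNN-style fractional predictions\n",
    "y_true = np.array([1, 0, 1, 0, 0])\n",
    "y_pred = np.array([0.6, 0.2, 0.5, 0.1, 0.3])\n",
    "\n",
    "err = np.abs(y_true - y_pred)\n",
    "mae = err.mean()         # (0.4 + 0.2 + 0.5 + 0.1 + 0.3) / 5 = 0.3\n",
    "mse = (err ** 2).mean()  # 0.55 / 5 = 0.11\n",
    "medae = np.median(err)   # median of [0.1, 0.2, 0.3, 0.4, 0.5] = 0.3\n",
    "r2 = 1 - np.sum((y_true - y_pred) ** 2) / np.sum((y_true - y_true.mean()) ** 2)  # 1 - 0.55 / 1.2\n",
    "\n",
    "print(mae, mean_absolute_error(y_true, y_pred))\n",
    "print(mse, mean_squared_error(y_true, y_pred))\n",
    "print(medae, median_absolute_error(y_true, y_pred))\n",
    "print(r2, r2_score(y_true, y_pred))\n",
    "\n",
    "# Always predicting the mean of y gives R^2 = 0\n",
    "print(r2_score(y_true, np.full(len(y_true), y_true.mean())))"
   ]
  }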
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}