# Worksheet 7 - Classification (Part II)

### Lecture and Tutorial Learning Goals:

After completing this week's lecture and tutorial work, you will be able to:

* Describe what a test data set is and how it is used in classification.
* Using Python, evaluate classification accuracy using a test data set and appropriate metrics.
* Using Python, execute cross-validation to choose the number of neighbours.
* Identify when it is necessary to scale variables before classification and do this using Python.
* In a dataset with > 2 attributes, perform k-nearest neighbour classification in Python using the `scikit-learn` package to predict the class of a test dataset.
* Describe advantages and disadvantages of the k-nearest neighbour classification algorithm.


In [None]:
### Run this cell before continuing.
import altair as alt
import numpy as np
import pandas as pd
import sklearn
from sklearn.compose import make_column_transformer
from sklearn.metrics import confusion_matrix
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.model_selection import (
    GridSearchCV,
    RandomizedSearchCV,
    cross_validate,
    train_test_split,
)
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

alt.data_transformers.disable_max_rows()
alt.renderers.enable("mimetype")

**Question 0.1** Multiple Choice:
<br>{points: 1}

Before applying k-nearest neighbour to a classification task, we need to scale the data. What is the purpose of this step?

A. To help speed up the knn algorithm.

B. To convert all data observations to numeric values.

C. To ensure all data observations will be on a comparable scale and contribute equal shares to the calculation of the distance between points.

D. None of the above.

*Assign your answer to an object called `answer0_1`. Make sure your answer is an uppercase letter and is surrounded by quotation marks (e.g. `"F"`)*.

In [None]:
# your code here
raise NotImplementedError

In [None]:
from hashlib import sha1
assert sha1(str(type(answer0_1)).encode("utf-8")+b"6d918ae6d6ce01ae").hexdigest() == "88a22a44eab962877dbbfd38cf03905a65468829", "type of answer0_1 is not str. answer0_1 should be an str"
assert sha1(str(len(answer0_1)).encode("utf-8")+b"6d918ae6d6ce01ae").hexdigest() == "0f2a8ef9e0399c1d396b58ae072efee31b72787e", "length of answer0_1 is not correct"
assert sha1(str(answer0_1.lower()).encode("utf-8")+b"6d918ae6d6ce01ae").hexdigest() == "c0d56f52fa4dfad808ba4e91ceddafa80b9e358c", "value of answer0_1 is not correct"
assert sha1(str(answer0_1).encode("utf-8")+b"6d918ae6d6ce01ae").hexdigest() == "27e26cd0893d5b5a2ee5044da105247947f5db20", "correct string value of answer0_1 but incorrect case of letters"

print('Success!')

## 1. Fruit Data Example - (Part II)
**Question 1.0** 
<br>{points: 1}

First, load the file, `fruit_data.csv` (found in the data folder) from the previous tutorial, into your notebook.

*Assign your data to an object called `fruit_data`.*

In [None]:
# your code here
raise NotImplementedError

In [None]:
from hashlib import sha1
assert sha1(str(type(fruit_data is None)).encode("utf-8")+b"34f1cc4afe805395").hexdigest() == "1eb5081ff3b6020e49ceeaa8d7c8bbf6c76fa66f", "type of fruit_data is None is not bool. fruit_data is None should be a bool"
assert sha1(str(fruit_data is None).encode("utf-8")+b"34f1cc4afe805395").hexdigest() == "7b16639472c42e576b0a1882e5366e5eec528c42", "boolean value of fruit_data is None is not correct"

assert sha1(str(type(fruit_data.shape)).encode("utf-8")+b"d5ae33a19b58d4cf").hexdigest() == "cb989c83f59b8b96589b231f28d085fc1b1fa0fb", "type of fruit_data.shape is not tuple. fruit_data.shape should be a tuple"
assert sha1(str(len(fruit_data.shape)).encode("utf-8")+b"d5ae33a19b58d4cf").hexdigest() == "51af3317f4f0bece0859b666b6cd6ab2f6b6dd7c", "length of fruit_data.shape is not correct"
assert sha1(str(sorted(map(str, fruit_data.shape))).encode("utf-8")+b"d5ae33a19b58d4cf").hexdigest() == "ccdd6ede1e736993fb462f18693f4f73cc12fdbf", "values of fruit_data.shape are not correct"
assert sha1(str(fruit_data.shape).encode("utf-8")+b"d5ae33a19b58d4cf").hexdigest() == "e6e4c5a9e5850fa0e56aa4681d7d1603318e6574", "order of elements of fruit_data.shape is not correct"

assert sha1(str(type(fruit_data.fruit_name.dtype)).encode("utf-8")+b"df91db5e30f56bf2").hexdigest() == "73222a35bd90d3b3391f8a87795e522da12ac545", "type of fruit_data.fruit_name.dtype is not correct"
assert sha1(str(fruit_data.fruit_name.dtype).encode("utf-8")+b"df91db5e30f56bf2").hexdigest() == "e28d2910f0fd7be93fc9cfb082b6afe4177bfcaf", "value of fruit_data.fruit_name.dtype is not correct"

print('Success!')

Let's take a look at the first six observations in the fruit dataset. Run the cell below.

In [None]:
# Run this cell.
fruit_data.head(6)

Run the cell below, and find the nearest neighbour based on mass and width to the first observation just by looking at the scatterplot (the first observation has been circled for you).

In [None]:
# Run this cell.
point1 = [192, 8.4]
point2 = [180, 8]
point44 = [194, 7.2]

fruit_chart = (
    alt.Chart(fruit_data)
    .mark_point(size=15)
    .encode(
        x=alt.X("mass", title="Mass (grams)"),
        y=alt.Y("width", title="Width (cm)", scale=alt.Scale(zero=False)),
        color=alt.Color("fruit_name", title="Name of the Fruit"),
    )
)

(
    fruit_chart
    + alt.Chart(pd.DataFrame([[192, 8.4]], columns=["x", "y"]))
    .mark_point(size=150)
    .encode(x="x", y="y", color=alt.value("black"))
    + alt.Chart(pd.DataFrame([[1, 183, 8.5]], columns=["text", "x", "y"]))
    .mark_text(size=15)
    .encode(x="x", y="y", text="text", color=alt.value("black"))
).configure_axis(labelFontSize=20, titleFontSize=20).configure_legend(
    titleFontSize=15, labelFontSize=15
).properties(
    width=400, height=300
)

**Question 1.1** Multiple Choice:
<br>{points: 1}

Based on the graph generated, what is the `fruit_name` of the closest data point to the one circled?

A. apple

B. lemon

C. mandarin

D. orange

*Assign your answer to an object called `answer1_1`. Make sure your answer is an uppercase letter and is surrounded by quotation marks (e.g. `"F"`).*

In [None]:
# your code here
raise NotImplementedError

In [None]:
from hashlib import sha1
assert sha1(str(type(answer1_1)).encode("utf-8")+b"d98fae4ee65c8631").hexdigest() == "4b2f4b7df278b9cfa6f1864f03c964a5b3acbac9", "type of answer1_1 is not str. answer1_1 should be an str"
assert sha1(str(len(answer1_1)).encode("utf-8")+b"d98fae4ee65c8631").hexdigest() == "ba09b7ac7ba209f68860bfdfccc2b2e307d67eca", "length of answer1_1 is not correct"
assert sha1(str(answer1_1.lower()).encode("utf-8")+b"d98fae4ee65c8631").hexdigest() == "d3d988a907e8833dd5cfa2fbd7d0e3b7c7b8359e", "value of answer1_1 is not correct"
assert sha1(str(answer1_1).encode("utf-8")+b"d98fae4ee65c8631").hexdigest() == "bbbccd0cf0b25cefc0039043f3c7289f3a860209", "correct string value of answer1_1 but incorrect case of letters"

print('Success!')

**Question 1.2**
<br>{points: 1}

Using `mass` and `width`, calculate the distance between the first observation and the second observation.

Scaffolding is provided below to get you started.

*Assign your answer to an object called `fruit_dist_2`.*

In [None]:
# ___ = euclidean_distances(fruit_data.loc[0:1, ["mass", ___]])

# your code here
raise NotImplementedError
fruit_dist_2

In [None]:
from hashlib import sha1
assert sha1(str(type(fruit_dist_2)).encode("utf-8")+b"4e8b976f01797b35").hexdigest() == "74667340118243db09f4c4fde5a217e0b38ffa6c", "type of fruit_dist_2 is not correct"
assert sha1(str(fruit_dist_2).encode("utf-8")+b"4e8b976f01797b35").hexdigest() == "a01b382d196e66a373c7f06f0cb096b17d20585d", "value of fruit_dist_2 is not correct"

print('Success!')

**Question 1.3**
<br>{points: 1}

Calculate the distance between the first and the **44th observation** in the `fruit` dataset using the `mass` and `width` variables.

You can see from the data frame output from the cell below that **observation 44** has `mass` = 194 g and `width` = 7.2 cm.

*Assign your answer to an object called `fruit_dist_44`.*

In [None]:
# Run this cell to see the 44th observation
pd.DataFrame(fruit_data.iloc[43, :]).T

In [None]:
# Run this cell.
point1 = [192, 8.4]
point2 = [180, 8]
point44 = [194, 7.2]

(
    fruit_chart
    + alt.Chart(
        pd.DataFrame([[192, 8.4], [180, 8.0], [193.5, 7.2]], columns=["x", "y"])
    )
    .mark_point(size=150)
    .encode(x="x", y="y", color=alt.value("black"))
    + alt.Chart(
        pd.DataFrame(
            [[1, 183, 8.5], [2, 169, 8.1], [44, 204, 7.1]], columns=["text", "x", "y"]
        )
    )
    .mark_text(size=15)
    .encode(x="x", y="y", text="text", color=alt.value("black"))
).configure_axis(labelFontSize=20, titleFontSize=20).configure_legend(
    titleFontSize=15, labelFontSize=15
).properties(
    width=400, height=300
)

In [None]:
# your code here
raise NotImplementedError

In [None]:
from hashlib import sha1
assert sha1(str(type(fruit_dist_44)).encode("utf-8")+b"dd38d4bba8c1924c").hexdigest() == "ca70b97a821ef17b0dc53bab254238f8dd38b016", "type of fruit_dist_44 is not correct"
assert sha1(str(fruit_dist_44).encode("utf-8")+b"dd38d4bba8c1924c").hexdigest() == "25dbd133f39d19cf4471a3a756b05bcbf8c08bab", "value of fruit_dist_44 is not correct"

print('Success!')

What do you notice about the answers you just calculated for **Question 1.2** and **1.3**? Are they what you would expect given the scatter plot above? Why or why not? Discuss.

*Hint: Look at where the observations are on the scatterplot in the cell above this question. What might happen if we measured mass in kilograms instead of grams?*
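As a rough illustration of the hint, the sketch below uses made-up points (they are not rows of the fruit data, and the names `a`, `b`, `c`, and `to_kg` are hypothetical). It shows that which point counts as "closest" under Euclidean distance can flip when one variable changes units.

```python
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances

# Hypothetical observations: columns are mass (g) and width (cm).
a = np.array([[100.0, 5.0]])
b = np.array([[110.0, 5.0]])  # differs from a by 10 g only
c = np.array([[100.0, 8.0]])  # differs from a by 3 cm only

# With mass in grams, the 10 g gap dominates: c looks closer to a than b does.
dist_ab_grams = euclidean_distances(a, b)[0, 0]
dist_ac_grams = euclidean_distances(a, c)[0, 0]

# Re-express mass in kilograms; width stays in cm.
to_kg = np.array([0.001, 1.0])
dist_ab_kg = euclidean_distances(a * to_kg, b * to_kg)[0, 0]
dist_ac_kg = euclidean_distances(a * to_kg, c * to_kg)[0, 0]

print(dist_ab_grams, dist_ac_grams)
print(dist_ab_kg, dist_ac_kg)
```

The points themselves never changed, only the unit of one column, yet the ordering of the two distances reverses.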

**Question 1.4**
<br>{points: 1}

From the distance calculations, we see that observations 1 and 44 are closer together than observations 1 and 2. However, on the scatterplot, the first observation appears closer to the second observation than to the 44th.

Which of the following statements is correct?

A. A difference of 12 g in mass between observations 1 and 2 is large compared to a difference of 1.2 cm in width between observations 1 and 44. Consequently, mass will drive the classification results, and width will have less of an effect. Our distance calculation reflects this.

B. If we measured mass in kilograms, then we’d get different classification results.

C. We should standardize the data so that all variables will be on a comparable scale.

D. All of the above.

*Assign your answer to an object called `answer1_4`. Make sure your answer is an uppercase letter and is surrounded by quotation marks (e.g. `"F"`).*

In [None]:
# your code here
raise NotImplementedError

In [None]:
from hashlib import sha1
assert sha1(str(type(answer1_4)).encode("utf-8")+b"80fd7f987b33b4f5").hexdigest() == "93740966023b81ec0d69a8b8f9e711c92b3c9d8a", "type of answer1_4 is not str. answer1_4 should be an str"
assert sha1(str(len(answer1_4)).encode("utf-8")+b"80fd7f987b33b4f5").hexdigest() == "46002a0433ea53405f16db85ac37bb833255362e", "length of answer1_4 is not correct"
assert sha1(str(answer1_4.lower()).encode("utf-8")+b"80fd7f987b33b4f5").hexdigest() == "d0ce93fdea3850604c184a05bec6dbe4243e0511", "value of answer1_4 is not correct"
assert sha1(str(answer1_4).encode("utf-8")+b"80fd7f987b33b4f5").hexdigest() == "87df0caaffbbfc7204bbb13cac4b4095350f4e03", "correct string value of answer1_4 but incorrect case of letters"

print('Success!')

**Question 1.5**
<br>{points: 1}

Scale and center the numeric variables of the `fruit` dataset and save them as columns in your data table. We will use `StandardScaler` in the preprocessor, and then call `.fit_transform` on the preprocessor so that we can examine the output.

Fit and transform your preprocessor with the predictors `mass`, `width`, `height`, and `color_score`. The other columns should be passed through the preprocessor unchanged using `"passthrough"`.

Name the preprocessor `fruit_data_preprocessor`. Concatenate the transformed columns with the original data frame, and save the result as a data frame called `fruit_data_scaled`. Make sure to name the new columns `scaled_*`, where * is the old column name (e.g. `scaled_mass`).
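The general pattern can be sketched on a small made-up data frame. Everything named `toy*` below, along with the columns `label`, `length`, and `weight`, is hypothetical and chosen only for illustration. The sketch relies on the column transformer returning its output columns in the order the transformers are listed, which is why the new column names are supplied in that same order.

```python
import pandas as pd
from sklearn.compose import make_column_transformer
from sklearn.preprocessing import StandardScaler

# Hypothetical toy data frame (not the fruit data).
toy = pd.DataFrame(
    {
        "label": ["a", "b", "c", "d"],
        "length": [1.0, 2.0, 3.0, 4.0],
        "weight": [10.0, 20.0, 30.0, 40.0],
    }
)

toy_preprocessor = make_column_transformer(
    ("passthrough", ["label"]),  # kept unchanged
    (StandardScaler(), ["length", "weight"]),  # centered and scaled
)

# Output columns follow the order of the transformers above,
# so we supply matching column names ourselves.
toy_scaled = pd.DataFrame(
    toy_preprocessor.fit_transform(toy),
    columns=["label", "scaled_length", "scaled_weight"],
)

# Mixing strings and numbers yields generic object columns;
# convert the numeric ones back to float.
toy_scaled = toy_scaled.astype({"scaled_length": float, "scaled_weight": float})
print(toy_scaled)
```

After standardization, each scaled column has mean 0 and standard deviation 1 (using the population formula that `StandardScaler` applies).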

In [None]:
# ___ = make_column_transformer(
#     (
#         "passthrough",
#         [
#             "fruit_label",
#             "fruit_name",
#             "fruit_subtype",
#         ],
#     ),
#     (___, [___, ___, ___, ___]),
# )

# ___ = pd.DataFrame(
#     fruit_data_preprocessor.___(___),
#     columns=[
#         "fruit_label",
#         "fruit_name",
#         "fruit_subtype",
#         ___,
#         ___,
#         ___,
#         ___,
#     ],
# )


# your code here
raise NotImplementedError
fruit_data_scaled.head()

In [None]:
from hashlib import sha1
assert sha1(str(type(fruit_data_scaled is None)).encode("utf-8")+b"23461a55f13e294a").hexdigest() == "b1e0a25f7fe948b51f02de4a71efd3dc8e425bb5", "type of fruit_data_scaled is None is not bool. fruit_data_scaled is None should be a bool"
assert sha1(str(fruit_data_scaled is None).encode("utf-8")+b"23461a55f13e294a").hexdigest() == "005e11e738c69aa6b8b44af555084ea867de93bb", "boolean value of fruit_data_scaled is None is not correct"

assert sha1(str(type(fruit_data_scaled.shape)).encode("utf-8")+b"91a756536e8f6474").hexdigest() == "c8712d8907bdd864598a59849ed6753d5bcd7b92", "type of fruit_data_scaled.shape is not tuple. fruit_data_scaled.shape should be a tuple"
assert sha1(str(len(fruit_data_scaled.shape)).encode("utf-8")+b"91a756536e8f6474").hexdigest() == "6c6ebe81e074599903cba0706811c8aaa19e2572", "length of fruit_data_scaled.shape is not correct"
assert sha1(str(sorted(map(str, fruit_data_scaled.shape))).encode("utf-8")+b"91a756536e8f6474").hexdigest() == "805ff00cb22b8c65f5cb0afac267a4324af2141e", "values of fruit_data_scaled.shape are not correct"
assert sha1(str(fruit_data_scaled.shape).encode("utf-8")+b"91a756536e8f6474").hexdigest() == "4ad10c26b42645e9a29acfe0ec9472763f8863c2", "order of elements of fruit_data_scaled.shape is not correct"

assert sha1(str(type(fruit_data_scaled.fruit_name.dtype)).encode("utf-8")+b"4b339d3fda04f48b").hexdigest() == "e3f9857ab5b10ac639d8f9d99097cbb42b45a99c", "type of fruit_data_scaled.fruit_name.dtype is not correct"
assert sha1(str(fruit_data_scaled.fruit_name.dtype).encode("utf-8")+b"4b339d3fda04f48b").hexdigest() == "ef18f280c5eab5d9788efd511672310656b4afb1", "value of fruit_data_scaled.fruit_name.dtype is not correct"

assert sha1(str(type(fruit_data_preprocessor.transformers_[0][2])).encode("utf-8")+b"ef547f3635cadde4").hexdigest() == "b83291a7398ad20b334bc062563c87fdd3628c84", "type of fruit_data_preprocessor.transformers_[0][2] is not list. fruit_data_preprocessor.transformers_[0][2] should be a list"
assert sha1(str(len(fruit_data_preprocessor.transformers_[0][2])).encode("utf-8")+b"ef547f3635cadde4").hexdigest() == "7139b6a0a30005643f213e2a339b41e84f32eebf", "length of fruit_data_preprocessor.transformers_[0][2] is not correct"
assert sha1(str(sorted(map(str, fruit_data_preprocessor.transformers_[0][2]))).encode("utf-8")+b"ef547f3635cadde4").hexdigest() == "3cab12095a727034aea75847a2de0695ed5fa031", "values of fruit_data_preprocessor.transformers_[0][2] are not correct"
assert sha1(str(fruit_data_preprocessor.transformers_[0][2]).encode("utf-8")+b"ef547f3635cadde4").hexdigest() == "3cab12095a727034aea75847a2de0695ed5fa031", "order of elements of fruit_data_preprocessor.transformers_[0][2] is not correct"

assert sha1(str(type(fruit_data_preprocessor.transformers_[1][1])).encode("utf-8")+b"03962083d99625b3").hexdigest() == "9285b7ff93989089a10a4f3032e2742808b3bccf", "type of fruit_data_preprocessor.transformers_[1][1] is not correct"
assert sha1(str(fruit_data_preprocessor.transformers_[1][1]).encode("utf-8")+b"03962083d99625b3").hexdigest() == "a3b57e4c8c0dcee4840dd4a95afaac5fc9fdbe6d", "value of fruit_data_preprocessor.transformers_[1][1] is not correct"

assert sha1(str(type(fruit_data_preprocessor.transformers_[1][2])).encode("utf-8")+b"d771c0809f7d2c75").hexdigest() == "dce7dbbc663978ea17a14c9108230704a0e40fa5", "type of fruit_data_preprocessor.transformers_[1][2] is not list. fruit_data_preprocessor.transformers_[1][2] should be a list"
assert sha1(str(len(fruit_data_preprocessor.transformers_[1][2])).encode("utf-8")+b"d771c0809f7d2c75").hexdigest() == "d9bb7308b7033b9dc7c4269b8d93385beb5dc066", "length of fruit_data_preprocessor.transformers_[1][2] is not correct"
assert sha1(str(sorted(map(str, fruit_data_preprocessor.transformers_[1][2]))).encode("utf-8")+b"d771c0809f7d2c75").hexdigest() == "21eae4047f2091a55eceaaaf121a15990721278d", "values of fruit_data_preprocessor.transformers_[1][2] are not correct"
assert sha1(str(fruit_data_preprocessor.transformers_[1][2]).encode("utf-8")+b"d771c0809f7d2c75").hexdigest() == "6495d3e72e7e148c913f92734ecc4db42eeaed9a", "order of elements of fruit_data_preprocessor.transformers_[1][2] is not correct"

assert sha1(str(type(sum(fruit_data_scaled.scaled_mass.dropna()))).encode("utf-8")+b"60e42f5657d40094").hexdigest() == "b96f2e2f95e24b308ee13d7e7c2422de9d8b573b", "type of sum(fruit_data_scaled.scaled_mass.dropna()) is not float. Please make sure it is float and not np.float64, etc. You can cast your value into a float using float()"
assert sha1(str(round(sum(fruit_data_scaled.scaled_mass.dropna()), 2)).encode("utf-8")+b"60e42f5657d40094").hexdigest() == "22c244496e72cf6e896676ace8ac8569ab91cebb", "value of sum(fruit_data_scaled.scaled_mass.dropna()) is not correct (rounded to 2 decimal places)"

assert sha1(str(type(sum(fruit_data_scaled.scaled_width.dropna()))).encode("utf-8")+b"ff50ff8162e5fb85").hexdigest() == "daf3af2722a538271d2d8ff84d49b72383dfa04a", "type of sum(fruit_data_scaled.scaled_width.dropna()) is not float. Please make sure it is float and not np.float64, etc. You can cast your value into a float using float()"
assert sha1(str(round(sum(fruit_data_scaled.scaled_width.dropna()), 2)).encode("utf-8")+b"ff50ff8162e5fb85").hexdigest() == "b6f0a4f88581f9bce22e76ce17658fe247ce2c07", "value of sum(fruit_data_scaled.scaled_width.dropna()) is not correct (rounded to 2 decimal places)"

assert sha1(str(type(sum(fruit_data_scaled.scaled_height.dropna()))).encode("utf-8")+b"03ec619d801c8332").hexdigest() == "03c1373ead74232d12e80a272a08013708e5c499", "type of sum(fruit_data_scaled.scaled_height.dropna()) is not float. Please make sure it is float and not np.float64, etc. You can cast your value into a float using float()"
assert sha1(str(round(sum(fruit_data_scaled.scaled_height.dropna()), 2)).encode("utf-8")+b"03ec619d801c8332").hexdigest() == "d9877b75951c4dc865bad1dc5903cbb338bcff19", "value of sum(fruit_data_scaled.scaled_height.dropna()) is not correct (rounded to 2 decimal places)"

assert sha1(str(type(sum(fruit_data_scaled.scaled_color_score.dropna()))).encode("utf-8")+b"f71fd2667d70e1b7").hexdigest() == "16dc273ce346e81bc43a9d0e1bd7fcb3af38ba66", "type of sum(fruit_data_scaled.scaled_color_score.dropna()) is not float. Please make sure it is float and not np.float64, etc. You can cast your value into a float using float()"
assert sha1(str(round(sum(fruit_data_scaled.scaled_color_score.dropna()), 2)).encode("utf-8")+b"f71fd2667d70e1b7").hexdigest() == "0a1de72bfcda1427cbba3d26d3b3bd1dca1807d2", "value of sum(fruit_data_scaled.scaled_color_score.dropna()) is not correct (rounded to 2 decimal places)"

print('Success!')


**Question 1.6**
<br>{points: 1}

Let's repeat **Question 1.2** and **1.3** with the scaled variables:

- calculate the distance with the scaled mass and width variables between observations 1 and 2
- calculate the distance with the scaled mass and width variables between observations 1 and 44

After you do this, think about how these distances compare to the distances you computed in **Question 1.2** and **1.3** for the same points.

*Assign your answers to objects called `distance_2` and `distance_44` respectively.*

In [None]:
# your code here
raise NotImplementedError
print(distance_2)
print(distance_44)

In [None]:
from hashlib import sha1
assert sha1(str(type(distance_2 is None)).encode("utf-8")+b"c4d1f5de10d61eb9").hexdigest() == "8d0c9ef0891ac27a685db8f6914e326d5133396d", "type of distance_2 is None is not bool. distance_2 is None should be a bool"
assert sha1(str(distance_2 is None).encode("utf-8")+b"c4d1f5de10d61eb9").hexdigest() == "a6e79c55da03e5180a82ff3fa1450c357d006756", "boolean value of distance_2 is None is not correct"

assert sha1(str(type(distance_2)).encode("utf-8")+b"67456cf718c92435").hexdigest() == "2137e64cfff6d73c8e510a48fba39618ce731845", "type of type(distance_2) is not correct"

assert sha1(str(type(distance_2)).encode("utf-8")+b"2f49b07da2fa8245").hexdigest() == "a44af83f96cb6006fa5f4c1c3f88b5992da52f25", "type of distance_2 is not correct"
assert sha1(str(distance_2).encode("utf-8")+b"2f49b07da2fa8245").hexdigest() == "d166a6f100d021d8d180362eff86bbda5c32ef25", "value of distance_2 is not correct"

assert sha1(str(type(distance_44 is None)).encode("utf-8")+b"4760ee01d1686c89").hexdigest() == "2e2a1be7853a239d9c6519f9e8b23a8a05956c17", "type of distance_44 is None is not bool. distance_44 is None should be a bool"
assert sha1(str(distance_44 is None).encode("utf-8")+b"4760ee01d1686c89").hexdigest() == "494e2b7a8d15d5b951c3a94c94df76e5fab82c25", "boolean value of distance_44 is None is not correct"

assert sha1(str(type(distance_44)).encode("utf-8")+b"4fbf52033fa18987").hexdigest() == "edfeecdd9c8439c03c3ae10e53f11793148005b8", "type of type(distance_44) is not correct"

assert sha1(str(type(distance_44)).encode("utf-8")+b"a55f85850445ef0d").hexdigest() == "186b7b99e07aa1459cf4d1590b7ea8ded77738d7", "type of distance_44 is not correct"
assert sha1(str(distance_44).encode("utf-8")+b"a55f85850445ef0d").hexdigest() == "4dd6c7a2bf540e25ecd2c44eff602e6f04061cb3", "value of distance_44 is not correct"

print('Success!')

## Randomness and Setting Seeds

This worksheet uses functions from the `scikit-learn` library, which not only allows us to perform K-nearest neighbour classification, but also allows us to evaluate how well our classification worked. In order to ensure that the steps in the worksheet are reproducible, we need to set a *`random_state`* or *random seed*, i.e., a numerical "starting value," which determines the sequence of random numbers Python will generate.

Many of the cells below include an argument that sets `random_state`, or a call to `np.random.seed`. These are necessary for the autotesting code to function properly.
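A minimal sketch of what a fixed seed buys us, using a made-up array rather than the fruit data (the name `toy` and the seed values are arbitrary choices for illustration): repeating a random split with the same `random_state` gives the same result every time.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical toy data: ten observations labelled 0..9.
toy = np.arange(10)

# Same random_state -> identical splits on every run.
train_a, test_a = train_test_split(toy, test_size=0.25, random_state=123)
train_b, test_b = train_test_split(toy, test_size=0.25, random_state=123)
same_seed_matches = np.array_equal(test_a, test_b)

# A different random_state will generally give a different split.
train_c, test_c = train_test_split(toy, test_size=0.25, random_state=7)

print(same_seed_matches)
print(test_a, test_c)
```

If you change or remove the `random_state` values used in this worksheet, your results may still be valid, but they will likely not match what the autotests expect.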

## 2. Splitting the data into a training and test set

In this exercise, we will be partitioning `fruit_data` into a training (75%) and testing (25%) set using the `scikit-learn` package. After creating the test set, we will put the test set away in a lock box and not touch it again until we have found the best k-nn classifier we can make using the training set. We will use the variable `fruit_name` as our class label. 


**Question 2.0**
<br> {points: 1}

To create the training and test sets, use the `train_test_split` function from the `scikit-learn` package. Save the training dataset and test dataset as `fruit_train` and `fruit_test`, respectively.

In [None]:
# Randomly assign 75% of the data to the training set and the remaining 25% to the test set.
# Note: without a `stratify` argument, the split is not guaranteed to preserve
# the proportions of the different fruit names.

# ___, ___ = train_test_split(___, test_size=___, random_state=123) # set the random state to be 123

# your code here
raise NotImplementedError
print(fruit_train.head())
print(fruit_test.head())

In [None]:
from hashlib import sha1
assert sha1(str(type(fruit_train is None)).encode("utf-8")+b"0b39255c4c2e66e7").hexdigest() == "b732435724e7c1a3b25a835f82aa3e9500213b04", "type of fruit_train is None is not bool. fruit_train is None should be a bool"
assert sha1(str(fruit_train is None).encode("utf-8")+b"0b39255c4c2e66e7").hexdigest() == "45ab8eca17bd78878348135ba312968fa84a64d3", "boolean value of fruit_train is None is not correct"

assert sha1(str(type(fruit_test is None)).encode("utf-8")+b"100623cc5ff2eda3").hexdigest() == "a7ac8901fdd53665cc39e91e1129da5577b1bb0a", "type of fruit_test is None is not bool. fruit_test is None should be a bool"
assert sha1(str(fruit_test is None).encode("utf-8")+b"100623cc5ff2eda3").hexdigest() == "d0ca857bf3dac0b85344bb6af2f8dbdec10f89d7", "boolean value of fruit_test is None is not correct"

assert sha1(str(type(fruit_train.shape)).encode("utf-8")+b"8f3a70ef74860192").hexdigest() == "5fccab095d74661a43e4cb405a2987df9339b83c", "type of fruit_train.shape is not tuple. fruit_train.shape should be a tuple"
assert sha1(str(len(fruit_train.shape)).encode("utf-8")+b"8f3a70ef74860192").hexdigest() == "3aa616f1d50d062c8c4f491d6dddd9524a9fc92e", "length of fruit_train.shape is not correct"
assert sha1(str(sorted(map(str, fruit_train.shape))).encode("utf-8")+b"8f3a70ef74860192").hexdigest() == "026c561418ac6d4d82f19714b23ecba6a9870073", "values of fruit_train.shape are not correct"
assert sha1(str(fruit_train.shape).encode("utf-8")+b"8f3a70ef74860192").hexdigest() == "67bbbf1e7c7c0a3064bde0d41557a22722fb01c0", "order of elements of fruit_train.shape is not correct"

assert sha1(str(type(fruit_test.shape)).encode("utf-8")+b"6db59a536c8f44fa").hexdigest() == "b3430fed1c8db814abc9cbe90a995857511a0800", "type of fruit_test.shape is not tuple. fruit_test.shape should be a tuple"
assert sha1(str(len(fruit_test.shape)).encode("utf-8")+b"6db59a536c8f44fa").hexdigest() == "516b24ef63aca69c2de236c8a2cd2adeda9eaf0e", "length of fruit_test.shape is not correct"
assert sha1(str(sorted(map(str, fruit_test.shape))).encode("utf-8")+b"6db59a536c8f44fa").hexdigest() == "a8788d4e2dde4094614b5aec42edea555af4d362", "values of fruit_test.shape are not correct"
assert sha1(str(fruit_test.shape).encode("utf-8")+b"6db59a536c8f44fa").hexdigest() == "b7646a8afe0d72028d53e17622ed083ef2d40e63", "order of elements of fruit_test.shape is not correct"

assert sha1(str(type(sum(fruit_train.mass))).encode("utf-8")+b"548d2a789f773c14").hexdigest() == "a370a4e6ecbe9fb00857d29766c6c3751f8bedf2", "type of sum(fruit_train.mass) is not int. Please make sure it is int and not np.int64, etc. You can cast your value into an int using int()"
assert sha1(str(sum(fruit_train.mass)).encode("utf-8")+b"548d2a789f773c14").hexdigest() == "e6526cdbef27c0ba06a29b01fbcbe31202b395dc", "value of sum(fruit_train.mass) is not correct"

assert sha1(str(type(sum(fruit_test.mass))).encode("utf-8")+b"b827e6b271b1dd9d").hexdigest() == "44eab88c4873d08064d14195b422b0e6f87fe967", "type of sum(fruit_test.mass) is not int. Please make sure it is int and not np.int64, etc. You can cast your value into an int using int()"
assert sha1(str(sum(fruit_test.mass)).encode("utf-8")+b"b827e6b271b1dd9d").hexdigest() == "7a1173a05d5d87ba3e634294289092ebd104a3b2", "value of sum(fruit_test.mass) is not correct"

print('Success!')

**Question 2.1** 
<br> {points: 1}

K-nearest neighbors is sensitive to the scale of the predictors so we should do some preprocessing to standardize them. Remember that standardizing involves centering/shifting (subtracting the mean of each variable) and scaling (dividing by its standard deviation). Also remember that standardization is *part of your training procedure*, so you can't use your test data to compute the shift / scale values for each variable. Therefore, you must pass only the training data to your preprocessor to compute the preprocessing steps. This ensures that our test data does not influence any aspect of our model training. Once we have created the standardization preprocessor, we can then later on apply it separately to both the training and test data sets.
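The "fit on training data only" idea can be sketched with `StandardScaler` on its own. The arrays below are made up for illustration and are not the fruit data.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical one-column training and test sets.
X_train = np.array([[1.0], [2.0], [3.0]])
X_test = np.array([[4.0]])

scaler = StandardScaler()
scaler.fit(X_train)  # the shift and scale come from the training data only

X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)  # reuses the training mean and std

print(scaler.mean_, scaler.scale_)
print(X_test_scaled)
```

Notice that the scaled test value is not 0: it is expressed relative to the training mean and standard deviation, which is exactly what we want when the test set stands in for unseen data.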

For this exercise, let's see if `mass` and `color_score` can predict `fruit_name`. 

To scale and center the data, first, pass the predictors to the `make_column_transformer` function to make the preprocessor.

*Assign your answer to an object called `fruit_preprocessor`.*

In [None]:
# ___ = make_column_transformer(
#     (___, [___, ___]),
# )

# your code here
raise NotImplementedError
fruit_preprocessor

In [None]:
from hashlib import sha1
assert sha1(str(type(fruit_preprocessor is None)).encode("utf-8")+b"b517803ed81e6b41").hexdigest() == "aec139ffa717740fb43fafc885ceecfcdbb463bd", "type of fruit_preprocessor is None is not bool. fruit_preprocessor is None should be a bool"
assert sha1(str(fruit_preprocessor is None).encode("utf-8")+b"b517803ed81e6b41").hexdigest() == "baf799a8415a33998e1a42b502824381a81c69ab", "boolean value of fruit_preprocessor is None is not correct"

assert sha1(str(type(type(fruit_preprocessor))).encode("utf-8")+b"ccb2f7a97d25760f").hexdigest() == "d9b55cc667aeb2dd748103d03d6b2a0579110e7c", "type of type(fruit_preprocessor) is not correct"
assert sha1(str(type(fruit_preprocessor)).encode("utf-8")+b"ccb2f7a97d25760f").hexdigest() == "952b1289ebb2abcbcd8bf54d772ca7311f2217cf", "value of type(fruit_preprocessor) is not correct"

assert sha1(str(type(fruit_preprocessor.transformers[0][0])).encode("utf-8")+b"3b4075494a77fca2").hexdigest() == "5e2615c849c02d587ce8e24161c590a559131551", "type of fruit_preprocessor.transformers[0][0] is not str. fruit_preprocessor.transformers[0][0] should be an str"
assert sha1(str(len(fruit_preprocessor.transformers[0][0])).encode("utf-8")+b"3b4075494a77fca2").hexdigest() == "b32953be847bc0c6af10ce6e9c560bcb7ad0a0a2", "length of fruit_preprocessor.transformers[0][0] is not correct"
assert sha1(str(fruit_preprocessor.transformers[0][0].lower()).encode("utf-8")+b"3b4075494a77fca2").hexdigest() == "6004371d1d54093956efe182e331606d854fa200", "value of fruit_preprocessor.transformers[0][0] is not correct"
assert sha1(str(fruit_preprocessor.transformers[0][0]).encode("utf-8")+b"3b4075494a77fca2").hexdigest() == "6004371d1d54093956efe182e331606d854fa200", "correct string value of fruit_preprocessor.transformers[0][0] but incorrect case of letters"

assert sha1(str(type(fruit_preprocessor.transformers[0][2])).encode("utf-8")+b"8ea0fc8ab91f6695").hexdigest() == "5a3f5b65b335d8f0feefa5c38ba92621232300f1", "type of fruit_preprocessor.transformers[0][2] is not list. fruit_preprocessor.transformers[0][2] should be a list"
assert sha1(str(len(fruit_preprocessor.transformers[0][2])).encode("utf-8")+b"8ea0fc8ab91f6695").hexdigest() == "f9ebe0d2beac9837be130ccbd3ffc7d1e696d62c", "length of fruit_preprocessor.transformers[0][2] is not correct"
assert sha1(str(sorted(map(str, fruit_preprocessor.transformers[0][2]))).encode("utf-8")+b"8ea0fc8ab91f6695").hexdigest() == "de54fd93b3779bac2f3beb4c146231d175c84a9e", "values of fruit_preprocessor.transformers[0][2] are not correct"
assert sha1(str(fruit_preprocessor.transformers[0][2]).encode("utf-8")+b"8ea0fc8ab91f6695").hexdigest() == "fac2df93f02fc89055c12e7fdb97639535d5eaf7", "order of elements of fruit_preprocessor.transformers[0][2] is not correct"

print('Success!')

**Question 2.2**
<br> {points: 1}

So far, we have split the data into training and test sets and created a preprocessor. Now, let's create our K-nearest neighbour classifier with only the training set using the `scikit-learn` package. First, create the classifier by specifying that we want $K = 3$ neighbors and `weights="distance"`. *Assign your answer to an object called `knn_spec`*.

Name the predictors `X` and the target `y`.

Next, train the classifier with the training data set using the `make_pipeline` and `fit` functions. The `make_pipeline` function allows you to bundle together your pre-processing, modeling, and post-processing requests. Scaffolding is provided below for you.

*Assign your answer to an object called `fruit_fit`*.
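For orientation, here is a small sketch of the same bundling pattern on made-up data. The `toy_*` names and the `x1`, `x2`, and `label` columns are hypothetical, and a bare `StandardScaler` stands in for a column-transformer preprocessor. With `weights="distance"`, closer neighbours count for more in the vote than farther ones.

```python
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Hypothetical toy training data with two well-separated classes.
toy_train = pd.DataFrame(
    {
        "x1": [1.0, 1.2, 0.8, 5.0, 5.2, 4.8],
        "x2": [10.0, 12.0, 9.0, 50.0, 52.0, 49.0],
        "label": ["small", "small", "small", "large", "large", "large"],
    }
)

toy_X = toy_train[["x1", "x2"]]
toy_y = toy_train["label"]

# The pipeline standardizes the predictors, then fits the classifier.
toy_knn = KNeighborsClassifier(n_neighbors=3, weights="distance")
toy_fit = make_pipeline(StandardScaler(), toy_knn).fit(toy_X, toy_y)

# New observations are scaled with the training statistics before prediction.
new_points = pd.DataFrame({"x1": [1.1, 5.1], "x2": [11.0, 51.0]})
toy_predictions = toy_fit.predict(new_points)
print(toy_predictions)
```

Because the scaler lives inside the pipeline, calling `predict` applies the training-set scaling to new data automatically, so there is no need to transform it by hand.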

In [None]:
# ___ = KNeighborsClassifier(n_neighbors=___, weights="distance")

# ___ = ___[["mass", "color_score"]]
# ___ = fruit_train[___]

# ___ = make_pipeline(___, ___).fit(___, ___)

# your code here
raise NotImplementedError
fruit_fit

In [None]:
from hashlib import sha1
assert sha1(str(type(knn_spec is None)).encode("utf-8")+b"77f044790719f78f").hexdigest() == "1a865832162b29394f1efaf1f04183dfa2518165", "type of knn_spec is None is not bool. knn_spec is None should be a bool"
assert sha1(str(knn_spec is None).encode("utf-8")+b"77f044790719f78f").hexdigest() == "84d24b0c386cff5ce5abb5d23aa0780195ddc051", "boolean value of knn_spec is None is not correct"

assert sha1(str(type(type(knn_spec))).encode("utf-8")+b"9257093782246136").hexdigest() == "ca9f1d4dbf15bbbeb3e723d02a519d1b61dfebe2", "type of type(knn_spec) is not correct"
assert sha1(str(type(knn_spec)).encode("utf-8")+b"9257093782246136").hexdigest() == "b059ff724cbffe8789f8e5fe21de89325d3ab6db", "value of type(knn_spec) is not correct"

assert sha1(str(type(knn_spec.effective_metric_)).encode("utf-8")+b"5144fae2f7bf2551").hexdigest() == "32afb30fd514ae4a65c3d39460938ca879ce5285", "type of knn_spec.effective_metric_ is not str. knn_spec.effective_metric_ should be an str"
assert sha1(str(len(knn_spec.effective_metric_)).encode("utf-8")+b"5144fae2f7bf2551").hexdigest() == "b71f3ae366d83e87aea750cd38aa3cd60444f9ae", "length of knn_spec.effective_metric_ is not correct"
assert sha1(str(knn_spec.effective_metric_.lower()).encode("utf-8")+b"5144fae2f7bf2551").hexdigest() == "5a71a9febc1cb5409d08185e9ead2167c4ab0852", "value of knn_spec.effective_metric_ is not correct"
assert sha1(str(knn_spec.effective_metric_).encode("utf-8")+b"5144fae2f7bf2551").hexdigest() == "5a71a9febc1cb5409d08185e9ead2167c4ab0852", "correct string value of knn_spec.effective_metric_ but incorrect case of letters"

assert sha1(str(type(knn_spec.n_neighbors)).encode("utf-8")+b"1231a984bf0b8f59").hexdigest() == "9ef3ca54b9921de39bb6357e1412dac8f781931a", "type of knn_spec.n_neighbors is not int. Please make sure it is int and not np.int64, etc. You can cast your value into an int using int()"
assert sha1(str(knn_spec.n_neighbors).encode("utf-8")+b"1231a984bf0b8f59").hexdigest() == "ab271a2eaf8b96a25708097f480983ef7103d798", "value of knn_spec.n_neighbors is not correct"

assert sha1(str(type(sum(X.mass))).encode("utf-8")+b"0388a53fd056c85f").hexdigest() == "71fb66b45ad42c70e9e87f9249fe043e75dd00fa", "type of sum(X.mass) is not int. Please make sure it is int and not np.int64, etc. You can cast your value into an int using int()"
assert sha1(str(sum(X.mass)).encode("utf-8")+b"0388a53fd056c85f").hexdigest() == "6ad4ca7d565922edb1aa13d3fdbcdc4cb82c9ebd", "value of sum(X.mass) is not correct"

assert sha1(str(type(sum(X.color_score))).encode("utf-8")+b"b731f78c4828f8f5").hexdigest() == "c49c6dbd40ee411f2112c03ea59ecb6ad0528bf9", "type of sum(X.color_score) is not float. Please make sure it is float and not np.float64, etc. You can cast your value into a float using float()"
assert sha1(str(round(sum(X.color_score), 2)).encode("utf-8")+b"b731f78c4828f8f5").hexdigest() == "9db9541e5f84ee0aedb0816528d2662b08510334", "value of sum(X.color_score) is not correct (rounded to 2 decimal places)"

assert sha1(str(type(y.name)).encode("utf-8")+b"dcbf6f19df84dbc0").hexdigest() == "7376e9110bb28f2319a2a1909cbdbebcfcb3ee5e", "type of y.name is not str. y.name should be an str"
assert sha1(str(len(y.name)).encode("utf-8")+b"dcbf6f19df84dbc0").hexdigest() == "40160408283a59e3cf5b1a48cd83f6ab4f5de9c5", "length of y.name is not correct"
assert sha1(str(y.name.lower()).encode("utf-8")+b"dcbf6f19df84dbc0").hexdigest() == "13cbfa8b860046661f15d3a7eacc61b9621a88e2", "value of y.name is not correct"
assert sha1(str(y.name).encode("utf-8")+b"dcbf6f19df84dbc0").hexdigest() == "13cbfa8b860046661f15d3a7eacc61b9621a88e2", "correct string value of y.name but incorrect case of letters"

assert sha1(str(type(fruit_fit is None)).encode("utf-8")+b"7b368ab063c2f764").hexdigest() == "26551e49a6bfa070b124b40462d2ce2f0b87aa0d", "type of fruit_fit is None is not bool. fruit_fit is None should be a bool"
assert sha1(str(fruit_fit is None).encode("utf-8")+b"7b368ab063c2f764").hexdigest() == "43c7432018911bb7ea3f5270c69f082a4aa6f51b", "boolean value of fruit_fit is None is not correct"

assert sha1(str(type(type(fruit_fit))).encode("utf-8")+b"61daae49bee61cb2").hexdigest() == "989b27588202686838ef617cbe8aaa53a4567f15", "type of type(fruit_fit) is not correct"
assert sha1(str(type(fruit_fit)).encode("utf-8")+b"61daae49bee61cb2").hexdigest() == "cccded5ecb5d0ff4141744fa3982dbd3858e6248", "value of type(fruit_fit) is not correct"

assert sha1(str(type(len(fruit_fit.named_steps))).encode("utf-8")+b"f6793e9a8fff9567").hexdigest() == "4b51d6bb8e064f8cec30052c3c6df70ad904e3fd", "type of len(fruit_fit.named_steps) is not int. Please make sure it is int and not np.int64, etc. You can cast your value into an int using int()"
assert sha1(str(len(fruit_fit.named_steps)).encode("utf-8")+b"f6793e9a8fff9567").hexdigest() == "5fef1f0ba7c0645c509dfa94a2e03075c93518d1", "value of len(fruit_fit.named_steps) is not correct"

assert sha1(str(type(fruit_fit.named_steps.keys())).encode("utf-8")+b"cecf8c3019e854e7").hexdigest() == "367020a4aef6feb2ae42f7cfd1790409e01bffb6", "type of fruit_fit.named_steps.keys() is not correct"
assert sha1(str(fruit_fit.named_steps.keys()).encode("utf-8")+b"cecf8c3019e854e7").hexdigest() == "0ad7be27879087852e1957c7e37b19c163352142", "value of fruit_fit.named_steps.keys() is not correct"

print('Success!')

**Question 2.3**
<br> {points: 1}

Now that we have fitted our K-nearest neighbour classifier, let's predict the class labels for our test set.

First, call the `predict` method of your fitted model on the predictor columns of the **test dataset**. Then, use the `pd.concat` function to add the column of predictions to the original test data.

*Assign your answer to an object called `fruit_test_predictions`.*
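
The sketch below illustrates the predict-then-concatenate pattern on made-up data (all `toy_*` names and values are hypothetical). It also shows why the scaffolding calls `reset_index(drop=True)`: `pd.concat` with `axis=1` aligns rows by index, and a test set produced by a random split usually keeps its original, shuffled row labels.

```python
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier

# Hypothetical toy data; the test set keeps shuffled row labels (7 and 3)
toy_train = pd.DataFrame({"x": [0.0, 0.1, 1.0, 1.1], "label": ["a", "a", "b", "b"]})
toy_test = pd.DataFrame({"x": [0.05, 1.05], "label": ["a", "b"]}, index=[7, 3])

toy_model = KNeighborsClassifier(n_neighbors=1).fit(
    toy_train[["x"]], toy_train["label"]
)

# predict returns a NumPy array with one label per test row
toy_pred = toy_model.predict(toy_test[["x"]])

# Reset the test index so both pieces share the row labels 0, 1, ...
toy_test_predictions = pd.concat(
    [
        toy_test.reset_index(drop=True),
        pd.DataFrame(toy_pred, columns=["predicted"]),
    ],
    axis=1,
)
print(toy_test_predictions)
```

Without the `reset_index` call, the rows labelled 7 and 3 would not line up with the prediction rows labelled 0 and 1, and the result would contain missing values.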

In [None]:
# ___ = ___.predict(___[[___, ___]])
# ___ = pd.concat(
#     [
#         fruit_test.reset_index(drop=True),
#         pd.DataFrame(fruit_test_predictions, columns=["predicted"]),
#     ],
#     axis=1,
# ) # use concat to add the predictions column to the original test data

# your code here
raise NotImplementedError
fruit_test_predictions.head()

In [None]:
from hashlib import sha1
assert sha1(str(type(fruit_test_predictions is None)).encode("utf-8")+b"5de608b7f7bcb18b").hexdigest() == "bb6420318effbcc68200553506e884de813b39cf", "type of fruit_test_predictions is None is not bool. fruit_test_predictions is None should be a bool"
assert sha1(str(fruit_test_predictions is None).encode("utf-8")+b"5de608b7f7bcb18b").hexdigest() == "29595ca6668b2a696947cbd64a37074da6c1fc21", "boolean value of fruit_test_predictions is None is not correct"

assert sha1(str(type(fruit_test_predictions)).encode("utf-8")+b"1d52d00e48ae7c5d").hexdigest() == "e3374ada95975cf337dd1d9e59aa30909c9e2387", "type of type(fruit_test_predictions) is not correct"

assert sha1(str(type(fruit_test_predictions.shape)).encode("utf-8")+b"3894a51165929b73").hexdigest() == "5ad9e362b4b272d99d7a78eacf1de1a0a482384f", "type of fruit_test_predictions.shape is not tuple. fruit_test_predictions.shape should be a tuple"
assert sha1(str(len(fruit_test_predictions.shape)).encode("utf-8")+b"3894a51165929b73").hexdigest() == "0b18586c5499484860103f82b95611d1fcb9c822", "length of fruit_test_predictions.shape is not correct"
assert sha1(str(sorted(map(str, fruit_test_predictions.shape))).encode("utf-8")+b"3894a51165929b73").hexdigest() == "ad04ad7dbcc37e35727c04d15d83a2df597ede07", "values of fruit_test_predictions.shape are not correct"
assert sha1(str(fruit_test_predictions.shape).encode("utf-8")+b"3894a51165929b73").hexdigest() == "76d2146ace422fff8dc17eac54c9f12b9268d674", "order of elements of fruit_test_predictions.shape is not correct"

assert sha1(str(type("predicted" in fruit_test_predictions.columns)).encode("utf-8")+b"dd166f7ba2618fa3").hexdigest() == "b60bdec5b15149239290c8d57b6a69a2a9f60319", "type of \"predicted\" in fruit_test_predictions.columns is not bool. \"predicted\" in fruit_test_predictions.columns should be a bool"
assert sha1(str("predicted" in fruit_test_predictions.columns).encode("utf-8")+b"dd166f7ba2618fa3").hexdigest() == "a9314f2515f30c8f7ba819beb53cc9a910be6022", "boolean value of \"predicted\" in fruit_test_predictions.columns is not correct"

print('Success!')

**Question 2.4**
<br> {points: 1}

Great! We have now computed predictions for our test dataset! Wouldn't it be interesting to find out our classifier's accuracy?

Thankfully, the `score` method from the `scikit-learn` package can help us. For a classifier, it returns the accuracy: the proportion of observations whose predicted label matches the true label. Name the test-set predictors `X_test` and the target `y_test`, then pass both to the `score` method of the `fruit_fit` model.

*Assign your answer to an object called `fruit_prediction_accuracy`.*
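
As a small check of what `score` reports, the sketch below compares it with a hand-computed proportion of correct predictions. The data and all `toy_*` names are made up for illustration.

```python
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier

# Hypothetical toy data; two of the four test labels disagree with
# their nearest training neighbour on purpose
toy_train = pd.DataFrame({"x": [0.0, 1.0, 2.0, 3.0], "label": ["a", "a", "b", "b"]})
toy_test = pd.DataFrame({"x": [0.2, 1.2, 2.6, 2.9], "label": ["a", "b", "b", "a"]})

toy_model = KNeighborsClassifier(n_neighbors=1).fit(
    toy_train[["x"]], toy_train["label"]
)

# Accuracy as reported by the classifier's score method
toy_accuracy = toy_model.score(toy_test[["x"]], toy_test["label"])

# The same quantity by hand: proportion of predictions equal to the truth
toy_manual = (toy_model.predict(toy_test[["x"]]) == toy_test["label"]).mean()

print(toy_accuracy, toy_manual)
```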

In [None]:
# ___ = ___[[___, ___]]
# ___ = ___["fruit_name"]

# ___ = fruit_fit.score(___, ___)

# your code here
raise NotImplementedError
fruit_prediction_accuracy

In [None]:
from hashlib import sha1
assert sha1(str(type(fruit_prediction_accuracy is None)).encode("utf-8")+b"b18f0fb65c781930").hexdigest() == "e29669e70e1d70749964a2a1781c4dcf5168fea1", "type of fruit_prediction_accuracy is None is not bool. fruit_prediction_accuracy is None should be a bool"
assert sha1(str(fruit_prediction_accuracy is None).encode("utf-8")+b"b18f0fb65c781930").hexdigest() == "ccebec3a236c751fb3622d0913bebc0a91815950", "boolean value of fruit_prediction_accuracy is None is not correct"

assert sha1(str(type(fruit_prediction_accuracy)).encode("utf-8")+b"91086ab3fb55222e").hexdigest() == "f9b8e7cd6247bf4ad4ae2bb479ee7d5bc27589a0", "type of fruit_prediction_accuracy is not correct"
assert sha1(str(fruit_prediction_accuracy).encode("utf-8")+b"91086ab3fb55222e").hexdigest() == "3a3ac43aac706e2b06653de21daf097e4fd02940", "value of fruit_prediction_accuracy is not correct"

print('Success!')

**Question 2.5**
<br> {points: 1}

Now, let's look at the *confusion matrix* for the classifier. This will show us the table of predicted labels and correct labels. 

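A confusion matrix counts, for every combination of actual class and predicted class, how many observations fall into it. In the example below, the columns represent the actual class and the rows represent the predicted class; some tools use the opposite layout. In particular, `scikit-learn`'s `confusion_matrix` function puts the true labels on the rows and the predicted labels on the columns.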

|                  |          |  Actual Values |                |
|:----------------:|----------|:--------------:|:--------------:|
|                  |          |    Positive    |    Negative    |
|**Predicted Value**  | Positive |  True Positive | False Positive|
|                  | Negative | False Negative | True Negative  |


- A **true positive** is an outcome where the model correctly predicts the positive class.
- A **true negative** is an outcome where the model correctly predicts the negative class.
- A **false positive** is an outcome where the model incorrectly predicts the positive class.
- A **false negative** is an outcome where the model incorrectly predicts the negative class.

<br>

We can create a confusion matrix by using the `confusion_matrix` function from the `scikit-learn` package.

*Assign your answer to an object called `fruit_mat`*.
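
To make the layout concrete, here is a short sketch with made-up labels (the lists and the name `toy_mat` are hypothetical). In the array returned by `confusion_matrix`, each row is a true class and each column is a predicted class, in the order given by `labels`.

```python
from sklearn.metrics import confusion_matrix

# Hypothetical labels: three true apples, two true lemons
true_labels = ["apple", "apple", "apple", "lemon", "lemon"]
predicted_labels = ["apple", "lemon", "lemon", "lemon", "lemon"]

toy_mat = confusion_matrix(
    true_labels,
    predicted_labels,
    labels=["apple", "lemon"],  # fixes the row and column order
)
print(toy_mat)

# Correct predictions sit on the diagonal, so their count is the trace
print(toy_mat.trace())
```

Here the first row reads "one apple predicted as apple, two apples predicted as lemon", and the diagonal adds up to the three correctly labelled observations.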

In [None]:
# ___ = ___(
#     fruit_test_predictions[___],  # true labels
#     fruit_test_predictions[___],  # predicted labels
#     labels=fruit_fit.classes_, # specify the label for each class
# )

# your code here
raise NotImplementedError
fruit_mat

It is hard to interpret the confusion matrix when it is printed as a bare array. We can use the `ConfusionMatrixDisplay` class from the `scikit-learn` package to plot it with labelled axes. Please run the cell below.

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay

disp = ConfusionMatrixDisplay(
    confusion_matrix=fruit_mat, display_labels=fruit_fit.classes_
)
disp.plot()

In [None]:
from hashlib import sha1
assert sha1(str(type(fruit_mat is None)).encode("utf-8")+b"9e831774fc3eba04").hexdigest() == "ec2888c6e3f3d89ea6fb58252674ec1ca86a6006", "type of fruit_mat is None is not bool. fruit_mat is None should be a bool"
assert sha1(str(fruit_mat is None).encode("utf-8")+b"9e831774fc3eba04").hexdigest() == "530d7ab148da865bbcb7840315f495fbf1d7af20", "boolean value of fruit_mat is None is not correct"

assert sha1(str(type(fruit_mat)).encode("utf-8")+b"8671f5dc3ad6038f").hexdigest() == "499d82f1bf8fe92e43dfacab411252afadd21344", "type of type(fruit_mat) is not correct"

assert sha1(str(type(fruit_mat.sum())).encode("utf-8")+b"4d136ad9c9110aae").hexdigest() == "73a079f5311d4704f7414128419b8b1d7a004a3c", "type of fruit_mat.sum() is not correct"
assert sha1(str(fruit_mat.sum()).encode("utf-8")+b"4d136ad9c9110aae").hexdigest() == "d8172b6694ab65e0f41fb91e50826bada8eb209a", "value of fruit_mat.sum() is not correct"

print('Success!')

**Question 2.6** Multiple Choice:
<br> {points: 1}

Reading `fruit_mat`, how many observations were labelled correctly?

A. 7

B. 8

C. 9

D. 14

*Assign your answer to an object called `answer2_6`. Make sure your answer is an uppercase letter and is surrounded by quotation marks (e.g. `"F"`).*

In [None]:
# your code here
raise NotImplementedError
answer2_6

In [None]:
from hashlib import sha1
assert sha1(str(type(answer2_6)).encode("utf-8")+b"e191e31e8e307ea4").hexdigest() == "eb02ce64fceaedd450185c90810bd638942b687d", "type of answer2_6 is not str. answer2_6 should be an str"
assert sha1(str(len(answer2_6)).encode("utf-8")+b"e191e31e8e307ea4").hexdigest() == "2146c229bae7304d6bd6d4d49dc29b3b83d29aab", "length of answer2_6 is not correct"
assert sha1(str(answer2_6.lower()).encode("utf-8")+b"e191e31e8e307ea4").hexdigest() == "12e6520b21c8bf2080ebe2bdcdb371c3ca559a20", "value of answer2_6 is not correct"
assert sha1(str(answer2_6).encode("utf-8")+b"e191e31e8e307ea4").hexdigest() == "601ef9d5d93eb254458f72311508813fb4bd0b3a", "correct string value of answer2_6 but incorrect case of letters"

print('Success!')

### 3. Cross-validation

**Question 3.1**
<br> {points: 1}

The vast majority of predictive models in statistics and machine learning have parameters that you have to pick. For the past few exercises, we have had to pick the number of neighbours for the class vote. But, is it possible to make this selection, *i.e., tune the model, in a principled way?* Ideally, we want to maximize the performance of our classifier on data *it hasn’t seen yet*.

There is also an important detail to mention about the process of tuning: we can, if we want to, split our overall training data up in multiple different ways, train and evaluate a classifier for each split, and then choose the parameter based on all of the different results. If we split our overall training data only once, our best parameter choice will depend strongly on whichever observations happened to end up in the validation set. By using multiple different train / validation splits, we get a better estimate of accuracy, which leads to a better choice of the number of neighbours $K$ for the overall set of training data.

This leads to the idea of cross-validation. In cross-validation, we split our overall training data into $C$ evenly-sized chunks, and then iteratively use 1 chunk as the validation set and combine the remaining $C−1$ chunks as the **training set.**
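
The chunking itself can be sketched with `scikit-learn`'s `KFold` splitter on ten made-up row numbers. The names `toy_rows`, `splitter`, and `folds` are hypothetical, and this is only an illustration of the idea: when `cross_validate` is given a classifier and an integer `cv`, it uses a stratified variant that also balances the classes across chunks.

```python
import numpy as np
from sklearn.model_selection import KFold

toy_rows = np.arange(10)        # stand-in for 10 training observations
splitter = KFold(n_splits=5)    # C = 5 chunks

folds = []
for train_idx, val_idx in splitter.split(toy_rows):
    # one chunk is held out for validation, the other C - 1 are combined
    folds.append((train_idx, val_idx))
    print("train:", train_idx, "validation:", val_idx)
```

Every observation lands in the validation set exactly once, so each of the $C$ accuracy estimates is computed on data the corresponding classifier did not see.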

We can perform cross-validation in Python using the `cross_validate` function from the `scikit-learn` package. The function returns a dictionary of arrays containing the fit times, score times, and scores for each fold. To use this function, you have to specify the model, the training predictors and target, and optionally `cv` (the number of folds $C$, which defaults to 5). Set `return_train_score` to `True` to return the training scores as well. Before we use the function, we need to build the pipeline again. You can reuse the `fruit_preprocessor` and `knn_spec` objects you made earlier. Name your predictors `X_val` and your target `y_val`.

*Assign your answer to an object called `fruit_vfold_score`*.
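
For a sense of what the returned dictionary looks like, here is a minimal sketch that uses the iris data bundled with `scikit-learn` as a stand-in for the fruit data. The `iris_*` names and the choice of 3 neighbours are assumptions made for the illustration.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_validate
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Stand-in data set (not the fruit data)
iris = load_iris(as_frame=True)
iris_X = iris.data
iris_y = iris.target

iris_pipe = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=3))

# 5-fold cross-validation, also recording the score on each training portion
iris_cv = cross_validate(iris_pipe, iris_X, iris_y, cv=5, return_train_score=True)

# One row per fold; columns are fit_time, score_time, test_score, train_score
iris_cv_df = pd.DataFrame(iris_cv)
print(iris_cv_df)
```

The `test_score` column holds the accuracy on each held-out validation chunk; despite its name, it does not involve the test set.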

In [None]:
np.random.seed(2020)  # DO NOT REMOVE

# ___ = ___[["mass", "color_score"]]
# ___ = ___["fruit_name"]
# ___ = ___(fruit_preprocessor, knn_spec)
# ___ = cross_validate(___,  ___, ___, return_train_score=True,)

# your code here
raise NotImplementedError
pd.DataFrame(fruit_vfold_score)

In [None]:
from hashlib import sha1
assert sha1(str(type(fruit_vfold_score is None)).encode("utf-8")+b"1a393102aadde615").hexdigest() == "c036177cd361108c309dcf85f8e79a8acb00cadf", "type of fruit_vfold_score is None is not bool. fruit_vfold_score is None should be a bool"
assert sha1(str(fruit_vfold_score is None).encode("utf-8")+b"1a393102aadde615").hexdigest() == "52e271f79fe798b559a000b5749ebcbc565ee410", "boolean value of fruit_vfold_score is None is not correct"

assert sha1(str(type(fruit_vfold_score)).encode("utf-8")+b"a2156bffd1016d39").hexdigest() == "3bad1d029a791574c6a69786a3e0074327358d1b", "type of type(fruit_vfold_score) is not correct"

assert sha1(str(type(len(pd.DataFrame(fruit_vfold_score)))).encode("utf-8")+b"16db1850eefb7191").hexdigest() == "3d31554b0cf0b340a97af9547794079be6e3235d", "type of len(pd.DataFrame(fruit_vfold_score)) is not int. Please make sure it is int and not np.int64, etc. You can cast your value into an int using int()"
assert sha1(str(len(pd.DataFrame(fruit_vfold_score))).encode("utf-8")+b"16db1850eefb7191").hexdigest() == "1fc17cd04f8129d1d113a2884bb5ae1e150832af", "value of len(pd.DataFrame(fruit_vfold_score)) is not correct"

assert sha1(str(type(pd.DataFrame(fruit_vfold_score).shape)).encode("utf-8")+b"91a6373469f26eab").hexdigest() == "c8b575a62b020f00ec0c6446f1421244205be38b", "type of pd.DataFrame(fruit_vfold_score).shape is not tuple. pd.DataFrame(fruit_vfold_score).shape should be a tuple"
assert sha1(str(len(pd.DataFrame(fruit_vfold_score).shape)).encode("utf-8")+b"91a6373469f26eab").hexdigest() == "3ec29a3e13635ace8952e03eb37662005de2bdab", "length of pd.DataFrame(fruit_vfold_score).shape is not correct"
assert sha1(str(sorted(map(str, pd.DataFrame(fruit_vfold_score).shape))).encode("utf-8")+b"91a6373469f26eab").hexdigest() == "b3853e26da8a702e055ff8c5220c657dbad3f7c2", "values of pd.DataFrame(fruit_vfold_score).shape are not correct"
assert sha1(str(pd.DataFrame(fruit_vfold_score).shape).encode("utf-8")+b"91a6373469f26eab").hexdigest() == "e2079c2072d28c4d5ef197b9f2fae4f780e13b62", "order of elements of pd.DataFrame(fruit_vfold_score).shape is not correct"

assert sha1(str(type(X_val.columns.values)).encode("utf-8")+b"7195080b523653b1").hexdigest() == "267b89f684ca37570e6d827d833e899423fc2253", "type of X_val.columns.values is not correct"
assert sha1(str(X_val.columns.values).encode("utf-8")+b"7195080b523653b1").hexdigest() == "79963f0bfa49c9497f1e54614532aba74b751d76", "value of X_val.columns.values is not correct"

assert sha1(str(type(sum(X_val.color_score))).encode("utf-8")+b"a92f372419a22ad4").hexdigest() == "248f9bc80d772450dde4b2c81ee2951a18fae1a5", "type of sum(X_val.color_score) is not float. Please make sure it is float and not np.float64, etc. You can cast your value into a float using float()"
assert sha1(str(round(sum(X_val.color_score), 2)).encode("utf-8")+b"a92f372419a22ad4").hexdigest() == "1d30c965da3b7e66ab6c198c8cc6e6d51c193002", "value of sum(X_val.color_score) is not correct (rounded to 2 decimal places)"

assert sha1(str(type(sum(X_val.mass))).encode("utf-8")+b"a8a03bf152d3b99d").hexdigest() == "29797a9afbb16bddc3a52e9458576bbe21c5ff90", "type of sum(X_val.mass) is not int. Please make sure it is int and not np.int64, etc. You can cast your value into an int using int()"
assert sha1(str(sum(X_val.mass)).encode("utf-8")+b"a8a03bf152d3b99d").hexdigest() == "e214b881b56ffae447f7e2670b09e79209e7f474", "value of sum(X_val.mass) is not correct"

print('Success!')

**Question 3.2**
<br> {points: 1}

Now that we have run cross-validation on each train/validation split, how accurate was the classifier across the folds? We can aggregate the scores from the folds by computing their *mean* and *standard error*. The standard error is essentially a measure of how uncertain we are in the mean value.

*Assign your answers to objects called `fruit_metrics_mean` and `fruit_metrics_std`.*
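
The sketch below shows how the mean, the standard deviation, and the standard error of per-fold scores relate to each other in `pandas`. The scores in `toy_scores` are invented numbers, not results from the fruit data.

```python
import numpy as np
import pandas as pd

# Hypothetical per-fold scores from a 5-fold cross-validation
toy_scores = pd.DataFrame(
    {
        "test_score": [0.8, 0.9, 1.0, 0.9, 0.9],
        "train_score": [1.0, 1.0, 1.0, 1.0, 1.0],
    }
)

toy_mean = toy_scores.mean()  # average score per column
toy_std = toy_scores.std()    # sample standard deviation (ddof=1)
toy_sem = toy_scores.sem()    # standard error of the mean

# The standard error is the standard deviation divided by sqrt(number of folds)
print(toy_sem)
print(toy_std / np.sqrt(len(toy_scores)))
```

Each aggregation returns a `pandas` Series with one entry per column, so individual values can be read off with, for example, `toy_mean["test_score"]`.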

In [None]:
# ___ = pd.DataFrame(___).___()
# ___ = pd.DataFrame(___).___()

# your code here
raise NotImplementedError

In [None]:
from hashlib import sha1
assert sha1(str(type(fruit_metrics_mean.shape)).encode("utf-8")+b"de9b6799e0c35ac8").hexdigest() == "c82abddec0b7e0d592c9b454673b01f43696391c", "type of fruit_metrics_mean.shape is not tuple. fruit_metrics_mean.shape should be a tuple"
assert sha1(str(len(fruit_metrics_mean.shape)).encode("utf-8")+b"de9b6799e0c35ac8").hexdigest() == "33cc1fd5256d8002dc5faa68047f991da6b6a2db", "length of fruit_metrics_mean.shape is not correct"
assert sha1(str(sorted(map(str, fruit_metrics_mean.shape))).encode("utf-8")+b"de9b6799e0c35ac8").hexdigest() == "95cc1d0d2605e0f957ed4bfe260da5fdcc9ca17e", "values of fruit_metrics_mean.shape are not correct"
assert sha1(str(fruit_metrics_mean.shape).encode("utf-8")+b"de9b6799e0c35ac8").hexdigest() == "791eaf106197377953221c66369a4221ca56db80", "order of elements of fruit_metrics_mean.shape is not correct"

assert sha1(str(type(fruit_metrics_std.shape)).encode("utf-8")+b"029439db9085f69c").hexdigest() == "3144efdb7ddf9c45177388f8c833558e5fa701e6", "type of fruit_metrics_std.shape is not tuple. fruit_metrics_std.shape should be a tuple"
assert sha1(str(len(fruit_metrics_std.shape)).encode("utf-8")+b"029439db9085f69c").hexdigest() == "c474c7c5f72ad08e881924b08460f41c618ea409", "length of fruit_metrics_std.shape is not correct"
assert sha1(str(sorted(map(str, fruit_metrics_std.shape))).encode("utf-8")+b"029439db9085f69c").hexdigest() == "4191e28e9f219ce58c3cdb67718d24c4032dc547", "values of fruit_metrics_std.shape are not correct"
assert sha1(str(fruit_metrics_std.shape).encode("utf-8")+b"029439db9085f69c").hexdigest() == "8620ec8695d40e8b9cd44d69402236a114716ddb", "order of elements of fruit_metrics_std.shape is not correct"

assert sha1(str(type(fruit_metrics_mean.train_score)).encode("utf-8")+b"0cda0b749871e97d").hexdigest() == "84159ff94d8e32c92ec8380bf90a2b38dbd6f3db", "type of fruit_metrics_mean.train_score is not correct"
assert sha1(str(fruit_metrics_mean.train_score).encode("utf-8")+b"0cda0b749871e97d").hexdigest() == "56bf5152f2bd657daa4bd35afc98772e16d5045e", "value of fruit_metrics_mean.train_score is not correct"

assert sha1(str(type(fruit_metrics_std.train_score)).encode("utf-8")+b"21521706643be23c").hexdigest() == "5de01093c4d2e0ca0201e718b132e92830afce6f", "type of fruit_metrics_std.train_score is not correct"
assert sha1(str(fruit_metrics_std.train_score).encode("utf-8")+b"21521706643be23c").hexdigest() == "f244427dddd0084376712d1e29cfb2d587d7e369", "value of fruit_metrics_std.train_score is not correct"

assert sha1(str(type(fruit_metrics_mean.test_score)).encode("utf-8")+b"c017966e2926fe55").hexdigest() == "997d044589621e150564cc46ce82e050d5b6cf53", "type of fruit_metrics_mean.test_score is not correct"
assert sha1(str(fruit_metrics_mean.test_score).encode("utf-8")+b"c017966e2926fe55").hexdigest() == "e75e1d20bbdb8c1e14f49199c8ca16a8d13c358c", "value of fruit_metrics_mean.test_score is not correct"

assert sha1(str(type(fruit_metrics_std.test_score)).encode("utf-8")+b"c4b5160758ae21a8").hexdigest() == "cc38519eed05da8e71cc812e0b768cde165c2875", "type of fruit_metrics_std.test_score is not correct"
assert sha1(str(fruit_metrics_std.test_score).encode("utf-8")+b"c4b5160758ae21a8").hexdigest() == "d77f9263abf3495177fb1d5df52389b95a9e6dfe", "value of fruit_metrics_std.test_score is not correct"

print('Success!')

## 4. Parameter value selection

Using 5-fold cross-validation, we have estimated the prediction accuracy of our classifier.

To improve our classifier, we can change its parameter: the number of neighbours, $K$. Since cross-validation helps us evaluate the accuracy of our classifier, we can use cross-validation to calculate an accuracy for each value of $K$ in a reasonable range, and then pick the value of $K$ that gives us the best accuracy.

The great thing about the `scikit-learn` package is that it provides the following two built-in methods for tuning parameters. Rather than fixing each hyperparameter at a single value, we define a set of candidate values for each one and search for the best combination in that set.

- Exhaustive grid search
    - [sklearn.model_selection.GridSearchCV](http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html)
    - A user specifies a set of values for each hyperparameter.
    - The method considers the Cartesian product of the sets and then evaluates each combination one by one.
    
- Randomized hyperparameter optimization
    - [sklearn.model_selection.RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html)
    - Samples configurations at random until a certain budget is exhausted (in `RandomizedSearchCV`, the budget is the number of sampled configurations, `n_iter`)
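
The difference between the two strategies can be sketched as follows, again using the bundled iris data as a stand-in. The `iris_*` names, the candidate range, and the budget of 4 are assumptions chosen for the illustration.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

iris = load_iris(as_frame=True)
iris_X, iris_y = iris.data, iris.target

iris_pipe = make_pipeline(StandardScaler(), KNeighborsClassifier())

# Parameter names are "<pipeline step name>__<parameter name>"
iris_grid = {"kneighborsclassifier__n_neighbors": list(range(1, 11))}

# Exhaustive search: evaluates all 10 candidate values
iris_grid_search = GridSearchCV(iris_pipe, iris_grid, cv=5).fit(iris_X, iris_y)

# Randomized search: evaluates only n_iter=4 randomly chosen candidates
iris_random_search = RandomizedSearchCV(
    iris_pipe, iris_grid, n_iter=4, cv=5, random_state=0
).fit(iris_X, iris_y)

print(len(iris_grid_search.cv_results_["params"]))
print(len(iris_random_search.cv_results_["params"]))
```

With a single small grid the exhaustive search is cheap; the randomized search becomes attractive when several hyperparameters are tuned at once and the number of combinations grows quickly.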

**Question 4.0**
<br> {points: 1}

Create a new K-nearest neighbour model specification, but instead of fixing a particular value for the `n_neighbors` argument, tune it with `GridSearchCV` and `RandomizedSearchCV`. Before we use them to run 4-fold cross-validation, we need to define the parameter grid: the set of candidate values for each parameter that we would like to tune. We also need to redefine the pipeline so that the classifier uses default parameter values. Both are done for you in the first cell below.

*Assign your answers to objects called `knn_tune_grid` and `knn_tune_random`, respectively.*

In [None]:
### Run this cell
param_grid = {
    "kneighborsclassifier__n_neighbors": range(2, 15, 1),
}
fruit_tune_pipe = make_pipeline(fruit_preprocessor, KNeighborsClassifier())

In [None]:
# ___ = GridSearchCV(
#     ___, ___, cv=__,
# )
# ___ = RandomizedSearchCV(
#     ___, ___, n_iter=10, cv=_, random_state=2023
# )

# your code here
raise NotImplementedError
knn_tune_grid
knn_tune_random

In [None]:
from hashlib import sha1
assert sha1(str(type(knn_tune_grid is None)).encode("utf-8")+b"34d0fc9f53599c1f").hexdigest() == "9b949efc63813df04cf8103cf4e4a96c29b19d6f", "type of knn_tune_grid is None is not bool. knn_tune_grid is None should be a bool"
assert sha1(str(knn_tune_grid is None).encode("utf-8")+b"34d0fc9f53599c1f").hexdigest() == "de60092a2e7180507aeba0a00ebe13dc671ac37d", "boolean value of knn_tune_grid is None is not correct"

assert sha1(str(type(type(knn_tune_grid))).encode("utf-8")+b"f6311d45d3701c42").hexdigest() == "b6c085eb718f5c3a956b31b2f7ab3a76a22d2145", "type of type(knn_tune_grid) is not correct"
assert sha1(str(type(knn_tune_grid)).encode("utf-8")+b"f6311d45d3701c42").hexdigest() == "92fc2b16199cf43e8545727fcffc158c8c01b47e", "value of type(knn_tune_grid) is not correct"

assert sha1(str(type(knn_tune_grid.param_grid.keys())).encode("utf-8")+b"3a6612fff021a686").hexdigest() == "783d7f050c6d176c057076a0c2d8d45159c6dd8e", "type of knn_tune_grid.param_grid.keys() is not correct"
assert sha1(str(knn_tune_grid.param_grid.keys()).encode("utf-8")+b"3a6612fff021a686").hexdigest() == "1af9c007d8a80a5f9d50d2d1913a569011f79ebf", "value of knn_tune_grid.param_grid.keys() is not correct"

assert sha1(str(type(knn_tune_grid.estimator.named_steps.keys())).encode("utf-8")+b"fc2fbb27922aef87").hexdigest() == "5fad37d96cbdb70d681899a87447afb975349741", "type of knn_tune_grid.estimator.named_steps.keys() is not correct"
assert sha1(str(knn_tune_grid.estimator.named_steps.keys()).encode("utf-8")+b"fc2fbb27922aef87").hexdigest() == "ba00008a5120364b839f52a01359ec650531160a", "value of knn_tune_grid.estimator.named_steps.keys() is not correct"

assert sha1(str(type(knn_tune_random is None)).encode("utf-8")+b"4a293e4ab2f3b9f5").hexdigest() == "9ecd8e41e1cea197674536d0c02c3418994d658e", "type of knn_tune_random is None is not bool. knn_tune_random is None should be a bool"
assert sha1(str(knn_tune_random is None).encode("utf-8")+b"4a293e4ab2f3b9f5").hexdigest() == "e6cbb125d485fd1987817cc49ba882b565c2110f", "boolean value of knn_tune_random is None is not correct"

assert sha1(str(type(type(knn_tune_random))).encode("utf-8")+b"575f4428cd0a6ea3").hexdigest() == "b1b838a8f9e92142ad3ef91cc9ef2af311b62aeb", "type of type(knn_tune_random) is not correct"
assert sha1(str(type(knn_tune_random)).encode("utf-8")+b"575f4428cd0a6ea3").hexdigest() == "3d1e38f7b8e45da897696126816a3ffee14ffbd3", "value of type(knn_tune_random) is not correct"

assert sha1(str(type(knn_tune_random.param_distributions.keys())).encode("utf-8")+b"d3444f420529de36").hexdigest() == "262ee18efc9422c750d4f8f1f02700f7203bc79c", "type of knn_tune_random.param_distributions.keys() is not correct"
assert sha1(str(knn_tune_random.param_distributions.keys()).encode("utf-8")+b"d3444f420529de36").hexdigest() == "cb7bd600e280c901d215c7b7953076a38d4981be", "value of knn_tune_random.param_distributions.keys() is not correct"

assert sha1(str(type(knn_tune_random.estimator.named_steps.keys())).encode("utf-8")+b"8751d21acf4056b4").hexdigest() == "2fb29fca97035b10e51c2e64b7ab25a73f191538", "type of knn_tune_random.estimator.named_steps.keys() is not correct"
assert sha1(str(knn_tune_random.estimator.named_steps.keys()).encode("utf-8")+b"8751d21acf4056b4").hexdigest() == "b65f26662e09f7a4bc2f2bd83d30f10a90a0053a", "value of knn_tune_random.estimator.named_steps.keys() is not correct"

print('Success!')

**Question 4.1**
<br>{points: 1}

Now, let's fit the models to the data. Name the predictors `X_tune` and the target `y_tune`.

*Assign your tuned models to objects called `knn_model_grid` and `knn_model_random`.*

Next, extract the `cv_results_` attribute from each `knn_model_*` object and store it in a data frame.

*Assign your answers to objects called `accuracies_grid` and `accuracies_random`, respectively.*
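
As a rough guide to what the `cv_results_` data frame contains, here is a sketch on the bundled iris data (a stand-in for the fruit data; the `iris_*` names and the three candidate values are assumptions).

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

iris = load_iris(as_frame=True)
iris_pipe = make_pipeline(StandardScaler(), KNeighborsClassifier())

iris_search = GridSearchCV(
    iris_pipe,
    {"kneighborsclassifier__n_neighbors": [1, 5, 9]},
    cv=5,
).fit(iris.data, iris.target)

# One row per candidate value of K
iris_results = pd.DataFrame(iris_search.cv_results_)

# The columns most useful for choosing K
iris_summary = iris_results[
    ["param_kneighborsclassifier__n_neighbors", "mean_test_score", "std_test_score"]
]
print(iris_summary)

# The row with the best mean validation accuracy matches best_params_
print(iris_search.best_params_)
```

The `mean_test_score` column is the accuracy averaged over the validation folds, and `std_test_score` describes how much it varied from fold to fold.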

In [None]:
# ___ = ___[[___, ___]]
# ___ = ___[___]

# ___ = ___.fit(___, ___)
# ___ = ___.fit(___, ___)

# ___ = pd.DataFrame(___.cv_results_)
# ___ = pd.DataFrame(___.cv_results_)

# your code here
raise NotImplementedError
knn_model_grid
knn_model_random
accuracies_grid
accuracies_random

In [None]:
from hashlib import sha1
assert sha1(str(type(type(knn_model_grid))).encode("utf-8")+b"cbc2aa9ebfb76239").hexdigest() == "e520a22f302950d2d33c1b411ac5b71416d1072c", "type of type(knn_model_grid) is not correct"
assert sha1(str(type(knn_model_grid)).encode("utf-8")+b"cbc2aa9ebfb76239").hexdigest() == "f62333ca0d5593fa06aac543e21930180a20f7b3", "value of type(knn_model_grid) is not correct"

assert sha1(str(type(type(knn_model_random))).encode("utf-8")+b"6f865cd8112c80ae").hexdigest() == "9b31442cdf6fba21c559a26004a1ecb503474876", "type of type(knn_model_random) is not correct"
assert sha1(str(type(knn_model_random)).encode("utf-8")+b"6f865cd8112c80ae").hexdigest() == "b34e56552a657a21d6928e25e4cf035d2ef50028", "value of type(knn_model_random) is not correct"

assert sha1(str(type(accuracies_grid is None)).encode("utf-8")+b"f69495a56ceab146").hexdigest() == "3c8ff32b26c02be050d5a4229ebf9e348a814818", "type of accuracies_grid is None is not bool. accuracies_grid is None should be a bool"
assert sha1(str(accuracies_grid is None).encode("utf-8")+b"f69495a56ceab146").hexdigest() == "e9342ca9eb3305685876b413a86494cbc60a7d70", "boolean value of accuracies_grid is None is not correct"

assert sha1(str(type(accuracies_grid)).encode("utf-8")+b"1f15d8f5979caf84").hexdigest() == "1babbc3735214129fb539f4da7edefedce60f02d", "type of type(accuracies_grid) is not correct"

assert sha1(str(type(accuracies_grid.shape)).encode("utf-8")+b"a0dcb1ebf833583e").hexdigest() == "e98aab4531265caf1fe1c2b5fe4e02a1c5c00020", "type of accuracies_grid.shape is not tuple. accuracies_grid.shape should be a tuple"
assert sha1(str(len(accuracies_grid.shape)).encode("utf-8")+b"a0dcb1ebf833583e").hexdigest() == "9402e9d0f083d493257857b7397c6082e75f8ced", "length of accuracies_grid.shape is not correct"
assert sha1(str(sorted(map(str, accuracies_grid.shape))).encode("utf-8")+b"a0dcb1ebf833583e").hexdigest() == "4b4a52570159aa9d8c991796b592f9348af365fa", "values of accuracies_grid.shape are not correct"
assert sha1(str(accuracies_grid.shape).encode("utf-8")+b"a0dcb1ebf833583e").hexdigest() == "7c6f09b60dd1c8ac3399a597517d1f69ae0d510b", "order of elements of accuracies_grid.shape is not correct"

assert sha1(str(type(sum(accuracies_grid.mean_test_score))).encode("utf-8")+b"665f245486d54ba9").hexdigest() == "bbb14fbfcbe77a152e7e2ef8bc35a549154d680f", "type of sum(accuracies_grid.mean_test_score) is not float. Please make sure it is float and not np.float64, etc. You can cast your value into a float using float()"
assert sha1(str(round(sum(accuracies_grid.mean_test_score), 2)).encode("utf-8")+b"665f245486d54ba9").hexdigest() == "64c2ff09cb093fc6ac71d9ca9c72f98aa3d55bc2", "value of sum(accuracies_grid.mean_test_score) is not correct (rounded to 2 decimal places)"

assert sha1(str(type(sum(accuracies_grid.std_test_score))).encode("utf-8")+b"c8c19bf50dfa94e0").hexdigest() == "9349beb32cfa8817d41cd5c34feafb874f0d8a0a", "type of sum(accuracies_grid.std_test_score) is not float. Please make sure it is float and not np.float64, etc. You can cast your value into a float using float()"
assert sha1(str(round(sum(accuracies_grid.std_test_score), 2)).encode("utf-8")+b"c8c19bf50dfa94e0").hexdigest() == "802125abd28c735a595e13ec13c217b4d02cc577", "value of sum(accuracies_grid.std_test_score) is not correct (rounded to 2 decimal places)"

assert sha1(str(type(sum(accuracies_grid.param_kneighborsclassifier__n_neighbors))).encode("utf-8")+b"57ad04681e7988cb").hexdigest() == "9f64f5bf079432b7a16f1c75e1ede102f9928ee3", "type of sum(accuracies_grid.param_kneighborsclassifier__n_neighbors) is not int. Please make sure it is int and not np.int64, etc. You can cast your value into an int using int()"
assert sha1(str(sum(accuracies_grid.param_kneighborsclassifier__n_neighbors)).encode("utf-8")+b"57ad04681e7988cb").hexdigest() == "4ab460a87e8acd68a90ec007b7a541780397d51e", "value of sum(accuracies_grid.param_kneighborsclassifier__n_neighbors) is not correct"

assert sha1(str(type(accuracies_random is None)).encode("utf-8")+b"a461ad08ddd63ced").hexdigest() == "e5e26ca3d62aef625464c996bcb82f8ebc807491", "type of accuracies_random is None is not bool. accuracies_random is None should be a bool"
assert sha1(str(accuracies_random is None).encode("utf-8")+b"a461ad08ddd63ced").hexdigest() == "7fdb6693b495c9335f4e8db288a50a60cbaa46d0", "boolean value of accuracies_random is None is not correct"

assert sha1(str(type(accuracies_random)).encode("utf-8")+b"b5f625d1323e05c2").hexdigest() == "6cf620d531f5ea837143c60d4c0503a6171f7e83", "type of type(accuracies_random) is not correct"

assert sha1(str(type(accuracies_random.shape)).encode("utf-8")+b"dbefcf285e09e894").hexdigest() == "42deead6eb39552c7c6242a4864220b3fc275183", "type of accuracies_random.shape is not tuple. accuracies_random.shape should be a tuple"
assert sha1(str(len(accuracies_random.shape)).encode("utf-8")+b"dbefcf285e09e894").hexdigest() == "536ce28744857df895866711ba12d6669a918839", "length of accuracies_random.shape is not correct"
assert sha1(str(sorted(map(str, accuracies_random.shape))).encode("utf-8")+b"dbefcf285e09e894").hexdigest() == "e62e764970fb38653c05c0a8980c89b6a902badf", "values of accuracies_random.shape are not correct"
assert sha1(str(accuracies_random.shape).encode("utf-8")+b"dbefcf285e09e894").hexdigest() == "159c2a23ca832ae5edf3c692b1bb0fabfaf91534", "order of elements of accuracies_random.shape is not correct"

assert sha1(str(type(sum(accuracies_random.mean_test_score))).encode("utf-8")+b"6bf1114fdb5fac45").hexdigest() == "764f32f397e03a272de287b3a0260cc59aa303f4", "type of sum(accuracies_random.mean_test_score) is not float. Please make sure it is float and not np.float64, etc. You can cast your value into a float using float()"
assert sha1(str(round(sum(accuracies_random.mean_test_score), 2)).encode("utf-8")+b"6bf1114fdb5fac45").hexdigest() == "6d333aa60b0f1f5b8b4c2059c0a4389fed4692fd", "value of sum(accuracies_random.mean_test_score) is not correct (rounded to 2 decimal places)"

assert sha1(str(type(sum(accuracies_random.std_test_score))).encode("utf-8")+b"0006af941b0d77ae").hexdigest() == "6664dbf6046656713484fb69ded2deb4b6e3edb7", "type of sum(accuracies_random.std_test_score) is not float. Please make sure it is float and not np.float64, etc. You can cast your value into a float using float()"
assert sha1(str(round(sum(accuracies_random.std_test_score), 2)).encode("utf-8")+b"0006af941b0d77ae").hexdigest() == "13970d0153ca98c34fcbcdd83b378ae70a2e4cf3", "value of sum(accuracies_random.std_test_score) is not correct (rounded to 2 decimal places)"

assert sha1(str(type(sum(accuracies_random.param_kneighborsclassifier__n_neighbors))).encode("utf-8")+b"6c9bf19570fce7cc").hexdigest() == "85c8e48635993821637f2f76feab504517fe7d85", "type of sum(accuracies_random.param_kneighborsclassifier__n_neighbors) is not int. Please make sure it is int and not np.int64, etc. You can cast your value into an int using int()"
assert sha1(str(sum(accuracies_random.param_kneighborsclassifier__n_neighbors)).encode("utf-8")+b"6c9bf19570fce7cc").hexdigest() == "36141693acde6d8764ea96fb6422189cb4622911", "value of sum(accuracies_random.param_kneighborsclassifier__n_neighbors) is not correct"

print('Success!')

**Question 4.2**
<br>{points: 1} 


Now, let's find the best value for the number of neighbors.

Create one line plot for each of the `accuracies_*` dataframes (`accuracies_grid` and `accuracies_random`), with `param_kneighborsclassifier__n_neighbors` on the x-axis and `mean_test_score` on the y-axis.

*Assign your plots to objects called `accuracy_versus_k_grid` and `accuracy_versus_k_random`, respectively.*

In [None]:
# ___ = (
#     alt.Chart(___, title="Grid Search")
#     .mark_line(point=True)
#     .encode(
#         x=alt.X(
#             ___,
#             title="Neighbors",
#             scale=alt.Scale(zero=False),
#         ),
#         y=alt.Y(
#             ___, 
#             title="Mean Test Score", 
#             scale=alt.Scale(zero=False)
#         ),
#     )
#     .configure_axis(labelFontSize=10, titleFontSize=15)
#     .properties(width=400, height=300)
# )

# ___ = (
#     alt.Chart(___, title="Randomized Search")
#     .mark_line(point=True)
#     .encode(
#         x=alt.X(
#             ___,
#             title="Neighbors",
#             scale=alt.Scale(zero=False),
#         ),
#         y=alt.Y(
#             ___, 
#             title="Mean Test Score", 
#             scale=alt.Scale(zero=False)
#         ),
#     )
#     .configure_axis(labelFontSize=10, titleFontSize=15)
#     .properties(width=400, height=300)
# )


# your code here
raise NotImplementedError

In [None]:
# run the cell
accuracy_versus_k_grid

In [None]:
# run the cell
accuracy_versus_k_random

In [None]:
from hashlib import sha1
assert sha1(str(type(accuracy_versus_k_grid is None)).encode("utf-8")+b"49f2da29d0d97921").hexdigest() == "3523b3d6b36989bd3570ab4a7134147162674722", "type of accuracy_versus_k_grid is None is not bool. accuracy_versus_k_grid is None should be a bool"
assert sha1(str(accuracy_versus_k_grid is None).encode("utf-8")+b"49f2da29d0d97921").hexdigest() == "6405edbbe7734d237985c3e6eede3c64bda7f4ea", "boolean value of accuracy_versus_k_grid is None is not correct"

assert sha1(str(type(accuracy_versus_k_grid.encoding.x.field)).encode("utf-8")+b"e62ec94e7e94766a").hexdigest() == "e365112fc60623961e3417cf90f5d6b34c2c15e2", "type of accuracy_versus_k_grid.encoding.x.field is not str. accuracy_versus_k_grid.encoding.x.field should be an str"
assert sha1(str(len(accuracy_versus_k_grid.encoding.x.field)).encode("utf-8")+b"e62ec94e7e94766a").hexdigest() == "f138329c34ad27888bdfda5751849853a7b52dd6", "length of accuracy_versus_k_grid.encoding.x.field is not correct"
assert sha1(str(accuracy_versus_k_grid.encoding.x.field.lower()).encode("utf-8")+b"e62ec94e7e94766a").hexdigest() == "f7b1b5be09f2a54ad7ccf6cfb875379af6539600", "value of accuracy_versus_k_grid.encoding.x.field is not correct"
assert sha1(str(accuracy_versus_k_grid.encoding.x.field).encode("utf-8")+b"e62ec94e7e94766a").hexdigest() == "f7b1b5be09f2a54ad7ccf6cfb875379af6539600", "correct string value of accuracy_versus_k_grid.encoding.x.field but incorrect case of letters"

assert sha1(str(type(accuracy_versus_k_grid.encoding.y.field)).encode("utf-8")+b"1d8bc7ae8f1bef28").hexdigest() == "918db73f6fda1e25833bab21a9ae6a7688cbb1f1", "type of accuracy_versus_k_grid.encoding.y.field is not str. accuracy_versus_k_grid.encoding.y.field should be an str"
assert sha1(str(len(accuracy_versus_k_grid.encoding.y.field)).encode("utf-8")+b"1d8bc7ae8f1bef28").hexdigest() == "397dc8a1e8369998caa1857c4639461bf3c33c3a", "length of accuracy_versus_k_grid.encoding.y.field is not correct"
assert sha1(str(accuracy_versus_k_grid.encoding.y.field.lower()).encode("utf-8")+b"1d8bc7ae8f1bef28").hexdigest() == "743e440f635ceb1ba130f17a0d77caf39777f60d", "value of accuracy_versus_k_grid.encoding.y.field is not correct"
assert sha1(str(accuracy_versus_k_grid.encoding.y.field).encode("utf-8")+b"1d8bc7ae8f1bef28").hexdigest() == "743e440f635ceb1ba130f17a0d77caf39777f60d", "correct string value of accuracy_versus_k_grid.encoding.y.field but incorrect case of letters"

assert sha1(str(type(accuracy_versus_k_grid.mark.type)).encode("utf-8")+b"f0e47bde55f7a86a").hexdigest() == "f89a0b712690d361b0e2d11a526d8174b4b77f25", "type of accuracy_versus_k_grid.mark.type is not str. accuracy_versus_k_grid.mark.type should be an str"
assert sha1(str(len(accuracy_versus_k_grid.mark.type)).encode("utf-8")+b"f0e47bde55f7a86a").hexdigest() == "d9602269a18e9bb41e6bb167d115ed165ba1abda", "length of accuracy_versus_k_grid.mark.type is not correct"
assert sha1(str(accuracy_versus_k_grid.mark.type.lower()).encode("utf-8")+b"f0e47bde55f7a86a").hexdigest() == "79e22e69ae62e6e28a481cca6e1c374c1696c781", "value of accuracy_versus_k_grid.mark.type is not correct"
assert sha1(str(accuracy_versus_k_grid.mark.type).encode("utf-8")+b"f0e47bde55f7a86a").hexdigest() == "79e22e69ae62e6e28a481cca6e1c374c1696c781", "correct string value of accuracy_versus_k_grid.mark.type but incorrect case of letters"

assert sha1(str(type(accuracy_versus_k_grid.mark['point'])).encode("utf-8")+b"da18fbfb88ce2ea2").hexdigest() == "ab832c3ea5a2439c6cca95fe9fd512fd96512265", "type of accuracy_versus_k_grid.mark['point'] is not bool. accuracy_versus_k_grid.mark['point'] should be a bool"
assert sha1(str(accuracy_versus_k_grid.mark['point']).encode("utf-8")+b"da18fbfb88ce2ea2").hexdigest() == "cdc62b907e551e89230d50ad92cf9847614d17f6", "boolean value of accuracy_versus_k_grid.mark['point'] is not correct"

assert sha1(str(type(accuracy_versus_k_random is None)).encode("utf-8")+b"51ee835feacb71a0").hexdigest() == "20f222632a325bdd59af48aecd01c5260d7fa354", "type of accuracy_versus_k_random is None is not bool. accuracy_versus_k_random is None should be a bool"
assert sha1(str(accuracy_versus_k_random is None).encode("utf-8")+b"51ee835feacb71a0").hexdigest() == "6f3dc1c30529866d7d121907675fb51000873922", "boolean value of accuracy_versus_k_random is None is not correct"

assert sha1(str(type(accuracy_versus_k_random.encoding.x.field)).encode("utf-8")+b"6e3a7bc8900cc90b").hexdigest() == "14644ddec7aac27385f400a115f71a6226740cbd", "type of accuracy_versus_k_random.encoding.x.field is not str. accuracy_versus_k_random.encoding.x.field should be an str"
assert sha1(str(len(accuracy_versus_k_random.encoding.x.field)).encode("utf-8")+b"6e3a7bc8900cc90b").hexdigest() == "0265a2fb74d3e43d5e5c3d1bdd8e78c4fb8dd8a6", "length of accuracy_versus_k_random.encoding.x.field is not correct"
assert sha1(str(accuracy_versus_k_random.encoding.x.field.lower()).encode("utf-8")+b"6e3a7bc8900cc90b").hexdigest() == "339c55775bb82c21aaaf5e707d35e0273ebe9d12", "value of accuracy_versus_k_random.encoding.x.field is not correct"
assert sha1(str(accuracy_versus_k_random.encoding.x.field).encode("utf-8")+b"6e3a7bc8900cc90b").hexdigest() == "339c55775bb82c21aaaf5e707d35e0273ebe9d12", "correct string value of accuracy_versus_k_random.encoding.x.field but incorrect case of letters"

assert sha1(str(type(accuracy_versus_k_random.encoding.y.field)).encode("utf-8")+b"9c3e67cd14bd032b").hexdigest() == "5c1f6f022bdc570eac0a8e10a23f0187919546a9", "type of accuracy_versus_k_random.encoding.y.field is not str. accuracy_versus_k_random.encoding.y.field should be an str"
assert sha1(str(len(accuracy_versus_k_random.encoding.y.field)).encode("utf-8")+b"9c3e67cd14bd032b").hexdigest() == "6bc2bee41c371c68b1e0750056355cf987a110b6", "length of accuracy_versus_k_random.encoding.y.field is not correct"
assert sha1(str(accuracy_versus_k_random.encoding.y.field.lower()).encode("utf-8")+b"9c3e67cd14bd032b").hexdigest() == "1091d38371294ed5b88ccd9fe94cf41f6ad355e0", "value of accuracy_versus_k_random.encoding.y.field is not correct"
assert sha1(str(accuracy_versus_k_random.encoding.y.field).encode("utf-8")+b"9c3e67cd14bd032b").hexdigest() == "1091d38371294ed5b88ccd9fe94cf41f6ad355e0", "correct string value of accuracy_versus_k_random.encoding.y.field but incorrect case of letters"

assert sha1(str(type(accuracy_versus_k_random.mark.type)).encode("utf-8")+b"2a8c9a4460ee7a4b").hexdigest() == "30dbc0d04218b0eb7e06f4e2591de07ede399b25", "type of accuracy_versus_k_random.mark.type is not str. accuracy_versus_k_random.mark.type should be an str"
assert sha1(str(len(accuracy_versus_k_random.mark.type)).encode("utf-8")+b"2a8c9a4460ee7a4b").hexdigest() == "8a74b1797dab41f0cfa6951fa719b7262d807f88", "length of accuracy_versus_k_random.mark.type is not correct"
assert sha1(str(accuracy_versus_k_random.mark.type.lower()).encode("utf-8")+b"2a8c9a4460ee7a4b").hexdigest() == "9109fd0fed79556ac60b007c8d1a8cb4ee15f8a4", "value of accuracy_versus_k_random.mark.type is not correct"
assert sha1(str(accuracy_versus_k_random.mark.type).encode("utf-8")+b"2a8c9a4460ee7a4b").hexdigest() == "9109fd0fed79556ac60b007c8d1a8cb4ee15f8a4", "correct string value of accuracy_versus_k_random.mark.type but incorrect case of letters"

assert sha1(str(type(accuracy_versus_k_random.mark['point'])).encode("utf-8")+b"f09265ae54bb4ef0").hexdigest() == "9522b04f91b0429cd7f98f70cd551f7448ad662e", "type of accuracy_versus_k_random.mark['point'] is not bool. accuracy_versus_k_random.mark['point'] should be a bool"
assert sha1(str(accuracy_versus_k_random.mark['point']).encode("utf-8")+b"f09265ae54bb4ef0").hexdigest() == "72674d08069e5113755292e9edbe42580e53d745", "boolean value of accuracy_versus_k_random.mark['point'] is not correct"

print('Success!')

From the plots above, we can see that $K = 2$, $3$, or $4$ provides the highest accuracy estimate, and that larger values of $K$ result in a lower estimate. Remember: the values you see on these plots are *estimates* of the true accuracy of our classifier. Although the estimates for $K = 2$, $3$, or $4$ are higher than the others on these plots, that doesn’t mean the classifier is necessarily more accurate with these parameter values!
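As a rough illustration of reading the best $K$ from a results table rather than from a plot, the sketch below uses a small made-up dataframe with the same column names as `accuracies_grid`. The dataframe `accuracies_example` and all of its numbers are hypothetical and are not the worksheet's results. It also lists the $K$ values whose estimate falls within one standard deviation of the best one, which is one informal way to see that several values can be practically indistinguishable.

```python
import pandas as pd

# Hypothetical cross-validation results (made-up numbers, for illustration only)
accuracies_example = pd.DataFrame(
    {
        "param_kneighborsclassifier__n_neighbors": [1, 2, 3, 4, 5, 6],
        "mean_test_score": [0.85, 0.885, 0.90, 0.89, 0.86, 0.84],
        "std_test_score": [0.03, 0.02, 0.02, 0.03, 0.03, 0.04],
    }
)

# Row label of the highest estimated accuracy
best_idx = accuracies_example["mean_test_score"].idxmax()
best_k = int(accuracies_example.loc[best_idx, "param_kneighborsclassifier__n_neighbors"])

# Informal comparison: which K values are within one standard deviation of the best?
threshold = (
    accuracies_example.loc[best_idx, "mean_test_score"]
    - accuracies_example.loc[best_idx, "std_test_score"]
)
comparable_ks = accuracies_example.loc[
    accuracies_example["mean_test_score"] >= threshold,
    "param_kneighborsclassifier__n_neighbors",
].tolist()

print(f"Best K by mean test score: {best_k}")
print(f"K values within one standard deviation of the best: {comparable_ks}")
```

In this toy table a single $K$ has the highest estimate, but its neighbours are close enough that the difference could easily be noise.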

Great, now you have completed a full analysis with cross-validation using the `scikit-learn` package! For your information, we can choose any number of folds, and typically the more folds we use, the better our accuracy estimate will be (lower standard error). However, more folds also mean more computation time. In practice, the number of folds (the `cv` argument) is usually chosen to be either 5 or 10.
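To make the trade-off concrete, here is a minimal sketch that runs `cross_validate` with 5 and with 10 folds and summarizes the fold scores. It assumes a synthetic dataset from `make_classification` and a hypothetical pipeline named `toy_pipeline` with $K = 3$; neither comes from this worksheet, and the numbers it prints say nothing about the worksheet's data. The standard error here is computed as the sample standard deviation of the fold scores divided by the square root of the number of folds.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_validate
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic data standing in for a real training set (assumption for illustration)
X_toy, y_toy = make_classification(n_samples=300, n_features=5, random_state=1)

# Scale the predictors, then classify with K = 3 nearest neighbours
toy_pipeline = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=3))

fold_summaries = {}
for n_folds in (5, 10):
    # One accuracy score per fold
    scores = cross_validate(toy_pipeline, X_toy, y_toy, cv=n_folds)["test_score"]
    fold_summaries[n_folds] = {
        "n_scores": len(scores),
        "mean_accuracy": float(scores.mean()),
        "standard_error": float(scores.std(ddof=1) / np.sqrt(len(scores))),
    }

for n_folds, summary in fold_summaries.items():
    print(n_folds, summary)
```

Each extra fold means one more model fit per candidate value of $K$, which is where the additional computation time comes from.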