From 2d91b84da95c5146baa3828c923572949266bf68 Mon Sep 17 00:00:00 2001 From: Charlotte Godley Date: Tue, 3 Dec 2019 15:20:44 +0000 Subject: [PATCH] NFC: add model call to sklearn tutorial --- sklearn-gridsearch/grid-search.ipynb | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/sklearn-gridsearch/grid-search.ipynb b/sklearn-gridsearch/grid-search.ipynb index 112b76d..64a083e 100644 --- a/sklearn-gridsearch/grid-search.ipynb +++ b/sklearn-gridsearch/grid-search.ipynb @@ -2,9 +2,7 @@ "cells": [ { "cell_type": "markdown", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "# Hyperparamerter optimisation with Dotscience\n", "\n", @@ -17,6 +15,7 @@ "metadata": {}, "outputs": [], "source": [ + "import sklearn\n", "from sklearn import datasets\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.model_selection import GridSearchCV\n", @@ -28,9 +27,7 @@ }, { "cell_type": "markdown", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "Let's load the digits dataset and apply a classifier on this data." ] @@ -54,9 +51,7 @@ }, { "cell_type": "markdown", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "Split the dataset in two equal parts\n" ] @@ -73,9 +68,7 @@ }, { "cell_type": "markdown", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "Set the parameters by cross-validation" ] @@ -93,9 +86,7 @@ }, { "cell_type": "markdown", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "Let's tune hyper-parameters for the scores. \n", "\n", @@ -126,6 +117,7 @@ " print()\n", " means = clf.cv_results_['mean_test_score']\n", " stds = clf.cv_results_['std_test_score']\n", + " ds.sklearn_model(sklearn, clf, \"grid search\", \"model_%s.joblib\" % score)\n", " for mean, std, params in zip(means, stds, clf.cv_results_['params']):\n", " ds.start()\n", " ds.add_parameters(**params)\n", @@ -143,7 +135,6 @@ " print()\n", " y_true, y_pred = y_test, clf.predict(X_test)\n", " print(classification_report(y_true, y_pred))\n", - " print()\n", "\n", "# Note the problem is too easy: the hyperparameter plateau is too flat and the\n", "# output model is the same for precision and recall with ties in quality." @@ -166,9 +157,10 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.5.2" + "version": "3.6.8" } }, "nbformat": 4, - "nbformat_minor": 2 -} \ No newline at end of file + "nbformat_minor": 4 +} + \ No newline at end of file