In [1]:
%%capture
# capture will not print in notebook

import os
import sys
ENV_COLAB = 'google.colab' in sys.modules

if ENV_COLAB:
    ## install modules
    !python -m pip install dask[complete] --upgrade
    !pip install dask-ml[complete]

    ## print
    print('Environment: Google Colaboratory.')

# NOTE: If we update modules in gcolab, we need to restart runtime.

In [9]:
#Imports
import numpy as np
import pandas as pd
import seaborn as sns

from sklearn import datasets
from sklearn.model_selection import train_test_split

import dask
import dask_ml
import dask.array as da
from dask.distributed import Client, LocalCluster
from dask_ml.xgboost import XGBRegressor

# Record the library versions this run used, as (name, version) pairs.
print([(mod.__name__, mod.__version__) for mod in (dask, dask_ml)])

# data
# NOTE(review): load_boston was deprecated in scikit-learn 1.0 and removed in
# 1.2, so this cell needs an older scikit-learn - confirm the pinned version.
SEED = 100  # shared seed for the split and the estimator
X, y = datasets.load_boston(return_X_y=True)

# Hold out 20% of the rows for evaluation.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=SEED
)

# Wrap each NumPy split in a dask array (default chunking).
da_Xtrain, da_ytrain, da_Xtest, da_ytest = map(
    da.from_array, (X_train, y_train, X_test, y_test)
)

# modelling
# Re-running this cell would otherwise leave the previous cluster alive and
# then try to bind the same fixed scheduler port again, so shut down any
# client/cluster left over from an earlier run before creating new ones.
for _leftover in (globals().get('client'), globals().get('cluster')):
    if _leftover is not None:
        _leftover.close()

# processes=False asks LocalCluster not to spawn separate worker processes.
cluster = LocalCluster(processes=False, scheduler_port=1234)
client = Client(cluster)

# Train the dask-ml XGBoost wrapper on the dask-array training split.
est = XGBRegressor(random_state=SEED)
est.fit(da_Xtrain, da_ytrain)

# Predictions for the held-out split; call .compute() to get concrete values.
da_txpreds = est.predict(da_Xtest)

[('dask', '2.20.0'), ('dask_ml', '1.5.0')]


Perhaps you already have a cluster running?
Hosting the HTTP server on port 36399 instead
  http_address["port"], self.http_server.port


In [10]:
# Display the prediction array's summary (bytes, shape, chunks, dtype).
da_txpreds

Unnamed: 0,Array,Chunk
Bytes,408 B,408 B
Shape,"(102,)","(102,)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 408 B 408 B Shape (102,) (102,) Count 2 Tasks 1 Chunks Type float32 numpy.ndarray",102  1,

Unnamed: 0,Array,Chunk
Bytes,408 B,408 B
Shape,"(102,)","(102,)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray


In [11]:
# Check which container type predict() handed back.
type(da_txpreds)

dask.array.core.Array

In [14]:
# Evaluate the dask array into concrete values, then check the resulting type.
tx_preds = da_txpreds.compute()
type(tx_preds)

numpy.ndarray

In [15]:
from sklearn import metrics

In [19]:
# Score the held-out predictions: RMSE is the square root of the mean
# squared error, and R-squared comes straight from scikit-learn.
mse = metrics.mean_squared_error(y_test, tx_preds)
rmse = mse ** 0.5
r2 = metrics.r2_score(y_test, tx_preds)

for _label, _value in (('RMSE     : ', rmse), ('R-Squared: ', r2)):
    print(_label, _value)

RMSE     :  3.148289197036133
R-Squared:  0.8973881297021495
