## Machine Learning: Train on Large Datasets

Most estimators in scikit-learn are designed to work on in-memory arrays. Training with larger datasets may require different algorithms.

All of the algorithms implemented in Dask-ML work well on larger-than-memory datasets, which you might store in a [Dask array](http://dask.pydata.org/en/latest/array.html) or [Dask DataFrame](http://dask.pydata.org/en/latest/dataframe.html).
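
A Dask array is made of many smaller NumPy arrays (chunks), and reductions over it are computed as per-chunk work followed by a combine step, so the whole dataset never has to sit in one process's memory. The following is a minimal NumPy-only sketch of that idea, not Dask's implementation: `iter_chunks` and `blocked_mean` are hypothetical helpers, and `iter_chunks` stands in for reading blocks from storage one at a time.

```python
import numpy as np

def iter_chunks(n_samples, n_features, chunk_size, seed=0):
    """Yield row-blocks of a synthetic dataset one at a time.

    Hypothetical stand-in for reading blocks from disk or cloud storage;
    only one block is materialized per iteration.
    """
    rng = np.random.default_rng(seed)
    for start in range(0, n_samples, chunk_size):
        rows = min(chunk_size, n_samples - start)
        yield rng.normal(size=(rows, n_features))

def blocked_mean(chunks):
    """Column means built from per-chunk partial sums and row counts."""
    total, count = None, 0
    for block in chunks:
        partial = block.sum(axis=0)  # independent per chunk, so parallelizable
        total = partial if total is None else total + partial
        count += block.shape[0]
    return total / count

# Small enough to check against the in-memory answer (same seed, same data)
streamed = blocked_mean(iter_chunks(10_500, 5, 1_000))
in_memory = np.concatenate(list(iter_chunks(10_500, 5, 1_000))).mean(axis=0)
print(np.allclose(streamed, in_memory))  # True
```

Only the small partial results (one sum per column plus a count) have to be combined, which is what lets the same pattern run across many workers.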

*Note: this notebook requires [dask-ml](http://ml.dask.org/)*

In [None]:
from dask.distributed import Client

# Start a local cluster; displaying the client shows a link to the dashboard
client = Client()
client

## Create a large random dataset

In [None]:
import dask
from dask.utils import format_bytes

import dask_ml.cluster
import dask_ml.datasets

In this example, we'll generate a large random Dask array on our cluster. In practice, we would load the data from a data store such as a SQL table, HDFS, or cloud storage.

In [None]:
# 100 million samples x 50 features, generated in blocks of 500,000 rows
X, y = dask_ml.datasets.make_blobs(
    n_samples=100_000_000,
    n_features=50,
    centers=3,
    chunks=500_000,
)

format_bytes(X.nbytes)

In [None]:
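# Compute the chunks once and keep them in worker memory, so fitting and
# plotting reuse them instead of regenerating the data
X = X.persist()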
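
The size reported by `format_bytes` follows from simple arithmetic. A quick check, assuming `make_blobs` returns float64 values (8 bytes each) and that `chunks=500_000` splits the rows into blocks of 500,000 while keeping all 50 features in each block:

```python
# Back-of-the-envelope size of the array above (assumes float64 values)
n_samples = 100_000_000
n_features = 50
chunk_rows = 500_000
bytes_per_value = 8  # float64

total_bytes = n_samples * n_features * bytes_per_value
n_chunks = -(-n_samples // chunk_rows)  # ceiling division
chunk_bytes = chunk_rows * n_features * bytes_per_value

print(f"total: {total_bytes / 1e9:.0f} GB")      # total: 40 GB
print(f"chunks: {n_chunks}")                     # chunks: 200
print(f"per chunk: {chunk_bytes / 1e6:.0f} MB")  # per chunk: 200 MB
```

Depending on the Dask version, `format_bytes` may report binary units (1 GiB = 2**30 bytes), in which case the same 40 * 10**9 bytes appear as about 37 GiB. Because `X.persist()` keeps every chunk in worker memory, the cluster needs roughly this much memory in aggregate, or the workers will spill to disk.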

## Cluster with K-Means ||

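We'll use the k-means implementation in Dask-ML to cluster the points. By default it uses the `k-means||` (read: "k-means parallel") initialization algorithm, which scales better than `k-means++`: k-means++ picks one center per pass over the data, so it needs `k` sequential passes, while k-means|| samples many candidate centers in each pass and needs only a few passes. All of the computation, both during and after initialization, can be done in parallel.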
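
The initialization can be sketched in NumPy as follows. This is a minimal sketch of the general k-means|| scheme, not Dask-ML's code. Assumptions: `oversampling` is the expected number of new candidates per round (Dask-ML's `oversampling_factor` plays a similar role, but its exact scaling may differ), `n_rounds` is a fixed round count comparable to `init_max_iter`, and the final reclustering of the weighted candidates uses weighted k-means++ seeding as a simplification. All function names here are hypothetical.

```python
import numpy as np

def nearest_sq_dist(X, C):
    """Squared distance from each row of X to its closest row of C, and that row's index."""
    d2 = ((X[:, None, :] - C[None, :, :]) ** 2).sum(axis=2)
    return d2.min(axis=1), d2.argmin(axis=1)

def weighted_kmeanspp(C, w, k, rng):
    """Weighted k-means++ seeding on the small candidate set (runs in memory)."""
    if len(C) < k:
        raise ValueError("fewer candidates than clusters; add rounds or oversampling")
    centers = [C[rng.choice(len(C), p=w / w.sum())]]
    for _ in range(1, k):
        d2, _ = nearest_sq_dist(C, np.array(centers))
        p = w * d2
        centers.append(C[rng.choice(len(C), p=p / p.sum())])
    return np.array(centers)

def kmeans_parallel_init(X, k, oversampling, n_rounds, seed=0):
    rng = np.random.default_rng(seed)
    C = X[[rng.integers(len(X))]]  # one center chosen uniformly at random
    for _ in range(n_rounds):
        d2, _ = nearest_sq_dist(X, C)
        cost = d2.sum()
        if cost == 0:
            break
        # Keep each point independently with probability oversampling * d2 / cost,
        # i.e. about `oversampling` new candidates per round. This is a single
        # pass over the data that could run chunk by chunk in parallel.
        keep = rng.random(len(X)) < np.minimum(1.0, oversampling * d2 / cost)
        C = np.vstack([C, X[keep]])
    # Weight each candidate by the number of points closest to it
    _, nearest = nearest_sq_dist(X, C)
    w = np.bincount(nearest, minlength=len(C)).astype(float)
    return weighted_kmeanspp(C, w, k, rng)

# Three tight, well-separated synthetic blobs (variable names avoid clobbering X)
rng = np.random.default_rng(42)
true_centers = np.array([[0.0, 0.0], [100.0, 0.0], [0.0, 100.0]])
pts = np.vstack([c + 0.1 * rng.standard_normal((200, 2)) for c in true_centers])
init = kmeans_parallel_init(pts, k=3, oversampling=10, n_rounds=5)
```

The expensive work (distances from every point to the current candidates) happens in a handful of passes, and only the small candidate set is reclustered in memory.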

In [None]:
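km = dask_ml.cluster.KMeans(
    n_clusters=3,
    init_max_iter=2,         # cap on the number of k-means|| sampling rounds
    oversampling_factor=10,  # scales how many candidate centers are sampled per round
    random_state=0,
)

%time km.fit(X)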

During training, you'll notice some distinct phases:

* Initialization: choosing the initial cluster centers
* Expectation Maximization: alternating between assigning each point to its closest cluster center and moving each center to the mean of the points assigned to it
* Finalization: computing statistics like `inertia`
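
The three phases can be sketched in NumPy. This is a minimal sketch, not Dask-ML's implementation: `blocks` is a list of NumPy arrays standing in for the chunks of a Dask array, the function names are hypothetical, and the stopping rule (`tol` on the largest change in any center coordinate) is an assumption. It shows that each step only needs per-block partial results (per-cluster sums and counts), which is why it parallelizes across chunks.

```python
import numpy as np

def nearest_center(block, centers):
    """Index of, and squared distance to, the closest center for each row."""
    d2 = ((block[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    return d2.argmin(axis=1), d2.min(axis=1)

def lloyd_step(blocks, centers):
    """One assignment ("E") + update ("M") step, accumulated block by block."""
    k, d = centers.shape
    sums = np.zeros((k, d))
    counts = np.zeros(k)
    for block in blocks:  # each block is processed independently
        labels, _ = nearest_center(block, centers)
        np.add.at(sums, labels, block)  # per-cluster coordinate sums
        counts += np.bincount(labels, minlength=k)
    new = centers.copy()
    nonempty = counts > 0  # a center with no points stays where it is
    new[nonempty] = sums[nonempty] / counts[nonempty, None]
    return new

def compute_inertia(blocks, centers):
    """Finalization: sum of squared distances from each point to its closest center."""
    return sum(nearest_center(b, centers)[1].sum() for b in blocks)

def fit_lloyd(blocks, centers, max_iter=50, tol=1e-8):
    centers = np.asarray(centers, dtype=float)
    for _ in range(max_iter):
        new = lloyd_step(blocks, centers)
        converged = np.abs(new - centers).max() <= tol
        centers = new
        if converged:
            break
    return centers, compute_inertia(blocks, centers)

pts = np.array([[0, 0], [0, 1], [1, 0], [1, 1],
                [10, 0], [10, 1], [11, 0], [11, 1]], dtype=float)
blocks = [pts[:3], pts[3:6], pts[6:]]  # three uneven "chunks"
centers, total_inertia = fit_lloyd(blocks, [[0.0, 0.0], [10.0, 0.0]])
# centers -> [[0.5, 0.5], [10.5, 0.5]], total_inertia -> 4.0
```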

## Inspect Results

We'll plot a sample of points, colored by the cluster each falls into.

In [None]:
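%matplotlib inline

import matplotlib.pyplot as plt

# Take every 20,000th point and compute the sample and its labels together
xs, ys, labels = dask.compute(X[::20000, 0], X[::20000, 1], km.labels_[::20000])

# Plot the first two of the 50 features, colored by assigned cluster
fig, ax = plt.subplots()
ax.scatter(xs, ys, marker='.', c=labels, cmap='viridis', alpha=0.25);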

For all the estimators implemented in Dask-ML, see the [API documentation](http://ml.dask.org/modules/api.html).