# LightGBM + Dask

<table>
    <tr>
        <td>
            <img src="./_img/lightgbm.svg" width="300">
        </td>
        <td>
            <img src="./_img/dask-horizontal.svg" width="300">
        </td>
        <td>
            <img src="./_img/aws.svg" width="150">
        </td>
    </tr>
</table>

This notebook shows how to use `lightgbm.dask` to train a LightGBM model on data stored as a [Dask DataFrame](https://docs.dask.org/en/latest/dataframe.html) or [Dask Array](https://docs.dask.org/en/latest/array.html).

It uses `FargateCluster` from [`dask-cloudprovider`](https://github.com/dask/dask-cloudprovider) to create a distributed cluster.

<hr>

## Set up a cluster on AWS Fargate

In [2]:
import json

with open("../ecr-details.json", "r") as f:
    ecr_details = json.load(f)

CONTAINER_IMAGE = ecr_details["repository"]["repositoryUri"] + ":1"
print(f"scheduler and worker image: {CONTAINER_IMAGE}")

scheduler and worker image: public.ecr.aws/m0z8a6o8/dask-lgb-test-4468592b-5204-4e3b-ad80-7c5ae698472a-cluster:1
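
The cell above reads only `repository.repositoryUri` from the JSON file. As a minimal sketch of that lookup with an explicit error for a missing key, the helper below builds the same tagged image name. The `image_uri` function, its `tag` default, and the sample payload are hypothetical, and the real file may contain more fields than shown here.

```python
import json


def image_uri(ecr_details: dict, tag: str = "1") -> str:
    """Build '<repositoryUri>:<tag>' from a parsed ecr-details payload."""
    try:
        uri = ecr_details["repository"]["repositoryUri"]
    except KeyError as err:
        # Surface a clearer message than a bare KeyError.
        raise ValueError(f"ECR details are missing key {err}") from err
    return f"{uri}:{tag}"


# Hypothetical payload: only the keys the notebook reads are included.
sample = json.loads(
    '{"repository": {"repositoryUri": "public.ecr.aws/example/my-repo"}}'
)
print(f"scheduler and worker image: {image_uri(sample)}")
```

Failing early here is cheaper than discovering a bad image name after the cluster has started to provision.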


Before proceeding, set up your AWS credentials. If you're unsure how to do this, see [the AWS docs](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html).

In [3]:
import os

os.environ["AWS_DEFAULT_REGION"] = "us-west-2"

Create a cluster with 3 workers. See the [`FargateCluster` documentation](https://cloudprovider.dask.org/en/latest/aws.html#dask_cloudprovider.aws.FargateCluster) for more options.

In [4]:
from dask_cloudprovider.aws import FargateCluster
from dask.distributed import Client

n_workers = 3
cluster = FargateCluster(
    image=CONTAINER_IMAGE,
    worker_cpu=512,
    worker_mem=4096,
    n_workers=n_workers,
    fargate_use_private_ip=False,
    scheduler_timeout="15 minutes",
    find_address_timeout=60 * 10,
)
client = Client(cluster)
client.wait_for_workers(n_workers)


+---------+---------------+---------------+---------------+
| Package | client        | scheduler     | workers       |
+---------+---------------+---------------+---------------+
| blosc   | 1.10.1        | 1.9.2         | 1.9.2         |
| msgpack | 1.0.2         | 1.0.0         | 1.0.0         |
| numpy   | 1.19.5        | 1.18.5        | 1.18.5        |
| python  | 3.8.5.final.0 | 3.8.0.final.0 | 3.8.0.final.0 |
| tornado | 6.0.4         | 6.1           | 6.1           |
+---------+---------------+---------------+---------------+
Notes: 
-  msgpack: Variation is ok, as long as everything is above 0.6
distributed.client - ERROR - Failed to reconnect to scheduler after 10.00 seconds, closing client
_GatheringFuture exception was never retrieved
future: <_GatheringFuture finished exception=CancelledError()>
asyncio.exceptions.CancelledError
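
The table in the output above is the warning `distributed` prints when the client, scheduler, and workers run different package versions. As a rough illustration of what that check compares, the sketch below lists the packages whose versions differ. The `find_mismatches` helper and the flat dictionaries are hypothetical stand-ins, not the structure `distributed` uses internally. The version strings are copied from the table.

```python
def find_mismatches(client: dict, scheduler: dict, workers: dict) -> dict:
    """Return {package: (client, scheduler, workers)} where versions differ."""
    packages = sorted(set(client) | set(scheduler) | set(workers))
    mismatches = {}
    for pkg in packages:
        versions = (client.get(pkg), scheduler.get(pkg), workers.get(pkg))
        if len(set(versions)) > 1:
            mismatches[pkg] = versions
    return mismatches


# Versions copied from the warning table in the cell output.
client_versions = {
    "blosc": "1.10.1",
    "msgpack": "1.0.2",
    "numpy": "1.19.5",
    "python": "3.8.5.final.0",
    "tornado": "6.0.4",
}
# Scheduler and workers run the same container image, so they agree.
cluster_versions = {
    "blosc": "1.9.2",
    "msgpack": "1.0.0",
    "numpy": "1.18.5",
    "python": "3.8.0.final.0",
    "tornado": "6.1",
}

mismatched = find_mismatches(client_versions, cluster_versions, cluster_versions)
for pkg, versions in mismatched.items():
    print(pkg, versions)
```

On a live cluster, `client.get_versions(check=True)` performs a comparable check and raises an error on mismatches. Building the local environment from the same pinned versions as the container image is one way to avoid the warning.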


In [None]:
print(f"View the dashboard: {cluster.dashboard_link}")

Click the link above to view a diagnostic dashboard while you run the training code below.

<hr>

## Train a model

In [None]:
import dask.array as da
from dask.distributed import wait
from lightgbm.dask import DaskLGBMRegressor

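# Shapes and chunk sizes must be integers, so avoid float literals like 1e6.
num_rows = 1_000_000
num_features = 100
num_partitions = 10
rows_per_chunk = num_rows // num_partitions

# One chunk per partition: each chunk holds `rows_per_chunk` rows and all features.
data = da.random.random((num_rows, num_features), chunks=(rows_per_chunk, num_features))

# A single column of random labels, chunked along rows to match `data`.
labels = da.random.random((num_rows, 1), chunks=(rows_per_chunk, 1))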

Right now, the Dask Arrays `data` and `labels` are lazy: no values have been generated yet. Before training, call `.persist()` to start computing them on the cluster, then call `wait()` on each one to block until that computation finishes.

In [None]:
data = data.persist()
labels = labels.persist()
_ = wait(data)
_ = wait(labels)
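
Persisting keeps every chunk in worker memory, so it helps to check that the arrays fit. The back-of-the-envelope sketch below uses the sizes from this notebook and rests on three assumptions: the values are 8-byte floats, `worker_mem` is counted in MiB, and chunks spread evenly across workers (Dask does not guarantee an even spread).

```python
import math

num_rows = 1_000_000
num_features = 100
num_partitions = 10
n_workers = 3
worker_mem_mib = 4096  # worker_mem passed to FargateCluster; MiB is assumed
bytes_per_value = 8    # assumes float64 values

rows_per_chunk = num_rows // num_partitions

# Size of one chunk of `data` and one chunk of `labels`, in MiB.
data_chunk_mib = rows_per_chunk * num_features * bytes_per_value / 2**20
label_chunk_mib = rows_per_chunk * 1 * bytes_per_value / 2**20

# Total footprint of both persisted arrays across the cluster.
total_mib = num_partitions * (data_chunk_mib + label_chunk_mib)

# If chunks spread evenly, the busiest worker holds ceil(10 / 3) = 4 of each.
chunks_on_busiest_worker = math.ceil(num_partitions / n_workers)
busiest_worker_mib = chunks_on_busiest_worker * (data_chunk_mib + label_chunk_mib)

print(f"one data chunk: {data_chunk_mib:.1f} MiB")
print(f"all chunks:     {total_mib:.1f} MiB")
print(f"busiest worker: {busiest_worker_mib:.1f} of {worker_mem_mib} MiB")
```

Training needs working memory on top of the raw arrays, so treat these numbers as a lower bound rather than a sizing guarantee.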

With the data set up on the workers, train a model. `lightgbm.dask.DaskLGBMRegressor` has an interface that tries to stay as close as possible to the non-Dask scikit-learn interface to LightGBM (`lightgbm.sklearn.LGBMRegressor`).

In [None]:
dask_reg = DaskLGBMRegressor(
    silent=False,
    max_depth=5,
    random_state=708,
    objective="regression_l2",
    learning_rate=0.1,
    tree_learner="data",
    n_estimators=10,
    min_child_samples=1,
    n_jobs=-1,
)

dask_reg.fit(
    client=client,
    X=data,
    y=labels,
)

The model produced by this training run is an instance of `DaskLGBMRegressor`. To get a regular non-Dask model (which can be pickled and saved), run `.to_local()`.

In [None]:
local_model = dask_reg.to_local()
type(local_model)
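
As a minimal sketch of the "pickled and saved" step mentioned above, the helpers below write an object to disk and read it back. `save_model`, `load_model`, and the file name are hypothetical, and a plain dictionary stands in for `local_model` so the example runs without a cluster.

```python
import pickle
import tempfile
from pathlib import Path


def save_model(model, path: Path) -> None:
    """Serialize `model` to `path` with pickle."""
    with open(path, "wb") as f:
        pickle.dump(model, f)


def load_model(path: Path):
    """Load a pickled model. Only unpickle files from a trusted source."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Stand-in object; in the notebook this would be `local_model`.
stand_in = {"n_estimators": 10, "max_depth": 5}

with tempfile.TemporaryDirectory() as tmp:
    model_path = Path(tmp) / "model.pkl"
    save_model(stand_in, model_path)
    restored = load_model(model_path)

print(restored == stand_in)
```

Pickle files are tied to the library versions that wrote them. For a more portable artifact, LightGBM's `booster_.save_model(...)` writes the trained trees in its own text format.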

You can inspect the trained trees as a data frame, with one row per tree node.

In [None]:
local_model.booster_.trees_to_dataframe()