This notebook shows how to use `lightgbm.dask` to train a LightGBM model on data stored as a [Dask DataFrame](https://docs.dask.org/en/latest/dataframe.html) or [Dask Array](https://docs.dask.org/en/latest/array.html).

It uses `FargateCluster` from [`dask-cloudprovider`](https://github.com/dask/dask-cloudprovider) to create a distributed cluster.

<hr>

## Set up a cluster on AWS Fargate

In [1]:
import json

with open("ecr-details.json", "r") as f:
    ecr_details = json.load(f)

CONTAINER_IMAGE = ecr_details["repository"]["repositoryUri"] + ":1"
print(f"scheduler and worker image: {CONTAINER_IMAGE}")

scheduler and worker image: public.ecr.aws/m0z8a6o8/dask-lgb-test-66d8e7ff-3c48-426c-993b-6bdf1243f916-cluster:1


Before proceeding, set up your AWS credentials. If you're unsure how to do this, see [the AWS docs](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html).
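As a quick sanity check before creating the cluster, the sketch below looks in two common places where credentials are usually found: the `AWS_ACCESS_KEY_ID` / `AWS_SECRET_ACCESS_KEY` environment variables and the shared `~/.aws/credentials` file. The helper name `find_aws_credential_sources` is hypothetical and the check is deliberately incomplete. It does not cover IAM roles, SSO, or other providers, and it does not validate the credentials themselves.

```python
import os
from pathlib import Path


def find_aws_credential_sources(environ=None, home=None):
    """Return the common credential sources that appear to be configured.

    Hypothetical helper for illustration only: it checks two well-known
    locations and does not verify that the credentials are valid.
    """
    environ = os.environ if environ is None else environ
    home = Path.home() if home is None else Path(home)

    sources = []
    # Environment variables: both halves of the key pair must be present.
    if environ.get("AWS_ACCESS_KEY_ID") and environ.get("AWS_SECRET_ACCESS_KEY"):
        sources.append("environment variables")
    # Shared credentials file, typically written by `aws configure`.
    if (home / ".aws" / "credentials").is_file():
        sources.append("shared credentials file")
    return sources


print(find_aws_credential_sources() or "no credentials found in the checked locations")
```

An empty result does not necessarily mean the next cells will fail, since credentials may come from a source this sketch does not inspect.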

In [2]:
import os

os.environ["AWS_DEFAULT_REGION"] = "us-west-2"

Create a cluster with 3 workers. See [the `FargateCluster` documentation](https://cloudprovider.dask.org/en/latest/aws.html#dask_cloudprovider.aws.FargateCluster) for more options.
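To get a feel for how much compute the settings in the next cell request, here is a small sketch that converts them into vCPUs and GiB. It assumes that `worker_cpu` is expressed in Fargate CPU units (1024 units = 1 vCPU) and `worker_mem` in MiB. The helper name `describe_fargate_sizing` is hypothetical and not part of `dask-cloudprovider`.

```python
def describe_fargate_sizing(n_workers, worker_cpu, worker_mem):
    """Summarize the total resources requested for the workers.

    Assumption: worker_cpu is in CPU units (1024 units = 1 vCPU)
    and worker_mem is in MiB.
    """
    vcpus_per_worker = worker_cpu / 1024
    gib_per_worker = worker_mem / 1024
    return {
        "vcpus_per_worker": vcpus_per_worker,
        "gib_per_worker": gib_per_worker,
        "total_vcpus": n_workers * vcpus_per_worker,
        "total_gib": n_workers * gib_per_worker,
    }


# Values taken from the FargateCluster call in this notebook.
print(describe_fargate_sizing(n_workers=3, worker_cpu=512, worker_mem=4096))
```

Under those assumptions, the three workers together request 1.5 vCPUs and 12 GiB of memory. The scheduler's resources are not included.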

In [3]:
from dask_cloudprovider.aws import FargateCluster
from dask.distributed import Client

n_workers = 3
cluster = FargateCluster(
    image=CONTAINER_IMAGE,
    worker_cpu=512,
    worker_mem=4096,
    n_workers=n_workers,
    fargate_use_private_ip=False,
    scheduler_timeout="40 minutes",
    find_address_timeout=60 * 10,
)
client = Client(cluster)
client.wait_for_workers(n_workers)


+---------+---------------+---------------+---------------+
| Package | client        | scheduler     | workers       |
+---------+---------------+---------------+---------------+
| blosc   | 1.10.1        | 1.9.2         | 1.9.2         |
| msgpack | 1.0.2         | 1.0.0         | 1.0.0         |
| numpy   | 1.19.5        | 1.18.5        | 1.18.5        |
| python  | 3.8.5.final.0 | 3.8.0.final.0 | 3.8.0.final.0 |
| tornado | 6.0.4         | 6.1           | 6.1           |
+---------+---------------+---------------+---------------+
Notes: 
-  msgpack: Variation is ok, as long as everything is above 0.6


In [4]:
print(f"View the dashboard: {cluster.dashboard_link}")

View the dashboard: http://54.203.139.62:8787/status


Click the link above to view a diagnostic dashboard while you run the training code below.

<hr>

## Train a model

In [5]:
import dask.array as da
from dask.distributed import wait
from lightgbm.dask import DaskLGBMRegressor

In [6]:
for i in range(10):
    print(f"attempt {i}")
    
    client.restart()

    # Use integers for array shapes and chunk sizes.
    num_rows = 1_000_000
    num_features = 100
    num_partitions = 10
    rows_per_chunk = num_rows // num_partitions

    data = da.random.random((num_rows, num_features), chunks=(rows_per_chunk, num_features))

    labels = da.random.random((num_rows, 1), chunks=(rows_per_chunk, 1))

    data = data.persist()
    labels = labels.persist()
    _ = wait(data)
    _ = wait(labels)

    dask_reg = DaskLGBMRegressor(
        silent=False,
        max_depth=5,
        random_state=708,
        objective="regression_l2",
        learning_rate=0.1,
        tree_learner="data",
        n_estimators=10,
        min_child_samples=1,
        n_jobs=-1,
        local_listen_port=12400,
    )

    dask_reg.fit(
        client=client,
        X=data,
        y=labels,
    )

attempt 0
worker_address_to_port
{'tcp://172.31.4.117:35941': 12400, 'tcp://172.31.49.205:42975': 12401, 'tcp://172.31.57.243:37337': 12402}
attempt 1
worker_address_to_port
{'tcp://172.31.4.117:40295': 12400, 'tcp://172.31.49.205:37995': 12402, 'tcp://172.31.57.243:38989': 12401}
attempt 2
worker_address_to_port
{'tcp://172.31.57.243:46333': 12400, 'tcp://172.31.49.205:45193': 12402, 'tcp://172.31.4.117:36941': 12401}
attempt 3
worker_address_to_port
{'tcp://172.31.4.117:43961': 12400, 'tcp://172.31.49.205:33229': 12402, 'tcp://172.31.57.243:44241': 12403}
attempt 4
worker_address_to_port
{'tcp://172.31.4.117:34297': 12400, 'tcp://172.31.57.243:46027': 12402, 'tcp://172.31.49.205:41983': 12401}
attempt 5
worker_address_to_port
{'tcp://172.31.4.117:34703': 12400, 'tcp://172.31.57.243:40619': 12401, 'tcp://172.31.49.205:43009': 12402}
attempt 6
worker_address_to_port
{'tcp://172.31.57.243:33825': 12400, 'tcp://172.31.4.117:32845': 12401, 'tcp://172.31.49.205:41851': 12403}
attempt 7
wor

The model produced by this training run is an instance of `DaskLGBMRegressor`. To get a regular non-Dask model (which can be pickled and saved), call its `.to_local()` method.

In [None]:
local_model = dask_reg.to_local()
type(local_model)
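Since the local model can be pickled, one way to save and reload it is with the standard-library `pickle` module. The sketch below is a minimal illustration: the helper names `save_model` and `load_model` and the file name `model.pkl` are hypothetical choices, not part of LightGBM.

```python
import pickle


def save_model(model, path):
    """Serialize a fitted local model to disk with pickle."""
    with open(path, "wb") as f:
        pickle.dump(model, f)


def load_model(path):
    """Load a model previously written by save_model()."""
    with open(path, "rb") as f:
        return pickle.load(f)


# In this notebook, usage might look like:
#   save_model(local_model, "model.pkl")
#   restored_model = load_model("model.pkl")
```

Only unpickle files from sources you trust, because loading a pickle can execute arbitrary code.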

You can inspect this model by looking at a data frame representation of it. To confirm that the model really used data from all workers, check that the root node of the first tree saw all 1,000,000 training rows.

In [None]:
dask_reg.booster_.trees_to_dataframe().iloc[0]["count"] == 1_000_000