This notebook shows how to use `lightgbm.dask` to train a LightGBM model on data stored as a [Dask DataFrame](https://docs.dask.org/en/latest/dataframe.html) or [Dask Array](https://docs.dask.org/en/latest/array.html).

It uses `FargateCluster` from [`dask-cloudprovider`](https://github.com/dask/dask-cloudprovider) to create a distributed cluster.

This notebook was most recently updated to test https://github.com/microsoft/LightGBM/pull/3994.

<hr>

## Set up a cluster on AWS Fargate

In [1]:
import json

with open("../ecr-details.json", "r") as f:
    ecr_details = json.load(f)

CONTAINER_IMAGE = ecr_details["repository"]["repositoryUri"] + ":1"
print(f"scheduler and worker image: {CONTAINER_IMAGE}")

scheduler and worker image: public.ecr.aws/w8s1c8b1/dask-lgb-test-4468592b-5204-4e3b-ad80-7c5ae698472a-cluster:1


Before proceeding, set up your AWS credentials. If you're unsure how to do this, see [the AWS docs](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html).

In [2]:
import os

os.environ["AWS_DEFAULT_REGION"] = "us-west-2"

Create a cluster with 3 workers. See https://cloudprovider.dask.org/en/latest/aws.html#dask_cloudprovider.aws.FargateCluster for more options.

In [3]:
from dask.distributed import Client
from dask_cloudprovider.aws import FargateCluster

n_workers = 3
cluster = FargateCluster(
    image=CONTAINER_IMAGE,
    worker_cpu=512,
    worker_mem=4096,
    n_workers=n_workers,
    fargate_use_private_ip=False,
    scheduler_timeout="40 minutes",
    find_address_timeout=60 * 10,
)
client = Client(cluster)
client.wait_for_workers(n_workers)

+---------+---------------+---------------+---------------+
| Package | client        | scheduler     | workers       |
+---------+---------------+---------------+---------------+
| blosc   | 1.10.2        | 1.9.2         | 1.9.2         |
| lz4     | 3.1.3         | 3.1.1         | 3.1.1         |
| msgpack | 1.0.2         | 1.0.0         | 1.0.0         |
| numpy   | 1.20.1        | 1.18.1        | 1.18.1        |
| python  | 3.8.6.final.0 | 3.8.0.final.0 | 3.8.0.final.0 |
+---------+---------------+---------------+---------------+
Notes: 
-  msgpack: Variation is ok, as long as everything is above 0.6
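
The table above lists packages whose versions differ between the client, the scheduler, and the workers. As a small illustration of that comparison, the sketch below flags every package whose versions are not identical across roles. `find_mismatches` and the dict layout are hypothetical (they are not part of Dask), and the sample values are copied from two rows of the table.

```python
def find_mismatches(versions):
    """Return the packages whose version differs between roles.

    `versions` maps package name -> {role: version string}.
    Hypothetical helper for illustration only.
    """
    return {
        package: by_role
        for package, by_role in versions.items()
        if len(set(by_role.values())) > 1
    }


# values copied from the table above
versions = {
    "blosc": {"client": "1.10.2", "scheduler": "1.9.2", "workers": "1.9.2"},
    "numpy": {"client": "1.20.1", "scheduler": "1.18.1", "workers": "1.18.1"},
}
mismatched = find_mismatches(versions)
print(sorted(mismatched))
```

Rebuilding the container image against the same package versions as the client is one way to make such a list come back empty.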


In [4]:
print(f"View the dashboard: {cluster.dashboard_link}")

View the dashboard: http://54.186.105.162:8787/status


Click the link above to view a diagnostic dashboard while you run the training code below.

<hr>

## Train a model

In [6]:
import dask.array as da
from dask.distributed import wait
from lightgbm.dask import DaskLGBMRegressor

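# array shapes and chunk sizes must be integers, not floats
num_rows = 1_000_000
num_features = 100
num_partitions = 10
rows_per_chunk = num_rows // num_partitions

# one chunk per partition: each chunk holds all features for a block of rows
data = da.random.random((num_rows, num_features), chunks=(rows_per_chunk, num_features))

labels = da.random.random((num_rows, 1), chunks=(rows_per_chunk, 1))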

data = data.persist()
labels = labels.persist()
_ = wait(data)
_ = wait(labels)
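
The arrays above are split along the row axis into fixed-size chunks. As a minimal sketch in plain Python (no Dask needed), the helper below shows the integer arithmetic behind that layout: full chunks of `rows_per_chunk` rows, followed by a smaller final chunk if the rows do not divide evenly. `row_chunks` is a hypothetical name used only for illustration; it is not a Dask function.

```python
def row_chunks(num_rows: int, rows_per_chunk: int) -> tuple:
    """Return the row count of each chunk when tiling with a fixed chunk size.

    Hypothetical helper for illustration; not part of Dask or LightGBM.
    """
    if num_rows <= 0 or rows_per_chunk <= 0:
        raise ValueError("num_rows and rows_per_chunk must be positive integers")
    full_chunks, remainder = divmod(num_rows, rows_per_chunk)
    chunks = [rows_per_chunk] * full_chunks
    if remainder:
        # leftover rows form one final, smaller chunk
        chunks.append(remainder)
    return tuple(chunks)


even_split = row_chunks(1_000_000, 100_000)
uneven_split = row_chunks(250, 100)
print(len(even_split), uneven_split)
```

With the sizes used in this notebook the split is even, so every partition holds the same number of rows.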

## Method 1 - just defaults

With no network parameters set, `lightgbm.dask` finds a random open port on each worker, so you do not need to specify any IP addresses or ports.

In [8]:
model = DaskLGBMRegressor(num_iterations=10, num_leaves=10)
model.fit(data, labels)

Finding random open ports for workers


DaskLGBMRegressor(num_iterations=10, num_leaves=10, num_threads=1, time_out=120,
                  tree_learner='data')

## Method 2 - local_listen_port

If you pass only `local_listen_port`, LightGBM uses that port on every worker.

In [15]:
model = DaskLGBMRegressor(num_iterations=10, num_leaves=10, local_listen_port=16000)
model.fit(data, labels)

Using passed-in 'local_listen_port' for all workers


DaskLGBMRegressor(num_iterations=10, num_leaves=10, num_threads=1, time_out=120,
                  tree_learner='data')
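
Because one port is reused on every worker, this option presumably needs each worker to run on its own host. That is an assumption worth checking against the LightGBM documentation for your version. The sketch below looks for hosts that appear more than once in a list of worker addresses. `hosts_with_multiple_workers` is a hypothetical helper and the addresses are made up; on a live cluster the addresses could come from `client.scheduler_info()["workers"].keys()`.

```python
from collections import Counter
from urllib.parse import urlparse


def hosts_with_multiple_workers(worker_addresses):
    """Return the hosts that appear more than once among the worker addresses.

    Hypothetical helper for illustration only.
    """
    hosts = Counter(urlparse(address).hostname for address in worker_addresses)
    return sorted(host for host, count in hosts.items() if count > 1)


# made-up addresses in the "tcp://host:port" form that Dask reports
example_addresses = [
    "tcp://10.0.0.1:40001",
    "tcp://10.0.0.2:40002",
    "tcp://10.0.0.2:40003",
]
shared_hosts = hosts_with_multiple_workers(example_addresses)
print(shared_hosts)
```

An empty result means no two workers share a host, which is what you would expect on Fargate, where each worker runs as its own task.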

## Method 3 - machines

You should be able to pass `machines` directly, as a comma-separated string of `ip:port` pairs with one entry per worker.

In [13]:
import socket
from urllib.parse import urlparse


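def _find_random_open_port() -> int:
    """Find a random open port on the machine where this function runs.

    Called from the notebook, that is the client machine, not a worker.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # binding to port 0 asks the OS to pick any free port
        s.bind(("", 0))
        port = s.getsockname()[1]
    return port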


worker_addresses = client.scheduler_info()["workers"].keys()
machines = ",".join(
    [
        urlparse(worker_address).hostname + ":" + str(_find_random_open_port())
        for worker_address in worker_addresses
    ]
)
print(machines)

172.31.4.109:34617,172.31.42.101:47359,172.31.55.201:36471
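
The ports above are found on the machine running the notebook, so they are not guaranteed to be free on the workers. One alternative, sketched here under assumptions, is to look for a port on each worker (for example with `client.run(_find_random_open_port)`, which returns a dict keyed by worker address) and then assemble `machines` from that mapping. `build_machines` is a hypothetical helper, and the addresses and ports in the example are made up.

```python
from urllib.parse import urlparse


def build_machines(worker_ports):
    """Build a "host:port,host:port" string from {worker_address: port}.

    Hypothetical helper for illustration; not part of LightGBM or Dask.
    """
    entries = []
    for address, port in worker_ports.items():
        host = urlparse(address).hostname
        entries.append(f"{host}:{port}")
    if len(set(entries)) != len(entries):
        raise ValueError("duplicate host:port pairs in machines")
    return ",".join(entries)


# made-up mapping shaped like the result of client.run(_find_random_open_port)
example_worker_ports = {
    "tcp://10.0.0.1:40001": 34617,
    "tcp://10.0.0.2:40002": 47359,
}
example_machines = build_machines(example_worker_ports)
print(example_machines)
```

A port found this way can still be taken by another process before training starts, so treat it as a best effort rather than a guarantee.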


In [14]:
model = DaskLGBMRegressor(num_iterations=10, num_leaves=10, machines=machines)
model.fit(data, labels)

Using passed-in 'machines' parameter


DaskLGBMRegressor(num_iterations=10, num_leaves=10, num_threads=1, time_out=120,
                  tree_learner='data')