In [None]:
import dask.array as da
import os

from dask_cloudprovider.aws import FargateCluster
from dask.distributed import Client, LocalCluster, wait

from lightgbm.dask import LGBMRegressor

In [None]:
# Start a local Dask cluster and connect a client to it.
n_workers = 3

# Request exactly the number of workers we wait for below. Without an explicit
# n_workers, LocalCluster chooses its own worker count, and
# wait_for_workers(n_workers) would block indefinitely if fewer workers start.
cluster = LocalCluster(n_workers=n_workers)
client = Client(cluster)
client.wait_for_workers(n_workers)

print(f"View the dashboard: {cluster.dashboard_link}")

In [None]:
# Dimensions of the synthetic dataset. These are used as array shapes and chunk
# sizes, so they are kept as integers rather than float literals (1e6, 1e2):
# counts should not depend on a downstream library coercing floats to int.
num_rows = 1_000_000
num_features = 100
num_partitions = 10

# Floor division keeps the per-chunk row count an int (true division would
# produce a float even when the split is exact).
rows_per_chunk = num_rows // num_partitions

In [None]:
# Random feature matrix, split along the row axis into blocks of
# rows_per_chunk rows; each block spans every feature column.
feature_shape = (num_rows, num_features)
feature_chunks = (rows_per_chunk, num_features)
data = da.random.random(size=feature_shape, chunks=feature_chunks)

# Random single-column target, chunked with the same row blocks as the features.
label_shape = (num_rows, 1)
label_chunks = (rows_per_chunk, 1)
labels = da.random.random(size=label_shape, chunks=label_chunks)

# Persist both collections on the cluster, then block until each has finished
# computing before moving on to training.
data = data.persist()
labels = labels.persist()
_ = wait(data)
_ = wait(labels)

In [None]:
# Hyperparameters for the distributed regressor, collected in one mapping so
# they can be inspected or tweaked before the estimator is built.
regressor_params = dict(
    objective="regression_l2",
    tree_learner="data",
    n_estimators=10,
    learning_rate=0.1,
    max_depth=5,
    min_child_samples=1,
    random_state=708,
    n_jobs=-1,
    silent=False,
)
dask_reg = LGBMRegressor(**regressor_params)

# Train on the persisted Dask arrays using the client created earlier.
# NOTE(review): passing `client` to fit() and importing `LGBMRegressor` from
# lightgbm.dask depend on the installed LightGBM version -- confirm against it.
dask_reg.fit(X=data, y=labels, client=client)

In [None]:
# Call to_local() on the fitted Dask estimator and keep the result.
# NOTE(review): presumably this yields a regular (non-Dask) LightGBM model that
# can be used without the cluster -- confirm against the LightGBM Dask docs.
local_model = dask_reg.to_local()