# Distributed training

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/yggdrasil-decision-forests/blob/main/documentation/public/docs/tutorial/distributed_training.ipynb)

## Setup

```shell
pip install ydf -U
```

```python
import ydf  # Yggdrasil Decision Forests
import pandas as pd  # We use Pandas to load small datasets

import os
import threading
```

## What is distributed training?

**Distributed training** splits the computation cost of training a model over multiple computers: instead of training the model on a single machine, several machines work on it in parallel. This can significantly speed up training and makes it possible to train on datasets that are too large for a single machine.

Distributed training has been used to train YDF models on datasets containing **billions of examples** and **thousands of features**.

Distributed training can also be used for hyperparameter tuning: while each model is trained on a single machine, multiple models are trained and evaluated in parallel, which speeds up the search for the best hyperparameters.
See the [tuning](../tuning) notebook for details.

Distributed training requires:

- A collection of machines called **workers** that can communicate with each other via IP. In this notebook, we will spawn workers locally for the sake of example.
- A **manager** machine that will execute the YDF code and can communicate with the workers via their IP addresses.
- A **shared directory** that is accessible by both the manager and worker machines. This directory will be used as a computation cache and for checkpoints in case one of the workers or the manager is interrupted during training.
- A dataset on disk in one of the formats supported by YDF, such as CSV or TFRecord. The dataset must also be sharded, that is, divided into multiple files so that each worker can process a different portion of it in parallel. In this example, we split a small dataset into shards directly in the notebook. Large datasets should instead be sharded ahead of time, for example with Apache Beam.

<div style="border:1px solid #8FAFDF; background-color:#DCEAFF; padding: 5px;">
<b>For Googlers</b>
<p>YDF internal examples available at <a href="http://go/ydf/examples">go/ydf/examples</a> demonstrate how to use distributed training on Google infrastructure.</p></div>

## Download and split dataset

As a general guideline, there should be approximately 10 shards for each worker. For instance, if you have 100 workers, the dataset should be sharded into 1000 pieces.
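The guideline above, and the row ranges that the `split_dataset` function below produces when it rounds the shard size up, can be sketched in plain Python. The helper names `suggested_num_shards` and `shard_row_ranges` are hypothetical, and the factor of 10 is only the rule of thumb stated above, not a hard requirement.

```python
SHARDS_PER_WORKER = 10  # Rule of thumb from the text, not a YDF constant.


def suggested_num_shards(num_workers: int,
                         shards_per_worker: int = SHARDS_PER_WORKER) -> int:
    """Returns a suggested number of dataset shards for `num_workers` workers."""
    return num_workers * shards_per_worker


def shard_row_ranges(num_rows: int, num_shards: int) -> list[tuple[int, int]]:
    """Returns the [begin, end) row range of each shard.

    Uses the same rounded-up shard size as `split_dataset`, so the last
    shards can be smaller than the others, or even empty.
    """
    rows_per_shard = (num_rows + num_shards - 1) // num_shards
    return [
        (min(i * rows_per_shard, num_rows), min((i + 1) * rows_per_shard, num_rows))
        for i in range(num_shards)
    ]


print(suggested_num_shards(100))  # 100 workers -> 1000 shards
print(shard_row_ranges(10, 4))    # Row ranges for 10 rows split into 4 shards
```

Rounding the shard size up guarantees that every row lands in some shard. With very few rows per shard it can leave trailing shards empty (for example, 25 rows split into 10 shards leaves the last shard empty), which is worth checking before starting a distributed job.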


```python
# Download a classification dataset and load it as a Pandas DataFrame.
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
test_ds = pd.read_csv(f"{ds_path}/adult_test.csv")
```

```python
def split_dataset(
    dataset: pd.DataFrame, tmp_dir: str, num_shards: int
) -> list[str]:
  """Splits a DataFrame into `num_shards` csv files and returns their paths."""

  os.makedirs(tmp_dir, exist_ok=True)
  # Round up so that the shards cover every row; the last shard may be smaller.
  num_row_per_shard = (dataset.shape[0] + num_shards - 1) // num_shards
  paths = []
  for shard_idx in range(num_shards):
    begin_idx = shard_idx * num_row_per_shard
    end_idx = (shard_idx + 1) * num_row_per_shard
    shard_dataset = dataset.iloc[begin_idx:end_idx]
    shard_path = os.path.join(tmp_dir, f"shard_{shard_idx}.csv")
    paths.append(shard_path)
    shard_dataset.to_csv(shard_path, index=False)
  return paths

sharded_train_paths = split_dataset(train_ds, "train_ds", 10)
print(sharded_train_paths)
```

```
['train_ds/shard_0.csv', 'train_ds/shard_1.csv', 'train_ds/shard_2.csv', 'train_ds/shard_3.csv', 'train_ds/shard_4.csv', 'train_ds/shard_5.csv', 'train_ds/shard_6.csv', 'train_ds/shard_7.csv', 'train_ds/shard_8.csv', 'train_ds/shard_9.csv']
```


## Configure workers

A worker is a Python program that runs the command `ydf.start_worker(port)`.

Let's start 4 workers locally for the example.

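```python
def create_worker_thread(port):
    # Run each worker in its own thread so that the notebook can keep
    # executing cells while the workers serve requests.
    thread = threading.Thread(target=ydf.start_worker, args=(port,))
    thread.start()

create_worker_thread(8101)
create_worker_thread(8102)
create_worker_thread(8103)
create_worker_thread(8104)
```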

```
[INFO 23-11-10 14:33:30.6428 CET worker.cc:41] Start YDF worker on port 8101
[INFO 23-11-10 14:33:30.6433 CET grpc_worker.cc:395] Start worker server at address [::]:8101
[INFO 23-11-10 14:33:30.6437 CET worker.cc:41] Start YDF worker on port 8102
[INFO 23-11-10 14:33:30.6437 CET grpc_worker.cc:395] Start worker server at address [::]:8102
[INFO 23-11-10 14:33:30.6441 CET worker.cc:41] Start YDF worker on port 8103
[INFO 23-11-10 14:33:30.6441 CET grpc_worker.cc:395] Start worker server at address [::]:8103
[INFO 23-11-10 14:33:30.6450 CET worker.cc:41] Start YDF worker on port 8104
[INFO 23-11-10 14:33:30.6450 CET grpc_worker.cc:395] Start worker server at address [::]:8104
```
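On a real cluster, each worker machine runs its own long-lived Python process instead of a thread in the notebook. The following is a minimal sketch of such a worker program, assuming it is saved as a hypothetical `worker.py` on every worker machine. Only `ydf.start_worker(port)` comes from this tutorial; the `--port` flag, its default value, and the `parse_port` and `run_worker` helpers are illustrative choices.

```python
import argparse
import sys
from typing import Optional


def parse_port(argv: list[str]) -> int:
    """Parses the port the worker should listen on from command-line flags."""
    parser = argparse.ArgumentParser(description="Start a YDF worker.")
    # Hypothetical flag; the manager must list this port in `workers`.
    parser.add_argument("--port", type=int, default=8101,
                        help="TCP port the worker listens on.")
    return parser.parse_args(argv).port


def run_worker(argv: Optional[list[str]] = None) -> None:
    """Starts a YDF worker on the port given on the command line."""
    port = parse_port(sys.argv[1:] if argv is None else argv)
    import ydf  # Imported here so that parse_port can be used without YDF.

    # In this notebook, start_worker runs in a separate thread, which suggests
    # it keeps running while serving the manager's requests.
    ydf.start_worker(port)
```

Each machine could then start a worker with a command such as `python -c "import worker; worker.run_worker()" --port 8101`, and the manager would list it as `"<worker-host>:8101"` in `workers`. The port must be reachable from the manager over IP.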


## Train model

Let's train the model:

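```python
ydf.verbose(2)  # To show (a lot of) training logs

learner = ydf.DistributedGradientBoostedTreesLearner(
    label="income",
    num_trees=10,
    workers=["localhost:8101", "localhost:8102", "localhost:8103", "localhost:8104"],
    # Shared directory used as a computation cache and for checkpoints.
    working_dir="work_dir",
    # Resume from the checkpoints in working_dir if a previous run was interrupted.
    resume_training=True,
)
# The dataset is the "csv:" format prefix followed by the comma-separated shard paths.
model = learner.train("csv:" + ",".join(sharded_train_paths))
```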

```
[INFO 23-11-10 14:33:30.7025 CET csv_example_reader.cc:208] 0 row(s) processed
[INFO 23-11-10 14:33:30.7135 CET csv_example_reader.cc:212] Stop scanning the csv file to infer the type. Some records were not considered.
[INFO 23-11-10 14:33:30.7136 CET data_spec_inference.cc:426] 15 column(s) found
[INFO 23-11-10 14:33:30.7139 CET csv_example_reader.cc:296] 0 row(s) processed
[INFO 23-11-10 14:33:30.7934 CET data_spec_inference.cc:305] 1 item(s) have been pruned (i.e. they are considered out of dictionary) for the column workclass (7 item(s) left) because min_value_count=5 and max_number_of_unique_values=2000
[INFO 23-11-10 14:33:30.7935 CET data_spec_inference.cc:305] 1 item(s) have been pruned (i.e. they are considered out of dictionary) for the column occupation (13 item(s) left) because min_value_count=5 and max_number_of_unique_values=2000
[INFO 23-11-10 14:33:30.7935 CET data_spec_inference.cc:305] 1 item(s) have been pruned (i.e. they are considered out of dictionary) for the col
```
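The two learner inputs that depend on the cluster layout are plain strings: each worker is a `host:port` address, and the training data is a `csv:` prefix followed by the comma-separated shard paths. The sketch below builds both with hypothetical helpers, `worker_addresses` and `csv_typed_path`. Because commas separate the shards, the sketch also rejects file names that contain a comma; this extra safety check is a design choice of the sketch, not something taken from YDF.

```python
def worker_addresses(host: str, ports: list[int]) -> list[str]:
    """Formats workers as "host:port" strings, as in the `workers` argument."""
    return [f"{host}:{port}" for port in ports]


def csv_typed_path(shard_paths: list[str]) -> str:
    """Builds a dataset path: the "csv:" prefix, then comma-separated shards."""
    for path in shard_paths:
        # A comma inside a file name would be read as a shard separator.
        if "," in path:
            raise ValueError(f"Shard path contains a comma: {path!r}")
    return "csv:" + ",".join(shard_paths)
```

For example, `worker_addresses("localhost", list(range(8101, 8105)))` reproduces the `workers` list used above, and `csv_typed_path(sharded_train_paths)` returns the same string as `"csv:" + ",".join(sharded_train_paths)`.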

The model can be inspected and evaluated.

```python
model.describe()
```

```python
model.evaluate(test_ds)
```

```
[INFO 23-11-10 14:33:32.1714 CET abstract_model.cc:1343] Engine "GradientBoostedTreesQuickScorerExtended" built
```


| Label \ Pred | <=50K | >50K |
|---|---|---|
| <=50K | 7188 | 224 |
| >50K | 1235 | 1122 |
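Summary metrics follow directly from this confusion matrix. The sketch below computes them in plain Python from the counts shown, reading rows as true labels and columns as predictions, as the `Label \ Pred` header indicates. It only illustrates the arithmetic; the evaluation report produced by `model.evaluate` is the authoritative source of metrics.

```python
# Confusion matrix from the evaluation above: rows are true labels,
# columns are predictions ("Label \ Pred").
LABELS = ["<=50K", ">50K"]
CONFUSION = [
    [7188, 224],   # true <=50K
    [1235, 1122],  # true >50K
]


def accuracy(confusion: list[list[int]]) -> float:
    """Fraction of examples on the diagonal, i.e. correctly predicted."""
    correct = sum(confusion[i][i] for i in range(len(confusion)))
    total = sum(sum(row) for row in confusion)
    return correct / total


def recall(confusion: list[list[int]], label_idx: int) -> float:
    """Fraction of examples of class `label_idx` predicted as that class."""
    row = confusion[label_idx]
    return row[label_idx] / sum(row)


print(f"accuracy: {accuracy(CONFUSION):.4f}")
for idx, label in enumerate(LABELS):
    print(f"recall({label}): {recall(CONFUSION, idx):.4f}")
```

With these counts, accuracy is about 0.851 (8310 correct out of 9769 examples), while recall is about 0.97 for `<=50K` and about 0.48 for `>50K`, a gap that the single accuracy number hides.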


## Limitations

Distributed training is currently available only for Gradient Boosted Trees models, on classification and regression tasks.
