## Ray AIR : Trainer

### Deep Learning Trainers

Ray Train offer 3 main deep learning trainers: `TorchTrainer`, `TensorflowTrainer`, and `HorovodTrainer`

#### PyTorch

In [4]:
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

import ray
from ray import train
from ray.air import session, Checkpoint
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig

In [27]:
config = {
    'input_size': 1,
    'layer_size': 15,
    'output_size': 1,
    'num_epochs': 20
}

In [28]:
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(config['input_size'], config['layer_size'])
        self.fc2 = nn.Linear(config['layer_size'], config['output_size'])
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        
        return x
    
def train_loop_per_worker():
    dataset_shard = session.get_dataset_shard('train')
    model = Network()
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.003)
    
    model = train.torch.prepare_model(model)
    for epoch in range(config['num_epochs']):
        for batches in dataset_shard.iter_torch_batches(batch_size=32, dtypes=torch.float):
            inputs, labels = torch.unsqueeze(batches['x'], 1), batches['y']
            output = model(inputs)
            loss = loss_fn(output, labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f'epoch: {epoch}, loss; {loss.item()}')
        
        session.report(
            {},
            checkpoint=Checkpoint.from_dict(dict(epoch=epoch, model=model.state_dict())),
        )

In [22]:
train_dataset = ray.data.from_items([{"x": x, "y": 2 * x + 1} for x in range(200)])
scaling_config = ScalingConfig(num_workers=3, use_gpu=False)

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=scaling_config,
    datasets={'train': train_dataset},
)

result = trainer.fit()

Trial name,status,loc,iter,total time (s),_timestamp,_time_this_iter_s,_training_iteration
TorchTrainer_f8638_00000,TERMINATED,127.0.0.1:11891,20,6.76346,1666082202,0.210664,20


[2m[36m(RayTrainWorker pid=11896)[0m 2022-10-18 17:36:38,761	INFO config.py:71 -- Setting up process group for: env:// [rank=0, world_size=3]
[2m[36m(RayTrainWorker pid=11896)[0m 2022-10-18 17:36:38,906	INFO train_loop_utils.py:300 -- Moving model to device: cpu
[2m[36m(RayTrainWorker pid=11896)[0m 2022-10-18 17:36:38,906	INFO train_loop_utils.py:347 -- Wrapping provided model in DDP.


[2m[36m(RayTrainWorker pid=11897)[0m epoch: 0, loss; 45093.17578125
[2m[36m(RayTrainWorker pid=11897)[0m epoch: 0, loss; 25554690048.0
[2m[36m(RayTrainWorker pid=11896)[0m epoch: 0, loss; 54554.171875
[2m[36m(RayTrainWorker pid=11896)[0m epoch: 0, loss; 22445670400.0
[2m[36m(RayTrainWorker pid=11898)[0m epoch: 0, loss; 51718.91015625
[2m[36m(RayTrainWorker pid=11898)[0m epoch: 0, loss; 25215422464.0
[2m[36m(RayTrainWorker pid=11897)[0m epoch: 0, loss; 882762.5
[2m[36m(RayTrainWorker pid=11896)[0m epoch: 0, loss; 980065.0625
[2m[36m(RayTrainWorker pid=11898)[0m epoch: 0, loss; 1078653.0
Result for TorchTrainer_f8638_00000:
  _time_this_iter_s: 0.41376280784606934
  _timestamp: 1666082199
  _training_iteration: 1
  date: 2022-10-18_17-36-39
  done: false
  experiment_id: 6af9a4d21a064a91b0cf281e476862ec
  hostname: YONGJINs-MacBook-Pro.local
  iterations_since_restore: 1
  node_ip: 127.0.0.1
  pid: 11891
  should_checkpoint: true
  time_since_restore: 3.3061208

[2m[36m(RayTrainWorker pid=11897)[0m   return torch.as_tensor(ndarray, dtype=dtype, device=device)
[2m[36m(RayTrainWorker pid=11897)[0m   return F.mse_loss(input, target, reduction=self.reduction)
[2m[36m(RayTrainWorker pid=11896)[0m   return torch.as_tensor(ndarray, dtype=dtype, device=device)
[2m[36m(RayTrainWorker pid=11896)[0m   return F.mse_loss(input, target, reduction=self.reduction)
[2m[36m(RayTrainWorker pid=11896)[0m   return F.mse_loss(input, target, reduction=self.reduction)
[2m[36m(RayTrainWorker pid=11898)[0m   return torch.as_tensor(ndarray, dtype=dtype, device=device)
[2m[36m(RayTrainWorker pid=11898)[0m   return F.mse_loss(input, target, reduction=self.reduction)
[2m[36m(RayTrainWorker pid=11898)[0m   return F.mse_loss(input, target, reduction=self.reduction)
[2m[36m(RayTrainWorker pid=11897)[0m   return F.mse_loss(input, target, reduction=self.reduction)


[2m[36m(RayTrainWorker pid=11897)[0m epoch: 1, loss; 1000791.625
[2m[36m(RayTrainWorker pid=11897)[0m epoch: 1, loss; 1060378.75
[2m[36m(RayTrainWorker pid=11897)[0m epoch: 1, loss; 849056.875
[2m[36m(RayTrainWorker pid=11896)[0m epoch: 1, loss; 1059024.5
[2m[36m(RayTrainWorker pid=11896)[0m epoch: 1, loss; 1037115.4375
[2m[36m(RayTrainWorker pid=11896)[0m epoch: 1, loss; 944651.8125
[2m[36m(RayTrainWorker pid=11898)[0m epoch: 1, loss; 1041520.625
[2m[36m(RayTrainWorker pid=11898)[0m epoch: 1, loss; 1048851.25
[2m[36m(RayTrainWorker pid=11898)[0m epoch: 1, loss; 1041540.25
[2m[36m(RayTrainWorker pid=11897)[0m epoch: 2, loss; 930459.875
[2m[36m(RayTrainWorker pid=11896)[0m epoch: 2, loss; 961454.4375
[2m[36m(RayTrainWorker pid=11898)[0m epoch: 2, loss; 965653.375
[2m[36m(RayTrainWorker pid=11897)[0m epoch: 2, loss; 60474158743552.0
[2m[36m(RayTrainWorker pid=11897)[0m epoch: 2, loss; 1677171200.0
[2m[36m(RayTrainWorker pid=11896)[0m epoch: 2,

[2m[36m(RayTrainWorker pid=11897)[0m epoch: 14, loss; 38566543360.0
[2m[36m(RayTrainWorker pid=11897)[0m epoch: 14, loss; 38119370752.0
[2m[36m(RayTrainWorker pid=11897)[0m epoch: 14, loss; 37623693312.0
[2m[36m(RayTrainWorker pid=11896)[0m epoch: 14, loss; 38578601984.0
[2m[36m(RayTrainWorker pid=11896)[0m epoch: 14, loss; 38115459072.0
[2m[36m(RayTrainWorker pid=11896)[0m epoch: 14, loss; 37645025280.0
[2m[36m(RayTrainWorker pid=11898)[0m epoch: 14, loss; 38574698496.0
[2m[36m(RayTrainWorker pid=11898)[0m epoch: 14, loss; 38116786176.0
[2m[36m(RayTrainWorker pid=11898)[0m epoch: 14, loss; 37663268864.0
[2m[36m(RayTrainWorker pid=11897)[0m epoch: 15, loss; 37198725120.0
[2m[36m(RayTrainWorker pid=11896)[0m epoch: 15, loss; 37210566656.0
[2m[36m(RayTrainWorker pid=11898)[0m epoch: 15, loss; 37206736896.0
[2m[36m(RayTrainWorker pid=11897)[0m epoch: 15, loss; 36767662080.0
[2m[36m(RayTrainWorker pid=11897)[0m epoch: 15, loss; 36288868352.0
[2m[

2022-10-18 17:36:43,696	INFO tune.py:758 -- Total run time: 10.58 seconds (10.44 seconds for the tuning loop).


#### Tensorflow

In [23]:
import tensorflow as tf

from ray.air import session
from ray.air.callbacks.keras import Callback
from ray.train.tensorflow import prepare_dataset_shard
from ray.train.tensorflow import TensorflowTrainer
from ray.air.config import ScalingConfig

In [34]:
def build_model() -> tf.keras.Model:
    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=()),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(10),
            tf.keras.layers.Dense(1),
        ]
    )
    
    return model

def train_func(cofnig: dict):
    batch_size = config.get("batch_size", 64)
    epochs = config.get('epochs', 10)
    
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        multi_worker_model = build_model()
        multi_worker_model.compile(
            optimizer=tf.keras.optimizers.SGD(learning_rate=config.get('lr', 1e-3)),
            loss=tf.keras.losses.mean_squared_error,
            metrics=[tf.keras.metrics.mean_squared_error],
        )
        
    dataset = session.get_dataset_shard('train')
    
    def to_tf_dataset(dataset, batch_size):
        def to_tensor_iterator():
            for batch in dataset.iter_tf_batches(
                batch_size=batch_size, dtypes=tf.float32
            ):
                yield batch['x'], batch['y']
                
        output_signature = (
            tf.TensorSpec(shape=(None), dtype=tf.float32),
            tf.TensorSpec(shape=(None), dtype=tf.float32),
        )
        tf_dataset = tf.data.Dataset.from_generator(
            to_tensor_iterator, output_signature=output_signature
        )
        
        return prepare_dataset_shard(tf_dataset)
    
    result = []
    for _ in range(epochs):
        tf_dataset = to_tf_dataset(dataset=dataset, batch_size=batch_size)
        history = multi_worker_model.fit(tf_dataset, callbacks=[Callback()])
        result.append(history.history)
        
    return result

In [35]:
train_dataset = ray.data.from_items([{"x": x, "y": 2 * x + 1} for x in range(200)])

num_workers = 2
use_gpu = False

config = {"lr": 1e-3, "batch_size": 32, "epochs": 4}

trainer = TensorflowTrainer(
    train_loop_per_worker=train_func,
    train_loop_config=config,
    scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
    datasets={'train': train_dataset},
)

result = trainer.fit()
print(result.metrics)

Trial name,status,loc,iter,total time (s),loss,mean_squared_error,_timestamp
TensorflowTrainer_e22d0_00000,TERMINATED,127.0.0.1:13685,4,10.5579,,,1666084749


[2m[36m(RayTrainWorker pid=13695)[0m 2022-10-18 18:19:04.285909: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
[2m[36m(RayTrainWorker pid=13695)[0m To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
[2m[36m(RayTrainWorker pid=13695)[0m 2022-10-18 18:19:04.289391: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:272] Initialize GrpcChannelCache for job worker -> {0 -> 127.0.0.1:53021, 1 -> 127.0.0.1:53022}
[2m[36m(RayTrainWorker pid=13695)[0m 2022-10-18 18:19:04.289512: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:272] Initialize GrpcChannelCache for job worker -> {0 -> 127.0.0.1:53021, 1 -> 127.0.0.1:53022}
[2m[36m(RayTrainWorker pid=13695)[0m 2022-10-18 18:19:04.291473: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:438] 

      1/Unknown - 3s 3s/step - loss: 103792.5625 - mean_squared_error: 103792.5625
      1/Unknown - 3s 3s/step - loss: 103792.5625 - mean_squared_error: 103792.5625
      5/Unknown - 3s 13ms/step - loss: nan - mean_squared_error: nan              
      5/Unknown - 3s 13ms/step - loss: nan - mean_squared_error: nan              
Result for TensorflowTrainer_e22d0_00000:
  _time_this_iter_s: 3.653472900390625
  _timestamp: 1666084747
  _training_iteration: 1
  date: 2022-10-18_18-19-08
  done: false
  experiment_id: 2279cddf94834c3392d3276938acaf67
  hostname: YONGJINs-MacBook-Pro.local
  iterations_since_restore: 1
  loss: .nan
  mean_squared_error: .nan
  node_ip: 127.0.0.1
  pid: 13685
  should_checkpoint: true
  time_since_restore: 9.213613986968994
  time_this_iter_s: 9.213613986968994
  time_total_s: 9.213613986968994
  timestamp: 1666084748
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: e22d0_00000
  warmup_time: 0.017626285552978516
  
      1/Unknown - 0s 58m

2022-10-18 18:19:09,591	INFO tune.py:758 -- Total run time: 16.76 seconds (16.63 seconds for the tuning loop).


{'loss': nan, 'mean_squared_error': nan, '_timestamp': 1666084749, '_time_this_iter_s': 0.4578092098236084, '_training_iteration': 4, 'time_this_iter_s': 0.4618649482727051, 'should_checkpoint': True, 'done': True, 'timesteps_total': None, 'episodes_total': None, 'training_iteration': 4, 'trial_id': 'e22d0_00000', 'experiment_id': '2279cddf94834c3392d3276938acaf67', 'date': '2022-10-18_18-19-09', 'timestamp': 1666084749, 'time_total_s': 10.557934999465942, 'pid': 13685, 'hostname': 'YONGJINs-MacBook-Pro.local', 'node_ip': '127.0.0.1', 'config': {}, 'time_since_restore': 10.557934999465942, 'timesteps_since_restore': 0, 'iterations_since_restore': 4, 'warmup_time': 0.017626285552978516, 'experiment_tag': '0'}


[2m[36m(RayTrainWorker pid=13694)[0m Exception ignored in: <function Pool.__del__ at 0x7fad9ce34ee0>
[2m[36m(RayTrainWorker pid=13694)[0m Traceback (most recent call last):
[2m[36m(RayTrainWorker pid=13694)[0m   File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/multiprocessing/pool.py", line 268, in __del__
[2m[36m(RayTrainWorker pid=13694)[0m     self._change_notifier.put(None)
[2m[36m(RayTrainWorker pid=13694)[0m   File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/multiprocessing/queues.py", line 368, in put
[2m[36m(RayTrainWorker pid=13694)[0m     self._writer.send_bytes(obj)
[2m[36m(RayTrainWorker pid=13694)[0m   File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
[2m[36m(RayTrainWorker pid=13694)[0m     self._send_bytes(m[offset:offset + size])
[2m[36m(RayTrainWorker pid=13694)[0m   File "/Library/Frameworks/Python.framework/Versions/3.8/lib/