In [1]:
import random

import ray
from ray.air import session, Checkpoint
from ray.air import DatasetConfig
from ray.air.config import ScalingConfig
from ray.data import Dataset
from ray.train.torch import TorchTrainer

In [2]:
# Dummy stand-in for a real model, used only to produce a fake loss value.
def model(batch):
    """Return a random pseudo-loss proportional to the batch size.

    The result is ``len(batch) * 0.1`` multiplied by a single draw from
    ``random.uniform(0, 1)``.
    """
    size_term = len(batch) * 0.1
    draw = random.uniform(0, 1)
    return size_term * draw

def train_loop():
    """Per-worker training loop passed to ``TorchTrainer``.

    For each of ``num_epochs`` epochs:
    - iterate over this worker's "train" dataset shard;
    - compute a dummy loss per batch with ``model``;
    - report the scaled mean batch loss, plus a checkpoint holding the same
      values, via ``session.report``.

    Progress is printed every second epoch, and the shard's stats are printed
    at the end.
    """
    num_epochs = 10
    # The summed batch losses are divided by num_batches * loss_scale, i.e. the
    # mean batch loss scaled down so the dummy value stays in a small range.
    loss_scale = 100
    # By default, bulk loading is used and returns a Dataset object.
    data_shard: Dataset = session.get_dataset_shard("train")
    for epoch in range(1, num_epochs + 1):
        # Reset every epoch. Previously the running total carried over, so each
        # epoch's reported loss also contained the previous epoch's loss.
        epoch_loss = 0.0
        num_batches = 0
        for batch in data_shard.iter_batches():
            num_batches += 1
            epoch_loss += model(batch)
        # An empty shard yields no batches; report NaN (undefined loss) rather
        # than raising ZeroDivisionError or reporting a misleading 0.0.
        if num_batches:
            loss = epoch_loss / (num_batches * loss_scale)
        else:
            loss = float("nan")
        if epoch % 2 == 0:
            print(f"Doing some training on epoch: {epoch} for batches: {num_batches} and loss over batch: {loss:.3f}")
        session.report({"loss": loss, "epoch": epoch},
                       checkpoint=Checkpoint.from_dict({"loss": loss, "epoch": epoch}))
    # View the stats for performance debugging.
    print(data_shard.stats())

In [3]:
# Create our TorchTrainer over a small synthetic tensor dataset and run it.
# ScalingConfig is imported in the first (imports) cell.
NUM_ROWS = 1000    # size argument passed to ray.data.range_tensor
NUM_WORKERS = 1    # number of Ray Train workers running train_loop

train_ds = ray.data.range_tensor(NUM_ROWS)
trainer = TorchTrainer(
    train_loop,
    scaling_config=ScalingConfig(num_workers=NUM_WORKERS),
    datasets={"train": train_ds},
)
result = trainer.fit()

2022-08-02 11:36:01,693	INFO worker.py:1481 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8267.


Trial name,status,loc,iter,total time (s),loss,epoch,_timestamp
TorchTrainer_f5fe9_00000,TERMINATED,127.0.0.1:62493,10,2.70262,0.167062,10,1659465366


[2m[36m(RayTrainWorker pid=62500)[0m 2022-08-02 11:36:04,858	INFO config.py:71 -- Setting up process group for: env:// [rank=0, world_size=1]


Result for TorchTrainer_f5fe9_00000:
  _time_this_iter_s: 0.060938358306884766
  _timestamp: 1659465366
  _training_iteration: 1
  date: 2022-08-02_11-36-06
  done: false
  epoch: 1
  experiment_id: 0b5e8d1f4c664af09090c1ac797f0a9a
  hostname: Juless-MacBook-Pro-16
  iterations_since_restore: 1
  loss: 0.12134055856052596
  node_ip: 127.0.0.1
  pid: 62493
  should_checkpoint: true
  time_since_restore: 2.3281710147857666
  time_this_iter_s: 2.3281710147857666
  time_total_s: 2.3281710147857666
  timestamp: 1659465366
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: f5fe9_00000
  warmup_time: 0.002975940704345703
  
[2m[36m(RayTrainWorker pid=62500)[0m Doing some training on epoch: 2 for batches: 4 and loss over batch: 0.131
[2m[36m(RayTrainWorker pid=62500)[0m Doing some training on epoch: 4 for batches: 4 and loss over batch: 0.096
[2m[36m(RayTrainWorker pid=62500)[0m Doing some training on epoch: 6 for batches: 4 and loss over batch: 0.138
[2m[36m(RayTrainW

2022-08-02 11:36:06,750	INFO tune.py:758 -- Total run time: 4.25 seconds (4.08 seconds for the tuning loop).


In [6]:
# Show the metrics dict of the final training result. A bare last expression
# uses IPython's rich (pretty-printed) display instead of a one-line print().
result.metrics

{'loss': 0.11696128050647575, 'epoch': 10, '_timestamp': 1659227860, '_time_this_iter_s': 0.03619694709777832, '_training_iteration': 10, 'time_this_iter_s': 0.03714895248413086, 'should_checkpoint': True, 'done': True, 'timesteps_total': None, 'episodes_total': None, 'training_iteration': 10, 'trial_id': 'f98f8_00000', 'experiment_id': '974e7749bd434876b92f52e09afd0c6c', 'date': '2022-07-30_17-37-40', 'timestamp': 1659227860, 'time_total_s': 2.3749780654907227, 'pid': 51580, 'hostname': 'Juless-MacBook-Pro-16', 'node_ip': '127.0.0.1', 'config': {}, 'time_since_restore': 2.3749780654907227, 'timesteps_since_restore': 0, 'iterations_since_restore': 10, 'warmup_time': 0.0033597946166992188, 'experiment_tag': '0'}


In [7]:
# Loss value stored in the final result's metrics.
final_metrics = result.metrics
final_metrics["loss"]

0.11696128050647575

In [8]:
# train_loop stores the same loss in both the reported metrics and the
# checkpoint, so the final checkpoint should agree with the final metrics.
checkpoint_state = result.checkpoint.to_dict()
assert checkpoint_state["loss"] == result.metrics["loss"], "checkpoint and metrics disagree on final loss"
checkpoint_state["loss"]

0.11696128050647575