In [1]:
import ray
import random
from ray.air import session, Checkpoint
from ray.air import DatasetConfig
from ray.data import Dataset
from ray.train.torch import TorchTrainer

In [7]:
def model(batch):
    """Dummy stand-in for a real model that fakes a per-batch loss.

    Returns ``len(batch) * 0.1`` multiplied by one draw from
    ``random.uniform(0, 1)``. The result is pseudo-random and grows linearly
    with the batch length. It consumes one value from the module-level
    ``random`` generator per call.
    """
    size_factor = len(batch) * 0.1
    jitter = random.uniform(0, 1)
    return size_factor * jitter

def train_loop():
    """Per-worker training loop run by ``TorchTrainer``.

    Fetches this worker's ``"train"`` dataset shard from the AIR session.
    For 10 epochs it scores every batch with the dummy ``model`` function.
    After each epoch it reports ``{"loss", "epoch"}`` to the session, together
    with a dict checkpoint holding the same two values.

    An epoch's reported loss is the sum of that epoch's batch losses divided
    by ``num_batches * 100``; the ``100`` is a fixed scaling factor. If the
    shard yields no batches, the epoch's loss is reported as ``0.0``.
    """
    num_epochs = 10
    # By default, bulk loading is used and returns a Dataset object.
    data_shard: Dataset = session.get_dataset_shard("train")
    # Manually iterate over the data once per epoch.
    for epoch in range(1, num_epochs + 1):
        # Reset every epoch. Otherwise the previous epoch's (already
        # normalised) loss leaks into this epoch's sum, and the reported
        # value is no longer a per-epoch loss.
        loss = 0.0
        num_batches = 0
        for batch in data_shard.iter_batches():
            num_batches += 1
            loss += model(batch)
        # An empty shard yields no batches: skip normalisation rather than
        # raising ZeroDivisionError.
        if num_batches > 0:
            loss /= num_batches * 100
        if epoch % 2 == 0:
            print(f"Doing some training on epoch: {epoch} for batches: {num_batches} and loss over batch: {loss:.3f}")
        session.report(
            {"loss": loss, "epoch": epoch},
            checkpoint=Checkpoint.from_dict({"loss": loss, "epoch": epoch}),
        )
    # View the stats for performance debugging:
    # print(data_shard.stats())

In [8]:
# Build a 1000-row tensor dataset and fit `train_loop` on it with a single
# Torch worker; the fit result is kept in `result` for the cells below.
train_ds = ray.data.range_tensor(1000)
trainer = TorchTrainer(
    train_loop_per_worker=train_loop,
    datasets={"train": train_ds},
    scaling_config={"num_workers": 1},
)
result = trainer.fit()

Trial name,status,loc,iter,total time (s),loss,epoch,_timestamp
TorchTrainer_badf0_00000,TERMINATED,127.0.0.1:85750,10,2.11318,0.0263115,10,1658009272


[2m[36m(BaseWorkerMixin pid=85760)[0m 2022-07-16 15:07:51,384	INFO config.py:70 -- Setting up process group for: env:// [rank=0, world_size=1]


Result for TorchTrainer_badf0_00000:
  _time_this_iter_s: 0.05498075485229492
  _timestamp: 1658009272
  _training_iteration: 1
  date: 2022-07-16_15-07-52
  done: false
  epoch: 1
  experiment_id: de863d5814fd42479ea1aa6368376edb
  hostname: Juless-MacBook-Pro-16
  iterations_since_restore: 1
  loss: 0.028504291855607735
  node_ip: 127.0.0.1
  pid: 85750
  should_checkpoint: true
  time_since_restore: 1.814877986907959
  time_this_iter_s: 1.814877986907959
  time_total_s: 1.814877986907959
  timestamp: 1658009272
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: badf0_00000
  warmup_time: 0.002830982208251953
  
[2m[36m(BaseWorkerMixin pid=85760)[0m Doing some training on epoch: 2 for batches: 20 and loss over batch: 0.028
[2m[36m(BaseWorkerMixin pid=85760)[0m Doing some training on epoch: 4 for batches: 20 and loss over batch: 0.021
[2m[36m(BaseWorkerMixin pid=85760)[0m Doing some training on epoch: 6 for batches: 20 and loss over batch: 0.018
[2m[36m(BaseWo

2022-07-16 15:07:53,164	INFO tune.py:737 -- Total run time: 3.77 seconds (3.66 seconds for the tuning loop).


In [7]:
# Show the final metrics dict via IPython's rich display (last expression)
# instead of print(), which only emits the flat str() form.
result.metrics

{'loss': 0.024684769441396207, 'epoch': 100, '_timestamp': 1657839954, '_time_this_iter_s': 0.035117149353027344, '_training_iteration': 100, 'time_this_iter_s': 0.035944223403930664, 'should_checkpoint': True, 'done': True, 'timesteps_total': None, 'episodes_total': None, 'training_iteration': 100, 'trial_id': '7f364_00000', 'experiment_id': '5b6262e8afa148a6a87455dfbbdaad79', 'date': '2022-07-14_16-05-54', 'timestamp': 1657839954, 'time_total_s': 5.424535036087036, 'pid': 1526, 'hostname': 'Juless-MacBook-Pro-16', 'node_ip': '127.0.0.1', 'config': {}, 'time_since_restore': 5.424535036087036, 'timesteps_since_restore': 0, 'iterations_since_restore': 100, 'warmup_time': 0.002769947052001953, 'experiment_tag': '0'}


In [74]:
# Loss value from the trainer result's final metrics.
final_loss = result.metrics["loss"]
final_loss

0.23612264564954022

In [75]:
# Loss stored in the result's checkpoint. train_loop writes the same value
# it reports, so this should match result.metrics["loss"].
checkpoint_state = result.checkpoint.to_dict()
checkpoint_state["loss"]

0.23612264564954022