# Trainables

So far we have been using the functional interface to Ray Tune, which is lightweight and easy to get started with.

However, it is limited in a few ways: (1) it doesn't let us maintain state between iterations, (2) Ray Tune cannot 'see' or manage individual training iterations, and (3) it doesn't let us use some other useful parts of Ray Tune, such as checkpointing or schedulers.

We'll take a look at a simple trainable below.


In [1]:
%load_ext autoreload
%autoreload 2

from dependencies import *

Loading dependencies we have already seen...
Importing ray...
Done...


## Trainable Interface

 1. Subclass `tune.Trainable`.
 2. Set up state in `_setup(config)`.
 3. Implement `_train()` so that each call completes one unit (iteration) of training and returns a dict of metrics.
 4. Implement `_save` to save state, checkpoint models, etc.
 5. Implement `_restore` to load that state back from a checkpoint.
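
The contract in the list above can be sketched without Ray at all. The class and driver below are hypothetical stand-ins (`SketchTrainable` and `run_trial` are not part of Ray Tune); they only imitate how a runner might call the hooks: set up once, train repeatedly, and checkpoint every few iterations. JSON is used for the checkpoint purely to keep the sketch dependency-free.

```python
import json
import os
import tempfile


class SketchTrainable:
    """Hypothetical stand-in that mirrors the hook names used in this notebook."""

    def __init__(self, config):
        self._setup(config)

    def _setup(self, config):
        self.x = 0
        self.a = config["a"]

    def _train(self):
        # One unit of training: here, a single increment.
        self.x += self.a
        return {"score": self.x}

    def _save(self, checkpoint_dir):
        checkpoint_path = os.path.join(checkpoint_dir, "state.json")
        with open(checkpoint_path, "w") as f:
            json.dump({"x": self.x}, f)
        return checkpoint_path

    def _restore(self, checkpoint_path):
        with open(checkpoint_path) as f:
            self.x = json.load(f)["x"]


def run_trial(trainable, iterations, checkpoint_freq, checkpoint_dir):
    """Assumed driver loop: train repeatedly, saving every checkpoint_freq iterations."""
    last_checkpoint = None
    result = None
    for i in range(1, iterations + 1):
        result = trainable._train()
        if i % checkpoint_freq == 0:
            last_checkpoint = trainable._save(checkpoint_dir)
    return result, last_checkpoint


with tempfile.TemporaryDirectory() as tmp:
    first = SketchTrainable({"a": 2})
    result, checkpoint = run_trial(first, iterations=7, checkpoint_freq=5, checkpoint_dir=tmp)

    # A fresh instance resumes from the last checkpoint (iteration 5), not iteration 7.
    resumed = SketchTrainable({"a": 2})
    resumed._restore(checkpoint)
    final_score = result["score"]
    resumed_x = resumed.x
```

Because the state lives on the object and is written out by `_save`, a second instance can continue where the first one was last checkpointed. This is the property the functional interface lacks.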


In [2]:
from os import path

class MyTrainable(tune.Trainable):
    
    # In a real trainable: create data loaders and load the model here.
    def _setup(self, config):
        # config (dict): A dict of hyperparameters
        self.x = 0
        self.a = config["a"]

    # One unit of training (in a real trainable, e.g. one pass over the data).
    def _train(self):  # This is called iteratively.
        self.x += self.a
        print("Trainable", f"({self.a})", self.x)
        return {"score": self.x }
    
    
    def _save(self, checkpoint_dir):
        checkpoint_path = path.join(checkpoint_dir, "model.npy")
        np.save(checkpoint_path, np.array(self.x))
        return checkpoint_path

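    # Tune and its schedulers call _restore internally to resume a trial from
    # a checkpoint. Calling it by hand is mainly useful for single training runs.
    def _restore(self, checkpoint_path):
        print("CHECKPOINT PATH", checkpoint_path)
        # _save stored a zero-dimensional array, so use .item() rather than [0].
        self.x = np.load(checkpoint_path).item()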
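
One detail worth checking when checkpointing a scalar with NumPy: `np.array(scalar)` is a zero-dimensional array, so after `np.load` it cannot be indexed with `[0]`, while `.item()` returns the plain Python value. A small round trip, using a temporary directory as a stand-in for the checkpoint directory Tune would provide:

```python
import os
import tempfile

import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    checkpoint_path = os.path.join(tmp, "model.npy")
    np.save(checkpoint_path, np.array(7))  # same pattern as _save

    loaded = np.load(checkpoint_path)
    shape = loaded.shape  # () for a zero-dimensional array

    try:
        loaded[0]
        index_ok = True
    except IndexError:
        index_ok = False

    value = loaded.item()  # plain Python int
```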


## Start Ray

In [4]:
ray.shutdown()
ray.init(num_cpus=2, num_gpus=0, include_dashboard=False)

{'node_ip_address': '192.168.123.68',
 'raylet_ip_address': '192.168.123.68',
 'redis_address': '192.168.123.68:6379',
 'object_store_address': '/tmp/ray/session_2020-11-09_17-38-16_449401_6032/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2020-11-09_17-38-16_449401_6032/sockets/raylet',
 'webui_url': None,
 'session_dir': '/tmp/ray/session_2020-11-09_17-38-16_449401_6032',
 'metrics_export_port': 65250}

## Run

Do some simple tuning

In [5]:
analysis = tune.run(
    MyTrainable,
    name="simple_trainable",
    stop={"training_iteration": 20},
    config={ "a": tune.grid_search([1,2]) },
    checkpoint_freq=5,
    resources_per_trial=dict(cpu=1, gpu=0),
    local_dir="~/ray_results/my_trainable")

print('best config: ', analysis.get_best_config(metric="score", mode="max"))

Trial name,status,loc,a
MyTrainable_0083d_00000,RUNNING,,1
MyTrainable_0083d_00001,PENDING,,2




Result for MyTrainable_0083d_00000:
  date: 2020-11-09_17-38-31
  done: false
  experiment_id: 484aae004d544b0f9f947b6e807e7e53
  experiment_tag: 0_a=1
  hostname: Schlepptop
  iterations_since_restore: 1
  node_ip: 192.168.123.68
  pid: 6445
  score: 1
  time_since_restore: 0.0002770423889160156
  time_this_iter_s: 0.0002770423889160156
  time_total_s: 0.0002770423889160156
  timestamp: 1604939911
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: 0083d_00000
  
Result for MyTrainable_0083d_00001:
  date: 2020-11-09_17-38-31
  done: false
  experiment_id: e3dab9f255aa49f4b5a3085b30001614
  experiment_tag: 1_a=2
  hostname: Schlepptop
  iterations_since_restore: 1
  node_ip: 192.168.123.68
  pid: 6446
  score: 2
  time_since_restore: 0.00029468536376953125
  time_this_iter_s: 0.00029468536376953125
  time_total_s: 0.00029468536376953125
  timestamp: 1604939911
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: 0083d_00001
  
[2m[36m(pid=6445)[0m Trainable

Trial name,status,loc,a,iter,total time (s),score
MyTrainable_0083d_00000,TERMINATED,,1,20,0.00149846,20
MyTrainable_0083d_00001,TERMINATED,,2,20,0.00161791,40
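
The final scores in the table follow directly from `_train`: every iteration adds `a` to `x`, so after `n` iterations the score is `a * n`. A quick check of that arithmetic, independent of Ray (`final_score` is a helper made up for this illustration):

```python
def final_score(a, iterations):
    """Replay the update rule from _train: x starts at 0 and grows by a each step."""
    x = 0
    for _ in range(iterations):
        x += a
    return x


# Same grid and stopping point as the run above.
scores = {a: final_score(a, 20) for a in [1, 2]}
best_a = max(scores, key=scores.get)
```

With 20 iterations this gives 20 for `a=1` and 40 for `a=2`, matching the table, so `a=2` is the configuration one would expect `get_best_config(metric="score", mode="max")` to report.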


Now go and check the results directory under `~/ray_results/my_trainable`, where the trial logs and checkpoints are written.
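
A minimal sketch for inspecting that directory from Python rather than a file browser. `list_files` is a hypothetical helper, and the folder names in the demo are made up so the example runs anywhere; the real layout depends on the Ray version.

```python
import os
import tempfile


def list_files(root):
    """Return relative paths of all files under root, sorted."""
    root = os.path.expanduser(root)
    found = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            found.append(os.path.relpath(os.path.join(dirpath, name), root))
    return sorted(found)


# Demo on a throwaway directory with made-up names. On a real run you would
# pass the local_dir used above, e.g. list_files("~/ray_results/my_trainable").
with tempfile.TemporaryDirectory() as tmp:
    trial_dir = os.path.join(tmp, "trial_a", "checkpoint_5")
    os.makedirs(trial_dir)
    open(os.path.join(trial_dir, "model.npy"), "w").close()
    demo_listing = list_files(tmp)
```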

In [7]:
ray.shutdown()