# Trainables

So far we have been using the functional interface to Ray Tune, which is lightweight and easy to get started with.

However, it is limited in a few ways: (1) it doesn't let us maintain state, (2) Ray Tune cannot 'see' or manage individual training iterations, and (3) it doesn't let us use other useful parts of Ray Tune, such as checkpointing or schedulers.
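The control-flow difference behind points (1) and (2) can be shown without Ray at all. The sketch below is a toy illustration only: `function_trainable`, `ClassTrainable` and `drive_until` are hypothetical names, not Ray Tune's API. When the training loop lives inside a function, the caller only regains control once every iteration has run; when state lives on an object and each call does one iteration, an outside driver can inspect each result and stop early, which is the kind of control a scheduler needs.

```python
# Toy illustration only: none of these names are Ray Tune's API.

def function_trainable(config, num_iters=20):
    """Function style: the training loop lives inside the function."""
    x = 0  # state is a local variable, invisible to the caller
    history = []
    for _ in range(num_iters):
        x += config["a"]
        history.append({"score": x})
    return history  # the caller regains control only after every iteration ran


class ClassTrainable:
    """Class style: state lives on the object, one train() call per iteration."""

    def __init__(self, config):
        self.x = 0
        self.a = config["a"]

    def train(self):
        self.x += self.a
        return {"score": self.x}


def drive_until(trainable, target_score, max_iters=20):
    """Hypothetical driver: inspect every result and stop once the target is hit."""
    result = None
    for iteration in range(1, max_iters + 1):
        result = trainable.train()
        if result["score"] >= target_score:
            return iteration, result
    return max_iters, result


full_run = function_trainable({"a": 2})
stopped_at, last_result = drive_until(ClassTrainable({"a": 2}), target_score=10)
print(len(full_run), stopped_at, last_result)
```

The function runs all 20 iterations no matter what, while the driver stops the class-based version after 5 iterations, as soon as the score reaches 10.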

Let's look at a simple trainable below.


In [1]:
%load_ext autoreload
%autoreload 2

from dependencies import *

Loading dependencies we have already seen...
Importing ray...
Done...


## Trainable Interface

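 1. Subclass `tune.Trainable`.
 2. Set up state in `_setup(self, config)`, which receives the trial's hyperparameter `config` (rather than in `__init__`).
 3. Implement `_train()` so that each call performs exactly one unit (iteration) of training and returns a dict of metrics.
 4. Implement `_save(checkpoint_dir)` to save state (model weights, counters, etc.) into `checkpoint_dir` and return the checkpoint path.
 5. Implement `_restore(checkpoint_path)` to load that state back from a checkpoint.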
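To make the five steps above concrete without needing a Ray cluster, here is a minimal sketch of the lifecycle under stated assumptions. `ToyTrainable` and `run_trial` are hypothetical stand-ins that only mimic the hook layout (a public `train()` that wraps `_train()`, a `save()` that hands `_save()` a fresh `checkpoint_<iteration>` directory); they are not Ray's implementation. `Counter` mirrors `MyTrainable`, but checkpoints to JSON so the example needs only the standard library.

```python
import json
import os
import tempfile


class ToyTrainable:
    """Hypothetical stand-in for tune.Trainable; it only mimics the hook layout."""

    def __init__(self, config):
        self.iteration = 0
        self._setup(config)  # subclasses initialise their state here, not in __init__

    def train(self):
        result = self._train()  # exactly one unit of training
        self.iteration += 1
        result["training_iteration"] = self.iteration
        return result

    def save(self, checkpoint_root):
        checkpoint_dir = os.path.join(checkpoint_root, f"checkpoint_{self.iteration}")
        os.makedirs(checkpoint_dir, exist_ok=True)
        return self._save(checkpoint_dir)  # _save returns the path it wrote

    def restore(self, checkpoint_path):
        self._restore(checkpoint_path)

    # Hooks that subclasses override.
    def _setup(self, config):
        pass

    def _train(self):
        raise NotImplementedError

    def _save(self, checkpoint_dir):
        raise NotImplementedError

    def _restore(self, checkpoint_path):
        raise NotImplementedError


class Counter(ToyTrainable):
    """Same logic as MyTrainable, but checkpointing to JSON instead of .npy."""

    def _setup(self, config):
        self.x = 0
        self.a = config["a"]

    def _train(self):
        self.x += self.a
        return {"score": self.x}

    def _save(self, checkpoint_dir):
        checkpoint_path = os.path.join(checkpoint_dir, "state.json")
        with open(checkpoint_path, "w") as f:
            json.dump({"x": self.x}, f)
        return checkpoint_path

    def _restore(self, checkpoint_path):
        with open(checkpoint_path) as f:
            self.x = json.load(f)["x"]


def run_trial(trainable_cls, config, stop_iters, checkpoint_freq, checkpoint_root):
    """Hypothetical driver: step the trainable, checkpointing every checkpoint_freq iterations."""
    trainable = trainable_cls(config)
    checkpoints = []
    result = None
    for _ in range(stop_iters):
        result = trainable.train()
        if result["training_iteration"] % checkpoint_freq == 0:
            checkpoints.append(trainable.save(checkpoint_root))
    return result, checkpoints


with tempfile.TemporaryDirectory() as root:
    final, checkpoints = run_trial(Counter, {"a": 2}, stop_iters=20,
                                   checkpoint_freq=5, checkpoint_root=root)
    resumed = Counter({"a": 2})
    resumed.restore(checkpoints[1])  # the checkpoint written at iteration 10
    print(final["score"], len(checkpoints), resumed.x)
```

This toy does not restore its own iteration counter; the real `tune.Trainable` keeps that kind of bookkeeping for you, so your `_save`/`_restore` only need to handle your own state.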


In [2]:
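import numpy as np
from os import path

from ray import tune


class MyTrainable(tune.Trainable):

    def _setup(self, config):
        # config (dict): the hyperparameters for this trial
        self.x = 0
        self.a = config["a"]

    def _train(self):
        # Called once per training iteration; returns that iteration's metrics.
        self.x += self.a
        print("Trainable", f"({self.a})", self.x)
        return {"score": self.x}

    def _save(self, checkpoint_dir):
        # Write the current state into the directory Tune provides and return
        # the file path; Tune later passes this path back to _restore.
        checkpoint_path = path.join(checkpoint_dir, "model.npy")
        np.save(checkpoint_path, np.array(self.x))
        return checkpoint_path

    #
    # _restore is called internally by Ray Tune (for example by schedulers,
    # or when a trial is resumed). You rarely call it yourself, except when
    # manually restoring a single training run.
    #
    def _restore(self, checkpoint_path):
        print("CHECKPOINT PATH", checkpoint_path)
        # np.array(self.x) is zero-dimensional, so use .item() rather than [0]
        self.x = np.load(checkpoint_path).item()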
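Because `_save` writes `np.array(self.x)`, which is a zero-dimensional array, the value loaded back cannot be indexed with `[0]`. A quick round trip outside Ray (using a temporary directory in place of Tune's checkpoint directory) shows the pitfall and the `.item()` fix:

```python
import os
import tempfile

import numpy as np

with tempfile.TemporaryDirectory() as checkpoint_dir:
    checkpoint_path = os.path.join(checkpoint_dir, "model.npy")
    np.save(checkpoint_path, np.array(7))  # what _save writes when self.x == 7

    loaded = np.load(checkpoint_path)
    print(loaded.shape)  # a zero-dimensional array has shape ()

    index_failed = False
    try:
        loaded[0]  # indexing a 0-d array raises IndexError
    except IndexError as err:
        index_failed = True
        print("indexing fails:", err)

    restored_x = loaded.item()  # converts the 0-d array back to a Python scalar
    print(restored_x)
```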


## Start Ray

In [2]:
ray.shutdown()
ray.init(num_cpus=2, num_gpus=0, include_webui=True)

2020-06-10 16:43:56,206	INFO resource_spec.py:204 -- Starting Ray with 35.84 GiB memory available for workers and up to 17.94 GiB for objects. You can adjust these settings with ray.init(memory=<bytes>, object_store_memory=<bytes>).
2020-06-10 16:43:56,508	INFO services.py:1168 -- View the Ray dashboard at localhost:8265


{'node_ip_address': '192.168.1.39',
 'raylet_ip_address': '192.168.1.39',
 'redis_address': '192.168.1.39:18955',
 'object_store_address': '/tmp/ray/session_2020-06-10_16-43-56_180499_14591/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2020-06-10_16-43-56_180499_14591/sockets/raylet',
 'webui_url': 'localhost:8265',
 'session_dir': '/tmp/ray/session_2020-06-10_16-43-56_180499_14591'}

## Run

Run a simple grid search over `a`, checkpointing every 5 iterations:

In [None]:
analysis = tune.run(
    MyTrainable,
    name="simple_trainable",
    stop={"training_iteration": 20},
    config={ "a": tune.grid_search([1,2]) },
    checkpoint_freq=5,
    resources_per_trial=dict(cpu=1, gpu=0),
    local_dir="~/ray_results/my_trainable")

print('best config: ', analysis.get_best_config(metric="score", mode="max"))
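Since `MyTrainable` is deterministic, the outcome of this run can be worked out by hand. The arithmetic below assumes that `tune.grid_search([1, 2])` produces one trial per value and that Tune checkpoints whenever the iteration count is a multiple of `checkpoint_freq`:

```python
stop_iters = 20       # stop={"training_iteration": 20}
checkpoint_freq = 5   # checkpoint_freq=5
grid = [1, 2]         # tune.grid_search([1, 2]) -> one trial per value

# MyTrainable adds `a` to x once per iteration, so the final score is a * stop_iters.
final_scores = {a: a * stop_iters for a in grid}
checkpoint_iters = [i for i in range(1, stop_iters + 1) if i % checkpoint_freq == 0]
best_config = {"a": max(final_scores, key=final_scores.get)}

print(final_scores)      # {1: 20, 2: 40}
print(checkpoint_iters)  # [5, 10, 15, 20]
print(best_config)       # {'a': 2}
```

So each trial should leave four checkpoints, and the best config reported for `metric="score", mode="max"` should be `{'a': 2}`.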

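Now check `~/ray_results/my_trainable/simple_trainable`: each trial gets its own subdirectory containing its logged results and checkpoints.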
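To see which checkpoints each trial left behind, a small helper can walk the experiment directory. This is a sketch that assumes checkpoint folders are named `checkpoint_<iteration>` inside each trial's subdirectory; `list_checkpoints` is a hypothetical helper, and the demo runs on a fake layout with made-up trial names so it works without a real Tune run.

```python
import tempfile
from pathlib import Path


def list_checkpoints(experiment_dir):
    """Map each trial subdirectory to the sorted iterations it has checkpoints for."""
    experiment_dir = Path(experiment_dir).expanduser()
    if not experiment_dir.is_dir():
        return {}
    found = {}
    for ckpt in experiment_dir.glob("*/checkpoint_*"):
        suffix = ckpt.name.rsplit("_", 1)[-1]
        # Keep only directories whose suffix is an iteration number.
        if ckpt.is_dir() and suffix.isdigit():
            found.setdefault(ckpt.parent.name, []).append(int(suffix))
    return {trial: sorted(iters) for trial, iters in sorted(found.items())}


# Demo on a fake layout (hypothetical trial names). For the run above, call
# list_checkpoints("~/ray_results/my_trainable/simple_trainable") instead.
with tempfile.TemporaryDirectory() as fake_experiment:
    for rel in ["trial_a/checkpoint_10", "trial_a/checkpoint_5", "trial_b/checkpoint_5"]:
        (Path(fake_experiment) / rel).mkdir(parents=True)
    demo = list_checkpoints(fake_experiment)
print(demo)
```

Sorting the parsed integers (rather than the folder names) keeps `checkpoint_5` ahead of `checkpoint_10`.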

In [3]:
ray.shutdown()