# Parallel training with Thinc and Ray

This notebook is based on one of [Ray's tutorials](https://ray.readthedocs.io/en/latest/auto_examples/plot_parameter_server.html) and shows how to use Thinc and Ray to implement parallel training. It includes implementations of both synchronous and asynchronous parameter server training.

```bash
pip install thinc ml_datasets ray psutil setproctitle
```

Let's start with a simple model and a [config file](https://thinc.ai/docs/usage-config). You can edit the `CONFIG` string in the cell below, or copy it into a separate file and load it from a path with `Config().from_disk`. The `[ray]` section contains the settings for Ray. (We're using a config for convenience, but you don't have to – you can also just hard-code the values.)

```python
import thinc
from thinc.api import chain, ReLu, Softmax

# Register the model-building function so the config can refer to it by name.
@thinc.registry.layers("relu_relu_softmax.v1")
def make_relu_relu_softmax(hidden_width: int, dropout: float):
    return chain(
        ReLu(hidden_width, dropout=dropout),
        ReLu(hidden_width, dropout=dropout),
        Softmax(),
    )

CONFIG = """
[training]
iterations = 200
batch_size = 128

[evaluation]
batch_size = 256
frequency = 10

[model]
@layers = "relu_relu_softmax.v1"
hidden_width = 128
dropout = 0.2

[optimizer]
@optimizers = "Adam.v1"

[ray]
num_workers = 2
object_store_memory = 3000000000
num_cpus = 2
"""
```

Just like in the original Ray tutorial, we're using the MNIST data (via our [`ml-datasets`](https://github.com/explosion/ml-datasets) package) and are setting up two helper functions: 

1. `get_data_loader`: Return shuffled batches of a given batch size.
2. `evaluate`: Evaluate a model on batches of data.

```python
import ml_datasets
from thinc.api import get_shuffled_batches, evaluate_model_on_arrays

MNIST = ml_datasets.mnist()

def get_data_loader(batch_size):
    """Return shuffled training and dev batches of the given size."""
    (train_X, train_Y), (dev_X, dev_Y) = MNIST
    train_batches = get_shuffled_batches(train_X, train_Y, batch_size)
    dev_batches = get_shuffled_batches(dev_X, dev_Y, batch_size)
    return train_batches, dev_batches

def evaluate(model, batch_size):
    """Return the model's accuracy on the dev set, evaluated in batches."""
    dev_X, dev_Y = MNIST[1]
    return evaluate_model_on_arrays(model, dev_X, dev_Y, batch_size)
```
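`get_shuffled_batches` comes from the Thinc version this notebook was written against, and its internals aren't shown here. As a rough sketch of what "shuffled batches" means, the hypothetical `shuffled_batches` function below shuffles the example indices once with a seeded RNG and then slices them into batches, so inputs and labels stay aligned; in this sketch the last batch may be smaller. It also shows why the per-worker seed matters later on: the same seed produces the same batch order.

```python
import random


def shuffled_batches(X, Y, batch_size, seed=0):
    """Yield (X_batch, Y_batch) pairs in a seeded random order, keeping X/Y aligned."""
    indices = list(range(len(X)))
    random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch = indices[start:start + batch_size]
        yield [X[i] for i in batch], [Y[i] for i in batch]


X = list(range(10))
Y = [x * 10 for x in X]
batches = list(shuffled_batches(X, Y, batch_size=4, seed=1))
print([len(bx) for bx, _ in batches])  # [4, 4, 2]
```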

---

## Setting up Ray

### Getters and setters for gradients and weights

Using Thinc's `Model.walk` method, we can implement the following helper functions to get and set the weights and gradients of each node in a model's tree. The parameter server and the workers will use these functions later.

```python
from collections import defaultdict

def get_model_weights(model):
    """Collect the parameters as {node_id: {param_name: array}}."""
    params = defaultdict(dict)
    for node in model.walk():
        for name in node.param_names:
            if node.has_param(name):
                params[node.id][name] = node.get_param(name)
    return params

def set_model_weights(model, params):
    """Write parameters from a {node_id: {param_name: array}} dict into the model."""
    for node in model.walk():
        for name, param in params[node.id].items():
            node.set_param(name, param)

def get_model_grads(model):
    """Collect the current gradients as {node_id: {param_name: array}}."""
    grads = defaultdict(dict)
    for node in model.walk():
        for name in node.grad_names:
            grads[node.id][name] = node.get_grad(name)
    return grads

def set_model_grads(model, grads):
    """Write gradients from a {node_id: {param_name: array}} dict into the model."""
    for node in model.walk():
        for name, grad in grads[node.id].items():
            node.set_grad(name, grad)
```
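All four helpers use the same nested layout, `{node_id: {param_name: value}}`, so they only work if every copy of the model (in the driver, the parameter server and each worker) reports the same `node.id` for the same layer. Because the dicts are `defaultdict`s, a node whose id is missing gets an empty dict and is skipped silently instead of raising an error. The stdlib-only sketch below uses a hypothetical `ToyNode` class in place of Thinc layers to show the layout and a round trip from one model tree to another.

```python
from collections import defaultdict


class ToyNode:
    """Hypothetical stand-in for a Thinc layer: an id, named params, children."""

    def __init__(self, node_id, params, children=()):
        self.id = node_id
        self.params = dict(params)
        self.children = list(children)

    def walk(self):
        # Yield this node, then every node below it (depth-first).
        yield self
        for child in self.children:
            yield from child.walk()


def get_weights(root):
    weights = defaultdict(dict)
    for node in root.walk():
        for name, value in node.params.items():
            weights[node.id][name] = value
    return weights


def set_weights(root, weights):
    for node in root.walk():
        # A missing id yields an empty dict, so that node is skipped silently.
        for name, value in weights[node.id].items():
            node.params[name] = value


src = ToyNode(0, {}, [ToyNode(1, {"W": [1.0, 2.0], "b": [0.5]}), ToyNode(2, {"W": [3.0]})])
dst = ToyNode(0, {}, [ToyNode(1, {"W": [0.0, 0.0], "b": [0.0]}), ToyNode(2, {"W": [0.0]})])
set_weights(dst, get_weights(src))
print(dict(get_weights(dst)))  # {1: {'W': [1.0, 2.0], 'b': [0.5]}, 2: {'W': [3.0]}}
```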

### Defining the Parameter Server

> The parameter server will hold a copy of the model. During training, it will:
>
> 1. Receive gradients and apply them to its model.
> 2. Send the updated model back to the workers.
>
> The `@ray.remote` decorator defines a remote process. It wraps the `ParameterServer` class and allows users to instantiate it as a remote actor. ([Source](https://ray.readthedocs.io/en/latest/auto_examples/plot_parameter_server.html#defining-the-parameter-server))

Here, the `ParameterServer` is initialized with a model and an optimizer. Using the helper functions defined above, it has one method that applies the gradients received from the workers and another that returns the weights of the current model.

```python
import ray

@ray.remote
class ParameterServer:
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer

    def apply_gradients(self, *worker_grads):
        # Sum the gradients from all workers, node by node and param by param.
        summed_gradients = defaultdict(dict)
        for grads in worker_grads:
            for node_id, node_grads in grads.items():
                for name, grad in node_grads.items():
                    if name in summed_gradients[node_id]:
                        summed_gradients[node_id][name] += grad
                    else:
                        # Copy, so later in-place additions don't modify the worker's array.
                        summed_gradients[node_id][name] = grad.copy()
        # Load the summed gradients into the model and let the optimizer update it.
        set_model_grads(self.model, summed_gradients)
        self.model.finish_update(self.optimizer)
        return get_model_weights(self.model)

    def get_weights(self):
        return get_model_weights(self.model)
```
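The aggregation step in `apply_gradients` sums (rather than averages) the gradients element by element, and it copies the first array it sees for each parameter so that the in-place `+=` never writes into a worker's array. Here is a minimal stdlib-only sketch of the same logic, with plain lists standing in for arrays and a hypothetical `sum_worker_grads` name:

```python
from collections import defaultdict


def sum_worker_grads(*worker_grads):
    """Sum {node_id: {name: [floats]}} dicts elementwise across workers."""
    summed = defaultdict(dict)
    for grads in worker_grads:
        for node_id, node_grads in grads.items():
            for name, grad in node_grads.items():
                if name in summed[node_id]:
                    total = summed[node_id][name]
                    for k, g in enumerate(grad):
                        total[k] += g  # in place, like += on a NumPy array
                else:
                    # Copy first; otherwise the in-place addition above would
                    # modify the first worker's list.
                    summed[node_id][name] = list(grad)
    return summed


w1 = {1: {"W": [1.0, 2.0]}, 2: {"b": [0.5]}}
w2 = {1: {"W": [0.5, -1.0]}, 2: {"b": [1.5]}}
print(dict(sum_worker_grads(w1, w2)))  # {1: {'W': [1.5, 1.0]}, 2: {'b': [2.0]}}
```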

### Defining the Worker

> The worker will also hold a copy of the model. During training, it will continuously evaluate data and send gradients to the parameter server. The worker will synchronize its model with the Parameter Server model weights. ([Source](https://ray.readthedocs.io/en/latest/auto_examples/plot_parameter_server.html#defining-the-worker))

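To compute the gradients during training, we call the model on a batch of data with `is_train=True`. This returns the predictions and a `backprop` callback, which we call with the gradient of the loss to compute the gradients for the model's weights. The worker doesn't update its own weights; it sends the gradients to the parameter server.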
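The `DataWorker` defined in the next cell passes `(guesses - target) / target.shape[0]` to `backprop`. Assuming `target` holds one-hot labels and the loss is categorical cross-entropy averaged over the batch, this is the standard gradient of that loss with respect to the scores that go into the softmax, `(softmax(z) - y) / N`. The stdlib-only sketch below checks that formula against a central finite-difference estimate on a tiny made-up batch; all function names are illustrative.

```python
import math


def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def mean_cross_entropy(logits_batch, onehot_batch):
    loss = 0.0
    for z, y in zip(logits_batch, onehot_batch):
        p = softmax(z)
        loss -= sum(t * math.log(q) for t, q in zip(y, p))
    return loss / len(logits_batch)


def analytic_grad(logits_batch, onehot_batch):
    # Same expression as the worker: (guesses - target) / batch_size
    n = len(logits_batch)
    return [
        [(q - t) / n for q, t in zip(softmax(z), y)]
        for z, y in zip(logits_batch, onehot_batch)
    ]


def numeric_grad(logits_batch, onehot_batch, eps=1e-6):
    # Central differences: (L(z + eps) - L(z - eps)) / (2 * eps), entry by entry.
    grad = []
    for i, z in enumerate(logits_batch):
        row = []
        for j in range(len(z)):
            plus = [list(r) for r in logits_batch]
            minus = [list(r) for r in logits_batch]
            plus[i][j] += eps
            minus[i][j] -= eps
            diff = mean_cross_entropy(plus, onehot_batch) - mean_cross_entropy(minus, onehot_batch)
            row.append(diff / (2 * eps))
        grad.append(row)
    return grad


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]]
targets = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
analytic = analytic_grad(logits, targets)
numeric = numeric_grad(logits, targets)
max_error = max(abs(a - b) for ra, rn in zip(analytic, numeric) for a, b in zip(ra, rn))
print(max_error < 1e-6)  # True
```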

```python
from thinc.api import fix_random_seed

@ray.remote
class DataWorker:
    def __init__(self, model, batch_size=128, seed=0):
        self.model = model
        fix_random_seed(seed)
        self.data_iterator = iter(get_data_loader(batch_size)[0])
        self.batch_size = batch_size

    def compute_gradients(self, weights):
        # Sync with the parameter server's latest weights.
        set_model_weights(self.model, weights)
        try:
            data, target = next(self.data_iterator)
        except StopIteration:  # When the epoch ends, start a new epoch.
            self.data_iterator = iter(get_data_loader(self.batch_size)[0])
            data, target = next(self.data_iterator)
        guesses, backprop = self.model(data, is_train=True)
        # Softmax + cross-entropy gradient, averaged over the batch.
        backprop((guesses - target) / target.shape[0])
        return get_model_grads(self.model)
```
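One thing to watch for: Thinc's `backprop` callbacks add into a layer's stored gradients (through `Model.inc_grad`) instead of overwriting them, and the stored gradients are normally reset as part of `Model.finish_update`. In this notebook only the parameter server calls `finish_update`, so if the Thinc version you're using behaves this way, each call to `DataWorker.compute_gradients` would return the running sum over every batch that worker has seen, not just the current batch's gradient. The outputs shown in this notebook come from the code as written, without a reset. A minimal guard, sketched below as the hypothetical helper `reset_model_grads`, zeroes the gradients with the same `walk`/`grad_names`/`get_grad`/`set_grad` calls the helpers above use; it could be called right after `set_model_weights(self.model, weights)` in `compute_gradients`. `ToyLayer` is a stand-in so the example runs without Thinc.

```python
def reset_model_grads(model):
    """Set every stored gradient to zero (hypothetical helper, not part of Thinc).

    Uses the same Model calls as the helpers above: walk(), grad_names,
    get_grad() and set_grad(). Multiplying by 0 creates a new zero-valued
    array (or float), so an array that was already returned is not modified.
    """
    for node in model.walk():
        for name in node.grad_names:
            node.set_grad(name, node.get_grad(name) * 0)


class ToyLayer:
    """Hypothetical one-node stand-in that accumulates gradients on inc_grad."""

    def __init__(self):
        self._grads = {}

    @property
    def grad_names(self):
        return tuple(self._grads)

    def walk(self):
        yield self

    def get_grad(self, name):
        return self._grads[name]

    def set_grad(self, name, value):
        self._grads[name] = value

    def inc_grad(self, name, value):
        # Add to the stored gradient instead of replacing it.
        self._grads[name] = self._grads.get(name, 0.0) + value


layer = ToyLayer()
layer.inc_grad("W", 1.5)  # gradient from a first batch
layer.inc_grad("W", 2.0)  # second batch, no reset in between
accumulated = layer.get_grad("W")
reset_model_grads(layer)
layer.inc_grad("W", 2.0)  # second batch again, after a reset
print(accumulated, layer.get_grad("W"))  # 3.5 2.0
```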

### Setting up the model

Using the `CONFIG` defined above, we can load the settings and set up the model and optimizer. Thinc's `registry.make_from_config` will parse the config, resolve all references to registered functions and return a dict.

```python
from thinc.api import registry, Config
C = registry.make_from_config(Config().from_str(CONFIG))
C
```

```
{'training': {'iterations': 200, 'batch_size': 128},
 'evaluation': {'batch_size': 256, 'frequency': 10},
 'model': <thinc.model.Model at 0x143eb3d08>,
 'optimizer': <thinc.optimizers.Optimizer at 0x14438d128>,
 'ray': {'num_workers': 2, 'object_store_memory': 3000000000, 'num_cpus': 2}}
```
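Conceptually, resolving the config means walking the nested dict and replacing every section that contains an `@namespace` key with the result of calling the function registered under that name, using the section's other keys as keyword arguments. The sketch below is a toy version of that idea with a hypothetical `register`/`resolve` pair and string placeholders instead of real models; Thinc's own registry does more, such as validating the arguments against the registered function's type hints.

```python
# Hypothetical mini-registry: namespace -> {registered name: function}.
REGISTRY = {"layers": {}, "optimizers": {}}


def register(namespace, name):
    def decorator(func):
        REGISTRY[namespace][name] = func
        return func
    return decorator


def resolve(section):
    """Resolve nested sections first, then call the function an '@' key names."""
    section = {k: resolve(v) if isinstance(v, dict) else v for k, v in section.items()}
    refs = [k for k in section if k.startswith("@")]
    if not refs:
        return section
    ref = refs[0]
    func = REGISTRY[ref[1:]][section[ref]]
    kwargs = {k: v for k, v in section.items() if k != ref}
    return func(**kwargs)


@register("layers", "relu_relu_softmax.v1")
def make_toy_model(hidden_width, dropout):
    return f"Model(hidden_width={hidden_width}, dropout={dropout})"


@register("optimizers", "Adam.v1")
def make_toy_adam():
    return "Adam()"


config = {
    "training": {"iterations": 200, "batch_size": 128},
    "model": {"@layers": "relu_relu_softmax.v1", "hidden_width": 128, "dropout": 0.2},
    "optimizer": {"@optimizers": "Adam.v1"},
}
resolved = resolve(config)
print(resolved["model"], resolved["optimizer"])  # Model(hidden_width=128, dropout=0.2) Adam()
```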

We didn't specify all the dimensions in the model, so we need to pass in a batch of data to finish initialization. This lets Thinc infer the missing shapes.

```python
optimizer = C["optimizer"]
model = C["model"]

(train_X, train_Y), (dev_X, dev_Y) = MNIST
model.initialize(X=train_X[:5], Y=train_Y[:5])
```

```
<thinc.model.Model at 0x143eb3d08>
```

---

## Training

### Synchronous Parameter Server training

We can now create a synchronous parameter server training scheme:

1. Call `ray.init` with the settings defined in the config.
2. Instantiate a process for the `ParameterServer`.
3. Create multiple workers (`num_workers`, as defined in the config).


_(The Ray tutorial doesn't say whether the workers should use different random seeds. We give each worker its own seed (`seed=i`) because with a shared seed, all workers would likely shuffle the data the same way and compute gradients on identical batches.)_

```python
ray.init(
    ignore_reinit_error=True,
    object_store_memory=C["ray"]["object_store_memory"],
    num_cpus=C["ray"]["num_cpus"],
)
ps = ParameterServer.remote(model, optimizer)
workers = []
for i in range(C["ray"]["num_workers"]):
    worker = DataWorker.remote(model, batch_size=C["training"]["batch_size"], seed=i)
    workers.append(worker)
```

```
2020-01-21 02:26:00,552	INFO resource_spec.py:216 -- Starting Ray with 2.93 GiB memory available for workers and up to 2.79 GiB for objects. You can adjust these settings with ray.init(memory=<bytes>, object_store_memory=<bytes>).
```


On each iteration, every worker computes gradients from the current weights. Once all gradients are available, `ParameterServer.apply_gradients` sums them and applies the update. The `frequency` setting in the `[evaluation]` section specifies how often to evaluate: a frequency of `10` means we only evaluate every 10th iteration (0, 10, …, 190), so the `Final` line repeats the accuracy measured at iteration 190 rather than after the last update.

```python
current_weights = ps.get_weights.remote()
for i in range(C["training"]["iterations"]):
    # Every worker computes gradients from the same weights...
    gradients = [worker.compute_gradients.remote(current_weights) for worker in workers]
    # ...and the server applies them together once all of them are available.
    current_weights = ps.apply_gradients.remote(*gradients)
    if i % C["evaluation"]["frequency"] == 0:
        set_model_weights(model, ray.get(current_weights))
        accuracy = evaluate(model, C["evaluation"]["batch_size"])
        print(f"{i} \taccuracy: {accuracy:.3f}")
print(f"Final \taccuracy: {accuracy:.3f}")
ray.shutdown()
```

```
0 	accuracy: 0.107
10 	accuracy: 0.291
20 	accuracy: 0.538
30 	accuracy: 0.648
40 	accuracy: 0.492
50 	accuracy: 0.590
60 	accuracy: 0.508
70 	accuracy: 0.382
80 	accuracy: 0.317
90 	accuracy: 0.356
100 	accuracy: 0.387
110 	accuracy: 0.296
120 	accuracy: 0.317
130 	accuracy: 0.443
140 	accuracy: 0.392
150 	accuracy: 0.368
160 	accuracy: 0.407
170 	accuracy: 0.492
180 	accuracy: 0.546
190 	accuracy: 0.561
Final 	accuracy: 0.561
```
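Stripped of Ray and Thinc, one synchronous step looks like this: every worker computes a gradient from the same weights, the server waits until all of them are available (in the notebook, Ray resolves the gradient references passed to `apply_gradients` before the method runs), and then it applies one summed update. The sketch below uses threads from `concurrent.futures` as stand-ins for remote workers and a made-up 1-D objective, the mean of `(w - x)**2` over each worker's batch; `worker_gradient` and `sync_parameter_server` are hypothetical names.

```python
from concurrent.futures import ThreadPoolExecutor


def worker_gradient(w, batch):
    # Gradient of the mean of (w - x)^2 over this worker's batch.
    return sum(2.0 * (w - x) for x in batch) / len(batch)


def sync_parameter_server(worker_batches, iterations, lr=0.1):
    w = 0.0
    with ThreadPoolExecutor(max_workers=len(worker_batches)) as pool:
        for _ in range(iterations):
            # Every worker starts from the same weights...
            futures = [pool.submit(worker_gradient, w, batch) for batch in worker_batches]
            # ...and the server blocks until all results are in, then sums them.
            summed = sum(future.result() for future in futures)
            w -= lr * summed
    return w


w = sync_parameter_server([[2.0, 4.0], [1.0, 5.0]], iterations=50)
print(round(w, 6))  # 3.0
```

Both batches have mean 3, so the summed gradient is `4 * (w - 3)` and each step shrinks the error by a factor of 0.6.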


### Asynchronous Parameter Server Training

> Here, workers will asynchronously compute the gradients given their current weights and send these gradients to the parameter server as soon as they are ready. When the Parameter server finishes applying the new gradient, the server will send back a copy of the current weights to the worker. The worker will then update the weights and repeat. ([Source](https://ray.readthedocs.io/en/latest/auto_examples/plot_parameter_server.html#asynchronous-parameter-server-training))

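The setup is the same, and we can reuse the config. Make sure to call `ray.shutdown()` to clean up resources and processes before calling `ray.init` again. Note that `model` still holds the weights we loaded from the parameter server at the last evaluation of the synchronous run, so the asynchronous run starts from those weights rather than from a freshly initialized model; this is why its accuracy at iteration 0 is already around 0.56. Because each asynchronous step applies the gradients of a single worker, the loop runs `iterations * num_workers` steps, so both schemes process roughly the same number of batches.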

```python
ray.init(
    ignore_reinit_error=True,
    object_store_memory=C["ray"]["object_store_memory"],
    num_cpus=C["ray"]["num_cpus"],
)
ps = ParameterServer.remote(model, optimizer)
workers = []
for i in range(C["ray"]["num_workers"]):
    worker = DataWorker.remote(model, batch_size=C["training"]["batch_size"], seed=i)
    workers.append(worker)
```

```
2020-01-21 02:26:15,201	INFO resource_spec.py:216 -- Starting Ray with 2.78 GiB memory available for workers and up to 2.79 GiB for objects. You can adjust these settings with ray.init(memory=<bytes>, object_store_memory=<bytes>).
```


```python
current_weights = ps.get_weights.remote()
# Map each pending gradient computation to the worker that is running it.
gradients = {}
for worker in workers:
    gradients[worker.compute_gradients.remote(current_weights)] = worker

for i in range(C["training"]["iterations"] * C["ray"]["num_workers"]):
    # Wait for whichever worker finishes first.
    ready_gradient_list, _ = ray.wait(list(gradients))
    ready_gradient_id = ready_gradient_list[0]
    worker = gradients.pop(ready_gradient_id)
    # Apply that worker's gradients on their own, then send it the updated weights.
    current_weights = ps.apply_gradients.remote(ready_gradient_id)
    gradients[worker.compute_gradients.remote(current_weights)] = worker
    if i % C["evaluation"]["frequency"] == 0:
        set_model_weights(model, ray.get(current_weights))
        accuracy = evaluate(model, C["evaluation"]["batch_size"])
        print(f"{i} \taccuracy: {accuracy:.3f}")
print(f"Final \taccuracy: {accuracy:.3f}")
ray.shutdown()
```

```
0 	accuracy: 0.566
10 	accuracy: 0.615
20 	accuracy: 0.664
30 	accuracy: 0.699
40 	accuracy: 0.724
50 	accuracy: 0.753
60 	accuracy: 0.773
70 	accuracy: 0.744
80 	accuracy: 0.776
90 	accuracy: 0.748
100 	accuracy: 0.673
110 	accuracy: 0.694
120 	accuracy: 0.715
130 	accuracy: 0.713
140 	accuracy: 0.701
150 	accuracy: 0.693
160 	accuracy: 0.717
170 	accuracy: 0.725
180 	accuracy: 0.686
190 	accuracy: 0.674
200 	accuracy: 0.707
210 	accuracy: 0.745
220 	accuracy: 0.759
230 	accuracy: 0.745
240 	accuracy: 0.695
250 	accuracy: 0.663
260 	accuracy: 0.672
270 	accuracy: 0.693
280 	accuracy: 0.717
290 	accuracy: 0.745
300 	accuracy: 0.737
310 	accuracy: 0.702
320 	accuracy: 0.692
330 	accuracy: 0.688
340 	accuracy: 0.697
350 	accuracy: 0.719
360 	accuracy: 0.730
370 	accuracy: 0.709
380 	accuracy: 0.682
390 	accuracy: 0.689
Final 	accuracy: 0.689
```
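The asynchronous loop differs in that the server applies each gradient as soon as it arrives and sends fresh weights back only to the worker that produced it. To make the schedule visible and deterministic, the sketch below replaces Ray with a simulated event queue (`heapq`) and assumes made-up compute times for two workers on the 1-D objective `(w - 3)**2`. Popping the earliest finish time plays the role of `ray.wait`; the faster worker ends up contributing more updates, and the slower worker's gradients are computed from weights that are a few updates old by the time they're applied. All names are illustrative.

```python
import heapq


def worker_gradient(w):
    # Gradient of the made-up objective (w - 3)^2.
    return 2.0 * (w - 3.0)


def async_parameter_server(compute_times, num_updates, lr=0.1):
    """Simulate async training: apply each gradient as soon as it 'arrives'."""
    w = 0.0
    # Each event is (finish_time, worker_id, gradient computed from the weights
    # the worker was given when it started).
    events = [(t, i, worker_gradient(w)) for i, t in enumerate(compute_times)]
    heapq.heapify(events)
    applied = [0] * len(compute_times)
    for _ in range(num_updates):
        finish, i, grad = heapq.heappop(events)  # first result to be ready
        w -= lr * grad  # possibly stale: computed from older weights
        applied[i] += 1
        # Send the updated weights back to that worker only.
        heapq.heappush(events, (finish + compute_times[i], i, worker_gradient(w)))
    return w, applied


w, applied = async_parameter_server(compute_times=[1.0, 2.5], num_updates=200)
print(applied)  # [143, 57]
```

The slow worker still helps, but only about 2 of every 7 updates come from it; in the real notebook, differences in worker speed come from scheduling and load rather than fixed compute times.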


---

## Links & Resources

- [Ray documentation](https://ray.readthedocs.io/en/latest/index.html)
- [Training models](https://thinc.ai/docs/usage-training) (Thinc)