#!/usr/bin/env python3

"""
Optimization utilities for scalable, high-performance reinforcement learning.
"""

import torch.distributed as dist
from torch.optim.optimizer import Optimizer, required
class Distributed(Optimizer):

    """
    Synchronizes the gradients of a model across replicas.

    At every step, `Distributed` averages the gradient across all replicas
    before calling the wrapped optimizer.

    The `sync` parameter determines how frequently the parameters are
    synchronized between replicas, to minimize numerical divergences.
    This is done by calling the `sync_parameters()` method.
    If `sync is None`, this never happens except upon initialization of the
    class.

    **Arguments**

    * **params** (iterable) - Iterable of parameters.
    * **opt** (Optimizer) - The optimizer to wrap and synchronize.
    * **sync** (int, *optional*, default=None) - Parameter
      synchronization frequency.

    **References**

    1. Zinkevich et al. 2010. "Parallelized Stochastic Gradient Descent."

    **Example**

    ~~~python
    opt = optim.Adam(model.parameters())
    opt = Distributed(model.parameters(), opt, sync=1)
    ~~~
    """

    def __init__(self, params, opt, sync=None):
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.opt = opt
        self.sync = sync
        self.iter = 0  # steps since the last parameter synchronization
        defaults = {}
        super(Distributed, self).__init__(params, defaults)
        # Per the class docstring, replicas are synchronized "upon
        # initialization": broadcast root's parameters so all replicas
        # start from identical weights.
        self.sync_parameters()

    def sync_parameters(self, root=0):
        """
        Broadcasts all parameters of root to all other replicas.

        **Arguments**

        * **root** (int, *optional*, default=0) - Rank of root replica.
        """
        if self.world_size > 1:
            for group in self.param_groups:
                for p in group['params']:
                    dist.broadcast(p.data, src=root)

    def step(self):
        """
        Averages gradients across replicas, then steps the wrapped
        optimizer. Every `sync` steps, also re-broadcasts the parameters
        from rank 0 to eliminate accumulated numerical divergence.
        """
        if self.world_size > 1:
            num_replicas = float(self.world_size)
            # Average all gradients
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    d_p = p.grad.data
                    # Perform the averaging: all_reduce sums in-place
                    # across replicas, so divide by the replica count
                    # to obtain the mean gradient.
                    dist.all_reduce(d_p)
                    d_p.div_(num_replicas)
        # Perform optimization step
        self.opt.step()
        self.iter += 1
        if self.sync is not None and self.iter >= self.sync:
            self.sync_parameters()
            self.iter = 0