[bug] Make OSS Gloo-compliant (#102)
* Broadcasting grad-enabled tensors is forbidden in Gloo, because broadcast is in-place and not differentiable. Work around this by temporarily dropping the grad requirement for the duration of the broadcast and restoring it afterwards.
blefaudeux committed Sep 22, 2020
1 parent d80c38f commit b488dcf
Showing 3 changed files with 24 additions and 11 deletions.
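
The essence of the workaround, as a minimal standalone sketch mirroring the OSS.step() change below (a process group is assumed to be initialized already; the helper name broadcast_params_from_rank and the flat parameter list are illustrative, not part of the commit):

    from typing import List, Tuple

    import torch
    import torch.distributed as dist


    def broadcast_params_from_rank(params: List[torch.Tensor], src_rank: int, group=None) -> None:
        # Gloo rejects broadcasting tensors that require grad, since broadcast is
        # in-place and not differentiable. Temporarily drop the grad requirement,
        # issue the broadcasts asynchronously, then restore the original flags once
        # each request has completed. Assumes leaf tensors such as nn.Parameter.
        saved: List[Tuple[torch.Tensor, bool]] = []
        requests = []
        for p in params:
            saved.append((p, p.requires_grad))
            p.requires_grad = False
            requests.append(dist.broadcast(tensor=p, src=src_rank, group=group, async_op=True))

        for work, (p, needed_grad) in zip(requests, saved):
            work.wait()
            p.requires_grad = needed_grad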
1 change: 1 addition & 0 deletions .circleci/config.yml
@@ -101,6 +101,7 @@ run_oss_benchmark: &run_oss_benchmark
       name: Run OSS Benchmark
       command: |
         python benchmarks/oss.py
+        python benchmarks/oss.py --gloo

 run_oss_ddp_benchmark: &run_oss_ddp_benchmark
   - run:
22 changes: 12 additions & 10 deletions benchmarks/oss.py
@@ -18,14 +18,12 @@
 from fairscale.nn.data_parallel import ShardedDataParallel
 from fairscale.optim.oss import OSS

-BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore
 OPTIM = torch.optim.RMSprop


-def dist_init(rank, world_size):
-    dist.init_process_group(
-        backend=BACKEND, init_method="tcp://localhost:29501", rank=rank, world_size=world_size, store=None
-    )
+def dist_init(rank, world_size, backend):
+    print(f"Using backend: {backend}")
+    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


 def get_problem(rank, data_size, batch_size):
@@ -47,11 +45,11 @@ def collate(inputs: List[Any]):


 def train_oss_ddp(
-    rank: int, world_size: int, num_epochs: int = 10, batch_size: int = 32, data_size: int = 200,
+    rank: int, world_size: int, num_epochs: int = 10, batch_size: int = 32, data_size: int = 200, backend: str = "gloo",
 ):

     # DDP
-    dist_init(rank, world_size)
+    dist_init(rank, world_size, backend)

     # Setup
     model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)
@@ -120,13 +118,14 @@ def train(
     num_epochs: int = 10,
     batch_size: int = 32,
     data_size: int = 200,
+    backend: str = "gloo",
     use_oss: bool = True,
     check_regression: bool = True,
     reference_speed: float = -1.0,
     reference_memory: float = -1.0,
 ):
     # DDP
-    dist_init(rank, world_size)
+    dist_init(rank, world_size, backend)

     # Setup
     model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)
@@ -214,26 +213,28 @@ def closure():
     parser.add_argument("--check_regression", action="store_true", default=False)
     parser.add_argument("--reference_speed", action="store", default=32.32, type=float)
     parser.add_argument("--reference_memory", action="store", default=4475, type=float)
+    parser.add_argument("--gloo", action="store_true", default=False)

     # beta - test oss_ddp
     parser.add_argument("--oss_ddp", action="store_true", default=False)

     args = parser.parse_args()
     print(f"Benchmark arguments: {args}")

backend = "nccl" if not args.gloo or not torch.cuda.is_available() else "gloo"
     if args.oss_ddp:
         print("\nBenchmark OSS DDP")
         mp.spawn(
             train_oss_ddp,
-            args=(args.world_size, args.epochs, args.batch_size, args.data_size),
+            args=(args.world_size, args.epochs, args.batch_size, args.data_size, backend),
             nprocs=args.world_size,
             join=True,
         )
     else:
         print("\nBenchmark vanilla optimizer")
         mp.spawn(
             train,
-            args=(args.world_size, args.epochs, args.batch_size, args.data_size, False, False),
+            args=(args.world_size, args.epochs, args.batch_size, args.data_size, backend, False, False),
             nprocs=args.world_size,
             join=True,
         )
@@ -246,6 +247,7 @@ def closure():
             args.epochs,
             args.batch_size,
             args.data_size,
+            backend,
             True,
             args.check_regression,
             args.reference_speed,
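
Taken together, the benchmark changes thread the backend choice from the command line through mp.spawn into process-group initialization. A condensed sketch of that flow, with illustrative names and defaults (run_worker and its argument defaults are not part of the diff):

    import argparse

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


    def run_worker(rank: int, world_size: int, backend: str) -> None:
        # Mirrors dist_init in the benchmark: every worker joins the same TCP rendezvous.
        dist.init_process_group(
            backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size
        )
        # ... benchmark body would run here ...
        dist.destroy_process_group()


    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--gloo", action="store_true", default=False)
        parser.add_argument("--world_size", type=int, default=2)
        args = parser.parse_args()

        # Fall back to Gloo whenever it is requested explicitly or CUDA is unavailable.
        backend = "gloo" if args.gloo or not torch.cuda.is_available() else "nccl"
        mp.spawn(run_worker, args=(args.world_size, backend), nprocs=args.world_size, join=True)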
12 changes: 11 additions & 1 deletion fairscale/optim/oss.py
@@ -160,12 +160,22 @@ def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) ->

         # Sync all the states. Broadcast requests are issued async, we check completeness before moving on
         requests = []
+        requires_grad = []
         for rank, param_groups in enumerate(self.partition_parameters()):
             for param_group in param_groups:
                 for param in param_group["params"]:
+                    # NOTE: Broadcast is in-place and not differentiable
+                    # Gloo will rightly assert on this operation for any tensor that requires grad.
+                    # We save and restore the grad requirement state to work around that, in our case
+                    # the grad is only useful on the source rank.
+                    requires_grad.append((param, param.requires_grad))
+                    param.requires_grad = False
                     requests.append(dist.broadcast(tensor=param, src=rank, group=self.group, async_op=True))

-        _ = list(map(lambda x: x.wait(), requests))
+        for fut, req_grad in zip(requests, requires_grad):
+            fut.wait()
+            req_grad[0].requires_grad = req_grad[1]

         return loss

     def local_state_dict(self) -> dict:
