[feat] OSS: adding a --profile option to the benchmark (#135)
blefaudeux committed Oct 14, 2020
1 parent 37c686e commit 34915bf
Showing 1 changed file with 37 additions and 22 deletions.
benchmarks/oss.py: 37 additions & 22 deletions
@@ -9,6 +9,7 @@
 
 import numpy as np
 import torch
+import torch.autograd.profiler as profiler
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
@@ -49,21 +50,27 @@ def collate(inputs: List[Any]):
     return model, dataloader, loss_fn
 
 
+class OptimType(str, Enum):
+    vanilla = "pytorch"
+    oss = "oss"
+    oss_sdp = "oss_sdp"
+    everyone = "everyone"
+
+
 def train(
     rank: int,
     world_size: int,
     num_epochs: int = 10,
     batch_size: int = 32,
     data_size: int = 200,
     backend: str = "gloo",
-    use_oss: bool = True,
-    use_sdp: bool = False,
+    optim_type: OptimType = OptimType.vanilla,
+    profile: bool = False,
     check_regression: bool = True,
     reference_speed: float = -1.0,
     reference_memory: float = -1.0,
     reference_loss: float = -1.0,
 ):
-    assert not use_sdp or (use_sdp and use_oss), "ShardedDataParallel requires OSS"
     # DDP
     dist_init(rank=rank, world_size=world_size, backend=backend)
 
@@ -82,7 +89,7 @@ def train(
     # Shard the optimizer
     optimizer: Optional[torch.optim.Optimizer] = None
 
-    if use_sdp:
+    if optim_type == OptimType.oss_sdp:
         ddp = ShardedDataParallel(
             module=model,
             optimizer=OPTIM,
@@ -97,7 +104,7 @@ def train(
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)  # type: ignore
         optimizer = (
             OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
-            if use_oss
+            if optim_type == OptimType.oss
             else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
         )
 
@@ -111,6 +118,7 @@ def train(
 
     measurements = []
     final_loss: Optional[float] = -1.0
+    need_profiling = profile
 
     for epoch in range(num_epochs):
         epoch_start = time.monotonic()
@@ -124,16 +132,29 @@ def closure():
                 loss /= world_size
                 loss.backward()
 
-                if use_sdp:
+                if optim_type == OptimType.oss_sdp:
                     ddp.reduce()  # Send the gradients to the appropriate shards
 
                 return loss
 
-            final_loss = optimizer.step(closure)
+            if need_profiling:
+                print("Profiling the run")
+                with profiler.profile(use_cuda=True) as prof:  # type: ignore
+                    with profiler.record_function("batch"):
+                        final_loss = optimizer.step(closure)
+                        print("profiling done, final loss ", cast(float, final_loss))
+
+                if rank == 0:
+                    prof.export_chrome_trace(f"{optim_type}_trace.json")
+
+                need_profiling = False  # only profile once
+
+            else:
+                final_loss = optimizer.step(closure)
 
         epoch_end = time.monotonic()
 
-        if use_oss:
+        if optim_type == OptimType.oss:
             # Check the checkpointing in the case of the OSS optimizer
             # Memory usage could spill over from there
             optimizer = cast(OSS, optimizer)
@@ -160,7 +181,7 @@ def closure():
     std = math.sqrt(sum(diff) / (len(measurements) - 1))
     print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")
 
-    if use_oss and check_regression and dist.get_rank() == 0:
+    if check_regression and dist.get_rank() == 0:
         assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
         assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
         assert abs(cast(float, final_loss) - reference_loss) < 1e-3, "Loss regression detected"
@@ -171,13 +192,6 @@ def closure():
 
 
 if __name__ == "__main__":
-
-    class OptimType(str, Enum):
-        vanilla = "pytorch"
-        oss = "oss"
-        oss_sdp = "oss_sdp"
-        everyone = "everyone"
-
     parser = argparse.ArgumentParser(
         description="Benchmark the optimizer state sharding, on a typical computer vision workload"
     )
@@ -193,6 +207,7 @@ class OptimType(str, Enum):
         "--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
     )
     parser.add_argument("--gloo", action="store_true", default=False)
+    parser.add_argument("--profile", action="store_true", default=False)
 
     args = parser.parse_args()
     print(f"Benchmark arguments: {args}")
@@ -209,8 +224,8 @@ class OptimType(str, Enum):
                 args.batch_size,
                 args.data_size,
                 backend,
-                False,  # OSS
-                False,  # SDP
+                OptimType.vanilla,
+                args.profile,
                 False,  # no regression check
             ),
             nprocs=args.world_size,
@@ -227,8 +242,8 @@ class OptimType(str, Enum):
                 args.batch_size,
                 args.data_size,
                 backend,
-                True,  # OSS
-                False,  # SDP
+                OptimType.oss,
+                args.profile,
                 args.check_regression,
                 args.reference_speed,
                 args.reference_memory,
@@ -248,8 +263,8 @@ class OptimType(str, Enum):
                 args.batch_size,
                 args.data_size,
                 backend,
-                True,  # OSS
-                True,  # SDP
+                OptimType.oss_sdp,
+                args.profile,
                 False,  # FIXME: @lefaudeux - SDP should give the same results
                 -1,  # Not checking SDP for speed regression for now, still slower than OSS
                 args.reference_memory,
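
The profiling block added to train() follows the standard torch.autograd.profiler pattern: wrap the step in profiler.profile, mark the region of interest with record_function, then export a Chrome trace. Below is a minimal standalone sketch of that same pattern with the distributed setup stripped out; the linear model, tensor shapes, and trace filename are illustrative placeholders, not part of this commit.

# Minimal sketch of the profiling pattern used in this commit.
# Model, data, and trace filename are illustrative placeholders.
import torch
import torch.autograd.profiler as profiler
import torch.nn as nn

model = nn.Linear(128, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
inputs = torch.randn(32, 128)
labels = torch.randint(0, 10, (32,))
loss_fn = nn.CrossEntropyLoss()

def closure():
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), labels)
    loss.backward()
    return loss

# The diff passes use_cuda=True to also record CUDA kernel timings;
# it is left off here so the sketch runs on a CPU-only machine.
with profiler.profile(use_cuda=False) as prof:
    with profiler.record_function("batch"):  # named region, visible in the trace
        final_loss = optimizer.step(closure)

prof.export_chrome_trace("example_trace.json")  # load via chrome://tracing
print(prof.key_averages().table(sort_by="cpu_time_total"))

Profiling only one step, as the benchmark does via its need_profiling flag, keeps the trace small and avoids paying the profiler overhead on every iteration.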
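From the command line, the new option is just --profile on top of the existing flags. A run could look like the following (the --world_size flag is assumed from the unchanged part of the argument parser, so exact names may differ):

python benchmarks/oss.py --world_size 2 --gloo --optim_type oss --profile

With --profile set, each spawned rank profiles its first optimizer step, and rank 0 exports the f"{optim_type}_trace.json" Chrome trace shown in the diff.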
