[fix] OSS benchmark cleanup (#109)
- small benchmark refactor: a single script now covers all backends and DDP
- deterministic: enforce alignment with PyTorch DDP (see the seeding sketch below)
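
A minimal sketch of the determinism setup the second bullet refers to; it mirrors the seeding block added to train() in benchmarks/oss.py (shown in the diff below). The set_deterministic helper name is illustrative and not part of the commit.

# Sketch only: seeding so that OSS / SDP runs can be compared against the
# vanilla PyTorch baseline; set_deterministic is a hypothetical helper.
import numpy as np
import torch


def set_deterministic(seed: int = 0, backend: str = "nccl") -> None:
    torch.manual_seed(seed)  # seeds the CPU and all CUDA RNGs
    np.random.seed(seed)     # covers numpy-side randomness
    if backend == "nccl":
        # avoid non-deterministic cuDNN kernel selection
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False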
blefaudeux committed Sep 24, 2020
1 parent 7c5203e commit 5355347
Showing 2 changed files with 82 additions and 98 deletions.
11 changes: 2 additions & 9 deletions .circleci/config.yml
@@ -100,14 +100,9 @@ run_oss_benchmark: &run_oss_benchmark
- run:
name: Run OSS Benchmark
command: |
python benchmarks/oss.py
python benchmarks/oss.py --gloo
python benchmarks/oss.py --check_regression
python benchmarks/oss.py --gloo --optim_type oss
run_oss_ddp_benchmark: &run_oss_ddp_benchmark
- run:
name: Run OSS DDP Benchmark
command: |
python benchmarks/oss.py --oss_ddp
# -------------------------------------------------------------------------------------
# Jobs to run
@@ -259,8 +254,6 @@ jobs:

- <<: *run_oss_benchmark

- <<: *run_oss_ddp_benchmark



workflows:
169 changes: 80 additions & 89 deletions benchmarks/oss.py
@@ -2,10 +2,12 @@


import argparse
from enum import Enum
import math
import time
from typing import Any, List, Optional, cast

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
@@ -44,74 +46,6 @@ def collate(inputs: List[Any]):
return model, dataloader, loss_fn


def train_oss_ddp(
rank: int, world_size: int, num_epochs: int = 10, batch_size: int = 32, data_size: int = 200, backend: str = "gloo",
):

# DDP
dist_init(rank, world_size, backend)

# Setup
model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

ddp = ShardedDataParallel(
module=model, optimizer=torch.optim.SGD, optimizer_params={"lr": 1e-4, "momentum": 0.9}, world_size=world_size
)
optimizer = ddp.optimizer

# Reset the memory use counter
torch.cuda.reset_peak_memory_stats(rank)

# Dummy training loop
torch.cuda.synchronize(rank)
training_start = time.monotonic()
model.train()

measurements = []

for epoch in range(num_epochs):
epoch_start = time.monotonic()

for batch in dataloader:

def closure():
model.zero_grad()
outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"])
loss /= world_size
loss.backward()

dist.all_reduce(loss, op=dist.ReduceOp.SUM)

if dist.get_rank() == 0:
print(f"Loss: {loss.item()}")

ddp.reduce() # Send the gradients to the appropriate shards
return loss

optimizer.step(closure)

epoch_end = time.monotonic()

measurements.append(data_size / (epoch_end - epoch_start))
if dist.get_rank() == 0:
print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")

torch.cuda.synchronize(rank)
training_stop = time.monotonic()
img_per_sec = data_size / (training_stop - training_start) * num_epochs
max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20

print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

# Compute the mean and average img per second
mean = sum(measurements) / len(measurements)
diff = map(lambda x: pow(x - mean, 2.0), measurements)
std = math.sqrt(sum(diff) / (len(measurements) - 1))
print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")


def train(
rank: int,
world_size: int,
@@ -120,22 +54,44 @@ def train(
data_size: int = 200,
backend: str = "gloo",
use_oss: bool = True,
use_sdp: bool = False,
check_regression: bool = True,
reference_speed: float = -1.0,
reference_memory: float = -1.0,
reference_loss: float = -1.0,
):
assert not use_sdp or (use_sdp and use_oss), "ShardedDataParallel requires OSS"
# DDP
dist_init(rank, world_size, backend)

# Setup
torch.cuda.set_device(rank)
torch.cuda.manual_seed(0)
torch.manual_seed(0) # also sets the cuda seed
np.random.seed(0)

if backend == "nccl":
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

# Shard the optimizer
optimizer: torch.optim.Optimizer = (
OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
if use_oss
else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
)
optimizer: Optional[torch.optim.Optimizer] = None

if use_sdp:
ddp = ShardedDataParallel(
module=model, optimizer=OPTIM, optimizer_params={"lr": 1e-4, "momentum": 0.9}, world_size=world_size,
)
ddp.train()
optimizer = ddp.optimizer
model = ddp
else:
optimizer = (
OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
if use_oss
else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
)

# Reset the memory use counter
torch.cuda.reset_peak_memory_stats(rank)
@@ -162,6 +118,9 @@ def closure():

dist.all_reduce(loss, op=dist.ReduceOp.SUM)

if use_sdp:
ddp.reduce() # Send the gradients to the appropriate shards

return loss

final_loss = optimizer.step(closure)
@@ -179,7 +138,7 @@ def closure():

measurements.append(data_size / (epoch_end - epoch_start))
if dist.get_rank() == 0:
print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss}")
print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

torch.cuda.synchronize(rank)
training_stop = time.monotonic()
@@ -198,11 +157,19 @@ def closure():
if use_oss and check_regression and dist.get_rank() == 0:
assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
assert abs(cast(float, final_loss) - reference_loss) < 1e-3, "Loss regression detected"

print("[Regression Test] VALID")


if __name__ == "__main__":

class OptimType(str, Enum):
vanilla = "pytorch"
oss = "oss"
oss_sdp = "oss_sdp"
everyone = "everyone"

parser = argparse.ArgumentParser(
description="Benchmark the optimizer state sharding, on a typical computer vision workload"
)
@@ -211,34 +178,38 @@ def closure():
parser.add_argument("--batch_size", action="store", default=32, type=int)
parser.add_argument("--data_size", action="store", default=512, type=int)
parser.add_argument("--check_regression", action="store_true", default=False)
parser.add_argument("--reference_speed", action="store", default=32.32, type=float)
parser.add_argument("--reference_speed", action="store", default=29.7, type=float)
parser.add_argument("--reference_memory", action="store", default=4475, type=float)
parser.add_argument("--reference_loss", action="store", default=0.866, type=float)
parser.add_argument(
"--optim_type", type=OptimType, choices=[o.value for o in OptimType], default=OptimType.everyone
)
parser.add_argument("--gloo", action="store_true", default=False)

# beta - test oss_ddp
parser.add_argument("--oss_ddp", action="store_true", default=False)

args = parser.parse_args()
print(f"Benchmark arguments: {args}")

backend = "nccl" if not args.gloo or not torch.cuda.is_available() else "gloo"
if args.oss_ddp:
print("\nBenchmark OSS DDP")
mp.spawn(
train_oss_ddp,
args=(args.world_size, args.epochs, args.batch_size, args.data_size, backend),
nprocs=args.world_size,
join=True,
)
else:

if args.optim_type == OptimType.vanilla or args.optim_type == OptimType.everyone:
print("\nBenchmark vanilla optimizer")
mp.spawn(
train,
args=(args.world_size, args.epochs, args.batch_size, args.data_size, backend, False, False),
args=(
args.world_size,
args.epochs,
args.batch_size,
args.data_size,
backend,
False, # OSS
False, # SDP
False, # no regression check
),
nprocs=args.world_size,
join=True,
)

if args.optim_type == OptimType.oss or args.optim_type == OptimType.everyone:
print("\nBenchmark OSS")
mp.spawn(
train,
@@ -248,10 +219,30 @@ def closure():
args.batch_size,
args.data_size,
backend,
True,
True, # OSS
False, # SDP
args.check_regression,
args.reference_speed,
args.reference_memory,
args.reference_loss,
),
nprocs=args.world_size,
join=True,
)

if args.optim_type == OptimType.oss_sdp or args.optim_type == OptimType.everyone:
print("\nBenchmark OSS DDP")
mp.spawn(
train,
args=(
args.world_size,
args.epochs,
args.batch_size,
args.data_size,
backend,
True, # OSS
True, # SDP
False, # no regression check
),
nprocs=args.world_size,
join=True,
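
For reference, a hedged usage sketch of the refactored entry point: it assumes benchmarks/oss.py is importable as oss (e.g. run from the benchmarks/ directory), that train() keeps the signature shown in the diff above, and that one CUDA device is available per rank. It drives the sharded-optimizer (OSS) variant directly, much like the __main__ block does for --optim_type oss, with the regression check disabled.

# Usage sketch, not part of the commit: spawn the refactored train() directly.
import torch.multiprocessing as mp

from oss import train  # assumes benchmarks/oss.py is on the import path

if __name__ == "__main__":
    world_size = 2
    mp.spawn(
        train,
        args=(
            world_size,  # world_size (the rank argument is supplied by mp.spawn)
            10,          # epochs
            32,          # batch_size
            512,         # data_size
            "nccl",      # backend
            True,        # use_oss: shard optimizer state with fairscale's OSS
            False,       # use_sdp: skip the ShardedDataParallel wrapper
            False,       # check_regression: no speed/memory/loss asserts
        ),
        nprocs=world_size,
        join=True,
    )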
