Skip to content

Commit

Permalink
[minor] OSS: bring DDP in the benchmark (#130)
Browse files Browse the repository at this point in the history
More realistic benchmarks, comparing apples to apples. DDP/OSS+DDP/OSS+SDP
  • Loading branch information
blefaudeux committed Oct 9, 2020
1 parent 81ac5b2 commit bfd88ca
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ run_oss_benchmark: &run_oss_benchmark
- run:
name: Run OSS Benchmark
command: |
python benchmarks/oss.py --check_regression --world_size 4 --reference_speed 21.2 --reference_memory 4220 --reference_loss 0.63
python benchmarks/oss.py --check_regression --world_size 4 --reference_speed 13.7 --reference_memory 4390 --reference_loss 0.595
run_oss_gloo: &run_oss_gloo
- run:
Expand Down
14 changes: 9 additions & 5 deletions benchmarks/oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision.models import resnet101
Expand Down Expand Up @@ -40,7 +41,9 @@ def collate(inputs: List[Any]):
}

dataloader = DataLoader(
dataset=FakeData(transform=ToTensor(), size=data_size), batch_size=batch_size, collate_fn=collate
dataset=FakeData(transform=ToTensor(), size=data_size, random_offset=rank),
batch_size=batch_size,
collate_fn=collate,
)
loss_fn = nn.CrossEntropyLoss()
return model, dataloader, loss_fn
Expand Down Expand Up @@ -85,12 +88,13 @@ def train(
optimizer=OPTIM,
optimizer_params={"lr": 1e-4, "momentum": 0.9},
world_size=world_size,
broadcast_buffers=False,
broadcast_buffers=True,
)
ddp.train()
optimizer = ddp.optimizer
model = ddp
else:
model = DDP(model, device_ids=[rank], find_unused_parameters=True) # type: ignore
optimizer = (
OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
if use_oss
Expand Down Expand Up @@ -216,7 +220,7 @@ class OptimType(str, Enum):
)

if args.optim_type == OptimType.oss or args.optim_type == OptimType.everyone:
print("\nBenchmark OSS")
print("\nBenchmark OSS with DDP")
mp.spawn(
train,
args=(
Expand All @@ -237,7 +241,7 @@ class OptimType(str, Enum):
)

if args.optim_type == OptimType.oss_sdp or args.optim_type == OptimType.everyone:
print("\nBenchmark OSS DDP")
print("\nBenchmark OSS with SDP")
mp.spawn(
train,
args=(
Expand All @@ -248,7 +252,7 @@ class OptimType(str, Enum):
backend,
True, # OSS
True, # SDP
args.check_regression,
False, # FIXME: @lefaudeux - SDP should give the same results
-1, # Not checking SDP for speed regression for now, still slower than OSS
args.reference_memory,
args.reference_loss,
Expand Down

0 comments on commit bfd88ca

Please sign in to comment.