-
Notifications
You must be signed in to change notification settings - Fork 94
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Attach AmoebaNet-D performance benchmark as example
- Loading branch information
Showing
9 changed files
with
327 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# AmoebaNet-D Performance Benchmark | ||
|
||
This example reproduces performance benchmark on AmoebaNet-D, as reported in | ||
Figure 3(a) of the paper. But there is some difference between torchgpipe and | ||
GPipe. We believe that this difference is not caused by the difference of | ||
torchgpipe and GPipe, rather by reimplementing the AmoebaNet-D model in | ||
TensorFlow for PyTorch. Results will be updated whenever a stable and | ||
reproducible AmoebaNet-D in PyTorch is available. | ||
|
||
The benchmark cares of only training performance rather than the model's | ||
accuracy. The batch size is adjusted to achieve higher throughput without any | ||
large batch training tricks. This example also doesn't feed actual dataset like | ||
ImageNet or CIFAR-100. Instead, a fake dataset with 50k 3×224×224 tensors is | ||
used to eliminate data loading overhead. | ||
|
||
Every experiment setting is optimized for Tesla P40 GPUs. | ||
|
||
## Result | ||
|
||
Experiment | Throughput | Speed up | ||
---------- | ----------------: | -------: | ||
naive-2 | 14.18 samples/sec | 1.000x | ||
pipeline-2 | 20.34 samples/sec | 1.434x | ||
pipeline-4 | 29.07 samples/sec | 2.049x | ||
pipeline-8 | 34.39 samples/sec | 2.424x | ||
|
||
## Optimized Environment | ||
|
||
- Python 3.6.7 | ||
- PyTorch 1.1.0 | ||
- CUDA 9.0.176 | ||
- 8 Tesla P40 GPUs | ||
- 8+ Intel E5-2650 v4 CPUs | ||
|
||
## To Reproduce | ||
|
||
First, resolve the dependencies. We highly recommend to use a separate virtual | ||
environment only for this benchmark: | ||
|
||
```sh | ||
$ pip install -r requirements.txt | ||
``` | ||
|
||
Then, run each benchmark: | ||
|
||
```sh | ||
$ python main.py naive-2 | ||
$ python main.py pipeline-2 | ||
$ python main.py pipeline-4 | ||
$ python main.py pipeline-8 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../amoebanet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,253 @@ | ||
"""AmoebaNet-D Performance Benchmark""" | ||
import platform | ||
import random | ||
import time | ||
from typing import Any, Callable, Dict, List, Optional, Tuple, cast | ||
|
||
import click | ||
import torch | ||
import torch.nn as nn | ||
from torch.nn import functional as F | ||
from torch.optim import SGD | ||
from torch.utils.data import DataLoader | ||
|
||
from amoebanet import amoebanetd | ||
from torchgpipe import GPipe | ||
|
||
Stuffs = Tuple[nn.Module, int, List[torch.device]] # (model, batch_size, devices) | ||
Experiment = Callable[[nn.Module, List[int]], Stuffs] | ||
|
||
|
||
class Experiments: | ||
|
||
@staticmethod | ||
def naive2(model: nn.Module, devices: List[int]) -> Stuffs: | ||
batch_size = 47 | ||
balance = [84, 241] | ||
|
||
model = cast(nn.Sequential, model) | ||
# GPipe with chunks=1, checkpoint='never' is equivalent to a typical model parallel. | ||
model = GPipe(model, balance, devices=devices, chunks=1, checkpoint='never') | ||
return model, batch_size, list(model.devices) | ||
|
||
@staticmethod | ||
def pipeline2(model: nn.Module, devices: List[int]) -> Stuffs: | ||
batch_size = 803 | ||
chunks = 48 | ||
balance = [144, 181] | ||
|
||
model = cast(nn.Sequential, model) | ||
model = GPipe(model, balance, devices=devices, chunks=chunks) | ||
return model, batch_size, list(model.devices) | ||
|
||
@staticmethod | ||
def pipeline4(model: nn.Module, devices: List[int]) -> Stuffs: | ||
batch_size = 378 | ||
chunks = 32 | ||
balance = [78, 77, 92, 78] | ||
|
||
model = cast(nn.Sequential, model) | ||
model = GPipe(model, balance, devices=devices, chunks=chunks) | ||
return model, batch_size, list(model.devices) | ||
|
||
@staticmethod | ||
def pipeline8(model: nn.Module, devices: List[int]) -> Stuffs: | ||
batch_size = 216 | ||
chunks = 32 | ||
balance = [43, 35, 36, 38, 43, 45, 46, 39] | ||
|
||
model = cast(nn.Sequential, model) | ||
model = GPipe(model, balance, devices=devices, chunks=chunks) | ||
return model, batch_size, list(model.devices) | ||
|
||
|
||
EXPERIMENTS: Dict[str, Experiment] = { | ||
'naive-2': Experiments.naive2, | ||
'pipeline-2': Experiments.pipeline2, | ||
'pipeline-4': Experiments.pipeline4, | ||
'pipeline-8': Experiments.pipeline8, | ||
} | ||
|
||
|
||
class RandomDataset(torch.utils.data.Dataset): | ||
def __len__(self) -> int: | ||
return 50000 | ||
|
||
def __getitem__(self, i: int) -> Tuple[torch.Tensor, int]: | ||
return torch.rand(3, 224, 224), random.randrange(10) | ||
|
||
|
||
BASE_TIME: float = 0 | ||
|
||
|
||
def hr() -> None: | ||
"""Prints a horizontal line.""" | ||
width, _ = click.get_terminal_size() | ||
click.echo('-' * width) | ||
|
||
|
||
def log(msg: str, clear: bool = False, nl: bool = True) -> None: | ||
"""Prints a message with elapsed time.""" | ||
if clear: | ||
# Clear the output line to overwrite. | ||
width, _ = click.get_terminal_size() | ||
click.echo('\b\r', nl=False) | ||
click.echo(' ' * width, nl=False) | ||
click.echo('\b\r', nl=False) | ||
|
||
t = time.time() - BASE_TIME | ||
h = t // 3600 | ||
t %= 3600 | ||
m = t // 60 | ||
t %= 60 | ||
s = t | ||
|
||
click.echo('%02d:%02d:%02d | ' % (h, m, s), nl=False) | ||
click.echo(msg, nl=nl) | ||
|
||
|
||
def parse_devices(ctx: Any, param: Any, value: Optional[str]) -> List[int]: | ||
if value is None: | ||
return list(range(torch.cuda.device_count())) | ||
return [int(x) for x in value.split(',')] | ||
|
||
|
||
@click.command() | ||
@click.pass_context | ||
@click.argument( | ||
'experiment', | ||
type=click.Choice(sorted(EXPERIMENTS.keys())), | ||
) | ||
@click.option( | ||
'--epochs', '-e', | ||
type=int, | ||
default=10, | ||
help='Number of epochs (default: 10)', | ||
) | ||
@click.option( | ||
'--skip-epochs', '-k', | ||
type=int, | ||
default=1, | ||
help='Number of epochs to skip in result (default: 1)', | ||
) | ||
@click.option( | ||
'--devices', '-d', | ||
metavar='0,1,2,3', | ||
callback=parse_devices, | ||
help='Device IDs to use (default: all CUDA devices)', | ||
) | ||
def cli(ctx: click.Context, | ||
experiment: str, | ||
epochs: int, | ||
skip_epochs: int, | ||
devices: List[int], | ||
) -> None: | ||
"""AmoebaNet-D Performance Benchmark""" | ||
if skip_epochs >= epochs: | ||
ctx.fail('--skip-epochs=%d must be less than --epochs=%d' % (skip_epochs, epochs)) | ||
|
||
model: nn.Module = amoebanetd(num_classes=10) | ||
|
||
f = EXPERIMENTS[experiment] | ||
try: | ||
model, batch_size, _devices = f(model, devices) | ||
except ValueError as exc: | ||
# Examples: | ||
# ValueError: too few devices to hold given partitions (devices: 1, paritions: 2) | ||
ctx.fail(str(exc)) | ||
|
||
optimizer = SGD(model.parameters(), lr=0.1) | ||
|
||
in_device = _devices[0] | ||
out_device = _devices[-1] | ||
|
||
# This experiment cares about only training performance, rather than | ||
# accuracy. To eliminate any overhead due to data loading, we use a fake | ||
# dataset with random 224x224 images over 10 labels. | ||
dataset = RandomDataset() | ||
loader = DataLoader( | ||
dataset, | ||
batch_size=batch_size, | ||
shuffle=False, | ||
num_workers=1, | ||
pin_memory=True, | ||
drop_last=False, | ||
) | ||
|
||
# HEADER ====================================================================================== | ||
|
||
title = '%s, %d-%d epochs' % (experiment, skip_epochs+1, epochs) | ||
click.echo(title) | ||
click.echo('python: %s, torch: %s, cudnn: %s, cuda: %s, gpu: %s' % ( | ||
platform.python_version(), | ||
torch.__version__, | ||
torch.backends.cudnn.version(), | ||
torch.version.cuda, | ||
torch.cuda.get_device_name(in_device))) | ||
|
||
# TRAIN ======================================================================================= | ||
|
||
global BASE_TIME | ||
BASE_TIME = time.time() | ||
|
||
def run_epoch(epoch: int) -> Tuple[float, float]: | ||
torch.cuda.synchronize(in_device) | ||
tick = time.time() | ||
|
||
data_trained = 0 | ||
for i, (input, target) in enumerate(loader): | ||
data_trained += len(input) | ||
|
||
input = input.to(in_device, non_blocking=True) | ||
target = target.to(out_device, non_blocking=True) | ||
|
||
output = model(input) | ||
loss = F.cross_entropy(output, target) | ||
loss.backward() | ||
|
||
optimizer.step() | ||
optimizer.zero_grad() | ||
|
||
# 00:01:02 | 1/20 epoch (42%) | 200.000 samples/sec (estimated) | ||
percent = i / len(loader) * 100 | ||
throughput = data_trained / (time.time()-tick) | ||
log('%d/%d epoch (%d%%) | %.3f samples/sec (estimated)' | ||
'' % (epoch+1, epochs, percent, throughput), clear=True, nl=False) | ||
|
||
torch.cuda.synchronize(in_device) | ||
tock = time.time() | ||
|
||
# 00:02:03 | 1/20 epoch | 200.000 samples/sec, 123.456 sec/epoch | ||
elapsed_time = tock - tick | ||
throughput = len(dataset) / elapsed_time | ||
log('%d/%d epoch | %.3f samples/sec, %.3f sec/epoch' | ||
'' % (epoch+1, epochs, throughput, elapsed_time), clear=True) | ||
|
||
return throughput, elapsed_time | ||
|
||
throughputs = [] | ||
elapsed_times = [] | ||
|
||
hr() | ||
for epoch in range(epochs): | ||
throughput, elapsed_time = run_epoch(epoch) | ||
|
||
if epoch < skip_epochs: | ||
continue | ||
|
||
throughputs.append(throughput) | ||
elapsed_times.append(elapsed_time) | ||
hr() | ||
|
||
# RESULT ====================================================================================== | ||
|
||
# pipeline-4, 2-10 epochs | 200.000 samples/sec, 123.456 sec/epoch (average) | ||
n = len(throughputs) | ||
throughput = sum(throughputs) / n | ||
elapsed_time = sum(elapsed_times) / n | ||
click.echo('%s | %.3f samples/sec, %.3f sec/epoch (average)' | ||
'' % (title, throughput, elapsed_time)) | ||
|
||
|
||
if __name__ == '__main__': | ||
cli() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
click==7.0 | ||
torch==1.1.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../torchgpipe |