Attach AmoebaNet-D performance benchmark as example
zep-hyr authored and sublee committed Aug 2, 2019
1 parent 14bbc9b commit 2535a5a
Showing 9 changed files with 327 additions and 13 deletions.
1 change: 1 addition & 0 deletions .travis.yml
@@ -29,6 +29,7 @@ jobs:
- lint examples/resnet
- lint examples/resnet101_performance_benchmark
- lint examples/amoebanet
- lint examples/amoebanetd_performance_benchmark

# Test with various Python and PyTorch versions.
install:
6 changes: 3 additions & 3 deletions README.ko.md
@@ -113,9 +113,9 @@ Parallelism is not applied and only the checkpointing overhead remains, so naive
Experiment | torchgpipe | GPipe (paper)
---------- | -----: | -----:
naive-2 | 1x | 1x
pipeline-2 | 1.442x | 1.156x
pipeline-4 | 2.094x | 2.483x
pipeline-8 | 2.463x | 3.442x
pipeline-2 | 1.434x | 1.156x
pipeline-4 | 2.049x | 2.483x
pipeline-8 | 2.424x | 3.442x

In the comparison with the AmoebaNet-D training speed benchmark reported in Figure 3(a)
of the GPipe paper, there is some difference between torchgpipe and GPipe. This is because the model implemented in TensorFlow
6 changes: 3 additions & 3 deletions README.md
@@ -122,9 +122,9 @@ The reproducible code can be found in
Experiment | torchgpipe | GPipe (original)
---------- | -----: | -----:
naive-2 | 1x | 1x
pipeline-2 | 1.442x | 1.156x
pipeline-4 | 2.094x | 2.483x
pipeline-8 | 2.463x | 3.442x
pipeline-2 | 1.434x | 1.156x
pipeline-4 | 2.049x | 2.483x
pipeline-8 | 2.424x | 3.442x

The table shows the reproduced performance benchmark on AmoebaNet-D, as
reported in Figure 3(a) of the paper. But there is some difference between
19 changes: 12 additions & 7 deletions docs/benchmarks.rst
@@ -17,12 +17,11 @@ pipeline-4 230.216 samples/sec 2.291x
pipeline-8 312.945 samples/sec 3.114x
========== =================== =======

The code which is reproducible on Tesla P40 GPUs, and the experiment details
The code is reproducible on Tesla P40 GPUs, and the experiment details
can be found in `examples/resnet101_performance_benchmark`_.

.. _examples/resnet101_performance_benchmark:
https://github.com/kakaobrain/torchgpipe/
tree/master/examples/resnet101_performance_benchmark
https://github.com/kakaobrain/torchgpipe/tree/master/examples/resnet101_performance_benchmark

AmoebaNet-D
~~~~~~~~~~~
@@ -33,12 +32,18 @@ AmoebaNet-D Performance Benchmark
========== =================== =======
Experiment Throughput Speedup
========== =================== =======
naive-2 ___.___ samples/sec 1.000x
pipeline-2 ___.___ samples/sec 1.442x
pipeline-4 ___.___ samples/sec 2.094x
pipeline-8 ___.___ samples/sec 2.463x
naive-2 14.188 samples/sec 1.000x
pipeline-2 20.346 samples/sec 1.434x
pipeline-4 29.074 samples/sec 2.049x
pipeline-8 34.392 samples/sec 2.424x
========== =================== =======

The code is reproducible on Tesla P40 GPUs, and the experiment details
can be found in `examples/amoebanetd_performance_benchmark`_.

.. _examples/amoebanetd_performance_benchmark:
https://github.com/kakaobrain/torchgpipe/tree/master/examples/amoebanetd_performance_benchmark

AmoebaNet-D Memory Benchmark
----------------------------

51 changes: 51 additions & 0 deletions examples/amoebanetd_performance_benchmark/README.md
@@ -0,0 +1,51 @@
# AmoebaNet-D Performance Benchmark

This example reproduces the AmoebaNet-D performance benchmark reported in Figure 3(a)
of the GPipe paper. There is still some gap between the torchgpipe and GPipe numbers.
We believe this gap is caused not by a difference between torchgpipe and GPipe
themselves, but by reimplementing the AmoebaNet-D model, originally written in
TensorFlow, in PyTorch. Results will be updated whenever a stable and reproducible
AmoebaNet-D implementation in PyTorch becomes available.

The benchmark cares only about training throughput, not the model's accuracy. The
batch size is tuned for higher throughput without any large-batch training tricks.
This example also doesn't feed an actual dataset such as ImageNet or CIFAR-100.
Instead, a fake dataset of 50k random 3×224×224 tensors is used to eliminate data
loading overhead.
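
For reference, such a fake dataset takes only a few lines of PyTorch. The sketch below mirrors the `RandomDataset` defined in `main.py` of this example; the class name `FakeImageNet` and the batch size are illustrative:

```python
import random

import torch
from torch.utils.data import DataLoader, Dataset


class FakeImageNet(Dataset):
    """50k random 3x224x224 tensors with labels drawn from 10 classes."""

    def __len__(self) -> int:
        return 50000

    def __getitem__(self, i: int):
        # A random "image" and a random label: the content doesn't matter
        # for a pure throughput benchmark.
        return torch.rand(3, 224, 224), random.randrange(10)


# Behaves like a real loader, but without disk I/O or image decoding.
loader = DataLoader(FakeImageNet(), batch_size=64, num_workers=1, pin_memory=True)
```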

Every experiment setting is optimized for Tesla P40 GPUs.

## Result

Experiment | Throughput | Speed up
---------- | ----------------: | -------:
naive-2 | 14.18 samples/sec | 1.000x
pipeline-2 | 20.34 samples/sec | 1.434x
pipeline-4 | 29.07 samples/sec | 2.049x
pipeline-8 | 34.39 samples/sec | 2.424x
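
The Speed up column is each experiment's throughput divided by the naive-2 baseline. A quick sanity check, not part of the benchmark script (the small rounding differences come from the two-decimal throughputs shown above):

```python
# Speedup relative to the naive-2 baseline, using the throughputs above.
baseline = 14.18
for name, throughput in [('pipeline-2', 20.34),
                         ('pipeline-4', 29.07),
                         ('pipeline-8', 34.39)]:
    print('%s: %.3fx' % (name, throughput / baseline))
# pipeline-2: 1.434x, pipeline-4: 2.050x, pipeline-8: 2.425x
```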

## Optimized Environment

- Python 3.6.7
- PyTorch 1.1.0
- CUDA 9.0.176
- 8 Tesla P40 GPUs
- 8+ Intel E5-2650 v4 CPUs

## To Reproduce

First, resolve the dependencies. We highly recommend using a separate virtual
environment dedicated to this benchmark:

```sh
$ pip install -r requirements.txt
```

Then, run each benchmark:

```sh
$ python main.py naive-2
$ python main.py pipeline-2
$ python main.py pipeline-4
$ python main.py pipeline-8
```
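
Each command also accepts the options defined in `main.py` (`--epochs`/`-e`, `--skip-epochs`/`-k`, and `--devices`/`-d`). For example, to pin an experiment to specific GPUs and shorten the run:

```sh
$ python main.py pipeline-4 --devices 0,1,2,3 --epochs 5
```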
1 change: 1 addition & 0 deletions examples/amoebanetd_performance_benchmark/amoebanet
253 changes: 253 additions & 0 deletions examples/amoebanetd_performance_benchmark/main.py
@@ -0,0 +1,253 @@
"""AmoebaNet-D Performance Benchmark"""
import platform
import random
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, cast

import click
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.optim import SGD
from torch.utils.data import DataLoader

from amoebanet import amoebanetd
from torchgpipe import GPipe

Stuffs = Tuple[nn.Module, int, List[torch.device]] # (model, batch_size, devices)
Experiment = Callable[[nn.Module, List[int]], Stuffs]


class Experiments:

    @staticmethod
    def naive2(model: nn.Module, devices: List[int]) -> Stuffs:
        batch_size = 47
        balance = [84, 241]

        model = cast(nn.Sequential, model)
        # GPipe with chunks=1, checkpoint='never' is equivalent to a typical model parallel.
        model = GPipe(model, balance, devices=devices, chunks=1, checkpoint='never')
        return model, batch_size, list(model.devices)

    @staticmethod
    def pipeline2(model: nn.Module, devices: List[int]) -> Stuffs:
        batch_size = 803
        chunks = 48
        balance = [144, 181]

        model = cast(nn.Sequential, model)
        model = GPipe(model, balance, devices=devices, chunks=chunks)
        return model, batch_size, list(model.devices)

    @staticmethod
    def pipeline4(model: nn.Module, devices: List[int]) -> Stuffs:
        batch_size = 378
        chunks = 32
        balance = [78, 77, 92, 78]

        model = cast(nn.Sequential, model)
        model = GPipe(model, balance, devices=devices, chunks=chunks)
        return model, batch_size, list(model.devices)

    @staticmethod
    def pipeline8(model: nn.Module, devices: List[int]) -> Stuffs:
        batch_size = 216
        chunks = 32
        balance = [43, 35, 36, 38, 43, 45, 46, 39]

        model = cast(nn.Sequential, model)
        model = GPipe(model, balance, devices=devices, chunks=chunks)
        return model, batch_size, list(model.devices)


EXPERIMENTS: Dict[str, Experiment] = {
    'naive-2': Experiments.naive2,
    'pipeline-2': Experiments.pipeline2,
    'pipeline-4': Experiments.pipeline4,
    'pipeline-8': Experiments.pipeline8,
}


class RandomDataset(torch.utils.data.Dataset):
    def __len__(self) -> int:
        return 50000

    def __getitem__(self, i: int) -> Tuple[torch.Tensor, int]:
        return torch.rand(3, 224, 224), random.randrange(10)


BASE_TIME: float = 0


def hr() -> None:
    """Prints a horizontal line."""
    width, _ = click.get_terminal_size()
    click.echo('-' * width)


def log(msg: str, clear: bool = False, nl: bool = True) -> None:
    """Prints a message with elapsed time."""
    if clear:
        # Clear the output line to overwrite.
        width, _ = click.get_terminal_size()
        click.echo('\b\r', nl=False)
        click.echo(' ' * width, nl=False)
        click.echo('\b\r', nl=False)

    t = time.time() - BASE_TIME
    h = t // 3600
    t %= 3600
    m = t // 60
    t %= 60
    s = t

    click.echo('%02d:%02d:%02d | ' % (h, m, s), nl=False)
    click.echo(msg, nl=nl)


def parse_devices(ctx: Any, param: Any, value: Optional[str]) -> List[int]:
    if value is None:
        return list(range(torch.cuda.device_count()))
    return [int(x) for x in value.split(',')]


@click.command()
@click.pass_context
@click.argument(
    'experiment',
    type=click.Choice(sorted(EXPERIMENTS.keys())),
)
@click.option(
    '--epochs', '-e',
    type=int,
    default=10,
    help='Number of epochs (default: 10)',
)
@click.option(
    '--skip-epochs', '-k',
    type=int,
    default=1,
    help='Number of epochs to skip in result (default: 1)',
)
@click.option(
    '--devices', '-d',
    metavar='0,1,2,3',
    callback=parse_devices,
    help='Device IDs to use (default: all CUDA devices)',
)
def cli(ctx: click.Context,
        experiment: str,
        epochs: int,
        skip_epochs: int,
        devices: List[int],
        ) -> None:
"""AmoebaNet-D Performance Benchmark"""
if skip_epochs >= epochs:
ctx.fail('--skip-epochs=%d must be less than --epochs=%d' % (skip_epochs, epochs))

model: nn.Module = amoebanetd(num_classes=10)

f = EXPERIMENTS[experiment]
try:
model, batch_size, _devices = f(model, devices)
except ValueError as exc:
# Examples:
# ValueError: too few devices to hold given partitions (devices: 1, paritions: 2)
ctx.fail(str(exc))

optimizer = SGD(model.parameters(), lr=0.1)

in_device = _devices[0]
out_device = _devices[-1]

# This experiment cares about only training performance, rather than
# accuracy. To eliminate any overhead due to data loading, we use a fake
# dataset with random 224x224 images over 10 labels.
dataset = RandomDataset()
loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
num_workers=1,
pin_memory=True,
drop_last=False,
)

# HEADER ======================================================================================

title = '%s, %d-%d epochs' % (experiment, skip_epochs+1, epochs)
click.echo(title)
click.echo('python: %s, torch: %s, cudnn: %s, cuda: %s, gpu: %s' % (
platform.python_version(),
torch.__version__,
torch.backends.cudnn.version(),
torch.version.cuda,
torch.cuda.get_device_name(in_device)))

    # TRAIN =======================================================================================

    global BASE_TIME
    BASE_TIME = time.time()

    def run_epoch(epoch: int) -> Tuple[float, float]:
        torch.cuda.synchronize(in_device)
        tick = time.time()

        data_trained = 0
        for i, (input, target) in enumerate(loader):
            data_trained += len(input)

            input = input.to(in_device, non_blocking=True)
            target = target.to(out_device, non_blocking=True)

            output = model(input)
            loss = F.cross_entropy(output, target)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            # 00:01:02 | 1/20 epoch (42%) | 200.000 samples/sec (estimated)
            percent = i / len(loader) * 100
            throughput = data_trained / (time.time()-tick)
            log('%d/%d epoch (%d%%) | %.3f samples/sec (estimated)'
                '' % (epoch+1, epochs, percent, throughput), clear=True, nl=False)

        torch.cuda.synchronize(in_device)
        tock = time.time()

        # 00:02:03 | 1/20 epoch | 200.000 samples/sec, 123.456 sec/epoch
        elapsed_time = tock - tick
        throughput = len(dataset) / elapsed_time
        log('%d/%d epoch | %.3f samples/sec, %.3f sec/epoch'
            '' % (epoch+1, epochs, throughput, elapsed_time), clear=True)

        return throughput, elapsed_time

    throughputs = []
    elapsed_times = []

    hr()
    for epoch in range(epochs):
        throughput, elapsed_time = run_epoch(epoch)

        if epoch < skip_epochs:
            continue

        throughputs.append(throughput)
        elapsed_times.append(elapsed_time)
    hr()

    # RESULT ======================================================================================

    # pipeline-4, 2-10 epochs | 200.000 samples/sec, 123.456 sec/epoch (average)
    n = len(throughputs)
    throughput = sum(throughputs) / n
    elapsed_time = sum(elapsed_times) / n
    click.echo('%s | %.3f samples/sec, %.3f sec/epoch (average)'
               '' % (title, throughput, elapsed_time))


if __name__ == '__main__':
    cli()
2 changes: 2 additions & 0 deletions examples/amoebanetd_performance_benchmark/requirements.txt
@@ -0,0 +1,2 @@
click==7.0
torch==1.1.0
1 change: 1 addition & 0 deletions examples/amoebanetd_performance_benchmark/torchgpipe
