[feat] Sharded DDP - small refactor and new features (#97)
- rename oss_ddp to ShardedDataParallel
- some refactoring
- ShardedDataParallel owns the sharded optimizer, exposed if need be
- some small perf bumps
blefaudeux committed Sep 17, 2020
1 parent 2ddce57 commit 49a198c
Showing 7 changed files with 260 additions and 123 deletions.
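Before the per-file diffs, a condensed sketch of the usage pattern the commit message describes, distilled from the benchmarks/oss.py changes below. The ShardedDataParallel constructor arguments, the .optimizer attribute and the ddp.reduce() call are taken from that benchmark; the process group setup, toy model and data are placeholder assumptions, not part of this commit.

# Sketch only: mirrors the train_oss_ddp() benchmark added in this commit.
# Backend, init_method, model and data below are illustrative assumptions.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fairscale.nn.data_parallel import ShardedDataParallel


def demo(rank: int, world_size: int) -> None:
    dist.init_process_group(backend="gloo", init_method="tcp://localhost:29501", rank=rank, world_size=world_size)

    model = torch.nn.Linear(8, 2)
    loss_fn = torch.nn.CrossEntropyLoss()

    # ShardedDataParallel wraps the module and builds the sharded (OSS) optimizer itself;
    # the optimizer stays reachable through the .optimizer attribute.
    ddp = ShardedDataParallel(
        module=model, optimizer=torch.optim.SGD, optimizer_params={"lr": 1e-4, "momentum": 0.9}, world_size=world_size,
    )
    optimizer = ddp.optimizer

    inputs, labels = torch.randn(4, 8), torch.randint(0, 2, (4,))

    def closure():
        model.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        ddp.reduce()  # hand each gradient shard to the rank that owns it
        return loss

    # The sharded optimizer steps through the usual closure interface
    optimizer.step(closure)


if __name__ == "__main__":
    mp.spawn(demo, args=(2,), nprocs=2, join=True)

As in the benchmark, the forward and backward passes go through the wrapped module directly; ddp.reduce() then sends the gradients to the ranks that own the corresponding optimizer shards before optimizer.step(closure) runs.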
8 changes: 8 additions & 0 deletions .circleci/config.yml
@@ -102,6 +102,12 @@ run_oss_benchmark: &run_oss_benchmark
      command: |
        python benchmarks/oss.py

run_oss_ddp_benchmark: &run_oss_ddp_benchmark
  - run:
      name: Run OSS DDP Benchmark
      command: |
        python benchmarks/oss.py --oss_ddp
# -------------------------------------------------------------------------------------
# Jobs to run
# -------------------------------------------------------------------------------------
@@ -252,6 +258,8 @@ jobs:

      - <<: *run_oss_benchmark

      - <<: *run_oss_ddp_benchmark



workflows:
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -17,7 +17,7 @@ repos:
- id: end-of-file-fixer

- repo: https://github.com/ambv/black
-  rev: stable
+  rev: 19.10b0
hooks:
- id: black
language_version: python3.6
174 changes: 130 additions & 44 deletions benchmarks/oss.py
@@ -16,6 +16,7 @@
from torchvision.models import resnet101
from torchvision.transforms import ToTensor

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore
@@ -28,21 +29,7 @@ def dist_init(rank, world_size):
    dist.init_process_group(backend=BACKEND, rank=rank, world_size=world_size)


def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    use_oss: bool = True,
    check_regression: bool = True,
    reference_speed: float = -1.0,
    reference_memory: float = -1.0,
):

    # DDP
    dist_init(rank, world_size)


def get_problem(rank, data_size, batch_size):
    # Standard RN101
    model = resnet101(pretrained=False, progress=True).to(rank)

@@ -57,14 +44,101 @@ def collate(inputs: List[Any]):
        dataset=FakeData(transform=ToTensor(), size=data_size), batch_size=batch_size, collate_fn=collate
    )
    loss_fn = nn.CrossEntropyLoss()
    return model, dataloader, loss_fn


def train_oss_ddp(
    rank: int, world_size: int, num_epochs: int = 10, batch_size: int = 32, data_size: int = 200,
):

    # DDP
    dist_init(rank, world_size)

    # Setup
    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

    ddp = ShardedDataParallel(
        module=model, optimizer=torch.optim.SGD, optimizer_params={"lr": 1e-4, "momentum": 0.9}, world_size=world_size
    )
    optimizer = ddp.optimizer

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                loss /= world_size
                loss.backward()
                if dist.get_rank() == 0:
                    print(f"Loss: {loss.item()}")

                ddp.reduce()  # Send the gradients to the appropriate shards
                return loss

            optimizer.step(closure)

        epoch_end = time.monotonic()

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20

    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the per-epoch img per second
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")


def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    use_oss: bool = True,
    check_regression: bool = True,
    reference_speed: float = -1.0,
    reference_memory: float = -1.0,
):
    # DDP
    dist_init(rank, world_size)

    # Setup
    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

    # Shard the optimizer
-    optimizer: torch.optim.Optimizer = OSS(
-        params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9
-    ) if use_oss else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
+    optimizer: torch.optim.Optimizer = (
+        OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
+        if use_oss
+        else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
+    )

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Dummy training loop
    torch.cuda.synchronize(rank)
@@ -95,9 +169,9 @@ def closure():
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
-            # optimizer.consolidate_state_dict()
+            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
-                # _ = optimizer.state_dict()
+                _ = optimizer.state_dict()
                print("... State dict collected")

        measurements.append(data_size / (epoch_end - epoch_start))
@@ -137,30 +211,42 @@ def closure():
    parser.add_argument("--reference_speed", action="store", default=32.32, type=float)
    parser.add_argument("--reference_memory", action="store", default=4475, type=float)

+    # beta - test oss_ddp
+    parser.add_argument("--oss_ddp", action="store_true", default=False)

    args = parser.parse_args()
    print(f"Benchmark arguments: {args}")

-    print("\nBenchmark vanilla optimizer")
-    mp.spawn(
-        train,
-        args=(args.world_size, args.epochs, args.batch_size, args.data_size, False, False),
-        nprocs=args.world_size,
-        join=True,
-    )
-
-    print("\nBenchmark OSS")
-    mp.spawn(
-        train,
-        args=(
-            args.world_size,
-            args.epochs,
-            args.batch_size,
-            args.data_size,
-            True,
-            args.check_regression,
-            args.reference_speed,
-            args.reference_memory,
-        ),
-        nprocs=args.world_size,
-        join=True,
-    )
+    if args.oss_ddp:
+        print("\nBenchmark OSS DDP")
+        mp.spawn(
+            train_oss_ddp,
+            args=(args.world_size, args.epochs, args.batch_size, args.data_size),
+            nprocs=args.world_size,
+            join=True,
+        )
+    else:
+        print("\nBenchmark vanilla optimizer")
+        mp.spawn(
+            train,
+            args=(args.world_size, args.epochs, args.batch_size, args.data_size, False, False),
+            nprocs=args.world_size,
+            join=True,
+        )
+
+        print("\nBenchmark OSS")
+        mp.spawn(
+            train,
+            args=(
+                args.world_size,
+                args.epochs,
+                args.batch_size,
+                args.data_size,
+                True,
+                args.check_regression,
+                args.reference_speed,
+                args.reference_memory,
+            ),
+            nprocs=args.world_size,
+            join=True,
+        )
2 changes: 1 addition & 1 deletion fairscale/nn/data_parallel/__init__.py
@@ -3,4 +3,4 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

-from .oss_ddp import OssDdp
+from .sharded_ddp import ShardedDataParallel
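For downstream code, the practical effect of this rename is the public import path, matching the import already used in benchmarks/oss.py above:

# Before this commit the package root exposed OssDdp:
# from fairscale.nn.data_parallel import OssDdp
# After it, the renamed class is imported as:
from fairscale.nn.data_parallel import ShardedDataParallel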
