Resubmit: [Gradient Compression] Implement the original layerwise PowerSGD (pytorch#49639)

Summary:
Pull Request resolved: pytorch#49639

Resubmit pytorch#49417 with a fix for distributed_test.

The previous submission broke a multi-GPU test that runs on 4 GPUs. Since this test only runs on master, the failure could not be detected before the submission.

The real diff is:
pytorch@4ca1014

This time I have verified that the previously failing test `pytorch_linux_xenial_cuda10_2_cudnn7_py3_multigpu_test` passes after creating a PR (pytorch#49651) from a separate branch:
https://app.circleci.com/pipelines/github/pytorch/pytorch/253644/workflows/c1c02b70-0877-40e6-8b4c-61f60f6b70ed/jobs/9768079

ghstack-source-id: 118969912

Test Plan: buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_DistributedDataParallel_powerSGD_ddp_comm_hook

Reviewed By: mrshenli

Differential Revision: D25654961

fbshipit-source-id: 2a45c8ceb9bdb54ff7309a8b66ec87e913e0150e
Yi Wang authored and hwangdeyu committed Dec 23, 2020
1 parent 0eecd3d commit 5ea5c01
Showing 3 changed files with 245 additions and 2 deletions.
11 changes: 11 additions & 0 deletions torch/distributed/algorithms/ddp_comm_hooks/__init__.py
@@ -66,6 +66,17 @@ class DDPCommHookType(Enum):
comm_hook=powerSGD.powerSGD_hook,
matrix_approximation_rank=2,
)
# Batching can lead to faster training at the cost of accuracy.
BATCHED_POWER_SGD = partial(
_powerSGD_comm_hook_wrapper,
comm_hook=powerSGD.batched_powerSGD_hook,
matrix_approximation_rank=1,
)
BATCHED_POWER_SGD_RANK2 = partial(
_powerSGD_comm_hook_wrapper,
comm_hook=powerSGD.batched_powerSGD_hook,
matrix_approximation_rank=2,
)


def register_ddp_comm_hook(
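A minimal registration sketch, assembled from the docstring example and the new test further below; the process-group initialization, the toy model, and the `rank` variable here are illustrative assumptions, and the BATCHED_POWER_SGD entries added above are convenience entries meant to be used through register_ddp_comm_hook rather than this direct call.

import torch.distributed as dist
import torch.nn as nn
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes the usual env:// variables (RANK, WORLD_SIZE, MASTER_ADDR, ...) are set.
dist.init_process_group("nccl")
rank = dist.get_rank()

ddp_model = DDP(nn.Linear(1, 5).to(rank), device_ids=[rank])

# process_group=None falls back to the default (WORLD) group inside the hook.
state = powerSGD.PowerSGDState(process_group=None, matrix_approximation_rank=1)
ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)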
187 changes: 185 additions & 2 deletions torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -75,10 +75,193 @@ def powerSGD_hook(
bucket,
) -> torch.futures.Future:
"""
This DDP communication hook implements a simplified PowerSGD gradient compression
This DDP communication hook implements the original PowerSGD gradient compression
algorithm described in https://arxiv.org/abs/1905.13727.
Once gradient tensors are aggregated across all workers, this hook applies
compression as follows:
1) Views the input flattened 1D gradient tensor as two groups of per-parameter tensors:
high-rank tensors and vector-like rank-1 tensors (for biases).
2) Handles rank-1 tensors by allreducing them without compression:
2.1) Allocate contiguous memory for those rank-1 tensors,
and allreduces all the rank-1 tensors as a batch, without compression;
2.2) Copies the individual rank-1 tensors from the contiguous memory back to the input tensor.
3) Handles high-rank tensors by PowerSGD compression:
3.1) For each high-rank tensor M, creates two low-rank tensors P and Q for decomposing M,
such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
3.2) Computes each P in Ps, which is equal to MQ;
3.3) Allreduces Ps as a batch;
3.4) Orthogonalizes each P in Ps;
3.5) Computes each Q in Qs, which is approximately equal to M^TP;
3.6) Allreduces Qs as a batch;
3.7) Computes each M among all the high-rank tensors, which is approximately equal to PQ^T.
TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
one left multiplication and one right multiplication.
For warm start, can take one such step at a time, and alternate between them.
Arguments:
state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
Note that since DDP comm hook only supports single-process single-device mode at this time,
exactly one tensor is stored in this bucket.
Returns:
Future handler of the communication, which updates the gradients in place.
Example::
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
>>> ddp_model.register_comm_hook(state, powerSGD_hook)
"""
process_group = state.process_group
group_to_use = process_group if process_group is not None else dist.group.WORLD
world_size = (
process_group.size() if process_group is not None else dist.get_world_size()
)

# The input tensor is a flattened 1D tensor.
input_tensor = bucket.get_tensors()[0]
device = input_tensor.device
dtype = input_tensor.dtype
# Unflatten the input tensor into per-parameter tensors, for layer-wise compression.
tensors = [
input_tensor[offset : offset + length].view(sizes)
for offset, length, sizes in zip(
bucket.get_offsets(), bucket.get_lengths(), bucket.get_sizes_list()
)
]

# Step I: Handle rank-1 tensors.
# Allocate contiguous memory for rank-1 tensors so that they can be allreduced efficiently without compression.
rank1_tensors = [tensor for tensor in tensors if tensor.ndimension() <= 1]
rank1_tensors_memory = (
torch.cat([tensor.view(-1) for tensor in rank1_tensors])
if rank1_tensors
else torch.tensor([], device=device)
)

# Step II: Handle high-rank tensors.
# Allocate contiguous memory for Ps and Qs to allreduce compressed high-rank tensors efficiently.
high_rank_tensors = [
tensor.view(tensor.shape[0], -1)
for tensor in tensors
if tensor.ndimension() > 1
]
total_Ps_size = 0
ps_memory = None # TODO(wayi): Store it in a dict of PowerSGDState for warm-up.
total_Qs_size = 0
qs_memory = None # TODO(wayi): Store it in a dict of PowerSGDState for warm-up.
for tensor in high_rank_tensors:
n, m = tensor.shape
matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
total_Ps_size += n * matrix_approximation_rank
total_Qs_size += m * matrix_approximation_rank
ps_memory = torch.empty(total_Ps_size, device=device, dtype=dtype)
qs_memory = torch.empty(total_Qs_size, device=device, dtype=dtype)

# Create Ps and Qs that point to the allocated memory.
ps = []
qs = []
p_idx = 0
q_idx = 0
for tensor in high_rank_tensors:
n, m = tensor.shape
matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
ps.append(
ps_memory[p_idx : p_idx + n * matrix_approximation_rank].view(
n, matrix_approximation_rank
)
)
qs.append(
qs_memory[q_idx : q_idx + m * matrix_approximation_rank].view(
m, matrix_approximation_rank
)
)
p_idx += n * matrix_approximation_rank
q_idx += m * matrix_approximation_rank

# Initialize and then orthogonalize Qs.
with torch.random.fork_rng(devices=[]):
# Fork this RNG to avoid changing the seed globally and affecting the random sampling anywhere else in the training.
# The seed makes sure that the initial random values are the same across all the DDP replicas.
# Such seed should differ at every step.
# Since it is very slow to fork RNG state across all the CUDA devices,
# only fork on CPU and then move the generated tensor to the CUDA device.
torch.manual_seed(state.rng.randint(1_000_000_000))
for q in qs:
q.data = torch.randn(
*q.shape,
device="cpu",
dtype=dtype,
).to(device)
_orthogonalize(q)

# Compute Ps.
for tensor, q, p in zip(high_rank_tensors, qs, ps):
torch.matmul(tensor, q, out=p)

# This allreduce is only applied to rank-1 tensors,
# so it should have been kicked off before the above computation on the high-rank tensors to hide more communication costs.
# However, this somehow requires a separate future chain at this time.
allreduce_contiguous_rank1_tensors_fut = dist.all_reduce(
rank1_tensors_memory, group=group_to_use, async_op=True
).get_future()

def unpack_rank1_tensors_and_allreduce_ps(fut):
rank1_tensors_memory = fut.value()[0].div_(world_size)
idx = 0
for tensor in rank1_tensors:
tensor.copy_(rank1_tensors_memory[idx : idx + tensor.shape[0]])
idx += tensor.shape[0]

# Since these Ps will be orthogonalized later, there is no need to divide them by world size.
return [
dist.all_reduce(ps_memory, group=group_to_use, async_op=True)
.get_future()
.wait()[0]
]

def compute_qs(fut):
ps_memory = fut.wait()[0]
for p in ps:
_orthogonalize(p)

# Compute Qs.
for tensor, p, q in zip(high_rank_tensors, ps, qs):
torch.matmul(tensor.t(), p, out=q)

# Allreduce Qs.
return [
dist.all_reduce(qs_memory, group=group_to_use, async_op=True)
.get_future()
.wait()[0]
]

def decompress(fut):
qs_memory = fut.wait()[0].div_(world_size)

for p, q, tensor in zip(ps, qs, high_rank_tensors):
torch.matmul(p, q.t(), out=tensor)
assert not torch.any(torch.isnan(tensor))
return [input_tensor]

return (
allreduce_contiguous_rank1_tensors_fut.then(
unpack_rank1_tensors_and_allreduce_ps
)
.then(compute_qs)
.then(decompress)
)
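To make the structure of the layerwise hook concrete, the following single-tensor sketch runs the rank-k round trip from steps 3.1 through 3.7 with the allreduces omitted; the Gram-Schmidt helper is a stand-in assumption for the module's _orthogonalize, not its actual implementation.

import torch

def orthogonalize(matrix):
    # Column-by-column Gram-Schmidt (stand-in for _orthogonalize).
    for i in range(matrix.shape[1]):
        col = matrix[:, i : i + 1]
        col.div_(torch.norm(col))
        if i + 1 < matrix.shape[1]:
            rest = matrix[:, i + 1 :]
            rest.sub_(col @ (col.t() @ rest))

torch.manual_seed(0)
n, m, rank = 64, 32, 2
M = torch.randn(n, m)                    # one high-rank (2D) gradient tensor

Q = torch.randn(m, rank)
orthogonalize(Q)                         # 3.1: random, orthogonalized Q
P = M @ Q                                # 3.2: P = MQ       (3.3 would allreduce P)
orthogonalize(P)                         # 3.4: orthogonalize P
Q = M.t() @ P                            # 3.5: Q ~= M^T P   (3.6 would allreduce Q)
M_approx = P @ Q.t()                     # 3.7: M ~= P Q^T
print((M - M_approx).norm() / M.norm())  # relative approximation error

With a small matrix_approximation_rank this recovers M only approximately, which is the accuracy/bandwidth trade-off the hook exposes.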


def batched_powerSGD_hook(
state: PowerSGDState,
bucket,
) -> torch.futures.Future:
"""
This DDP communication hook implements a simplified PowerSGD gradient compression
algorithm described in https://arxiv.org/abs/1905.13727.
Once gradient tensors are aggregated across all workers, this hook applies
compression to the flattened input tensor that batches per-parameter tensors as follows:
1) Views the input flattened 1D gradient tensor as a square-shaped tensor M with zero padding;
2) Creates two low-rank tensors P and Q for decomposing M,
such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
@@ -105,7 +288,7 @@ def powerSGD_hook(
Example::
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
>>> ddp_model.register_comm_hook(state, powerSGD_hook)
>>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
"""
process_group = state.process_group
group_to_use = process_group if process_group is not None else dist.group.WORLD
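For contrast with the layerwise hook above, here is a rough sketch of the batched view described in the (truncated) docstring: the whole flattened bucket is treated as one zero-padded square matrix, so a single P/Q pair compresses everything at once. The padding arithmetic below is an assumption based on that description, not the hook's actual code.

import math
import torch

flat = torch.randn(10_000)                      # flattened 1D gradient bucket
side = int(math.ceil(math.sqrt(flat.numel())))  # side of the smallest enclosing square
padded = torch.zeros(side * side, dtype=flat.dtype, device=flat.device)
padded[: flat.numel()] = flat
M = padded.view(side, side)                     # square matrix M with zero padding
# A single low-rank pair P (side x rank) and Q (side x rank) then approximates
# M ~= P Q^T, which is cheaper than per-tensor compression but less accurate.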
49 changes: 49 additions & 0 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -16,6 +16,7 @@
import torch
import torch.cuda
import torch.distributed as dist
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel.distributed import _dump_DDP_relevant_env_vars
import torch.nn as nn
@@ -2819,6 +2820,54 @@ def test_DistributedDataParallel_non_default_stream(self):
msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}",
)

@unittest.skipIf(
BACKEND != "nccl",
"Only NCCL backend supports DDP communication hook",
)
@skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
@skip_if_rocm
def test_DistributedDataParallel_powerSGD_ddp_comm_hook(self):
stream = torch.cuda.Stream(self.rank)
rank = self.rank
rank_to_GPU = self._init_multigpu_helper()
gpus = list(rank_to_GPU[rank])
with torch.cuda.stream(stream):
net = torch.nn.parallel.DistributedDataParallel(
torch.nn.Linear(1, 5).to(rank), device_ids=[rank]
)
process_group = torch.distributed.new_group(gpus)
state = powerSGD.PowerSGDState(
process_group=process_group, matrix_approximation_rank=1
)
net.register_comm_hook(state=state, hook=powerSGD.powerSGD_hook)
# NOTE: batched_powerSGD_hook cannot pass the following test, because it has lower accuracy.
for i in range(1000):
# Clear gradients manually.
grad = net.module.weight.grad
if grad is not None:
grad.requires_grad_(False)
grad.zero_()
# Forward + BW
batch = torch.tensor([rank]).float().cuda(rank)
loss = net(batch).sum()
loss.backward()
# For each worker, the gradient on the weight should be worker_rank.
grad = net.module.weight.grad
avg = grad.clone()
# All-reducing the gradient averages should give us the gradient
# average. If not, then one of the workers has not correctly
# written back the averaged gradient before this all-reduce call.
dist.all_reduce(avg)
world_size = int(os.environ["WORLD_SIZE"])
avg.div_(world_size)
expected_grad = sum(i for i in range(world_size)) / world_size
self.assertEqual(
avg[0, 0],
expected_grad,
msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}",
)


@unittest.skipIf(BACKEND != 'nccl' and BACKEND != 'gloo',
"Only Nccl & Gloo backend support DistributedDataParallel")
@skip_if_no_gpu
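As a sanity check on the assertion in the new test, the arithmetic below spells out the expected value; the world size of 4 is an assumption taken from the multi-GPU job mentioned in the summary.

# Each rank r feeds the scalar input [r] into Linear(1, 5), so its local weight
# gradient is r; DDP averaging (with the PowerSGD hook applied) should leave
# every rank holding the mean over all ranks.
world_size = 4
expected_grad = sum(range(world_size)) / world_size  # (0 + 1 + 2 + 3) / 4 = 1.5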
