Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Apr 15, 2024
1 parent 3bca491 commit dff7ba3
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 128 deletions.
2 changes: 1 addition & 1 deletion colossalai/nn/optimizer/adafactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
relative_step=True,
warmup_init=False,
):
lr=None
lr = None
if lr is not None and relative_step:
raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
if warmup_init and not relative_step:
Expand Down
119 changes: 57 additions & 62 deletions colossalai/nn/optimizer/distributed_adafactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

import torch
import torch.distributed as dist

# from torch.optim import Optimizer
from colossalai.interface.optimizer import DistributedOptim

from colossalai.shardformer.layer._operation import _gather, _split
from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor

Expand Down Expand Up @@ -50,14 +50,13 @@ def __init__(
self.data_parallel_group = None
self.shard_to_param = None # Dict{id:shape}, sample {id(param): torch.tensor}
self.use_zero = True
self.param_is_dtensor_dict = {} # {id(p): True/False}
self.grad_shape_dict = {} # {id(p): master param shape}
self.factored_dict = {} # {id(p): True/False}

self.param_is_dtensor_dict = {} # {id(p): True/False}
self.grad_shape_dict = {} # {id(p): master param shape}
self.factored_dict = {} # {id(p): True/False}
self.use_first_moment_dict = {} # {id(p): True/False}
self.shard_spec_dict = {} # {id(p): ShardSpec}
super().__init__(params, defaults)


def setup_distributed(
self,
Expand All @@ -84,19 +83,21 @@ def setup_distributed(
if self.data_parallel_group is not None:
self.data_parallel_size = dist.get_world_size(self.data_parallel_group)
self.use_zero = use_zero

self.shard_to_param = shard_to_param if shard_to_param is not None else {}
# grad is None, cause we dont setup now
for group in self.param_groups:
for p in group["params"]:
self.param_is_dtensor_dict[id(p)] = is_distributed_tensor(self.shard_to_param.get(id(p)))
self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape
self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options(group, self.grad_shape_dict[id(p)])
self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape
self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options(
group, self.grad_shape_dict[id(p)]
)
if self.param_is_dtensor_dict[id(p)]:
self.shard_spec_dict[id(p)] = get_sharding_spec(self.shard_to_param.get(id(p)))
else:
self.shard_spec_dict[id(p)] = None

@staticmethod
def _get_lr(param_group, param_state):
rel_step_sz = param_group["lr"]
Expand Down Expand Up @@ -125,7 +126,7 @@ def _get_options(param_group, param_shape):
def _rms(tensor, param_is_dtensor, tp_size, dp_size, tp_group, dp_group):
tensor_sum = tensor.pow(2).sum()
num_of_element = tensor.numel()

if param_is_dtensor:
# reduce tensor_sum from tp_group
dist.all_reduce(tensor_sum, group=tp_group)
Expand Down Expand Up @@ -157,25 +158,21 @@ def _approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam):
r_factor = (exp_avg_sq_row / sq_row_meam).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)

def _col_parallel_factor(self, update, grad, state, grad_shape, beta2t):
if grad_shape[0] % self.data_parallel_size != 0:
# gather update[flatten] along dp group then reshape to [H, W/tp]
update = _gather(
input_=update, dim=-1, process_group=self.data_parallel_group
)
update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group)
update_reshape = update.view(-1, grad_shape[1])
# gather grad[flatten] along dp group then reshape to [H, W/tp]
grad = _gather(
input_=grad, dim=-1, process_group=self.data_parallel_group
)
grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group)
grad_reshape = grad.view(-1, grad_shape[1])
exp_avg_sq_row = state["exp_avg_sq_row"] # [H]
exp_avg_sq_col = state["exp_avg_sq_col"] # [W/tp]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update_reshape.mean(dim=-2), alpha=(1.0 - beta2t))
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)
update_reshape.mul_(grad_reshape)
else:
update_reshape = update.view(-1, grad_shape[1])
grad_reshape = grad.view(-1, grad_shape[1])
Expand All @@ -187,25 +184,21 @@ def _col_parallel_factor(self, update, grad, state, grad_shape, beta2t):
exp_avg_sq_row.div_(self.tensor_parallel_size)
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)

if self.use_zero:
update = update_reshape.view(-1)
else:
update = update_reshape
return update

def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t):
if grad_shape[0] % self.data_parallel_size != 0:
# gather update[flatten] along dp group then reshape to [H/tp, W]
update = _gather(
input_=update, dim=-1, process_group=self.data_parallel_group
)
update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group)
# view update to origin[tp] shape
update_reshape = update.view(-1, grad_shape[1])
# gather grad[flatten] along dp group then reshape to [H/tp, W]
grad = _gather(
input_=grad, dim=-1, process_group=self.data_parallel_group
)
grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group)
grad_reshape = grad.view(-1, grad_shape[1])
exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp]
exp_avg_sq_col = state["exp_avg_sq_col"] # [W]
Expand All @@ -231,9 +224,7 @@ def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t):
dist.all_reduce(exp_avg_sq_col, group=self.tensor_parallel_group)
exp_avg_sq_col.div_(self.tensor_parallel_size)
# gather row
exp_avg_sq_row_gather = _gather(
input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group
)
exp_avg_sq_row_gather = _gather(input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group)
sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True)
update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam)
update_reshape.mul_(grad_reshape)
Expand All @@ -242,24 +233,20 @@ def _row_parallel_factor(self, update, grad, state, grad_shape, beta2t):
else:
update = update_reshape
return update

def _base_factor(self, update, grad, state, grad_shape, beta2t):
if self.use_zero:
# only zero
if grad_shape[0] % self.data_parallel_size != 0:
# view update to origin shape update.view(grad_shape[0]//self.data_parallel_size , grad_shape[1])
# row mean no change
# view update to origin shape update.view(grad_shape[0]//self.data_parallel_size , grad_shape[1])
# row mean no change
# col mean need reduce and div
# gather update[flatten] along dp group then reshape to [H, W]
update = _gather(
input_=update, dim=-1, process_group=self.data_parallel_group
)
update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group)
# view update to origin[tp] shape
update_reshape = update.view(-1, grad_shape[1])
# gather grad[flatten] along dp group then reshape to [H, W]
grad = _gather(
input_=grad, dim=-1, process_group=self.data_parallel_group
)
grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group)
grad_reshape = grad.view(-1, grad_shape[1])
exp_avg_sq_row = state["exp_avg_sq_row"] # [H/dp]
exp_avg_sq_col = state["exp_avg_sq_col"] # [W]
Expand All @@ -274,8 +261,8 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t):
else:
# no residual row
# view update to origin[tp] shape
update_reshape = update.view(-1, grad_shape[1]) # [H/dp, W]
grad_reshape = grad.view(-1, grad_shape[1]) # [H/dp, W]
update_reshape = update.view(-1, grad_shape[1]) # [H/dp, W]
grad_reshape = grad.view(-1, grad_shape[1]) # [H/dp, W]
exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp]
exp_avg_sq_col = state["exp_avg_sq_col"] # [W]
exp_avg_sq_row.mul_(beta2t).add_(update_reshape.mean(dim=-1), alpha=(1.0 - beta2t))
Expand All @@ -285,7 +272,7 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t):
exp_avg_sq_col.div_(self.tensor_parallel_size)
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)
update = update_reshape.view(-1)
update = update_reshape.view(-1)
else:
# base factor; no tp, no dp
exp_avg_sq_row = state["exp_avg_sq_row"]
Expand All @@ -298,9 +285,7 @@ def _base_factor(self, update, grad, state, grad_shape, beta2t):
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
return update




@torch.no_grad()
def step(self, closure=None):
"""
Expand Down Expand Up @@ -335,7 +320,7 @@ def step(self, closure=None):
grad = p.grad
if grad.is_sparse:
raise RuntimeError("Adafactor does not support sparse gradients.")

state = self.state[p]
grad_shape = self.grad_shape_dict[id(p)]
param_is_dtensor = self.param_is_dtensor_dict[id(p)]
Expand All @@ -355,11 +340,11 @@ def step(self, closure=None):
if grad_shape[0] % self.data_parallel_size != 0:
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0], device=p.device, dtype=p.dtype
) # [H]
) # [H]
else:
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
) # [H/dp]
grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
) # [H/dp]
state["exp_avg_sq_col"] = torch.zeros(
grad_shape[1], device=p.device, dtype=p.dtype
) # [W/TP]
Expand All @@ -369,23 +354,27 @@ def step(self, closure=None):
if grad_shape[0] % self.data_parallel_size != 0:
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0], device=p.device, dtype=p.dtype
) # [H/tp]
) # [H/tp]
else:
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
) # [H/dp/tp]
) # [H/dp/tp]

state["exp_avg_sq_col"] = torch.zeros(
grad_shape[1], device=p.device, dtype=p.dtype
) # [W]
grad_shape[1], device=p.device, dtype=p.dtype
) # [W]
else:
if self.use_zero:
if grad_shape[0] % self.data_parallel_size != 0:
# save all exp_avg_sq_row [H]
state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype)
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0], device=grad.device, dtype=p.dtype
)
else:
# exp_avg_sq_row [H // dp]
state["exp_avg_sq_row"] = torch.zeros(grad_shape[0] // self.data_parallel_size, device=grad.device, dtype=p.dtype)
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0] // self.data_parallel_size, device=grad.device, dtype=p.dtype
)
else:
# exp_avg_sq_row [H]
state["exp_avg_sq_row"] = torch.zeros(grad_shape[0], device=grad.device, dtype=p.dtype)
Expand Down Expand Up @@ -423,16 +412,23 @@ def step(self, closure=None):
elif shard_spec.sharding_sequence[-1] == "R":
update = self._row_parallel_factor(update, grad, state, grad_shape, beta2t)
else:
update = self._base_factor(update, grad, state, grad_shape, beta2t)
update = self._base_factor(update, grad, state, grad_shape, beta2t)
else:
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
update = exp_avg_sq.rsqrt().mul_(grad)

# # (Line No.8) RMS
rms = self._rms(update, param_is_dtensor, self.tensor_parallel_size, self.data_parallel_size, self.tensor_parallel_group, self.data_parallel_group)
rms = self._rms(
update,
param_is_dtensor,
self.tensor_parallel_size,
self.data_parallel_size,
self.tensor_parallel_group,
self.data_parallel_group,
)
update.div_((rms / group["clip_threshold"]).clamp_(min=1.0))

update.mul_(lr)
if use_first_moment:
exp_avg = state["exp_avg"]
Expand All @@ -441,8 +437,7 @@ def step(self, closure=None):

if group["weight_decay"] != 0:
p.add_(p, alpha=(-group["weight_decay"] * lr))

p.add_(-update)


return loss
12 changes: 6 additions & 6 deletions docs/source/en/features/distributed_adafactor.md
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
# Distributed Adafactor

Author:
Author:

**Related Paper**
- [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)

## Introduction

Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details.
Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details.

### API Reference

{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}

## Hands-On Practice
We now demonstrate how to start Distributed Adafactor with booster API.
We now demonstrate how to start Distributed Adafactor with booster API.
### step 1. Import libraries

```python
Expand Down Expand Up @@ -65,9 +65,9 @@ dist_optim = DistributedAdaFactor(model.parameters())
```python
plugin = LowLevelZeroPlugin()
booster = Booster(plugin=plugin)
model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion, dataloader)
model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion, dataloader)
```
### step 5.Train Your Model
### step 5.Train Your Model
```python
for epoch in range(max_epochs):
for input_ids, attention_mask in dataloader:
Expand Down Expand Up @@ -106,7 +106,7 @@ Model/Feature Compatibility Matrix:
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>

<tr>
<td colspan="39"></td>
</tr>
Expand Down
Loading

0 comments on commit dff7ba3

Please sign in to comment.