Commit 2e9288d: Resolve comments from Min

cyugao committed Dec 5, 2022
1 parent 987f12b commit 2e9288d
Showing 4 changed files with 29 additions and 24 deletions.
.circleci/config.yml (14 changes: 7 additions & 7 deletions)

@@ -100,16 +100,16 @@ install_dep_pytorch_lts: &install_dep_pytorch_lts
# most recent stable version
install_dep_pytorch_stable: &install_dep_pytorch_stable
- run:
-name: Install Dependencies with torch 1.12.0
+name: Install Dependencies with torch 1.13.0
command: |
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
-if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.11 && exit 0; fi
+if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.13 && exit 0; fi
# start installing
-pip install --progress-bar off torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
+pip install --progress-bar off torch==1.13.0 torchvision==0.14.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install --progress-bar off -r requirements-dev.txt
pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
-python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "12"], f"wrong torch version {torch.__version__}"'
+python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "13"], f"wrong torch version {torch.__version__}"'
python -m torch.utils.collect_env
wget -O /home/circleci/venv/check_version.py https://raw.githubusercontent.com/min-xu-ai/check_verion/main/check_version.py
@@ -118,13 +118,13 @@ install_dep_pytorch_nightly: &install_dep_pytorch_nightly
name: Install Dependencies with a torch nightly preview build
command: |
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
-if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.12 && exit 0; fi
+if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.14 && exit 0; fi
# start installing
-pip install --pre torch==1.13.0 torchvision==0.14.0 --extra-index-url https://download.pytorch.org/whl/nightly/cu113
+pip install --pre torch==1.14.0.dev20221121+cu117 torchvision==0.15.0.dev20221121+cu117 --extra-index-url https://download.pytorch.org/whl/nightly/cu117
pip install --progress-bar off -r requirements-dev.txt
pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
-python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "13"], f"wrong torch version {torch.__version__}"'
+python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "14"], f"wrong torch version {torch.__version__}"'
python -m torch.utils.collect_env
wget -O /home/circleci/venv/check_version.py https://raw.githubusercontent.com/min-xu-ai/check_verion/main/check_version.py
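Note on the pattern above: check_version.py is min-xu-ai's external script; only its invocation (e.g. `check_version.py torch eq 1.13 && exit 0`) is visible here. A minimal sketch of what such a checker could look like, assuming it only compares the installed major.minor version and reports the result through its exit code; the names and logic below are guesses for illustration, not the actual script:

# check_version_sketch.py -- illustrative only, NOT the real check_verion script
import importlib
import sys

def main() -> int:
    pkg, op, wanted = sys.argv[1], sys.argv[2], sys.argv[3]
    installed = importlib.import_module(pkg).__version__
    # Compare only major.minor, ignoring local build tags like "+cu113".
    have = installed.split("+")[0].split(".")[:2]
    want = wanted.split(".")[:2]
    if op == "eq":
        return 0 if have == want else 1
    raise SystemExit(f"unsupported operator: {op}")

if __name__ == "__main__":
    sys.exit(main())  # exit 0 lets the CI step skip reinstalling dependencies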
fairscale/optim/adascale.py (30 changes: 16 additions & 14 deletions)

@@ -387,26 +387,30 @@ def _update_avg(self, name: str, value: np.ndarray, factor: float) -> None:
else:
self._state[name] = factor * self._state[name] + (1.0 - factor) * value

+def _to_flat_view(self, p: torch.Tensor) -> torch.Tensor:

Comment from min-xu-ai (Contributor), Dec 5, 2022:
this could be a local (nested) function within _gather_flat_grad? It doesn't use self anyway, right?

"""
Helper function for _gather_flat_grad.
Returns a flattened view of the input tensor.
"""
if p.grad is None:
return p.new(p.numel()).zero_() # type: ignore
elif p.grad.is_sparse: # type: ignore
return p.grad.to_dense().view(-1)
else:
return p.grad.view(-1)

def _gather_flat_grad(self) -> torch.Tensor:
"""
Helper function for gathering all gradients into a single vector.
Duplicated from torch.optim.lbfgs.
"""
-views = []
-for param_group in self._optimizer.param_groups:
-    for p in param_group["params"]:
-        if p.grad is None:
-            view = p.new(p.numel()).zero_()
-        elif p.grad.is_sparse:
-            view = p.grad.to_dense().view(-1)
-        else:
-            view = p.grad.view(-1)
-        views.append(view)
+views = [self._to_flat_view(p) for param_group in self._optimizer.param_groups for p in param_group["params"]]
return torch.cat(views, 0)
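For reference, a sketch of the nested-function variant min-xu-ai suggests above. This is not what the commit does; it only illustrates how the same logic could live inside _gather_flat_grad, since it never touches self beyond self._optimizer (the sketch assumes the surrounding AdaScale class and `import torch`):

def _gather_flat_grad(self) -> torch.Tensor:
    """
    Helper function for gathering all gradients into a single vector.
    Duplicated from torch.optim.lbfgs.
    """

    def to_flat_view(p: torch.Tensor) -> torch.Tensor:
        # Same logic as _to_flat_view, scoped locally because it only needs `p`.
        if p.grad is None:
            return p.new(p.numel()).zero_()
        if p.grad.is_sparse:
            return p.grad.to_dense().view(-1)
        return p.grad.view(-1)

    views = [to_flat_view(p) for group in self._optimizer.param_groups for p in group["params"]]
    return torch.cat(views, 0)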

-def _compute_intra_grad_corr_mean(self) -> float:
+def _compute_intra_grad_corr_mean(self) -> torch.Tensor:
"""
Helper function for computing average intra correlation among gradients on different GPUs.
This should be called under `torch.no_grad()` context.

Comment from min-xu-ai (Contributor), Dec 5, 2022:
why is this? elaborate? Also, how about the no_sync context?

Comment from cyugao (Author, Contributor), Dec 5, 2022:
Had a typo here. Fixed

"""
assert self._world_size > 1, "Only for distributed training"
flat_grad = self._gather_flat_grad()
@@ -423,7 +427,7 @@ def _compute_intra_grad_corr_mean(self) -> float:
else:
dist.gather(flat_grad, gather_list=None, dst=0)
dist.broadcast(corr_mean, src=0)
-return corr_mean.item()
+return corr_mean
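Regarding the no_grad()/no_sync() question above, a hypothetical usage sketch, not taken from the repository, of how a caller might measure per-rank gradient correlation before gradients are synchronized. The names model, optim, x, y and criterion are illustrative; model is assumed to be a torch.nn.parallel.DistributedDataParallel module whose parameters are managed by the AdaScale wrapper optim:

# Keep this backward pass local to each rank so per-GPU gradients stay distinct.
with model.no_sync():
    loss = criterion(model(x), y)
    loss.backward()

# The helper only reads .grad tensors and runs collectives, so no autograd graph is needed.
with torch.no_grad():
    corr_mean = optim._compute_intra_grad_corr_mean()  # returns a tensor after this commit
print("intra-grad correlation:", corr_mean.item())

# Gradients can then be synchronized, e.g. with a regular synced backward on the next
# micro-batch or an explicit all_reduce, as the test below does manually.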

def _backward_hook(self, pg_idx: int, grad: torch.Tensor) -> None:
# This method should be invoked once for each parameter during the
@@ -488,8 +492,6 @@ def _final_callback(self) -> None:

# Since self._local_grad_sqr is FP32, sum shouldn't overflow.

-# TODO: Hongbo says param.grad might be FP16 should do this before converting to FP32.

# This vector has length of # of param_groups, so it is small, but we
# use async to hide the all_reduce latency, esp when # of nodes is large.
work = None
requirements-dev.txt (3 changes: 1 addition & 2 deletions)

@@ -28,11 +28,10 @@ pynvml == 8.0.4

# For mypy typing. It is important to have a fixed version. Otherwise, you
# may run into mypy errors out differently for different versions.
-# Using 1.21.5 for now because py3.7 only has up to 1.21.5, not 1.22.x.
numpy == 1.22.0

# For layerwise gradient scaler
-scikit-learn >= 0.0
+scikit-learn == 1.1.3

# For weigit. These are actually user requirements, not developer requirements.
# However, due to the experimental nature of weigit, we don't expose to the
tests/optim/test_ddp_adascale.py (6 changes: 5 additions & 1 deletion)

@@ -169,7 +169,7 @@ def _test_corr_mean_func(rank, world_size, tempfile_name, test_case):
in_data = Tensor(in_data[rank]).cuda()
out = model(in_data)
out.sum().backward()
-results.append(optim._compute_intra_grad_corr_mean())
+results.append(optim._compute_intra_grad_corr_mean().item())
# sync gradients manually
for p in model.parameters():
if p.grad is not None:
@@ -191,6 +191,10 @@ def _test_corr_mean_func(rank, world_size, tempfile_name, test_case):


@skip_if_single_gpu
+@pytest.mark.skipif(
+    torch.__version__.split("+")[0].split(".") < ["1", "10", "0"],

This comment has been minimized.

Copy link
@min-xu-ai

min-xu-ai Dec 5, 2022

Contributor

there is a function called torch_version(), which you can use and compare the results. See other tests for an example.

+    reason="torch.corrcoef available only for torch 1.10 or higher",
+)
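Following the torch_version() suggestion above, a sketch of what that skip condition might look like. It assumes torch_version() is importable from fairscale and returns a tuple of ints such as (1, 13, 0); the exact import path differs between fairscale releases, so treat it as illustrative rather than the project's actual helper:

from fairscale.internal import torch_version  # import path is an assumption

@skip_if_single_gpu
@pytest.mark.skipif(
    # Integer tuples compare numerically, avoiding string pitfalls where "9" > "10".
    torch_version() < (1, 10, 0),
    reason="torch.corrcoef available only for torch 1.10 or higher",
)
def test_corr_mean():
    ...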
def test_corr_mean():
"""
Test _compute_intra_grad_corr_mean and _gather_flat_grad using ddp.no_sync()
