Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions torchft/local_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import logging
import math
import os
from contextlib import nullcontext
from types import TracebackType
from typing import Any, Dict, List, Optional, Tuple, Type
Expand All @@ -25,6 +26,8 @@

logger: logging.Logger = logging.getLogger(__name__)

USE_BUCKETIZATION_ENV: str = "TORCHFT_USE_BUCKETIZATION"


def extract_local_tensor(t: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -171,7 +174,7 @@ def _average(self) -> list[torch.Tensor]:


class _StreamingDiLoCoFragment:
bucket_cap_mb: int = 32 * 1024 * 1024
bucket_cap_mb: int = 1 * 1024 * 1024 * 1024
use_bucketization: bool = False

def __init__(
Expand Down Expand Up @@ -220,7 +223,11 @@ def __init__(
if bucket_cap_mb is not None:
self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)

self.use_bucketization = use_bucketization
if os.getenv(USE_BUCKETIZATION_ENV, "False") == "True":
self.use_bucketization = True
else:
self.use_bucketization = use_bucketization

self.should_quantize = should_quantize

self._grads: Dict[str, torch.Tensor] = {}
Expand Down Expand Up @@ -535,14 +542,9 @@ def _bucketize_and_allreduce(
def callback(
fut: torch.futures.Future[list[torch.Tensor]],
) -> list[torch.Tensor]:
with torch.cuda.stream(self._stream) if self._stream else nullcontext():
nonlocal bucket_tensors, flat_buffer
# Setup stream dependency
fut.wait()
for t, pack_offset, numel in bucket_tensors:
t.copy_(
flat_buffer[pack_offset : pack_offset + numel].view_as(t)
)
nonlocal bucket_tensors, flat_buffer
for t, pack_offset, numel in bucket_tensors:
t.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(t))

return []

Expand Down
4 changes: 3 additions & 1 deletion torchft/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,9 @@ def allreduce(
torch.accelerator.current_stream(),
)
else:
work = self._pg.allreduce([tensor], reduce_op)
opts = AllreduceOptions()
opts.reduceOp = reduce_op
work = self._pg.allreduce([tensor], opts)

# schedule grad normalization as a continuation
# on the Future
Expand Down
1 change: 1 addition & 0 deletions torchft/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,7 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
# pyre-fixme[16]: no attribute ProcessGroupNCCL
opts = BaseProcessGroupNCCL.Options()
opts.config.blocking = False
opts.global_ranks_in_group = list(range(world_size))

pg = BaseProcessGroup(store, rank, world_size)
pg._set_default_backend(ProcessGroup.BackendType.NCCL)
Expand Down
Loading