diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index 5748def1..824b8f56 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -132,7 +132,11 @@ def _perform_sync(self) -> None:
                 # we averaged the local version of the tensor so need to copy it back as a DTensor
                 param.data.copy_(
                     DTensor.from_local(
-                        avg_param, param.device_mesh, param.placements
+                        avg_param,
+                        param.device_mesh,
+                        param.placements,
+                        shape=param.shape,
+                        stride=param.stride(),
                     )
                 )
             else:
@@ -249,7 +253,11 @@ def _restore_parameters(self) -> None:
                 # we averaged the local version of the tensor so need to copy it back as a DTensor
                 p.data.copy_(
                     DTensor.from_local(
-                        self.original_parameters[name], p.device_mesh, p.placements
+                        self.original_parameters[name],
+                        p.device_mesh,
+                        p.placements,
+                        shape=p.shape,
+                        stride=p.stride(),
                     ),
                     non_blocking=False,
                 )
diff --git a/torchft/local_sgd_integ_test.py b/torchft/local_sgd_integ_test.py
index 456b1782..88356775 100644
--- a/torchft/local_sgd_integ_test.py
+++ b/torchft/local_sgd_integ_test.py
@@ -1,5 +1,6 @@
 import copy
 import logging
+import os
 import re
 import traceback
 from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -11,8 +12,10 @@
 import torch
 from parameterized import parameterized
 from torch import nn, optim
+from torch.distributed.tensor import DTensor, Replicate
 
 from torchft._torchft import LighthouseServer
+from torchft.device_mesh import ft_init_device_mesh
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.manager_integ_test import FailureInjector, MyModel, Runner
@@ -64,6 +67,7 @@ def state_dict() -> Dict[str, Dict[str, object]]:
         stack.callback(lambda: manager.shutdown(wait=False))
 
         m: nn.Module = MyModel().to(device)
+
         optimizer: optim.Optimizer = optim.Adam(m.parameters())
         criterion = nn.CrossEntropyLoss()
 
@@ -156,6 +160,29 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
             **runner.manager_args,
         )
         stack.callback(manager.shutdown)
+        # initialize default group for device mesh to work
+        if not torch.distributed.is_initialized():
+            torch.distributed.init_process_group(
+                init_method=f"tcp://localhost:0",
+                rank=rank,
+                world_size=runner.world_size,
+            )
+
+        device_type = device.type
+        ft_device_mesh = ft_init_device_mesh(
+            device_type=device_type,
+            mesh_shape=(runner.world_size, 1),
+            mesh_dim_names=("replicate", "none"),
+            replicate_dim=0,
+            manager=manager,
+        )
+        for layer in m.layers:
+            if isinstance(layer, nn.Linear):
+                for param in layer.parameters():
+                    param = DTensor.from_local(
+                        param,
+                        device_mesh=ft_device_mesh,
+                    )
 
         criterion = nn.CrossEntropyLoss()
         all_state_dicts = {}
@@ -170,13 +197,10 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
         while True:
             manager_curr_step = manager.current_step()
             if manager_curr_step not in all_state_dicts:
-                print(
-                    f"{manager_curr_step=} {diloco._local_step=} {runner.replica_id=} {state_dict()=}"
-                )
                 all_state_dicts[manager_curr_step] = copy.deepcopy(state_dict())
             batch_size = 1
-            inputs = m.get_rand_inputs(batch_size).to(device)
-            labels = m.get_rand_labels(batch_size).to(device)
+            inputs = m.get_rand_inputs(batch_size, device=device)
+            labels = m.get_rand_labels(batch_size, device=device)
 
             out = m(inputs)
             loss = criterion(out, labels)
diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py
index e7622be0..8ae3103e 100644
--- a/torchft/manager_integ_test.py
+++ b/torchft/manager_integ_test.py
@@ -33,19 +33,29 @@ def __init__(self, in_dim: int = 3, out_dim: int = 4) -> None:
         super().__init__()
         self.in_dim = in_dim
         self.out_dim = out_dim
-        self.model = nn.Sequential(
-            nn.Linear(in_dim, out_dim),
-            nn.Sigmoid(),
+        self.layers = nn.ModuleList(
+            [
+                nn.Linear(in_dim, 8),
+                nn.ReLU(),
+                nn.Linear(8, out_dim),
+                nn.ReLU(),
+            ]
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.model(x)
-
-    def get_rand_inputs(self, batch_size: int) -> torch.Tensor:
-        return torch.rand(batch_size, self.in_dim)
-
-    def get_rand_labels(self, batch_size: int) -> torch.Tensor:
-        return torch.randint(3, (batch_size,))
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+    def get_rand_inputs(
+        self, batch_size: int, device: torch.device = torch.device("cpu")
+    ) -> torch.Tensor:
+        return torch.rand(batch_size, self.in_dim, device=device)
+
+    def get_rand_labels(
+        self, batch_size: int, device: torch.device = torch.device("cpu")
+    ) -> torch.Tensor:
+        return torch.randint(3, (batch_size,), device=device)
 
 
 class InjectedFailure(Exception):
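
Note on the `local_sgd.py` change: `DTensor.from_local` accepts optional keyword-only `shape` and `stride` arguments. When they are omitted, the global shape and stride are inferred from the local tensor and the placements, which assumes evenly sized shards across ranks; passing the parameter's existing global metadata sidesteps that inference. A minimal standalone sketch of the pattern, assuming a single-process gloo group purely for illustration (the names `p`, `avg`, and the tcp port are illustrative, not from this patch):

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, distribute_tensor

# Single-rank process group, only so DTensor has a mesh to live on.
dist.init_process_group(
    "gloo", init_method="tcp://localhost:29500", rank=0, world_size=1
)
mesh = init_device_mesh("cpu", (1,))

# `p` stands in for a model parameter that LocalSGD/DiLoCo synchronizes.
p = distribute_tensor(torch.randn(4, 3), mesh, [Replicate()])
avg = p.to_local().clone()  # the locally averaged tensor in the patch

# Passing the global shape/stride explicitly, as the patch does, avoids
# re-deriving them from the local shard.
p.data.copy_(
    DTensor.from_local(
        avg,
        p.device_mesh,
        p.placements,
        shape=p.shape,
        stride=p.stride(),
    )
)
dist.destroy_process_group()
```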
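
For reference, a usage sketch of the reworked `MyModel` test helpers: the per-layer `nn.ModuleList` replaces `nn.Sequential` so the DiLoCo test can wrap individual `nn.Linear` parameters as DTensors, and the random-data helpers now allocate directly on the target device instead of round-tripping through CPU. Standalone example (not part of the patch):

```python
import torch
from torchft.manager_integ_test import MyModel

device = torch.device("cpu")  # or torch.device("cuda")
m = MyModel().to(device)

# Tensors are created directly on the target device; no extra .to(device) copy.
inputs = m.get_rand_inputs(batch_size=2, device=device)
labels = m.get_rand_labels(batch_size=2, device=device)

out = m(inputs)  # forward() now iterates m.layers explicitly
assert out.shape == (2, m.out_dim)
```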