Merged
torchft/local_sgd.py (12 changes: 10 additions & 2 deletions)
@@ -132,7 +132,11 @@ def _perform_sync(self) -> None:
                 # we averaged the local version of the tensor so need to copy it back as a DTensor
                 param.data.copy_(
                     DTensor.from_local(
-                        avg_param, param.device_mesh, param.placements
+                        avg_param,
+                        param.device_mesh,
+                        param.placements,
+                        shape=param.shape,
+                        stride=param.stride(),
                     )
                 )
             else:
@@ -249,7 +253,11 @@ def _restore_parameters(self) -> None:
                 # we averaged the local version of the tensor so need to copy it back as a DTensor
                 p.data.copy_(
                     DTensor.from_local(
-                        self.original_parameters[name],
+                        self.original_parameters[name],
+                        p.device_mesh,
+                        p.placements,
+                        shape=p.shape,
+                        stride=p.stride(),
                     ),
                     non_blocking=False,
                 )
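For context: `DTensor.from_local` accepts keyword-only `shape` and `stride` arguments, and passing the original parameter's global metadata (as both hunks above now do) avoids re-deriving it from the local shard, where it can come out wrong for unevenly sharded tensors. A minimal sketch of the call, assuming a hypothetical single-process gloo group and 1-D mesh purely for illustration (not the PR's setup):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, distribute_tensor

dist.init_process_group("gloo", init_method="tcp://localhost:29500", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

# A sharded parameter and a stand-in for its averaged local shard.
param = distribute_tensor(torch.randn(5, 3), mesh, [Shard(0)])
avg_param = param.to_local().clone()

# Rebuild the DTensor with the parameter's own global shape/stride so the
# reconstructed metadata matches the original exactly.
param.data.copy_(
    DTensor.from_local(
        avg_param,
        param.device_mesh,
        param.placements,
        shape=param.shape,
        stride=param.stride(),
    )
)
dist.destroy_process_group()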
torchft/local_sgd_integ_test.py (34 changes: 29 additions & 5 deletions)
@@ -1,5 +1,6 @@
 import copy
 import logging
+import os
 import re
 import traceback
 from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -11,8 +12,10 @@
 import torch
 from parameterized import parameterized
 from torch import nn, optim
+from torch.distributed.tensor import DTensor, Replicate
 
 from torchft._torchft import LighthouseServer
+from torchft.device_mesh import ft_init_device_mesh
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.manager_integ_test import FailureInjector, MyModel, Runner
Expand Down Expand Up @@ -64,6 +67,7 @@ def state_dict() -> Dict[str, Dict[str, object]]:
stack.callback(lambda: manager.shutdown(wait=False))

m: nn.Module = MyModel().to(device)

optimizer: optim.Optimizer = optim.Adam(m.parameters())
criterion = nn.CrossEntropyLoss()

@@ -156,6 +160,29 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
             **runner.manager_args,
         )
         stack.callback(manager.shutdown)
+        # initialize default group for device mesh to work
+        if not torch.distributed.is_initialized():
+            torch.distributed.init_process_group(
+                init_method=f"tcp://localhost:0",
+                rank=rank,
+                world_size=runner.world_size,
+            )
+
+        device_type = device.type
+        ft_device_mesh = ft_init_device_mesh(
+            device_type=device_type,
+            mesh_shape=(runner.world_size, 1),
+            mesh_dim_names=("replicate", "none"),
+            replicate_dim=0,
+            manager=manager,
+        )
+        for layer in m.layers:
+            if isinstance(layer, nn.Linear):
+                for param in layer.parameters():
+                    param = DTensor.from_local(
+                        param,
+                        device_mesh=ft_device_mesh,
+                    )
+
         criterion = nn.CrossEntropyLoss()
         all_state_dicts = {}
@@ -170,13 +197,10 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
             while True:
                 manager_curr_step = manager.current_step()
                 if manager_curr_step not in all_state_dicts:
-                    print(
-                        f"{manager_curr_step=} {diloco._local_step=} {runner.replica_id=} {state_dict()=}"
-                    )
                     all_state_dicts[manager_curr_step] = copy.deepcopy(state_dict())
                 batch_size = 1
-                inputs = m.get_rand_inputs(batch_size).to(device)
-                labels = m.get_rand_labels(batch_size).to(device)
+                inputs = m.get_rand_inputs(batch_size, device=device)
+                labels = m.get_rand_labels(batch_size, device=device)
 
                 out = m(inputs)
                 loss = criterion(out, labels)
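One Python subtlety in the mesh-wiring hunk above: `param = DTensor.from_local(param, ...)` rebinds the loop variable only, so the `nn.Linear` layers themselves still hold plain tensors. A hypothetical variant (not part of this PR) that registers the wrapped tensor back onto the layer could look like this, assuming the mesh from `ft_init_device_mesh` and the default replicated placement:

import torch.nn as nn
from torch.distributed.tensor import DTensor

def wrap_linear_params(m: nn.Module, ft_device_mesh) -> None:
    # Re-register each Linear parameter as a DTensor so the module itself
    # holds the wrapped tensors; list() lets us mutate while iterating.
    for layer in m.modules():
        if isinstance(layer, nn.Linear):
            for name, param in list(layer.named_parameters(recurse=False)):
                dt = DTensor.from_local(param.data, device_mesh=ft_device_mesh)
                layer.register_parameter(name, nn.Parameter(dt))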
torchft/manager_integ_test.py (30 changes: 20 additions & 10 deletions)
@@ -33,19 +33,29 @@ def __init__(self, in_dim: int = 3, out_dim: int = 4) -> None:
         super().__init__()
         self.in_dim = in_dim
         self.out_dim = out_dim
-        self.model = nn.Sequential(
-            nn.Linear(in_dim, out_dim),
-            nn.Sigmoid(),
+        self.layers = nn.ModuleList(
+            [
+                nn.Linear(in_dim, 8),
+                nn.ReLU(),
+                nn.Linear(8, out_dim),
+                nn.ReLU(),
+            ]
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.model(x)
-
-    def get_rand_inputs(self, batch_size: int) -> torch.Tensor:
-        return torch.rand(batch_size, self.in_dim)
-
-    def get_rand_labels(self, batch_size: int) -> torch.Tensor:
-        return torch.randint(3, (batch_size,))
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+    def get_rand_inputs(
+        self, batch_size: int, device: torch.device = torch.device("cpu")
+    ) -> torch.Tensor:
+        return torch.rand(batch_size, self.in_dim, device=device)
+
+    def get_rand_labels(
+        self, batch_size: int, device: torch.device = torch.device("cpu")
+    ) -> torch.Tensor:
+        return torch.randint(3, (batch_size,), device=device)
 
 
 class InjectedFailure(Exception):
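A short usage sketch of the reworked MyModel helpers (CPU shown; a CUDA device works the same way, since the helpers now allocate directly on the target device instead of allocating on CPU and calling `.to(device)`):

import torch
import torch.nn as nn
from torchft.manager_integ_test import MyModel

device = torch.device("cpu")
m = MyModel(in_dim=3, out_dim=4).to(device)

inputs = m.get_rand_inputs(batch_size=2, device=device)  # shape (2, 3)
labels = m.get_rand_labels(batch_size=2, device=device)  # ints in [0, 3)
loss = nn.CrossEntropyLoss()(m(inputs), labels)          # logits shape (2, 4)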