Set GRPC_ENABLE_FORK_SUPPORT=False to avoid a deadlock. (#405)
* Fix a hang when getting tasks from the master over gRPC

* Set GRPC_ENABLE_FORK_SUPPORT=False to avoid a deadlock

* Restore some code

* Add an example to train the model with an elastic batch size

* Rename an argument
workingloong committed May 15, 2023
1 parent 9f26fa4 commit 15873b9
Showing 11 changed files with 95 additions and 45 deletions.
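gRPC registers pthread_atfork handlers, and a deadlock can happen when a process that has already opened a gRPC channel forks, for example when a PyTorch DataLoader with num_workers > 0 spawns worker processes after the elastic agent has talked to the DLRover master. This commit therefore exports GRPC_ENABLE_FORK_SUPPORT=False into every training container, from both the Go operator and the Python pod scaler. A minimal worker-side sketch of the same idea (illustrative only; the master address below is a placeholder, not part of this commit):

import os

# Disable gRPC's fork handlers before grpc is imported anywhere in this
# process, so that a later fork (e.g. DataLoader workers) cannot deadlock
# on the channel created below.
os.environ.setdefault("GRPC_ENABLE_FORK_SUPPORT", "False")

import grpc  # noqa: E402

channel = grpc.insecure_channel("dlrover-master:50001")  # placeholder address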
2 changes: 1 addition & 1 deletion .isort.cfg
@@ -1,5 +1,5 @@
[settings]
known_third_party = DeepFMAdaptor,MyEstimator,cv2,deepctr,deepctr_models,google,grpc,kubernetes,layers,numpy,psutil,pyhocon,ray,setuptools,tensorflow,tensorflow_estimator,torch,yaml
known_third_party = DeepFMAdaptor,MyEstimator,cv2,deepctr,deepctr_models,google,grpc,kubernetes,layers,numpy,psutil,pyhocon,ray,setuptools,tensorflow,tensorflow_estimator,torch,torchvision,yaml
multi_line_output=3
line_length=79
include_trailing_comma=True
6 changes: 6 additions & 0 deletions dlrover/go/operator/pkg/controllers/training/task.go
@@ -60,6 +60,7 @@ const (
workerRankEnvName = "WORKER_RANK"
workerNumEnvName = "WORKER_NUM"
rdzvEndpointEnvName = "RDZV_ENDPOINT"
grpcEnableFork = "GRPC_ENABLE_FORK_SUPPORT"
)

// TaskManager generates Pods for task in a distributed PS job.
@@ -312,6 +313,11 @@ func (m *TaskManager) newTask(
Value: fmt.Sprintf("%d", podMeta.RankIndex),
}
container.Env = append(container.Env, rankIDEnv)
grpcEnableForkEnv := corev1.EnvVar{
Name: grpcEnableFork,
Value: "False",
}
container.Env = append(container.Env, grpcEnableForkEnv)
if m.taskType == ReplicaTypeWorker {
workerNumEnv := corev1.EnvVar{
Name: workerNumEnvName,
2 changes: 1 addition & 1 deletion dlrover/go/operator/pkg/controllers/training/task_test.go
@@ -146,7 +146,7 @@ func TestNewTaskPod(t *testing.T) {
pod.Spec.Containers[0].Resources.Requests.Cpu().AsApproximateFloat64(),
float64(1),
)
assert.Equal(t, len(pod.Spec.Containers[0].Env), 4)
assert.Equal(t, len(pod.Spec.Containers[0].Env), 5)
assert.Equal(t, pod.Spec.Containers[0].Env[0].Name, "DLROVER_MASTER_ADDR")
assert.Equal(
t,
1 change: 1 addition & 0 deletions dlrover/python/common/constants.py
@@ -190,6 +190,7 @@ class NodeEnv(object):
WORKER_NUM = "WORKER_NUM"
WORKER_RANK = "WORKER_RANK"
RDZV_ENDPOINT = "RDZV_ENDPOINT"
GRPC_ENABLE_FORK = "GRPC_ENABLE_FORK_SUPPORT"


class DatasetType(object):
8 changes: 6 additions & 2 deletions dlrover/python/elastic_agent/master_client.py
@@ -41,9 +41,9 @@ def wrapper(self, *args, **kwargs):
func.__name__,
)
execption = e
logger.error(e)
time.sleep(5)
if execption:
logger.error(execption)
raise execption

return wrapper
@@ -120,14 +120,17 @@ def get_task(self, dataset_name):

success = False
res = None
exception = None
for _ in range(10):
try:
res = self._stub.get_task(req)
success = True
break
except Exception as e:
logger.warning(e)
exception = e
time.sleep(15)
if not success:
logger.warning(exception)
if not res:
res = elastic_training_pb2.Task()
return success, res
@@ -550,6 +553,7 @@ def kv_store_get(self, key):


def build_master_client(master_addr=None):
logger.info("Build master client")
if master_addr is None:
master_addr = os.getenv(NodeEnv.DLROVER_MASTER_ADDR, "")
worker_id = int(os.getenv(NodeEnv.WORKER_ID, 0))
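For clarity, get_task's new failure handling as a standalone sketch (get_with_retries and its arguments are illustrative names, not DLRover APIs): retry up to 10 times with a 15-second pause, keep only the last exception, log it once after the retries are exhausted, and fall back to an empty result so the caller can continue polling.

import logging
import time

logger = logging.getLogger(__name__)


def get_with_retries(fetch, build_empty, retries=10, delay=15):
    """Retry `fetch`; on repeated failure log the last error and return an empty result."""
    last_error = None
    for _ in range(retries):
        try:
            return True, fetch()
        except Exception as e:
            last_error = e
            time.sleep(delay)
    logger.warning(last_error)
    return False, build_empty()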
3 changes: 1 addition & 2 deletions dlrover/python/elastic_agent/sharding/client.py
@@ -113,7 +113,6 @@ def get_task(self) -> elastic_training_pb2.Task:
if len(self._pending_tasks) == 1:
self._current_task = task
self._shard_count += 1
logger.info("shard count = %s", self._shard_count)
return task
return None

@@ -211,7 +210,7 @@ def restore_shard_from_checkpoint(self, shard_checkpoint):

def get_current_epoch(self):
res = self._mc.get_dataset_epoch(self._dataset_name)
return res.epoch
return res.epoch - 1

def get_total_sample_num(self):
return self._dataset_size * self._num_epochs
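Note the convention change: get_current_epoch() now returns a zero-based epoch (res.epoch - 1). The updated test below expects 0 for the first epoch, and the MNIST example compares the value against a locally tracked counter to decide when to step the LR scheduler. A small illustrative helper (maybe_step_scheduler is not a DLRover API):

def maybe_step_scheduler(dataset, scheduler, last_epoch):
    """Step the LR scheduler once each time the zero-based dataset epoch advances."""
    current = dataset.get_epoch()  # 0 while the first epoch is still running
    if current > last_epoch:
        scheduler.step()
        return current
    return last_epoch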
4 changes: 4 additions & 0 deletions dlrover/python/master/scaler/pod_scaler.py
@@ -345,6 +345,10 @@ def _create_pod(self, node: Node, pod_stats: Dict[str, int], ps_addrs):
env.append(V1EnvVar(name=NodeEnv.WORKER_TYPE, value=node.type))
env.append(V1EnvVar(name=NodeEnv.WORKER_ID, value=str(node.id)))

# A deadlock can happen when pthread_atfork handler is running.
# For detail https://chromium.googlesource.com/external/github.com/grpc/grpc/+/refs/tags/v1.19.0-pre1/doc/fork_support.md # noqa: E501
env.append(V1EnvVar(name=NodeEnv.GRPC_ENABLE_FORK, value="False"))

worker_num = self._config_worker_num
if pod_stats[node.type] > worker_num:
worker_num = pod_stats[node.type]
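This is the scaler-side counterpart of the operator change above: the variable is appended to each container's env list with the Kubernetes Python client. A minimal sketch of the same pattern (the container name, image, and address values are placeholders):

from kubernetes import client as k8s

env = [
    k8s.V1EnvVar(name="DLROVER_MASTER_ADDR", value="dlrover-master:50001"),  # placeholder value
    k8s.V1EnvVar(name="GRPC_ENABLE_FORK_SUPPORT", value="False"),
]
container = k8s.V1Container(name="main", image="dlrover/worker:latest", env=env)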
2 changes: 1 addition & 1 deletion dlrover/python/tests/test_sharding_client.py
@@ -46,7 +46,7 @@ def test_sharding_client(self):
shard = data_shard_service.fetch_shard()
self.assertEqual(shard.start, 0)
self.assertEqual(shard.end, 32)
self.assertEqual(data_shard_service.get_current_epoch(), 1)
self.assertEqual(data_shard_service.get_current_epoch(), 0)
shard_count = 1
while True:
shard = data_shard_service.fetch_shard()
10 changes: 1 addition & 9 deletions dlrover/trainer/torch/elastic.py
@@ -20,11 +20,9 @@

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader

from dlrover.python.common.log import default_logger as logger
from dlrover.python.elastic_agent.master_client import GlobalMasterClient
from dlrover.trainer.torch.elastic_dataset import ElasticDataset

_MASTER_ADDR_KEY = "MASTER_ADDR"

@@ -173,8 +171,6 @@ class ElasticTrainer(object):
Args:
model (`torch.nn.Module`): PyTorch Module.
optimizer (`torch.nn.Optimizer`): PyTorch Optimizer.
dataloader: ElasticDataset in DLRover.
**Available attributes:**
- **step** -- the number of local step on the process.
@@ -188,11 +184,9 @@ class ElasticTrainer(object):
size fixed by adjusting the gradient_accumulation_steps.
"""

def __init__(self, model, dataloader: DataLoader):
def __init__(self, model):
self.model = model
self.optimizer = None
self.dataloader = dataloader
self._dataset = dataloader.dataset
self.gradient_state = GradientState()
self.gradient_accumulation_steps = 1

@@ -263,8 +257,6 @@ def _before_step(self, fix_total_batch_size):
)

def _after_step(self):
if isinstance(self._dataset, ElasticDataset):
self._dataset.step()
if self.gradient_state.sync_gradients:
self.gradient_state.num_steps += 1

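After this change ElasticTrainer wraps only the model; the dataloader argument is gone and the per-step dataset bookkeeping moves into the training loop (train_loader.dataset.step() in the updated MNIST example later in this commit). A usage sketch distilled from that example, assuming model, optimizer, scheduler and train_loader already exist:

import torch.nn.functional as F

from dlrover.trainer.torch.elastic import ElasticTrainer

elastic_trainer = ElasticTrainer(model)
optimizer, scheduler = elastic_trainer.prepare(optimizer, scheduler)

for data, target in train_loader:
    with elastic_trainer.step():
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loader.dataset.step()  # record the batch samples completed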
10 changes: 3 additions & 7 deletions dlrover/trainer/torch/elastic_dataset.py
@@ -11,7 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from abc import ABCMeta, abstractmethod
from typing import Dict

@@ -54,21 +53,18 @@ class ElasticDataset(Dataset, metaclass=ABCMeta):

def __init__(
self,
name,
dataset_size,
batch_size,
epochs,
shuffle,
name=None,
num_minibatches_per_shard=2,
):
self.dataset_size = dataset_size
if not name:
name = "dlrover-ds-" + str(time.time())
self._shard_client = IndexShardingClient(
dataset_name=name,
batch_size=batch_size,
num_epochs=epochs,
dataset_size=self.dataset_size,
dataset_size=dataset_size,
shuffle=shuffle,
storage_type="text",
num_minibatches_per_shard=num_minibatches_per_shard,
@@ -82,7 +78,7 @@ def __getitem__(self, _):
return self.read_sample(index)

def get_epoch(self):
self._shard_client.get_current_epoch()
return self._shard_client.get_current_epoch()

def step(self):
"""After updating models using the samples, the dataset need to
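The constructor now takes name as its first parameter with no timestamp-based default, so subclasses must name their dataset explicitly; the MNIST example below passes everything as keyword arguments. A minimal subclass sketch, assuming an in-memory list of samples:

from dlrover.trainer.torch.elastic_dataset import ElasticDataset


class ListDataset(ElasticDataset):
    def __init__(self, samples, batch_size, epochs, shuffle):
        self._samples = samples
        super().__init__(
            name="list-train-ds",  # explicit name; no timestamp default anymore
            dataset_size=len(samples),
            batch_size=batch_size,
            epochs=epochs,
            shuffle=shuffle,
        )

    def read_sample(self, index):
        return self._samples[index]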
92 changes: 70 additions & 22 deletions model_zoo/pytorch/mnist_cnn.py
@@ -21,6 +21,7 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
@@ -66,10 +67,11 @@ def __init__(self, path, batch_size, epochs, shuffle):
"""The dataset supports elastic training."""
self.data_meta = build_data_meta(path)
super(ElasticMnistDataset, self).__init__(
len(self.data_meta),
batch_size,
epochs,
shuffle,
name="mnist-train",
dataset_size=len(self.data_meta),
batch_size=batch_size,
epochs=epochs,
shuffle=shuffle,
)

def read_sample(self, index):
@@ -149,20 +151,15 @@ def train(args):
dataset=train_dataset, batch_size=args.batch_size, num_workers=2
)

test_dataset = ElasticMnistDataset(
path=args.validation_data,
batch_size=args.batch_size,
epochs=1,
shuffle=False,
)
test_loader = DataLoader(
dataset=test_dataset, batch_size=args.batch_size, num_workers=2
test_dataset = torchvision.datasets.ImageFolder(
args.validation_data,
transform=torchvision.transforms.ToTensor(),
)
test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size)

model = Net()
optimizer = optim.SGD(model.parameters(), lr=args.learning_rate)
step_size = int(train_dataset.dataset_size / args.batch_size)
scheduler = StepLR(optimizer, step_size=step_size, gamma=0.5)
scheduler = StepLR(optimizer, step_size=1, gamma=0.5)

if torch.cuda.is_available():
rank = int(os.environ["LOCAL_RANK"])
@@ -173,12 +170,30 @@
else:
model = DDP(model)

elastic_trainer = ElasticTrainer(model, train_loader)
optimizer, scheduler = elastic_trainer.prepare(optimizer, scheduler)
if checkpoint:
model.load_state_dict(checkpoint.get("model_state_dict", {}))
optimizer.load_state_dict(checkpoint.get("optimizer_state_dict", {}))

if args.fixed_batch_size:
train_with_fixed_batch_size(
model, optimizer, scheduler, train_loader, test_loader, device
)
else:
train_with_elastic_batch_size(
model, optimizer, scheduler, train_loader, test_loader, device
)


def train_with_fixed_batch_size(
model, optimizer, scheduler, train_loader, test_loader, device
):
"""
The global batch size will not change if the number of workers changes.
"""
elastic_trainer = ElasticTrainer(model)
optimizer, scheduler = elastic_trainer.prepare(optimizer, scheduler)

epoch = 0
for _, (data, target) in enumerate(train_loader):
model.train()
with elastic_trainer.step():
@@ -189,22 +204,51 @@
loss.backward()
optimizer.step()
optimizer.zero_grad()
train_loader.dataset.step() # Record the batch samples completed.
print(
"loss = {}, step = {}".format(loss, elastic_trainer.num_steps)
)
scheduler.step()
if (
elastic_trainer.num_steps > 0
and elastic_trainer.num_steps % 200 == 0
):
save_checkpoint(
CHEKPOINT_PATH, model, optimizer, train_dataset
CHEKPOINT_PATH, model, optimizer, train_loader.dataset
)
if (
elastic_trainer.num_steps > 0
and elastic_trainer.num_steps % 10000 == 0
):
dataset_epoch = train_loader.dataset.get_epoch()
if dataset_epoch > epoch:
epoch = dataset_epoch
scheduler.step()
test(model, device, test_loader)
test(model, device, test_loader)


def train_with_elastic_batch_size(
model, optimizer, scheduler, train_loader, test_loader, device
):
"""The global batch size will change if the number of worker changes."""
epoch = 0
for step, (data, target) in enumerate(train_loader):
model.train()
target = target.type(torch.LongTensor)
data, target = data.to(device), target.to(device)
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
optimizer.zero_grad()
train_loader.dataset.step() # Record the batch samples completed.
print("loss = {}, step = {}".format(loss, step))
if step > 0 and step % 200 == 0:
save_checkpoint(
CHEKPOINT_PATH, model, optimizer, train_loader.dataset
)
dataset_epoch = train_loader.dataset.get_epoch()
if dataset_epoch > epoch:
epoch = dataset_epoch
scheduler.step()
test(model, device, test_loader)
test(model, device, test_loader)


def save_checkpoint(path, model, optimizer, dataset: ElasticDataset):
@@ -225,6 +269,7 @@ def load_checkpoint(path):


def test(model, device, test_loader):
print("Test the model ...")
model.eval()
test_loss = 0
correct = 0
@@ -258,6 +303,9 @@ def arg_parser():
parser.add_argument("--batch_size", type=int, default=32, required=False)
parser.add_argument("--num_epochs", type=int, default=1, required=False)
parser.add_argument("--shuffle", type=bool, default=True, required=False)
parser.add_argument(
"--fixed_batch_size", type=bool, default=True, required=False
)
parser.add_argument(
"--learning_rate", type=float, default=0.1, required=False
)
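The new --fixed_batch_size flag chooses between the two training loops above. A rough way to see the difference (illustrative arithmetic only, not DLRover's internal formula): in the elastic path the global batch size follows the worker count, while ElasticTrainer keeps it constant by accumulating gradients when workers disappear.

def global_batch_size(per_worker_batch, workers, grad_accum_steps=1):
    """Number of samples contributing to one optimizer update."""
    return per_worker_batch * workers * grad_accum_steps


# Elastic path: 32 * 4 = 128 with four workers, 32 * 2 = 64 after scale-in.
# Fixed path:   with two workers ElasticTrainer can accumulate two micro-batches,
#               so 32 * 2 * 2 = 128 is preserved.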
