Skip to content
Permalink
Browse files

Let's keep socket with `bind(("", 0))` open as long as it is possible…

… to reduce the chance of hijacking

JIRA: ENG-30153
Change-Id: I449ae678996aed6398e8d0c477421c6b520b7bbe
  • Loading branch information...
Akim Boyko qabot escalation
Akim Boyko authored and qabot escalation committed Apr 24, 2019
1 parent a728187 commit 322ebba16119b9ac580b129d8fd36d53e4711917
@@ -47,14 +47,15 @@ def test__prepare_container():

# fill client mock
mocked_client = mock.MagicMock(spec=skein.ApplicationClient)
host_port = ('localhost', 1234)
instances = [('worker', 10), ('chief', 1)]
mocked_client.kv.wait.return_value = json.dumps(instances).encode()
mocked_client_call.return_value = mocked_client
client, cluster_spec, cluster_tasks = _prepare_container()
(client, cluster_spec, cluster_tasks) = _prepare_container(host_port)

# checks
mocked_logs.assert_called_once()
mocked_cluster_spec.assert_called_once_with(mocked_client, cluster_tasks)
mocked_cluster_spec.assert_called_once_with(host_port, mocked_client, cluster_tasks)
assert client == mocked_client
assert cluster_tasks == list(iter_tasks(instances))

@@ -50,17 +50,13 @@ def test_start_cluster_worker(task_name, task_index):

with contextlib.ExitStack() as stack:
stack.enter_context(mock.patch.dict(os.environ))
mock_reserve = stack.enter_context(
mock.patch(f"{MODULE_TO_TEST}._internal.reserve_sock_addr"))
mock_event = stack.enter_context(mock.patch(f"{MODULE_TO_TEST}.event"))

os.environ["SKEIN_CONTAINER_ID"] = f"{task_name}_{task_index}"

mock_reserve.return_value.__enter__.return_value = CURRENT_HOST, CURRENT_PORT
mock_event.wait.side_effect = lambda client, key: CLUSTER_SPEC[key]
mock_client = mock.Mock(spec=skein.ApplicationClient)
cluster.start_cluster(mock_client, [task, "worker:0"])

cluster.start_cluster((CURRENT_HOST, CURRENT_PORT), mock_client, [task, "worker:0"])
mock_event.init_event.assert_called_once_with(mock_client, task,
f"{CURRENT_HOST}:{CURRENT_PORT}")

@@ -4,20 +4,21 @@
_prepare_container, _execute_dispatched_function,
_shutdown_container, _process_arguments, _get_experiment
)
from tf_yarn._internal import reserve_sock_addr
from . import cluster


def main() -> None:
task_type, task_id = cluster.get_task_description()
client, cluster_spec, cluster_tasks = _prepare_container()
with reserve_sock_addr() as host_port:
client, cluster_spec, cluster_tasks = _prepare_container(host_port)
# Variable TF_CONFIG must be set before instantiating
# the estimator to train in a distributed way
cluster.setup_tf_config(cluster_spec)
experiment = _get_experiment(client)
run_config = experiment.config
tf.logging.info(f"Starting server {task_type}:{task_id}")

# Variable TF_CONFIG must be set before instantiating
# the estimator to train in a distributed way
cluster.setup_tf_config(cluster_spec)
experiment = _get_experiment(client)
run_config = experiment.config

tf.logging.info(f"Starting server {task_type}:{task_id}")
cluster.start_tf_server(cluster_spec, run_config.session_config)
thread = _execute_dispatched_function(client, experiment)

@@ -69,6 +69,9 @@ def reserve_sock_addr() -> Iterator[Tuple[str, int]]:
The reservation is done by binding a TCP socket to port 0 with
``SO_REUSEPORT`` flag set (requires Linux >=3.9). The socket is
then kept open until the generator is closed.
To reduce the probability of the port being hijacked, the socket should
stay open and be closed _just before_ the ``tf.train.Server`` is started.
"""
so_reuseport = get_so_reuseport()
if so_reuseport is None:
@@ -3,15 +3,18 @@

from . import event
from tf_yarn._task_commons import _prepare_container, _process_arguments
from tf_yarn._internal import reserve_sock_addr
from . import cluster, KV_TF_SESSION_CONFIG


def main() -> None:
    """Run the container's TF server task.

    The socket reserved by ``reserve_sock_addr`` is kept open for the
    whole container-preparation phase so another process cannot hijack
    the advertised port; it is released just before the TF server is
    started on that address.
    """
    task_type, task_id = cluster.get_task_description()
    # Keep the reserved port bound until preparation is complete.
    with reserve_sock_addr() as host_port:
        client, cluster_spec, cluster_tasks = _prepare_container(host_port)
        cluster.setup_tf_config(cluster_spec)
        tf_session_config = dill.loads(client.kv.wait(KV_TF_SESSION_CONFIG))
        tf.logging.info(f"tf_server_conf {tf_session_config}")

    # Socket is now closed: the TF server re-binds the same port.
    cluster.start_tf_server(cluster_spec, tf_session_config)
    event.wait(client, "stop")

@@ -41,14 +41,16 @@ def _setup_container_logs(client):


def _prepare_container(
        host_port: Tuple[str, int]
) -> Tuple[skein.ApplicationClient, Dict[str, List[str]], List[str]]:
    """Prepare the Skein container while the reserved socket stays open.

    The caller holds the socket bound to ``host_port`` open for the
    duration of this call, which reduces the chance of another process
    hijacking the port before the TF server binds it.

    Args:
        host_port: ``(host, port)`` pair reserved by
            ``reserve_sock_addr``; forwarded to ``cluster.start_cluster``
            so the task can advertise its address.

    Returns:
        A tuple of the Skein application client, the aggregated cluster
        spec, and the list of cluster tasks.
    """
    tf.logging.info("Python " + sys.version)
    tf.logging.info("Skein " + skein.__version__)
    tf.logging.info(f"TensorFlow {tf.GIT_VERSION} {tf.VERSION}")
    client = skein.ApplicationClient.from_current()
    _setup_container_logs(client)
    cluster_tasks = list(iter_tasks(json.loads(client.kv.wait(KV_CLUSTER_INSTANCES).decode())))
    cluster_spec = cluster.start_cluster(host_port, client, cluster_tasks)
    return client, cluster_spec, cluster_tasks


@@ -33,6 +33,7 @@ def get_task_description() -> typing.Tuple[str, int]:


def start_cluster(
    host_port: typing.Tuple[str, int],
    client: skein.ApplicationClient,
    all_tasks: typing.List[str]
) -> typing.Dict[str, typing.List[str]]:
    """Publish this task's reserved address and aggregate the cluster spec.

    ``host_port`` is the ``(host, port)`` pair reserved by the caller,
    who keeps the underlying socket open to avoid port hijacking and is
    responsible for closing it just before starting ``tf.train.Server``.
    """
    # NOTE(review): a few original comment lines were elided by the diff;
    # the surviving fragment below references a TF issue about workers
    # preempting the server.
    # See https://github.com/tensorflow/tensorflow/issues/21492
    cluster_spec: typing.Dict = dict()
    host, port = host_port
    # Advertise the resolved IP so peers can connect to this task.
    event.init_event(client, get_task(), f"{socket.gethostbyname(host)}:{port}")
    cluster_spec = aggregate_spec(client, all_tasks)
    return cluster_spec


def setup_tf_config(cluster_spec):

0 comments on commit 322ebba

Please sign in to comment.
You can’t perform that action at this time.