Skip to content
Permalink
Browse files

Make Spark's Hadoop token file available to Python method (#1532)

Python code run via horovod.spark.run() needs the
Hadoop token file to access secure HDFS. This is
made available to the Spark executor via environment
variable HADOOP_TOKEN_FILE_LOCATION. SparkTaskService
needs to pass this to orted, which runs the Python code.

Signed-off-by: Enrico Minack <github@enrico.minack.dev>
  • Loading branch information
EnricoMi authored and tgaddair committed Dec 2, 2019
1 parent c8c53a9 commit e3d63deefab42e32a0cd89edfe73e6769837ccaa
Showing with 66 additions and 2 deletions.
  1. +11 −1 horovod/run/common/service/task_service.py
  2. +4 −1 horovod/spark/task/task_service.py
  3. +22 −0 test/common.py
  4. +29 −0 test/test_spark.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================

import os
import threading
import time

@@ -52,10 +53,11 @@ def __init__(self, result):


class BasicTaskService(network.BasicService):
def __init__(self, name, key, nic):
def __init__(self, name, key, nic, service_env_keys):
super(BasicTaskService, self).__init__(name, key, nic)
self._initial_registration_complete = False
self._wait_cond = threading.Condition()
self._service_env_keys = service_env_keys
self._command_thread = None
self._fn_result = None

@@ -64,6 +66,14 @@ def _handle(self, req, client_address):
self._wait_cond.acquire()
try:
if self._command_thread is None:
# we inject all these environment variables
# to make them available to the executed command
# NOTE: this will overwrite environment variables that exist in req.env
for key in self._service_env_keys:
value = os.environ.get(key)
if value is not None:
req.env[key] = value

# We only permit executing exactly one command, so this is idempotent.
self._command_thread = threading.Thread(
target=safe_shell_exec.execute,
@@ -18,9 +18,12 @@

class SparkTaskService(task_service.BasicTaskService):
NAME_FORMAT = 'task service #%d'
SERVICE_ENV_KEYS = ['HADOOP_TOKEN_FILE_LOCATION']

def __init__(self, index, key, nic):
super(SparkTaskService, self).__init__(SparkTaskService.NAME_FORMAT % index, key, nic)
super(SparkTaskService, self).__init__(SparkTaskService.NAME_FORMAT % index,
key, nic,
SparkTaskService.SERVICE_ENV_KEYS)


class SparkTaskClient(task_service.BasicTaskClient):
@@ -18,7 +18,10 @@
from __future__ import division
from __future__ import print_function

import contextlib
import os
import shutil
import tempfile


def mpi_env_rank_and_size():
@@ -54,3 +57,22 @@ def mpi_env_rank_and_size():

# Default to rank zero and size one if there are no environment variables
return 0, 1


@contextlib.contextmanager
def tempdir():
dirpath = tempfile.mkdtemp()
try:
yield dirpath
finally:
shutil.rmtree(dirpath)


@contextlib.contextmanager
def temppath():
path = tempfile.mktemp()
try:
yield path
finally:
if os.path.exists(path):
shutil.rmtree(path)
@@ -27,12 +27,16 @@
import unittest
import warnings

from horovod.run.common.util import secret
from horovod.run.mpi_run import _get_mpi_implementation_flags
import horovod.spark
from horovod.spark.task.task_service import SparkTaskService, SparkTaskClient
import horovod.torch as hvd

from mock import MagicMock

from common import tempdir


@contextlib.contextmanager
def spark(app, cores=2, *args):
@@ -212,3 +216,28 @@ def fn():
self.assertTrue(len(actual_secret) > 0)
self.assertEqual(actual_stdout, stdout)
self.assertEqual(actual_stderr, stderr)

def test_spark_task_service_env(self):
key = secret.make_secret_key()
service_env = dict([(key, '{} value'.format(key))
for key in SparkTaskService.SERVICE_ENV_KEYS])
service_env.update({"other": "value"})
with os_environ(service_env):
service = SparkTaskService(1, key, None)
client = SparkTaskClient(1, service.addresses(), key, 3)

with tempdir() as d:
file = '{}/env'.format(d)
command = "env | grep -v '^PWD='> {}".format(file)
command_env = {"test": "value"}

try:
client.run_command(command, command_env)
client.wait_for_command_termination()
finally:
service.shutdown()

with open(file) as f:
env = sorted([line.strip() for line in f.readlines()])
expected = ['HADOOP_TOKEN_FILE_LOCATION=HADOOP_TOKEN_FILE_LOCATION value', 'test=value']
self.assertEqual(env, expected)

0 comments on commit e3d63de

Please sign in to comment.
You can’t perform that action at this time.