Skip to content
Permalink
Browse files

Remove _check_classpath

- libhdfs picks up its classpath from LIBHDFS_CLASSPATH,
which has already been initialized with `hadoop classpath --glob`

Change-Id: I011231791616499419f8403282aaf5d58dc76638
  • Loading branch information...
fhoering committed Jan 8, 2019
1 parent 8f16ea1 commit 9024b3a94c2a8ab21fae27e851d4b8cf9ff96d5c
Showing with 3 additions and 57 deletions.
  1. +2 −22 tests/test__task_commons.py
  2. +0 −20 tf_yarn/_internal.py
  3. +1 −15 tf_yarn/_task_commons.py
@@ -11,9 +11,9 @@
import tensorflow as tf

from tf_yarn.__init__ import Experiment
from tf_yarn._internal import iter_tasks, expand_wildcards_in_classpath, MonitoredThread
from tf_yarn._internal import iter_tasks, MonitoredThread
from tf_yarn._task_commons import matches_device_filters, \
_prepare_container, _check_classpath, _get_experiment, _execute_dispatched_function, \
_prepare_container, _get_experiment, _execute_dispatched_function, \
wait_for_connected_tasks, _shutdown_container


@@ -37,30 +37,11 @@ def test_does_not_match_device_filters(task, device_filters):
assert not matches_device_filters(task, device_filters)


def test__check_classpath():
    """Check both branches of ``_check_classpath``.

    Without ``$CLASSPATH`` a warning must be logged; with ``$CLASSPATH``
    set, its value must be replaced by the result of
    ``expand_wildcards_in_classpath``.
    """
    with contextlib.ExitStack() as stack:
        # patch.dict snapshots os.environ and restores it when the stack exits,
        # so the mutations below do not leak into other tests.
        stack.enter_context(patch.dict(os.environ))
        # Start from a known state: the machine running the tests may already
        # export CLASSPATH, in which case no warning would be emitted.
        os.environ.pop("CLASSPATH", None)
        mocked_logging = stack.enter_context(patch(f'{MODULE_TO_TEST}.tf.logging.warn'))
        # Only throw warning if classpath is not set
        _check_classpath()
        mocked_logging.assert_called_once()

        # expand classpath when provided
        mocked_expansion = stack.enter_context(
            patch(f'{MODULE_TO_TEST}.expand_wildcards_in_classpath'))
        expanded_path = 'expanded_path'
        mocked_expansion.return_value = expanded_path
        os.environ["CLASSPATH"] = "mock_path"
        _check_classpath()
        mocked_expansion.assert_called_once_with("mock_path")
        assert os.environ["CLASSPATH"] == expanded_path


def test__prepare_container():
with contextlib.ExitStack() as stack:
# mock modules
mocked_client_call = stack.enter_context(
patch(f"{MODULE_TO_TEST}.skein.ApplicationClient.from_current"))
mocked_check = stack.enter_context(patch(f'{MODULE_TO_TEST}._check_classpath'))
mocked_logs = stack.enter_context(patch(f'{MODULE_TO_TEST}._setup_container_logs'))
mocked_cluster_spec = stack.enter_context(patch(f'{MODULE_TO_TEST}.cluster.start_cluster'))

@@ -72,7 +53,6 @@ def test__prepare_container():
client, cluster_spec, cluster_tasks = _prepare_container()

# checks
mocked_check.assert_called_once()
mocked_logs.assert_called_once()
mocked_cluster_spec.assert_called_once_with(mocked_client, cluster_tasks)
assert client == mocked_client
@@ -129,26 +129,6 @@ def zip_path(path: str, tempdir: str):
return zip_path


def expand_wildcards_in_classpath(classpath: str) -> str:
    """Expand wildcard entries in the $CLASSPATH.

    JNI-invoked JVM does not support wildcards in the classpath. This
    function replaces all classpath entries of the form ``foo/bar/*``
    with the JARs in the ``foo/bar`` directory.

    See "Common Problems" section in the libhdfs docs
    https://hadoop.apache.org/docs/current/hadoop-project-dist/hadoop-hdfs/LibHdfs.html

    Args:
        classpath: ``:``-separated classpath entries.

    Returns:
        The ``:``-joined classpath in which every entry ending with ``*``
        has been replaced by the matching ``.jar`` files, sorted by path.
        Entries without a trailing ``*`` are kept as is; a wildcard entry
        that matches no JAR contributes nothing to the result.
    """
    def maybe_glob(path):
        if path.endswith("*"):
            # glob returns matches in arbitrary order; sort them so the
            # resulting classpath is the same on every run and host.
            yield from sorted(glob.iglob(path + ".jar"))
        else:
            yield path

    return ":".join(entry for path in classpath.split(":")
                    for entry in maybe_glob(path))


class StaticDefaultDict(dict):
"""A ``dict`` with a static default value.
@@ -26,7 +26,7 @@

from tf_yarn import event, cluster, Experiment
from tf_yarn.__init__ import KV_CLUSTER_INSTANCES, KV_EXPERIMENT_FN
from tf_yarn._internal import expand_wildcards_in_classpath, MonitoredThread, iter_tasks
from tf_yarn._internal import MonitoredThread, iter_tasks
from tf_yarn.tensorboard import start_tf_board, get_termination_timeout


@@ -42,19 +42,6 @@ def _process_arguments() -> None:
tf.logging.info(f"using log file {log_conf_file}")


def _check_classpath():
    """Expand wildcard entries of ``$CLASSPATH`` in place.

    When ``CLASSPATH`` is absent from the environment, only a warning is
    logged. Otherwise the variable is overwritten with the result of
    ``expand_wildcards_in_classpath`` and the new value is logged.
    """
    if "CLASSPATH" not in os.environ:
        tf.logging.warn(
            "$CLASSPATH is not defined. HDFS access will surely fail.")
        return

    tf.logging.info("Attempting to expand wildcards in $CLASSPATH")
    expanded = expand_wildcards_in_classpath(os.environ["CLASSPATH"])
    os.environ["CLASSPATH"] = expanded
    tf.logging.info(expanded)


def _setup_container_logs(client):
task = cluster.get_task()
event.broadcast_container_start_time(client, task)
@@ -71,7 +58,6 @@ def _prepare_container(
tf.logging.info("Python " + sys.version)
tf.logging.info("Skein " + skein.__version__)
tf.logging.info(f"TensorFlow {tf.GIT_VERSION} {tf.VERSION}")
_check_classpath()
client = skein.ApplicationClient.from_current()
_setup_container_logs(client)
cluster_tasks = list(iter_tasks(json.loads(client.kv.wait(KV_CLUSTER_INSTANCES).decode())))

0 comments on commit 9024b3a

Please sign in to comment.
You can’t perform that action at this time.