Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyzoo/test/zoo/orca/data/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
#
import os
from zoo import ZooContext
from zoo.orca import OrcaContext

import pytest

Expand All @@ -26,7 +26,7 @@
def orca_data_fixture():
from zoo import init_spark_on_local
from zoo.ray import RayContext
ZooContext._orca_eager_mode = True
OrcaContext._eager_mode = True
sc = init_spark_on_local(cores=4, spark_log_level="INFO")
access_key_id = os.getenv("AWS_ACCESS_KEY_ID")
secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")
Expand Down
6 changes: 5 additions & 1 deletion pyzoo/test/zoo/orca/data/test_pandas_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,18 @@

import zoo.orca.data
import zoo.orca.data.pandas
from zoo.orca import OrcaContext
from zoo.orca.data import SharedValue
from zoo.common.nncontext import *


class TestSparkXShards(TestCase):
def setup_method(self, method):
self.resource_path = os.path.join(os.path.split(__file__)[0], "../../resources")
OrcaContext.pandas_read_backend = "pandas"

def tearDown(self):
OrcaContext.pandas_read_backend = "spark"

def test_read_local_csv(self):
file_path = os.path.join(self.resource_path, "orca/data/csv")
Expand All @@ -48,7 +53,6 @@ def test_read_local_csv(self):
self.assertTrue('Error tokenizing data' in str(context.exception))

def test_read_local_json(self):
ZooContext.orca_pandas_read_backend = "pandas"
file_path = os.path.join(self.resource_path, "orca/data/json")
data_shard = zoo.orca.data.pandas.read_json(file_path, orient='columns', lines=True)
data = data_shard.collect()
Expand Down
4 changes: 0 additions & 4 deletions pyzoo/test/zoo/orca/data/test_spark_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@
class TestSparkBackend(TestCase):
def setup_method(self, method):
self.resource_path = os.path.join(os.path.split(__file__)[0], "../../resources")
ZooContext.orca_pandas_read_backend = "spark"

def tearDown(self):
ZooContext.orca_pandas_read_backend = "pandas"

def test_header_and_names(self):
file_path = os.path.join(self.resource_path, "orca/data/csv")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from zoo.orca import OrcaContext
import zoo.orca.data.pandas
from zoo.orca.learn.mxnet import Estimator, create_config

Expand Down Expand Up @@ -81,6 +82,13 @@ def forward(self, x):


class TestMXNetSparkXShards(TestCase):
def setup_method(self, method):
self.resource_path = os.path.join(os.path.split(__file__)[0], "../../resources")
OrcaContext.pandas_read_backend = "pandas"

def tearDown(self):
OrcaContext.pandas_read_backend = "spark"

def test_xshards_symbol_with_val(self):
resource_path = os.path.join(os.path.split(__file__)[0], "../../../../resources")
train_file_path = os.path.join(resource_path, "orca/learn/single_input_json/train")
Expand Down
3 changes: 3 additions & 0 deletions pyzoo/test/zoo/zouwu/test_model_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,10 @@ def test_forecast_tcmf_without_id(self):

def test_forecast_tcmf_xshards(self):
from zoo.zouwu.model.forecast import TCMFForecaster
from zoo.orca import OrcaContext
import zoo.orca.data.pandas
import tempfile
OrcaContext.pandas_read_backend = "pandas"

def preprocessing(df, id_name, y_name):
id = df.index
Expand Down Expand Up @@ -259,6 +261,7 @@ def get_pred(d):
final_df = pd.concat(final_df_list)
final_df.sort_values("datetime", inplace=True)
assert final_df.shape == (300 * horizon, 3)
OrcaContext.pandas_read_backend = "spark"


if __name__ == "__main__":
Expand Down
35 changes: 2 additions & 33 deletions pyzoo/zoo/common/nncontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,13 @@ def init_spark_on_yarn(hadoop_conf,
class ZooContextMeta(type):

_log_output = False
__orca_eager_mode = True
_orca_pandas_read_backend = "pandas"

@property
def log_output(cls):
"""
Whether to redirect Spark driver JVM's stdout and stderr to the current
python process. This is useful when running Analytics Zoo in jupyter notebook.
Default to False. Needs to be set before initializing SparkContext.
Default to be False. Needs to be set before initializing SparkContext.
"""
return cls._log_output

Expand All @@ -138,38 +136,9 @@ def log_output(cls, value):
raise AttributeError("log_output cannot be set after SparkContext is created."
" Please set it before init_nncontext, init_spark_on_local"
"or init_spark_on_yarn")
assert isinstance(value, bool), "log_output should either be True or False"
cls._log_output = value

@property
def _orca_eager_mode(cls):
"""
Default to True. Needs to be set before initializing SparkContext.
"""
return cls.__orca_eager_mode

@_orca_eager_mode.setter
def _orca_eager_mode(cls, value):
if SparkContext._active_spark_context is not None:
raise AttributeError("orca_eager_mode cannot be set after SparkContext is created."
" Please set it before init_nncontext, init_spark_on_local"
"or init_spark_on_yarn")
cls.__orca_eager_mode = value

@property
def orca_pandas_read_backend(cls):
"""
The backend for reading csv/json files. Either "spark" or "pandas".
spark backend would call spark.read and pandas backend would call pandas.read.
"""
return cls._orca_pandas_read_backend

@orca_pandas_read_backend.setter
def orca_pandas_read_backend(cls, value):
value = value.lower()
assert value == "spark" or value == "pandas", \
"orca_pandas_read_backend must be either spark or pandas"
cls._orca_pandas_read_backend = value


class ZooContext(metaclass=ZooContextMeta):
pass
Expand Down
2 changes: 2 additions & 0 deletions pyzoo/zoo/orca/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from .common import OrcaContext
69 changes: 69 additions & 0 deletions pyzoo/zoo/orca/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from zoo import ZooContext


class OrcaContextMeta(type):
    """Metaclass that exposes Orca-level configuration as class properties.

    State is kept at class level, so every class built with this metaclass
    (i.e. ``OrcaContext``) shares one global configuration.
    """

    # Backend used by zoo.orca.data.pandas read_csv/read_json; "spark" or "pandas".
    _pandas_read_backend = "spark"
    # Whether SparkXShards computes eagerly on creation.
    __eager_mode = True

    @property
    def log_output(cls):
        """
        Whether to redirect Spark driver JVM's stdout and stderr to the current
        python process. This is useful when running Analytics Zoo in jupyter notebook.
        Default to be False. Needs to be set before initializing SparkContext.
        """
        # Delegate to ZooContext so zoo-level and orca-level settings stay in sync.
        return ZooContext.log_output

    @log_output.setter
    def log_output(cls, value):
        ZooContext.log_output = value

    @property
    def pandas_read_backend(cls):
        """
        The backend for reading csv/json files. Either "spark" or "pandas".
        spark backend would call spark.read and pandas backend would call pandas.read.
        Default to be "spark".
        """
        return cls._pandas_read_backend

    @pandas_read_backend.setter
    def pandas_read_backend(cls, value):
        # Raise real exceptions instead of `assert`, which is stripped when
        # Python runs with -O; also reject non-strings explicitly rather than
        # letting `.lower()` raise an opaque AttributeError.
        if not isinstance(value, str):
            raise TypeError("pandas_read_backend must be a string, got %s"
                            % type(value).__name__)
        value = value.lower()
        if value not in ("spark", "pandas"):
            raise ValueError("pandas_read_backend must be either spark or pandas")
        cls._pandas_read_backend = value

    @property
    def _eager_mode(cls):
        """
        Whether to compute eagerly for SparkXShards.
        Default to be True.
        """
        return cls.__eager_mode

    @_eager_mode.setter
    def _eager_mode(cls, value):
        if not isinstance(value, bool):
            raise TypeError("_eager_mode should either be True or False")
        cls.__eager_mode = value


class OrcaContext(metaclass=OrcaContextMeta):
    """Global Orca configuration holder.

    All settings live on the metaclass and are accessed as class attributes,
    e.g. ``OrcaContext.pandas_read_backend = "pandas"``; no instances are needed.
    """
    pass
15 changes: 12 additions & 3 deletions pyzoo/zoo/orca/data/pandas/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
#

from bigdl.util.common import get_node_and_core_number
from zoo import init_nncontext, ZooContext
from zoo import init_nncontext
from zoo.orca import OrcaContext
from zoo.orca.data import SparkXShards
from zoo.orca.data.utils import *

Expand Down Expand Up @@ -47,8 +48,9 @@ def read_json(file_path, **kwargs):
def read_file_spark(file_path, file_type, **kwargs):
sc = init_nncontext()
node_num, core_num = get_node_and_core_number()
backend = OrcaContext.pandas_read_backend

if ZooContext.orca_pandas_read_backend == "pandas":
if backend == "pandas":
file_url_splits = file_path.split("://")
prefix = file_url_splits[0]

Expand Down Expand Up @@ -252,5 +254,12 @@ def f(iter):

pd_rdd = df.rdd.mapPartitions(to_pandas(df.columns, squeeze, index_col))

data_shards = SparkXShards(pd_rdd)
try:
data_shards = SparkXShards(pd_rdd)
except Exception as e:
alternative_backend = "pandas" if backend == "spark" else "spark"
print("An error occurred when reading files with '%s' backend, you may switch to '%s' "
"backend for another try. You can set the backend using "
"OrcaContext.pandas_read_backend" % (backend, alternative_backend))
raise e
return data_shards
4 changes: 2 additions & 2 deletions pyzoo/zoo/orca/data/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from py4j.protocol import Py4JError

from zoo.orca.data.utils import *
from zoo.orca import OrcaContext
from zoo.common.nncontext import init_nncontext
from zoo import ZooContext


class XShards(object):
Expand Down Expand Up @@ -137,7 +137,7 @@ def __init__(self, rdd, transient=False):
if transient:
self.eager = False
else:
self.eager = ZooContext._orca_eager_mode
self.eager = OrcaContext._eager_mode
self.rdd.cache()
if self.eager:
self.compute()
Expand Down