Commit

Handles import with exception (#539)
* handles import with exception

* Remove debris from previous change

* reformat with black
wangyinz committed May 18, 2024
1 parent de648fa commit 91addbb
Showing 2 changed files with 35 additions and 20 deletions.
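Both files adopt the same guarded-import idiom for optional dependencies: probe for the module in a try/except at import time and record the outcome in a module-level flag. A minimal sketch of the idiom, reusing the flag name from the commit (the make_bag helper is hypothetical, not part of MsPASS):

try:
    import dask.bag as daskbag

    _mspasspy_has_dask = True
except ImportError:
    # dask is absent; callers must consult the flag before
    # touching any dask API.
    _mspasspy_has_dask = False


def make_bag(items):
    # Hypothetical helper: return a parallel container when dask
    # is available, a plain list otherwise.
    if _mspasspy_has_dask:
        return daskbag.from_sequence(items)
    return list(items)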
17 changes: 3 additions & 14 deletions python/mspasspy/db/database.py
@@ -14,23 +14,12 @@
 import fcntl
 
 try:
     import dask.bag as daskbag
+    import dask.dataframe as daskdf
 
     _mspasspy_has_dask = True
 except ImportError:
     _mspasspy_has_dask = False
 
-try:
-    import dask.dataframe as daskdf
-except ImportError:
-    _mspasspy_has_dask = False
-
-try:
-    import pyspark
-
-    _mspasspy_has_pyspark = True
-except ImportError:
-    _mspasspy_has_pyspark = False
-
 import gridfs
 import pymongo
@@ -5864,7 +5853,7 @@ def save_dataframe(
         """
         dbcol = self[collection]
 
-        if parallel:
+        if parallel and _mspasspy_has_dask:
             df = daskdf.from_pandas(df, chunksize=1, sort=False)
 
         if not one_to_one:
@@ -5880,7 +5869,7 @@
                 df[key] = df[key].mask(df[key] == val, None)
 
         """
-        if parallel:
+        if parallel and _mspasspy_has_dask:
             df = daskdf.from_pandas(df, chunksize=1, sort=False)
             df = df.apply(lambda x: x.dropna(), axis=1).compute()
         else:
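With the guard in place, a save_dataframe call with parallel=True degrades to the serial pandas path when dask is missing, instead of raising a NameError on daskdf. A standalone sketch of that control flow under the same flag convention (the clean_rows helper is hypothetical, not the MsPASS implementation):

import pandas as pd

try:
    import dask.dataframe as daskdf

    _mspasspy_has_dask = True
except ImportError:
    _mspasspy_has_dask = False


def clean_rows(df, parallel=False):
    # Mirror of the guarded branch: use dask only when it is both
    # requested and importable.
    if parallel and _mspasspy_has_dask:
        ddf = daskdf.from_pandas(df, chunksize=1, sort=False)
        return ddf.apply(lambda x: x.dropna(), axis=1).compute()
    # Serial fallback, also taken when dask is not installed.
    return df.apply(lambda x: x.dropna(), axis=1)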
38 changes: 32 additions & 6 deletions python/mspasspy/io/distributed.py
@@ -22,8 +22,24 @@
     ErrorLogger,
 )
 
-import dask
-import pyspark
+try:
+    import dask
+
+    _mspasspy_has_dask = True
+except ImportError:
+    _mspasspy_has_dask = False
+try:
+    import pyspark
+
+    _mspasspy_has_pyspark = True
+except ImportError:
+    _mspasspy_has_pyspark = False
+
+if not _mspasspy_has_dask and not _mspasspy_has_pyspark:
+    message = "{} requires either dask or pyspark module. Please install dask or pyspark".format(
+        __name__
+    )
+    raise ModuleNotFoundError(message)
 
 
 def read_ensemble_parallel(
@@ -460,6 +476,11 @@ def read_distributed_data(
         message += "Unsupported value for scheduler={}\n".format(scheduler)
         message += "Must be either 'dask' or 'spark'"
         raise ValueError(message)
+    if scheduler == "spark" and not _mspasspy_has_pyspark:
+        print(
+            "WARNING(read_distributed_data): pyspark not found, will use dask instead. The scheduler argument is ignored."
+        )
+        scheduler = "dask"
 
     if isinstance(data, list):
         ensemble_mode = True
@@ -491,9 +512,10 @@
         ensemble_mode = False
         dataframe_input = False
         db = data
-    elif isinstance(
-        data,
-        (pd.DataFrame, pyspark.sql.dataframe.DataFrame, dask.dataframe.core.DataFrame),
+    elif (
+        isinstance(data, pd.DataFrame)
+        or (_mspasspy_has_dask and isinstance(data, dask.dataframe.core.DataFrame))
+        or (_mspasspy_has_pyspark and isinstance(data, pyspark.sql.dataframe.DataFrame))
     ):
         ensemble_mode = False
         dataframe_input = True
@@ -539,7 +561,6 @@
                 )
                 raise TypeError(message)
             container_partitions = container_to_merge.getNumPartitions()
-
         else:
             if not isinstance(container_to_merge, dask.bag.core.Bag):
                 message += (
@@ -1512,6 +1533,11 @@ class method `save_data`. See the docstring for details but the
             )
             message += "Must be either dask or spark"
             raise ValueError(message)
+        if scheduler == "spark" and not _mspasspy_has_pyspark:
+            print(
+                "WARNING(write_distributed_data): pyspark not found, will use dask instead. The scheduler argument is ignored."
+            )
+            scheduler = "dask"
     else:
         scheduler = "dask"
     # This use of the collection name to establish the schema is
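read_distributed_data and write_distributed_data now share the same fallback rule: a scheduler="spark" request is downgraded to dask, with a printed warning, when pyspark cannot be imported. The rule in isolation might look like this sketch (resolve_scheduler is a hypothetical helper, not a function in the module):

try:
    import pyspark  # noqa: F401

    _mspasspy_has_pyspark = True
except ImportError:
    _mspasspy_has_pyspark = False


def resolve_scheduler(scheduler="dask"):
    # Reject anything other than the two supported schedulers.
    if scheduler not in ("dask", "spark"):
        raise ValueError(
            "Unsupported value for scheduler={}; must be 'dask' or 'spark'".format(scheduler)
        )
    # Downgrade with a visible warning when spark was requested
    # but pyspark is not installed.
    if scheduler == "spark" and not _mspasspy_has_pyspark:
        print("WARNING: pyspark not found, will use dask instead.")
        return "dask"
    return scheduler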
