Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallelizes URL reads for images using Ray/Multithreading #2048

Merged
merged 77 commits into from
Jun 14, 2022
Merged
Show file tree
Hide file tree
Changes from 72 commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
57964de
wip
geoffreyangus May 16, 2022
764582d
debugging nans
geoffreyangus May 17, 2022
add6f60
Merge branch 'master' into speedup-url-load
geoffreyangus May 17, 2022
34ea789
failing parity test
geoffreyangus May 17, 2022
f81969a
not passing auc parity test... w logs
geoffreyangus May 17, 2022
1755fda
audio feature works
geoffreyangus May 17, 2022
dea61b3
cleanup and revert image changes to prepare for image work
geoffreyangus May 17, 2022
eaf8dcc
further cleanup
geoffreyangus May 17, 2022
1e44899
added batch size
geoffreyangus May 18, 2022
8e78736
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 18, 2022
95aecac
remove batch size
geoffreyangus May 18, 2022
15ff78b
Merge branch 'speedup-url-load' of https://github.com/ludwig-ai/ludwi…
geoffreyangus May 18, 2022
6ade301
Merge branch 'master' into speedup-url-load
geoffreyangus May 19, 2022
2b06eb0
address nit
geoffreyangus May 19, 2022
32febdb
cleanup
geoffreyangus May 19, 2022
9b66ea5
adds support for nans and unit test
geoffreyangus May 20, 2022
a150c91
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 20, 2022
7522ca6
adds type hint and fixes abstract class definition
geoffreyangus May 20, 2022
933bd38
merge
geoffreyangus May 20, 2022
b1a904f
fix docstring
geoffreyangus May 20, 2022
4ec7108
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 20, 2022
a2c40aa
add pandas to test
geoffreyangus May 20, 2022
8133338
Merge branch 'speedup-url-load' of https://github.com/ludwig-ai/ludwi…
geoffreyangus May 20, 2022
90e703d
wip
geoffreyangus May 20, 2022
3884240
wip
geoffreyangus May 20, 2022
c701b66
removed legacy audio fns and renamed bytes fns
geoffreyangus May 20, 2022
c0f10c8
refactor + read_binary_files
geoffreyangus May 20, 2022
6aa79e5
removed pd indexing warning and precommit errors, added nans to ray a…
geoffreyangus May 20, 2022
054f54d
Merge branch 'speedup-url-load' into image-url-load
geoffreyangus May 20, 2022
0f60863
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 20, 2022
59d3c18
remove np.ndarray type
geoffreyangus May 20, 2022
1342faf
added nans test for csv dataset type
geoffreyangus May 20, 2022
0bc6377
Merge branch 'speedup-url-load' of https://github.com/ludwig-ai/ludwi…
geoffreyangus May 20, 2022
f24cf16
remove prints
geoffreyangus May 20, 2022
806f312
Merge branch 'speedup-url-load' into image-url-load
geoffreyangus May 20, 2022
2fff70d
added test for csv data type
geoffreyangus May 20, 2022
d5b0e48
use backend.df_engine for map
geoffreyangus May 20, 2022
8d4370b
add back in src_path functionality
geoffreyangus May 20, 2022
38efacf
Merge branch 'speedup-url-load' into image-url-load
geoffreyangus May 20, 2022
ac6c521
cleanup
geoffreyangus May 20, 2022
ed700ca
standardize nan percent
geoffreyangus May 20, 2022
0a9f25a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 21, 2022
bff817f
typo
geoffreyangus May 21, 2022
a306754
typo
geoffreyangus May 21, 2022
93930e3
Merge branch 'speedup-url-load' into image-url-load
geoffreyangus May 21, 2022
62d01a0
Merge branch 'image-url-load' of https://github.com/ludwig-ai/ludwig …
geoffreyangus May 21, 2022
808ab48
merge
geoffreyangus Jun 3, 2022
ffab9f0
remove nan testing for now (until #2058 merged)
geoffreyangus Jun 3, 2022
1973b42
revert last change (on wrong branch)
geoffreyangus Jun 3, 2022
6c18ba2
github ui weirdness... revert the revert
geoffreyangus Jun 3, 2022
25d32f3
unify map_abs_path_to_entries b/t audio and image
geoffreyangus Jun 3, 2022
5339d0e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2022
62a1399
refactored to include read_image_from_path function
geoffreyangus Jun 3, 2022
5691267
Merge branch 'image-url-load' of https://github.com/ludwig-ai/ludwig …
geoffreyangus Jun 3, 2022
7a34002
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2022
8bac2ad
cleaned up naming and comments
geoffreyangus Jun 3, 2022
dd0e3a7
Merge branch 'image-url-load' of https://github.com/ludwig-ai/ludwig …
geoffreyangus Jun 3, 2022
65ca19b
fix docstring
geoffreyangus Jun 3, 2022
2843e7b
merge
geoffreyangus Jun 6, 2022
2fb77f8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 6, 2022
8770f4f
implement PR revisions
geoffreyangus Jun 6, 2022
cfff33b
add del buffer_view back
geoffreyangus Jun 7, 2022
9b53a08
addressed PR comments and deleted extraneous functions
geoffreyangus Jun 7, 2022
d8f3371
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 7, 2022
602914a
cleanup
geoffreyangus Jun 7, 2022
5b2e857
Merge branch 'image-url-load' of https://github.com/ludwig-ai/ludwig …
geoffreyangus Jun 7, 2022
a356894
Merge branch 'master' into image-url-load
geoffreyangus Jun 7, 2022
b92effe
removed read_*_if_* functions
geoffreyangus Jun 7, 2022
428f19d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 7, 2022
55a90cc
remove map_abs_path_if_entries
geoffreyangus Jun 7, 2022
5e3bb99
Merge branch 'image-url-load' of https://github.com/ludwig-ai/ludwig …
geoffreyangus Jun 7, 2022
b30c053
simplified getting abs path
geoffreyangus Jun 7, 2022
bb6fdb8
merge
geoffreyangus Jun 10, 2022
d4c2b09
Merge branch 'master' into image-url-load
geoffreyangus Jun 13, 2022
f20454d
add check for remote protocol
geoffreyangus Jun 13, 2022
fc9c9a1
Merge branch 'master' into image-url-load
geoffreyangus Jun 14, 2022
42f301e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions ludwig/features/audio_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,16 +372,6 @@ def _get_2D_feature(
else:
raise ValueError(f'feature_type "{feature_type}" is not recognized.')

@staticmethod
def map_abs_path_to_entries(column, src_path, backend):
def get_abs_path_if_entry_is_str(entry):
if not isinstance(entry, str) or has_remote_protocol(entry):
return entry
else:
return get_abs_path(src_path, entry)

return backend.df_engine.map_objects(column, get_abs_path_if_entry_is_str)

@staticmethod
def add_feature_data(
feature_config, input_df, proc_df, metadata, preprocessing_parameters, backend, skip_save_processed_input
Expand Down Expand Up @@ -413,7 +403,9 @@ def add_feature_data(
if SRC in metadata:
if isinstance(first_audio_entry, str) and not has_remote_protocol(first_audio_entry):
src_path = os.path.dirname(os.path.abspath(metadata.get(SRC)))
abs_path_column = AudioFeatureMixin.map_abs_path_to_entries(column, src_path, backend)
abs_path_column = backend.df_engine.map_objects(
column, lambda row: get_abs_path(src_path, row) if isinstance(row, str) else row
)

num_audio_utterances = len(input_df[feature_config[COLUMN]])
padding_value = preprocessing_parameters["padding_value"]
Expand Down
188 changes: 61 additions & 127 deletions ludwig/features/image_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
import os
from collections import Counter
from functools import partial
from multiprocessing import Pool
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import requests
import torch
import torchvision

Expand Down Expand Up @@ -49,20 +47,17 @@
from ludwig.data.cache.types import wrap
from ludwig.features.base_feature import BaseFeatureMixin, InputFeature
from ludwig.utils.data_utils import get_abs_path
from ludwig.utils.fs_utils import has_remote_protocol, upload_h5
from ludwig.utils.fs_utils import upload_h5
from ludwig.utils.image_utils import (
get_gray_default_image,
get_image_from_path,
grayscale,
num_channels_in_image,
read_image,
read_image_from_bytes_obj,
read_image_from_path,
resize_image,
)
from ludwig.utils.misc_utils import set_default_value
from ludwig.utils.types import TorchscriptPreprocessingInput

logger = logging.getLogger(__name__)

from ludwig.utils.types import Series, TorchscriptPreprocessingInput

# TODO(shreya): Confirm if it's ok to do per channel normalization
# TODO(shreya): Also confirm if this is being used anywhere
Expand Down Expand Up @@ -170,8 +165,8 @@ def get_feature_meta(column, preprocessing_parameters, backend):
return {PREPROCESSING: preprocessing_parameters}

@staticmethod
def _read_image_and_resize(
img_entry: Union[str, torch.Tensor],
def _read_image_if_bytes_obj_and_resize(
img_entry: Union[bytes, torch.Tensor],
img_width: int,
img_height: int,
should_resize: bool,
Expand All @@ -180,8 +175,8 @@ def _read_image_and_resize(
user_specified_num_channels: bool,
) -> Optional[np.ndarray]:
"""
:param img_entry Union[str, 'numpy.array']: if str file path to the
image else numpy.array of the image itself
:param img_entry Union[bytes, torch.Tensor]: if str file path to the
image else torch.Tensor of the image itself
:param img_width: expected width of the image
:param img_height: expected height of the image
:param should_resize: Should the image be resized?
Expand All @@ -199,11 +194,15 @@ def _read_image_and_resize(
If the user specifies a number of channels, we try to convert all the
images to the specifications by dropping channels/padding 0 channels
"""
if isinstance(img_entry, bytes):
img = read_image_from_bytes_obj(img_entry, num_channels)
else:
img = img_entry

img = read_image(img_entry, num_channels)
if img is None:
logger.info(f"{img_entry} cannot be read")
if not isinstance(img, torch.Tensor):
logging.info(f"Image with value {img} cannot be read")
return None

img_num_channels = num_channels_in_image(img)
# Convert to grayscale if needed.
if num_channels == 1 and img_num_channels != 1:
Expand All @@ -225,7 +224,7 @@ def _read_image_and_resize(
img = torch.nn.functional.pad(img, [0, 0, 0, 0, 0, extra_channels])

if img_num_channels != num_channels:
logger.warning(
logging.warning(
"Image has {} channels, where as {} "
"channels are expected. Dropping/adding channels "
"with 0s as appropriate".format(img_num_channels, num_channels)
Expand Down Expand Up @@ -273,7 +272,7 @@ def _infer_image_size(image_sample: List[torch.Tensor], max_height: int, max_wid
height = min(int(round(height_avg)), max_height)
width = min(int(round(width_avg)), max_width)

logger.debug(f"Inferring height: {height} and width: {width}")
logging.debug(f"Inferring height: {height} and width: {width}")
return height, width

@staticmethod
Expand Down Expand Up @@ -316,9 +315,7 @@ def _infer_number_of_channels(image_sample: List[torch.Tensor]):
@staticmethod
def _finalize_preprocessing_parameters(
preprocessing_parameters: dict,
first_img_entry: Optional[Union[str, torch.Tensor]],
src_path: str,
input_feature_col: np.array,
column: Series,
) -> Tuple:
"""Helper method to determine the height, width and number of channels for preprocessing the image data.

Expand All @@ -327,30 +324,25 @@ def _finalize_preprocessing_parameters(
expected be of the same size with the same number of channels
"""

explicit_height_width = (
HEIGHT in preprocessing_parameters or WIDTH in preprocessing_parameters
) and first_img_entry is not None
explicit_num_channels = (
NUM_CHANNELS in preprocessing_parameters and preprocessing_parameters[NUM_CHANNELS]
) and first_img_entry is not None
explicit_height_width = HEIGHT in preprocessing_parameters or WIDTH in preprocessing_parameters
explicit_num_channels = NUM_CHANNELS in preprocessing_parameters and preprocessing_parameters[NUM_CHANNELS]

if explicit_num_channels:
first_image = read_image(first_img_entry, preprocessing_parameters[NUM_CHANNELS])
sample = []
if preprocessing_parameters[INFER_IMAGE_DIMENSIONS] and not (explicit_height_width and explicit_num_channels):
sample_size = min(len(column), preprocessing_parameters[INFER_IMAGE_SAMPLE_SIZE])
else:
first_image = read_image(first_img_entry)
sample_size = 1 # Take first image

inferred_sample = None
if preprocessing_parameters[INFER_IMAGE_DIMENSIONS] and not (explicit_height_width and explicit_num_channels):
sample_size = min(len(input_feature_col), preprocessing_parameters[INFER_IMAGE_SAMPLE_SIZE])
sample = []
for img in input_feature_col.head(sample_size):
try:
sample.append(read_image(get_image_from_path(src_path, img, ret_bytes=True)))
except requests.exceptions.HTTPError:
pass
inferred_sample = [img for img in sample if img is not None]
if not inferred_sample:
raise ValueError("No readable images in sample, image dimensions cannot be inferred")
for image_entry in column.head(sample_size):
if isinstance(image_entry, str):
image = read_image_from_path(image_entry)
else:
image = image_entry

if isinstance(image, torch.Tensor):
sample.append(image)
if len(sample) == 0:
raise ValueError("No readable images in sample, image dimensions cannot be inferred")

should_resize = False
if explicit_height_width:
Expand All @@ -368,12 +360,10 @@ def _finalize_preprocessing_parameters(
if preprocessing_parameters[INFER_IMAGE_DIMENSIONS]:
should_resize = True
height, width = ImageFeatureMixin._infer_image_size(
inferred_sample,
sample,
max_height=preprocessing_parameters[INFER_IMAGE_MAX_HEIGHT],
max_width=preprocessing_parameters[INFER_IMAGE_MAX_WIDTH],
)
elif first_image is not None:
height, width = first_image.shape[0], first_image.shape[1]
else:
raise ValueError(
"Explicit image width/height are not set, infer_image_dimensions is false, "
Expand All @@ -388,73 +378,48 @@ def _finalize_preprocessing_parameters(
user_specified_num_channels = False
if preprocessing_parameters[INFER_IMAGE_DIMENSIONS]:
user_specified_num_channels = True
num_channels = ImageFeatureMixin._infer_number_of_channels(inferred_sample)
elif first_image is not None:
num_channels = num_channels_in_image(first_image)
num_channels = ImageFeatureMixin._infer_number_of_channels(sample)
elif len(sample) > 0:
num_channels = num_channels_in_image(sample[0])
else:
raise ValueError(
"Explicit image num channels is not set, infer_image_dimensions is false, "
"and first image cannot be read, so image num channels is unknown"
)

assert isinstance(num_channels, int), ValueError("Number of image channels needs to be an integer")

return (should_resize, width, height, num_channels, user_specified_num_channels, first_image)
return (should_resize, width, height, num_channels, user_specified_num_channels)

@staticmethod
def add_feature_data(
feature_config, input_df, proc_df, metadata, preprocessing_parameters, backend, skip_save_processed_input
):
set_default_value(feature_config["preprocessing"], "in_memory", preprocessing_parameters["in_memory"])

in_memory = preprocessing_parameters["in_memory"]
if PREPROCESSING in feature_config and "in_memory" in feature_config[PREPROCESSING]:
in_memory = feature_config[PREPROCESSING]["in_memory"]

num_processes = preprocessing_parameters["num_processes"]
if PREPROCESSING in feature_config and "num_processes" in feature_config[PREPROCESSING]:
num_processes = feature_config[PREPROCESSING]["num_processes"]
set_default_value(feature_config[PREPROCESSING], "in_memory", preprocessing_parameters["in_memory"])

num_images = len(input_df[feature_config[COLUMN]])
if num_images == 0:
raise ValueError("There are no images in the dataset provided.")

first_img_entry = next(iter(input_df[feature_config[COLUMN]]))
logger.debug(f"Detected image feature type is {type(first_img_entry)}")

if not isinstance(first_img_entry, str) and not isinstance(first_img_entry, torch.Tensor):
raise ValueError(
"Invalid image feature data type. Detected type is {}, "
"expect either string for file path or numpy array.".format(type(first_img_entry))
)
name = feature_config[NAME]
column = input_df[feature_config[COLUMN]]

src_path = None
if SRC in metadata:
if isinstance(first_img_entry, str) and not has_remote_protocol(first_img_entry):
src_path = os.path.dirname(os.path.abspath(metadata.get(SRC)))

try:
first_img_entry = get_image_from_path(src_path, first_img_entry, ret_bytes=True)
except requests.exceptions.HTTPError:
first_img_entry = None
src_path = os.path.dirname(os.path.abspath(metadata.get(SRC)))
abs_path_column = backend.df_engine.map_objects(
column, lambda row: get_abs_path(src_path, row) if isinstance(row, str) else row
)

(
should_resize,
width,
height,
num_channels,
user_specified_num_channels,
first_image,
) = ImageFeatureMixin._finalize_preprocessing_parameters(
preprocessing_parameters, first_img_entry, src_path, input_df[feature_config[COLUMN]]
)
) = ImageFeatureMixin._finalize_preprocessing_parameters(preprocessing_parameters, abs_path_column)

metadata[feature_config[NAME]][PREPROCESSING]["height"] = height
metadata[feature_config[NAME]][PREPROCESSING]["width"] = width
metadata[feature_config[NAME]][PREPROCESSING]["num_channels"] = num_channels
metadata[name][PREPROCESSING]["height"] = height
metadata[name][PREPROCESSING]["width"] = width
metadata[name][PREPROCESSING]["num_channels"] = num_channels

read_image_and_resize = partial(
ImageFeatureMixin._read_image_and_resize,
read_image_if_bytes_obj_and_resize = partial(
ImageFeatureMixin._read_image_if_bytes_obj_and_resize,
img_width=width,
img_height=height,
should_resize=should_resize,
Expand All @@ -470,55 +435,24 @@ def add_feature_data(
# image features from the hdf5 cache.
backend.check_lazy_load_supported(feature_config)

in_memory = feature_config[PREPROCESSING]["in_memory"]
if in_memory or skip_save_processed_input:
# Number of processes to run in parallel for preprocessing
metadata[feature_config[NAME]][PREPROCESSING]["num_processes"] = num_processes
metadata[feature_config[NAME]]["reshape"] = (num_channels, height, width)

# Split the dataset into pools only if we have an explicit request to use
# multiple processes. In case we have multiple input images use the
# standard code anyway.
if backend.supports_multiprocessing and (num_processes > 1 or num_images > 1):
all_img_entries = [
get_abs_path(src_path, img_entry) if isinstance(img_entry, str) else img_entry
for img_entry in input_df[feature_config[COLUMN]]
]

with Pool(num_processes) as pool:
logger.debug(f"Using {num_processes} processes for preprocessing images")
res = pool.map(read_image_and_resize, all_img_entries)
proc_df[feature_config[PROC_COLUMN]] = [x if x is not None else default_image for x in res]
else:
# If we're not running multiple processes and we are only processing one
# image just use this faster shortcut, bypassing multiprocessing.Pool.map
logger.debug("No process pool initialized. Using internal process for preprocessing images")

# helper function for handling single image
def _get_processed_image(img_store):
if isinstance(img_store, str):
res_single = read_image_and_resize(get_abs_path(src_path, img_store))
else:
res_single = read_image_and_resize(img_store)
return res_single if res_single is not None else default_image

proc_df[feature_config[PROC_COLUMN]] = backend.df_engine.map_objects(
input_df[feature_config[COLUMN]], _get_processed_image
)
else:
metadata[name]["reshape"] = (num_channels, height, width)

all_img_entries = [
get_abs_path(src_path, img_entry) if isinstance(img_entry, str) else img_entry
for img_entry in input_df[feature_config[COLUMN]]
]
proc_col = backend.read_binary_files(abs_path_column, map_fn=read_image_if_bytes_obj_and_resize)
dantreiman marked this conversation as resolved.
Show resolved Hide resolved
proc_col = backend.df_engine.map_objects(proc_col, lambda row: row if row is not None else default_image)
proc_df[feature_config[PROC_COLUMN]] = proc_col
else:
num_images = len(abs_path_column)

data_fp = backend.cache.get_cache_path(wrap(metadata.get(SRC)), metadata.get(CHECKSUM), TRAINING)
with upload_h5(data_fp) as h5_file:
# todo future add multiprocessing/multithreading
image_dataset = h5_file.create_dataset(
feature_config[PROC_COLUMN] + "_data", (num_images, num_channels, height, width), dtype=np.uint8
)
for i, img_entry in enumerate(all_img_entries):
res = read_image_and_resize(img_entry)
for i, img_entry in enumerate(abs_path_column):
res = read_image_if_bytes_obj_and_resize(img_entry)
image_dataset[i, :height, :width, :] = res if res is not None else default_image
h5_file.flush()

Expand Down
8 changes: 4 additions & 4 deletions ludwig/models/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from ludwig.features.feature_registries import input_type_registry, output_type_registry
from ludwig.features.feature_utils import get_module_dict_key_from_name, get_name_from_module_dict_key
from ludwig.globals import INFERENCE_MODULE_FILE_NAME, MODEL_HYPERPARAMETERS_FILE_NAME, TRAIN_SET_METADATA_FILE_NAME
from ludwig.utils import image_utils
from ludwig.utils.audio_utils import read_audio_if_path
from ludwig.utils.audio_utils import read_audio_from_path
from ludwig.utils.image_utils import read_image_from_path
from ludwig.utils.types import TorchscriptPreprocessingInput

# Prevents circular import errors from typing.
Expand Down Expand Up @@ -117,10 +117,10 @@ def to_inference_module_input(s: pd.Series, feature_type: str, load_paths=False)
"""Converts a pandas Series to be compatible with a torchscripted InferenceModule forward pass."""
if feature_type == "image":
if load_paths:
return [image_utils.read_image(v) for v in s]
return [read_image_from_path(v) if isinstance(v, str) else v for v in s]
elif feature_type == "audio":
if load_paths:
return [read_audio_if_path(v) for v in s]
return [read_audio_from_path(v) if isinstance(v, str) else v for v in s]
if feature_type in {"binary", "category", "bag", "set", "text", "sequence", "timeseries"}:
return s.astype(str).to_list()
return torch.from_numpy(s.to_numpy())
10 changes: 6 additions & 4 deletions ludwig/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from sklearn.model_selection import KFold

from ludwig.data.cache.types import CacheableDataset
from ludwig.utils.fs_utils import download_h5, open_file, upload_h5
from ludwig.utils.fs_utils import download_h5, has_remote_protocol, open_file, upload_h5
from ludwig.utils.misc_utils import get_from_registry

try:
Expand Down Expand Up @@ -111,9 +111,11 @@ def get_split_path(dataset_fp):
return os.path.splitext(dataset_fp)[0] + ".split.csv"


def get_abs_path(data_csv_path, file_path):
if data_csv_path is not None:
return os.path.join(data_csv_path, file_path)
def get_abs_path(src_path, file_path):
if has_remote_protocol(file_path):
return file_path
elif src_path is not None:
return os.path.join(src_path, file_path)
else:
return file_path

Expand Down
Loading