-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Orca: Reload DF or shard dataloader to keep consistency with pytorch dataloader #7728
Changes from 6 commits
837549c
ceca45e
f6148cb
6bfe520
b99718e
d39614e
ae5cbe8
88d2b58
07e146b
59d6bc7
bf24256
159804c
7972298
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -76,7 +76,11 @@ def __len__(self): | |
return get_size(self.y) | ||
|
||
def __getitem__(self, i): | ||
return index_data(self.x, i), index_data(self.y, i) | ||
index_data_x = index_data(self.x, i) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since we can only allocate x and y as two split parts here, we need to reform multi-input as [x1, x2] as the whole x |
||
if isinstance(index_data_x, (list, tuple)): | ||
return (*index_data_x, index_data(self.y, i)) | ||
else: | ||
return (index_data_x, index_data(self.y, i)) | ||
|
||
params = {"batch_size": batch_size, "shuffle": True} | ||
for arg in ["shuffle", "sampler", "batch_sampler", "num_workers", "collate_fn", | ||
|
@@ -92,7 +96,7 @@ def __getitem__(self, i): | |
data_loader = DataLoader(dataset, **params) | ||
return data_loader | ||
|
||
return data_creator | ||
return reload_dataloader_creator(data_creator) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we need to reload here? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We will disable reload_dataloader_creator in the next PR; for now we just keep everything the same before we create the dataloader. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ok |
||
|
||
|
||
def parse_model_dir(model_dir): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shouldn't we remove the reload in the data_creator function if branch? (in fit) |
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -302,7 +302,7 @@ def init_result_lists(first_row, cols): | |
return [[] for r in cols] | ||
|
||
def add_row(data, results, current): | ||
if not isinstance(data, list) and not isinstance(data, dict): | ||
if not isinstance(data, (list, tuple, dict)): | ||
arrays = [data] | ||
else: | ||
arrays = data | ||
|
@@ -318,15 +318,16 @@ def add_row(data, results, current): | |
feature_lists = None | ||
label_lists = None | ||
counter = 0 | ||
feature_tail = len(feature_cols) if feature_cols else None | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. remove this |
||
|
||
for row in iter: | ||
if feature_lists is None: | ||
feature_lists = init_result_lists(row[0], feature_cols) | ||
add_row(row[0], feature_lists, counter) | ||
feature_lists = init_result_lists(row[:feature_tail], feature_cols) | ||
add_row(row[:feature_tail], feature_lists, counter) | ||
if label_cols is not None: | ||
if label_lists is None: | ||
label_lists = init_result_lists(row[1], label_cols) | ||
add_row(row[1], label_lists, counter) | ||
label_lists = init_result_lists(get_label_row(row, feature_tail), label_cols) | ||
add_row(get_label_row(row, feature_tail), label_lists, counter) | ||
counter += 1 | ||
|
||
if shard_size and counter % shard_size == 0: | ||
|
@@ -357,6 +358,13 @@ def add_row(data, results, current): | |
arrays2pandas = partial(arrays2others, generate_func=_generate_output_pandas_df) | ||
|
||
|
||
def get_label_row(row, anchor): | ||
if anchor == len(row)-1: # In case label is the last one | ||
return row[-1] | ||
else: | ||
return row[anchor:] | ||
|
||
|
||
def transform_to_shard_dict(data, feature_cols, label_cols=None): | ||
def single_col_to_numpy(col_series, dtype): | ||
if dtype == np.ndarray: | ||
|
@@ -413,7 +421,8 @@ def _dataframe_to_xshards(data, feature_cols, label_cols=None, | |
schema, | ||
feature_cols, | ||
label_cols, | ||
accept_str_col)) | ||
accept_str_col, | ||
unpack_list=True)) | ||
shard_rdd = numpy_rdd.mapPartitions(lambda x: arrays2dict(x, | ||
feature_cols, | ||
label_cols, | ||
|
@@ -442,7 +451,8 @@ def dataframe_to_xshards_of_feature_dict(data, feature_cols, label_cols=None, | |
schema, | ||
feature_cols, | ||
label_cols, | ||
accept_str_col)) | ||
accept_str_col, | ||
unpack_list=True)) | ||
shard_rdd = numpy_rdd.mapPartitions(lambda x: arrays2feature_dict(x, | ||
feature_cols, | ||
label_cols, | ||
|
@@ -469,7 +479,8 @@ def dataframe_to_xshards_of_pandas_df(data, feature_cols, label_cols=None, accep | |
schema, | ||
feature_cols, | ||
label_cols, | ||
accept_str_col)) | ||
accept_str_col, | ||
unpack_list=True)) | ||
shard_rdd = numpy_rdd.mapPartitions(lambda x: arrays2pandas(x, | ||
feature_cols, | ||
label_cols, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We format multi-input as
f1, f2, label
instead of [f1, f2], label
to align with PyTorch DataLoader