Orca: Reload DF or shard dataloader to keep consistency with PyTorch DataLoader #7728
Conversation
return result[0]
return result

features = convert_for_cols(row, feature_cols)
# For pytorch we format multi-input as `f1, f2, label` instead of `[f1, f2], label`
We format multi-input as `f1, f2, label` instead of `[f1, f2], label` to align with PyTorch DataLoader.
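A minimal sketch of the alignment this comment describes; `row_to_sample` is a hypothetical helper (not the actual BigDL code), showing how flattening multi-input samples to `(f1, f2, label)` matches what PyTorch's default collate expects:

```python
# Hypothetical sketch: flatten multi-input samples so a sample is
# (f1, f2, label) rather than ([f1, f2], label). PyTorch's default
# collate then yields batches shaped like x1, x2, label.

def row_to_sample(features, label):
    """features: list of per-column values; label: a single value."""
    if len(features) > 1:
        # Multi-input: flatten to f1, f2, ..., label.
        return (*features, label)
    # Single input: f, label.
    return (features[0], label)

row_to_sample([[1.0, 2.0], [3.0]], 0)  # → ([1.0, 2.0], [3.0], 0)
row_to_sample([[1.0, 2.0]], 0)         # → ([1.0, 2.0], 0)
```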
@@ -92,7 +96,7 @@ def __getitem__(self, i):
        data_loader = DataLoader(dataset, **params)
        return data_loader

-    return data_creator
+    return reload_dataloader_creator(data_creator)
Why do we need to reload here?
We will disable reload_dataloader_creator in the next PR; for now we just keep everything the same until we create the dataloader.
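For context, a hypothetical sketch of what a `reload_dataloader_creator` wrapper could look like (the real BigDL implementation may differ): it wraps a `data_creator` so that all work is deferred to call time, leaving behavior unchanged until the DataLoader is actually created.

```python
# Hypothetical sketch of a reload wrapper, not the actual BigDL helper:
# it returns a new creator that builds a fresh DataLoader on each call.

def reload_dataloader_creator(data_creator):
    if data_creator is None:
        return None

    def reloaded_creator(config, batch_size):
        # Defer all work to the original creator at call time, so a
        # fresh DataLoader is constructed whenever this is invoked.
        return data_creator(config, batch_size)

    return reloaded_creator
```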
ok
@@ -92,7 +96,7 @@ def __getitem__(self, i):
        data_loader = DataLoader(dataset, **params)
        return data_loader

-    return data_creator
+    return reload_dataloader_creator(data_creator)


def parse_model_dir(model_dir):
Shouldn't we remove the reload in the data_creator `if` branch (in fit)?
@@ -413,7 +421,8 @@ def _dataframe_to_xshards(data, feature_cols, label_cols=None,
        schema,
        feature_cols,
        label_cols,
-        accept_str_col))
+        accept_str_col,
+        True))
Pass this as a keyword argument here so that readers can understand it more easily.
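The suggestion, sketched with a hypothetical function and parameter name (the diff does not show the actual flag name): passing the new boolean by keyword documents its meaning at the call site, while a bare positional `True` does not.

```python
# Hypothetical signature standing in for the converter in the diff;
# `flatten_feature` is an invented name for the new boolean flag.
def convert_row(schema, feature_cols, label_cols, accept_str_col,
                flatten_feature=False):
    return (schema, feature_cols, label_cols, accept_str_col, flatten_feature)

# Positional True (as in the diff) vs. the clearer keyword form:
positional = convert_row("s", ["f1"], ["l"], False, True)
keyword = convert_row("s", ["f1"], ["l"], False, flatten_feature=True)
```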
@@ -442,7 +451,8 @@ def dataframe_to_xshards_of_feature_dict(data, feature_cols, label_cols=None,
        schema,
        feature_cols,
        label_cols,
-        accept_str_col))
+        accept_str_col,
+        True))
Will always setting this to True impact TF-related logic? On the other hand, is there any place where this argument is False?
Actually, we only set this to True under the PyTorch estimator code paths, to prevent any influence on TF.
@@ -76,7 +76,11 @@ def __len__(self):
        return get_size(self.y)

    def __getitem__(self, i):
-        return index_data(self.x, i), index_data(self.y, i)
+        index_data_x = index_data(self.x, i)
Since we can only allocate x and y as two separate parts here, we need to re-form multi-input `[x1, x2]` as the whole x.
@@ -318,6 +318,7 @@ def add_row(data, results, current):
        feature_lists = None
        label_lists = None
        counter = 0
+        feature_tail = len(feature_cols) if feature_cols else None
remove this
…dataloader (intel-analytics#7728)

* Orca: reload dataloader when df or shard
* reformat df list
* Only in pytorch
* specify by feature len
Description
Consider the case where the features from XShards are a list of tensors; our runner will unpack them by mistake:
https://github.com/intel-analytics/BigDL/blob/main/python/orca/src/bigdl/orca/learn/pytorch/torch_runner.py#L423
So this PR simply applies the modifications of #5763 to DataFrames and XShards again.
In short, an ndarray dataset built from a DataFrame or shards previously formatted a batch as `[x1, x2], [y1, y2]` when there were multiple inputs or outputs, whereas a PyTorch DataLoader formats such a batch as `x1, x2, [y1, y2]`. So this PR keeps both cases consistent: `x1, x2, [y1, y2]`.
We may also support a more flexible way to split and reorganize feature columns and label columns based on the length of the feature columns.
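The feature-length-based split mentioned above could look like the following sketch; `split_batch` is a hypothetical helper, not part of the PR. Given a flat batch `x1, x2, y1, y2` and the number of feature columns, it recovers the feature and label groups.

```python
# Hypothetical sketch: split a flat batch back into features and labels
# using the number of feature columns, as the description suggests.
def split_batch(flat_batch, num_feature_cols):
    features = flat_batch[:num_feature_cols]
    labels = flat_batch[num_feature_cols:]
    return features, labels

split_batch(["x1", "x2", "y1", "y2"], 2)
# → (["x1", "x2"], ["y1", "y2"])
```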