New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
remove mandatory index
key from output of metric_function
in DataAnalysis
map operation
#5112
remove mandatory index
key from output of metric_function
in DataAnalysis
map operation
#5112
Conversation
DataAnalysis
index
key from output of metric_function in DataAnalysis
map operation
index
key from output of metric_function in DataAnalysis
map operationindex
key from output of metric_function
in DataAnalysis
map operation
index
key from output of metric_function
in DataAnalysis
map operationindex
key from output of metric_function
in DataAnalysis
map operation [ONGOING TESTING]
@microsoft-github-policy-service agree |
index
key from output of metric_function
in DataAnalysis
map operation [ONGOING TESTING]index
key from output of metric_function
in DataAnalysis
map operation
@bm-synth Could you resolve the conflicts? Thanks. |
@conglongli done |
@bm-synth Thanks you. On the other hand, after reading this PR's details, I'm concerning that your PR might not be able to replace the index key. The index key is because that user's dataset may have shuffling feature, so we have to ask user to always provide an index to indicate "inside the data, what is the exact index of this sample". Otherwise we could make a wrong connection between the data sample and the curriculum difficulty value. This PR basically assumes that the data is always in-order, which might not be always the case. You can refer to how I do the data analysis at here https://github.com/microsoft/Megatron-DeepSpeed/blob/6d4c535eeae782daa22583fd8abac7cec3bb60f2/examples_deepspeed/data_efficiency/gpt/ds_analyze_gpt_data_map.sh#L66 where I have to add a "--return-data-index" flag to return the actual index. |
@bm-synth To further clarify: in Megatron-DeepSpeed and Megatron-LM, the dataset is shuffled even before reaching sampler https://github.com/microsoft/Megatron-DeepSpeed/blob/6d4c535eeae782daa22583fd8abac7cec3bb60f2/megatron/data/gpt_dataset.py#L597. This is why even if we use a SequentialSampler for data analysis, the data could still be shuffled. Thus an index key provided by user is needed. |
@bm-synth After some more thinking, I think there is still values in your approach and it should work in many cases. So my proposal is that: we still keep the index key and use it when user provides it. Otherwise we use your approach. |
@conglongli ok i saw it in the megatron source code:
where
so This is a quick code I just wrote to test my theory (try with
with
and with
and this shows that the Also thinking about it, your
I looked at that particular code and I believe that it changes the problem, this is a "non-standard" shuffling procedure done outside So i added a new commit where i support:
|
Added missing `ininstance` check in [#5112. --------- Co-authored-by: Conglong Li <conglong.li@gmail.com>
…aAnalysis` map operation (microsoft#5112) When performing the map operation required for the curriculum learning, the output of `metric_function` requires an `index` field: ``` def update_metric_results(self, data, metric_types, metric_dtypes, metric_functions, metric_results): for m_idx in range(len(metric_types)): [...] if metric_type == 'single_value_per_sample': for row in range(metric_values.size()[0]): metric_result["sample_to_metric_builder"].add_item(metric_values[row].reshape(-1)) metric_result["metric_to_sample_dict"][metric_values[row].item()].append( data['index'][row][0].item()). ##<------- data['index']?? ``` There is no mention to this in the documentation, where it specifies that the output of `metric_function` should be a dict/DataFrame (?) with an `index` key/column. To makes things worse, on top of that, there is no way for an user to be able to specify a proper `index` value for each sample, because the distribution of samples across workers/threads is not know, as it's done inside `DataAnalysis`: ``` def run_map_helper(self, thread_id): start_idx, end_idx = self.thread_splits[thread_id][0], \ self.thread_splits[thread_id][1] logger.info(f"worker {self.worker_id} thread {thread_id}: start working " \ f"on data subset {start_idx} to {end_idx}") thread_dataset = Subset(self.dataset, list(range(start_idx, end_idx))) sampler = BatchSampler(SequentialSampler(thread_dataset), batch_size=self.batch_size, drop_last=False) ``` Since by design you picked a `SequentialSampler`, then you know beforehand the global id of each each sample of each batch of each thread of each worker by looking at ``` self.worker_splits, self.thread_splits = split_dataset(self.dataset, self.num_workers, self.worker_id, self.num_threads) start_idx, end_idx = thread_splits[t_idx_reduce][0], thread_splits[t_idx_reduce][1] ``` and you can populate that index value correctly, instead of asking the user to provide it. This PR removes the need for `'index'` key in `data` and uses instead the batch, thread, and worker ids to compute the global index of each sample.
Added missing `ininstance` check in [microsoft#5112. --------- Co-authored-by: Conglong Li <conglong.li@gmail.com>
…aAnalysis` map operation (microsoft#5112) When performing the map operation required for the curriculum learning, the output of `metric_function` requires an `index` field: ``` def update_metric_results(self, data, metric_types, metric_dtypes, metric_functions, metric_results): for m_idx in range(len(metric_types)): [...] if metric_type == 'single_value_per_sample': for row in range(metric_values.size()[0]): metric_result["sample_to_metric_builder"].add_item(metric_values[row].reshape(-1)) metric_result["metric_to_sample_dict"][metric_values[row].item()].append( data['index'][row][0].item()). ##<------- data['index']?? ``` There is no mention to this in the documentation, where it specifies that the output of `metric_function` should be a dict/DataFrame (?) with an `index` key/column. To makes things worse, on top of that, there is no way for an user to be able to specify a proper `index` value for each sample, because the distribution of samples across workers/threads is not know, as it's done inside `DataAnalysis`: ``` def run_map_helper(self, thread_id): start_idx, end_idx = self.thread_splits[thread_id][0], \ self.thread_splits[thread_id][1] logger.info(f"worker {self.worker_id} thread {thread_id}: start working " \ f"on data subset {start_idx} to {end_idx}") thread_dataset = Subset(self.dataset, list(range(start_idx, end_idx))) sampler = BatchSampler(SequentialSampler(thread_dataset), batch_size=self.batch_size, drop_last=False) ``` Since by design you picked a `SequentialSampler`, then you know beforehand the global id of each each sample of each batch of each thread of each worker by looking at ``` self.worker_splits, self.thread_splits = split_dataset(self.dataset, self.num_workers, self.worker_id, self.num_threads) start_idx, end_idx = thread_splits[t_idx_reduce][0], thread_splits[t_idx_reduce][1] ``` and you can populate that index value correctly, instead of asking the user to provide it. This PR removes the need for `'index'` key in `data` and uses instead the batch, thread, and worker ids to compute the global index of each sample.
Added missing `ininstance` check in [microsoft#5112. --------- Co-authored-by: Conglong Li <conglong.li@gmail.com>
When performing the map operation required for the curriculum learning, the output of
metric_function
requires anindex
field:There is no mention to this in the documentation, where it specifies that the output of
metric_function
should be a dict/DataFrame (?) with anindex
key/column. To makes things worse, on top of that, there is no way for an user to be able to specify a properindex
value for each sample, because the distribution of samples across workers/threads is not know, as it's done insideDataAnalysis
:Since by design you picked a
SequentialSampler
, then you know beforehand the global id of each each sample of each batch of each thread of each worker by looking atand you can populate that index value correctly, instead of asking the user to provide it.
This PR removes the need for
'index'
key indata
and uses instead the batch, thread, and worker ids to compute the global index of each sample.