[pipeline] support List of Dict data #1125

Merged
Changes from all commits (20 commits)
colossalai/engine/schedule/_base_schedule.py (4 changes: 3 additions & 1 deletion)
@@ -48,9 +48,11 @@ def _get_batch_size(self, data):
         if isinstance(data, torch.Tensor):
             return data.size(0)
         elif isinstance(data, (list, tuple)):
+            if isinstance(data[0], dict):
+                return data[0][list(data[0].keys())[0]].size(0)
             return data[0].size(0)
         elif isinstance(data, dict):
-            return data[next(data.keys())].size(0)
+            return data[list(data.keys())[0]].size(0)
 
     def load_batch(self, data_iter, to_gpu=True):
         """Loads a batch from data iterator. It returns the data and labels which are
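For reference, a minimal standalone sketch (not part of the diff) of how the updated batch-size logic behaves when the dataloader yields a list whose first element is a dict of tensors, e.g. a (features_dict, labels) pair; all names below are illustrative.

import torch

def get_batch_size(data):
    # Mirrors the branching added to _get_batch_size above (illustrative only).
    if isinstance(data, torch.Tensor):
        return data.size(0)
    elif isinstance(data, (list, tuple)):
        if isinstance(data[0], dict):
            # list of dicts: read the batch dim from the first tensor in the first dict
            return data[0][list(data[0].keys())[0]].size(0)
        return data[0].size(0)
    elif isinstance(data, dict):
        return data[list(data.keys())[0]].size(0)

batch = [
    {"input_ids": torch.zeros(16, 128, dtype=torch.long),
     "attention_mask": torch.ones(16, 128, dtype=torch.long)},
    torch.zeros(16, dtype=torch.long),   # labels
]
print(get_batch_size(batch))             # 16

The removed dict branch called next(data.keys()), which raises TypeError because dict_keys is not an iterator; indexing list(data.keys())[0] avoids that.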
colossalai/engine/schedule/_pipeline_schedule.py (10 changes: 8 additions & 2 deletions)
@@ -137,6 +137,12 @@ def _get_data_slice(self, data, offset):
         if isinstance(data, torch.Tensor):
             return data[offset:offset + self.microbatch_size]
         elif isinstance(data, (list, tuple)):
+            data_dict = {}
+            for element in data:
+                if isinstance(element, dict):
+                    data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()})
+            if data_dict:
+                return data_dict
             return [val[offset:offset + self.microbatch_size] for val in data]
         elif isinstance(data, dict):
             return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()}
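A rough standalone sketch (assumed names, not the PR's code) of what the new list/tuple branch does: dict elements in the batch are sliced per key and merged into a single dict for the current micro-batch.

import torch

microbatch_size = 4

def get_data_slice(data, offset):
    # Simplified version of the branching shown above (illustrative only).
    if isinstance(data, torch.Tensor):
        return data[offset:offset + microbatch_size]
    elif isinstance(data, (list, tuple)):
        data_dict = {}
        for element in data:
            if isinstance(element, dict):
                data_dict.update({k: v[offset:offset + microbatch_size] for k, v in element.items()})
        if data_dict:
            return data_dict   # list of dicts -> one merged dict of sliced tensors
        return [val[offset:offset + microbatch_size] for val in data]
    elif isinstance(data, dict):
        return {k: v[offset:offset + microbatch_size] for k, v in data.items()}

batch = [{"input_ids": torch.arange(16 * 8).view(16, 8)}]
micro = get_data_slice(batch, offset=4)
print(micro["input_ids"].shape)   # torch.Size([4, 8])

Note that once any dict element is found, only the merged dict is returned for that micro-batch; a plain sliced list is returned only when no element is a dict.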
@@ -216,7 +222,7 @@ def _get_data_label_for_current_step(self, stage_output, micro_batch_data, crite

         # get all parameter names for the forward function of the model
         fwd_sig = self._get_actual_forward_func(model)
-        fwd_sig_param_name = [p.name for p in fwd_sig.values()]
+        fwd_sig_param_name = [p.name for p in fwd_sig.parameters.values()]
 
         # build the kwargs for the forward function
         for idx, param_name in enumerate(fwd_sig_param_name):
@@ -228,7 +234,7 @@ def _get_data_label_for_current_step(self, stage_output, micro_batch_data, crite

         # get the tensors for loss
         loss_sig = inspect.signature(criterion)
-        loss_sig_param_name = [p.name for p in loss_sig.values()]
+        loss_sig_param_name = [p.name for p in loss_sig.parameters.values()]
 
         for param_name in loss_sig_param_name:
             if param_name in micro_batch_data:
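Both one-line fixes above address the same bug: inspect.Signature has no .values() method; its parameters live in the .parameters mapping. A quick standalone check (not PR code):

import inspect

def criterion(logits, labels):
    return ((logits - labels) ** 2).mean()

loss_sig = inspect.signature(criterion)
# loss_sig.values()  # would raise AttributeError: 'Signature' object has no attribute 'values'
loss_sig_param_name = [p.name for p in loss_sig.parameters.values()]
print(loss_sig_param_name)   # ['logits', 'labels']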
colossalai/pipeline/pipelinable.py (8 changes: 3 additions & 5 deletions)
@@ -12,7 +12,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
     A context manager to split the model into pipeline stages.
     """
 
-    def __init__(self, policy: str="balanced"):
+    def __init__(self, policy: str = "balanced"):
         super().__init__()
         self._layer_spec_dict = {}
         self._root_children = None
@@ -61,11 +61,12 @@ def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
"""
# iterate over the positional arguments
# to check if an argument is a torch Module
# if found any torch Module, replace it with its layer spec
# if found any torch Module, replace it with its layer spec
# for storage purpose
modified_args = []
for arg in args:
if isinstance(arg, torch.nn.Module):
# (lyl)TODO: if nn.Module is an argument of the root module, then we should just record the module instance itself.
arg = self._layer_spec_dict[id(arg)]
modified_args.append(arg)

Expand Down Expand Up @@ -255,6 +256,3 @@ def forward(self, input_tensor, **kwargs):
                 input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)
 
         return input_tensor