[pipeline]support List of Dict data (#1125)
* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e650.

* [pipeline]support List of Dict data

* polish
YuliangLiu0306 authored Jun 16, 2022
1 parent 91a5999 commit 3175bcb
Showing 3 changed files with 14 additions and 8 deletions.
4 changes: 3 additions & 1 deletion colossalai/engine/schedule/_base_schedule.py
@@ -48,9 +48,11 @@ def _get_batch_size(self, data):
         if isinstance(data, torch.Tensor):
             return data.size(0)
         elif isinstance(data, (list, tuple)):
+            if isinstance(data[0], dict):
+                return data[0][list(data[0].keys())[0]].size(0)
             return data[0].size(0)
         elif isinstance(data, dict):
-            return data[next(data.keys())].size(0)
+            return data[list(data.keys())[0]].size(0)
 
     def load_batch(self, data_iter, to_gpu=True):
         """Loads a batch from data iterator. It returns the data and labels which are
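The new list/tuple branch covers dataloaders whose batch arrives as a list wrapping a dict of named tensors. A minimal sketch of what the added lines compute; the tensor names and shapes below are made up for illustration:

```python
import torch

# Hypothetical batch, for illustration only: a dataloader yields (data, label)
# where ``data`` is a list containing a dict of named tensors.
data = [{
    "input_ids": torch.zeros(16, 128, dtype=torch.long),
    "attention_mask": torch.ones(16, 128, dtype=torch.long),
}]

# Mirrors the added branch in _get_batch_size: if the first element of the
# list is a dict, read the batch dimension from its first tensor.
if isinstance(data, (list, tuple)) and isinstance(data[0], dict):
    batch_size = data[0][list(data[0].keys())[0]].size(0)

print(batch_size)  # 16
```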
10 changes: 8 additions & 2 deletions colossalai/engine/schedule/_pipeline_schedule.py
@@ -137,6 +137,12 @@ def _get_data_slice(self, data, offset):
         if isinstance(data, torch.Tensor):
             return data[offset:offset + self.microbatch_size]
         elif isinstance(data, (list, tuple)):
+            data_dict = {}
+            for element in data:
+                if isinstance(element, dict):
+                    data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()})
+            if data_dict:
+                return data_dict
             return [val[offset:offset + self.microbatch_size] for val in data]
         elif isinstance(data, dict):
             return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()}
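For the same kind of batch, _get_data_slice now merges the per-key slices of every dict element into one dict per microbatch, keeping plain per-tensor slicing as the fallback when no dict is present. A standalone sketch of that slicing logic, with invented field names and shapes:

```python
import torch

microbatch_size = 4
offset = 0

# Hypothetical list-of-dict batch, for illustration only.
data = [{
    "input_ids": torch.arange(16 * 8).reshape(16, 8),
    "attention_mask": torch.ones(16, 8),
}]

# Mirrors the added branch: slice every tensor in every dict element along the
# batch dimension and merge the results into a single dict for this microbatch.
data_dict = {}
for element in data:
    if isinstance(element, dict):
        data_dict.update({k: v[offset:offset + microbatch_size] for k, v in element.items()})
micro_batch = data_dict if data_dict else [val[offset:offset + microbatch_size] for val in data]

print(micro_batch["input_ids"].shape)  # torch.Size([4, 8])
```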
@@ -216,7 +222,7 @@ def _get_data_label_for_current_step(self, stage_output, micro_batch_data, crite
 
         # get all parameter names for the forward function of the model
         fwd_sig = self._get_actual_forward_func(model)
-        fwd_sig_param_name = [p.name for p in fwd_sig.values()]
+        fwd_sig_param_name = [p.name for p in fwd_sig.parameters.values()]
 
         # build the kwargs for the forward function
         for idx, param_name in enumerate(fwd_sig_param_name):
@@ -228,7 +234,7 @@ def _get_data_label_for_current_step(self, stage_output, micro_batch_data, crite
 
         # get the tensors for loss
         loss_sig = inspect.signature(criterion)
-        loss_sig_param_name = [p.name for p in loss_sig.values()]
+        loss_sig_param_name = [p.name for p in loss_sig.parameters.values()]
 
         for param_name in loss_sig_param_name:
             if param_name in micro_batch_data:
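Both one-line fixes above follow the same pattern: parameter names are read from a signature's .parameters mapping rather than by calling .values() on the signature object itself (inspect.Signature has no .values() method). A small standard-library illustration; the criterion below is a stand-in:

```python
import inspect

# Stand-in criterion, for illustration only.
def criterion(logits, labels):
    return (logits - labels).abs().mean()

loss_sig = inspect.signature(criterion)
# Signature objects expose an ordered mapping of name -> Parameter.
loss_sig_param_name = [p.name for p in loss_sig.parameters.values()]

print(loss_sig_param_name)  # ['logits', 'labels']
```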
8 changes: 3 additions & 5 deletions colossalai/pipeline/pipelinable.py
@@ -12,7 +12,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
     A context manager to split the model into pipeline stages.
     """
 
-    def __init__(self, policy: str="balanced"):
+    def __init__(self, policy: str = "balanced"):
         super().__init__()
         self._layer_spec_dict = {}
         self._root_children = None
@@ -61,11 +61,12 @@ def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
         """
         # iterate over the positional arguments
         # to check if an argument is a torch Module
-        # if found any torch Module, replace it with its layer spec
+        # if found any torch Module, replace it with its layer spec
         # for storage purpose
         modified_args = []
         for arg in args:
             if isinstance(arg, torch.nn.Module):
+                # (lyl)TODO: if nn.Module is an argument of the root module, then we should just record the module instance itself.
                 arg = self._layer_spec_dict[id(arg)]
             modified_args.append(arg)
 
@@ -255,6 +256,3 @@ def forward(self, input_tensor, **kwargs):
         input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)
 
         return input_tensor
-
-
-