Commit

refactor _split_model_inputs function for sequential beam search.
Saibo-creator committed Jan 16, 2024
1 parent 6054c33 commit d78d038
Showing 2 changed files with 58 additions and 47 deletions.
4 changes: 2 additions & 2 deletions src/transformers/generation/configuration_utils.py
@@ -200,8 +200,8 @@ class GenerationConfig(PushToHubMixin):
             Higher guidance scale encourages the model to generate samples that are more closely linked to the input
             prompt, usually at the expense of poorer quality.
         low_memory (`bool`, *optional*):
-            Switch to sequential topk for contrastive search to reduce peak memory. Used with beam search
-            and contrastive search.
+            Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory.
+            Used with beam search and contrastive search.

     > Parameters that define the output variables of `generate`
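
For context, a minimal usage sketch of the option documented above (not part of this commit): passing low_memory=True to generate() together with num_beams selects the memory-saving sequential beam search path. The checkpoint, prompt, and generation settings below are illustrative assumptions, not taken from the diff.

    # Minimal sketch, assuming an arbitrary causal LM checkpoint ("gpt2") and a toy prompt.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    # With num_beams > 1, low_memory=True runs beam search sequentially over sub-batches,
    # trading throughput for a lower peak memory footprint; the same flag gates the
    # sequential top-k path in contrastive search.
    outputs = model.generate(**inputs, num_beams=4, max_new_tokens=20, low_memory=True)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
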
101 changes: 56 additions & 45 deletions src/transformers/generation/utils.py
@@ -4916,76 +4916,87 @@ def _ranking_fast(
     return selected_idx


+def _split(data, full_batch_size: int, split_size: int = None):
+    """
+    Takes care of three cases:
+    1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim
+    2. data is a tuple: e.g. hidden_states, attentions etc. Keep the tuple as it is and split each tensor in it and
+       return a list of tuples
+    3. data is a tuple of tuples, e.g. past_key_values. Keep the tuple as it is and split each tuple in it and
+       return a list of tuples of tuples
+    (see documentation of ModelOutput)
+    """
+    if data is None:
+        return [None] * (full_batch_size // split_size)
+    if isinstance(data, torch.Tensor):
+        return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
+    elif isinstance(data, tuple):
+        # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
+        if isinstance(data[0], tuple):
+            return [
+                tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data)
+                for i in range(0, full_batch_size, split_size)
+            ]
+
+        else:
+            return [
+                tuple(sub_tensor[i : i + split_size] for sub_tensor in data)
+                for i in range(0, full_batch_size, split_size)
+            ]
+    else:
+        raise ValueError(f"Unexpected attribute type: {type(data)}")
+
+
 def _split_model_inputs(
-    model_output: Union[ModelOutput, Dict], split_size: int, full_batch_size: int
+    model_input: Union[ModelOutput, Dict], split_size: int, full_batch_size: int
 ) -> List[Union[ModelOutput, Dict]]:
     """
     Split a ModelOutput object (or its subclasses) or Dict into a list of same-class objects based on a specified split
     size. The input object is dict when it was prepared for forward pass and ModelOutput when it was returned from
     previous forward pass.
     """
     # Infer the class from the object
-    model_output_cls = type(model_output)
+    model_output_cls = type(model_input)
     if (full_batch_size % split_size) != 0:
         raise ValueError("`full_batch_size` must be divisible by `split_size`")

     if split_size > full_batch_size:
         raise ValueError("`split_size` must be smaller or equal to `full_batch_size`")

     # Helper function to split tensors or tuples of tensors
-    def _split(data):
-        """
-        Takes care of three cases:
-        1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim
-        2. data is a tuple: e.g. hidden_states, attentions etc. Keep the tuple as it is and split each tensor in it and
-           return a list of tuples
-        3. data is a tuple of tuples, e.g. past_key_values. Keep the tuple as it is and split each tuple in it and
-           return a list of tuples of tuples
-        (see documentation of ModelOutput)
-        """
-        if data is None:
-            return [None] * (full_batch_size // split_size)
-        if isinstance(data, torch.Tensor):
-            return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
-        elif isinstance(data, tuple):
-            # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
-            if isinstance(data[0], tuple):
-                return [
-                    tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data)
-                    for i in range(0, full_batch_size, split_size)
-                ]
-
-            else:
-                return [
-                    tuple(sub_tensor[i : i + split_size] for sub_tensor in data)
-                    for i in range(0, full_batch_size, split_size)
-                ]
-        elif isinstance(data, ModelOutput):
-            return _split_model_inputs(model_output["encoder_outputs"], split_size, full_batch_size)
-        else:
-            raise ValueError(f"Unexpected attribute type: {type(data)}")
-
     # Find all the dataclass fields (e.g., last_hidden_state, pooler_output etc.) and split them
     keys = (
-        model_output.__dataclass_fields__.keys()
-        if hasattr(model_output, "__dataclass_fields__")
-        else model_output.keys()
+        model_input.__dataclass_fields__.keys() if hasattr(model_input, "__dataclass_fields__") else model_input.keys()
     )
-    # we only keep keys that are in the model_output
-    keys = [k for k in keys if k in model_output]
-    # here we can have three types of values: tensors, tuples of tensors and booleans
+    # We only keep keys that are in the model_input
+    keys = [k for k in keys if k in model_input]
+    # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a
+    # ModelOutput object.
     # bool should not be split but replicated for each split
-    bool_keys = [k for k in keys if isinstance(model_output[k], bool)]
-    non_bool_keys = [k for k in keys if not isinstance(model_output[k], bool)]
-    # import pdb; pdb.set_trace()
+    bool_keys = [k for k in keys if isinstance(model_input[k], bool)]
+    non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and not k == "encoder_outputs"]

     # we split the tensors and tuples of tensors
     data_split_list = [
-        {k: _split(model_output[k])[i] for k in non_bool_keys} for i in range(full_batch_size // split_size)
+        {k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys}
+        for i in range(full_batch_size // split_size)
     ]
     # bool values are the same and replicated for each split
-    bool_data = {k: model_output[k] for k in bool_keys}
+    bool_data = {k: model_input[k] for k in bool_keys}
+    # encoder_outputs is a ModelOutput object and should be split by its own
+    if "encoder_outputs" in model_input:
+        encoder_outputs_split = _split_model_inputs(model_input["encoder_outputs"], split_size, full_batch_size)
+        data_split_list = [
+            {**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list)
+        ]

     # Convert each dictionary in the list to an object of the inferred class
-    return [model_output_cls(**data_split, **bool_data) for data_split in data_split_list]
+    split_model_inputs: List[Union[ModelOutput, Dict]] = [
+        model_output_cls(**data_split, **bool_data) for data_split in data_split_list
+    ]
+
+    return split_model_inputs


 def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
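To make the refactor easier to follow, here is a small, hypothetical sketch of how the now module-level _split behaves on the three shapes its docstring lists. The tensor shapes and the private import path are assumptions for illustration, not taken from the commit.

    # Illustrative sketch only; shapes are made up and the private import path is assumed.
    import torch

    from transformers.generation.utils import _split

    full_batch_size, split_size = 4, 2

    # Case 1: a plain tensor (e.g. last_hidden_state) is sliced along the batch dimension.
    hidden = torch.randn(full_batch_size, 7, 16)
    chunks = _split(hidden, full_batch_size, split_size)
    assert len(chunks) == 2 and chunks[0].shape[0] == split_size

    # Case 2: a tuple of tensors (e.g. per-layer hidden_states) stays a tuple; each member is sliced.
    hidden_states = tuple(torch.randn(full_batch_size, 7, 16) for _ in range(3))
    hs_chunks = _split(hidden_states, full_batch_size, split_size)
    assert isinstance(hs_chunks[0], tuple) and hs_chunks[0][0].shape[0] == split_size

    # Case 3: a tuple of tuples (e.g. past_key_values) keeps its nesting; innermost tensors are sliced.
    past_key_values = tuple(
        (torch.randn(full_batch_size, 8, 7, 64), torch.randn(full_batch_size, 8, 7, 64)) for _ in range(2)
    )
    pkv_chunks = _split(past_key_values, full_batch_size, split_size)
    assert pkv_chunks[0][0][0].shape[0] == split_size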

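Similarly, a hedged sketch of how sequential beam search might call the refactored _split_model_inputs: the expanded batch of batch_size * num_beams rows is split into sub-batches that can be run one at a time, with plain booleans replicated into every split. The batch sizes, input keys, and import path below are assumptions for illustration.

    # Illustrative sketch only; inputs and split sizes are made up for demonstration.
    import torch

    from transformers.generation.utils import _split_model_inputs

    batch_size, num_beams = 2, 4
    full_batch_size = batch_size * num_beams  # one row per (sample, beam) pair

    model_inputs = {
        "input_ids": torch.ones(full_batch_size, 5, dtype=torch.long),
        "attention_mask": torch.ones(full_batch_size, 5, dtype=torch.long),
        "use_cache": True,  # a bool: replicated into every split rather than sliced
    }

    sub_inputs = _split_model_inputs(model_inputs, split_size=batch_size, full_batch_size=full_batch_size)
    assert len(sub_inputs) == num_beams
    assert sub_inputs[0]["input_ids"].shape[0] == batch_size
    assert sub_inputs[0]["use_cache"] is True

    # Each sub-batch can then be fed through the model separately and the per-split outputs
    # recombined afterwards (e.g. with stack_model_outputs), keeping peak memory lower than
    # a single forward pass over the full expanded batch.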