Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for dynamic batch padding #2352

Merged
merged 20 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/accelerate/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from .utils import (
calculate_maximum_sizes,
convert_bytes,
ignorant_find_batch_size,
infer_auto_device_map,
send_to_device,
slice_and_concatenate,
)


Expand Down Expand Up @@ -59,9 +61,33 @@ def build_pipeline(model, split_points, args, kwargs) -> PipelineStage:
def pippy_forward(forward, *args, **kwargs):
state = PartialState()
output = None

if state.num_processes == 1:
output = forward(*args, **kwargs)
elif state.is_local_main_process:
found_batch_size = None
for arg in args:
found_batch_size = ignorant_find_batch_size(arg)
if found_batch_size is not None:
break
for kwarg in kwargs.values():
found_batch_size = ignorant_find_batch_size(kwarg)
if found_batch_size is not None:
break
if found_batch_size is None:
raise ValueError("Could not find batch size from args or kwargs")
else:
if (found_batch_size % state.num_processes) != 0:
# First special case: bs == 1, we just duplicate
if found_batch_size == 1:
slice_to_cut = slice(0, found_batch_size % state.num_processes)
else:
# Second special case: bs < num_processes, we add a buffer to the batch size to bring it to num_processes
if state.num_processes > found_batch_size:
found_batch_size += (state.num_processes - found_batch_size) + 1
slice_to_cut = slice((found_batch_size % state.num_processes) + 1, found_batch_size)
args = slice_and_concatenate(args, slice_to_cut)
kwargs = slice_and_concatenate(kwargs, slice_to_cut)
forward(*args, **kwargs)
elif state.is_last_process:
output = forward()
Expand Down
48 changes: 40 additions & 8 deletions src/accelerate/test_utils/scripts/external_deps/test_pippy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torchvision.models import resnet34
from transformers import (
BertConfig,
BertForMaskedLM,
Expand All @@ -34,24 +35,28 @@
}


def get_model_and_data(model_name, device, num_processes: int = 2):
def get_model_and_data_for_text(model_name, device, num_processes: int = 2):
initializer, config, seq_len = model_to_config[model_name]
config = config()
model = initializer(config)
config_args = {}
# Eventually needed for batch inference tests on gpt-2 when bs != 1
# if model_name == "gpt2":
# config_args["pad_token_id"] = 0
model_config = config(**config_args)
model = initializer(model_config)
return model, torch.randint(
low=0,
high=config.vocab_size,
high=model_config.vocab_size,
size=(num_processes, seq_len),
device=device,
dtype=torch.int64,
requires_grad=False,
)


def test_gpt2():
def test_gpt2(batch_size: int = 2):
set_seed(42)
state = PartialState()
model, inputs = get_model_and_data("gpt2", "cpu", state.num_processes)
model, inputs = get_model_and_data_for_text("gpt2", "cpu", batch_size)
model = prepare_pippy(model, example_args=(inputs,), no_split_module_classes=model._no_split_modules)
# For inference args need to be a tuple
inputs = inputs.to("cuda")
Expand All @@ -64,10 +69,10 @@ def test_gpt2():
assert output is not None, "Output was not generated in the last process!"


def test_t5():
def test_t5(batch_size: int = 2):
set_seed(42)
state = PartialState()
model, inputs = get_model_and_data("t5", "cpu", state.num_processes)
model, inputs = get_model_and_data_for_text("t5", "cpu", batch_size)
example_inputs = {"input_ids": inputs, "decoder_input_ids": inputs}
model = prepare_pippy(
model,
Expand All @@ -85,13 +90,40 @@ def test_t5():
assert output is not None, "Output was not generated in the last process!"


def test_resnet(batch_size: int = 2):
set_seed(42)
state = PartialState()
model = resnet34()
input_tensor = torch.rand(batch_size, 3, 224, 224)
model = prepare_pippy(
model,
example_args=(input_tensor,),
)
inputs = send_to_device(input_tensor, "cuda:0")
with torch.no_grad():
output = model(inputs)
# Zach: Check that we just grab the real outputs we need at the end
if not state.is_last_process:
assert output is None, "Output was not generated on just the last process!"
else:
assert output is not None, "Output was not generated in the last process!"


if __name__ == "__main__":
state = PartialState()
state.print("Testing pippy integration...")
if state.distributed_type == DistributedType.MULTI_GPU:
state.print("Testing GPT2...")
test_gpt2()
# Issue: When modifying the tokenizer for batch GPT2 inference, there's an issue
# due to references
# NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope
SunMarc marked this conversation as resolved.
Show resolved Hide resolved
# test_gpt2(3)
state.print("Testing T5...")
test_t5()
test_t5(3)
state.print("Testing CV model...")
test_resnet()
test_resnet(3)
else:
print("Less than two GPUs found, not running tests!")
2 changes: 2 additions & 0 deletions src/accelerate/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
gather_object,
get_data_structure,
honor_type,
ignorant_find_batch_size,
initialize_tensors,
is_namedtuple,
is_tensor_information,
Expand All @@ -134,6 +135,7 @@
recursively_apply,
reduce,
send_to_device,
slice_and_concatenate,
slice_tensors,
)
from .versions import compare_versions, is_torch_version
Expand Down
25 changes: 25 additions & 0 deletions src/accelerate/utils/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,23 @@ def find_batch_size(data):
return data.shape[0]


def ignorant_find_batch_size(data):
"""
Same as [`utils.operations.find_batch_size`] except will ignore if `ValueError` and `TypeErrors` are raised

Args:
data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to find the batch size.

Returns:
`int`: The batch size.
"""
try:
return find_batch_size(data)
except (ValueError, TypeError):
pass
return None


def listify(data):
"""
Recursively finds tensors in a nested list/tuple/dictionary and converts them to a list of numbers.
Expand Down Expand Up @@ -591,6 +608,14 @@ def _pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False):
)


def slice_and_concatenate(tensor, tensor_slice, dim=0):
"""
Slices `tensor` based on `tensor_slice` and then returns a concatenated version of `tensor` and `tensor_slice`
"""
chunk = slice_tensors(tensor, tensor_slice)
return concatenate([tensor, chunk], dim=dim)


@verify_operation
def reduce(tensor, reduction="mean", scale=1.0):
"""
Expand Down
86 changes: 86 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
recursively_apply,
save,
send_to_device,
slice_and_concatenate,
)


Expand Down Expand Up @@ -237,3 +238,88 @@ def test_pad_across_processes(self):
with self.assertWarns(CannotPadNestedTensorWarning):
nt2 = pad_across_processes(nt)
self.assertIs(nt, nt2)

def test_slice_and_concatenate(self):
# Should be equivalent to the slice func used in `pippy_forward`
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
def get_slice(batch_size, num_processes):
# First special case: bs == 1, we just duplicate
if batch_size == 1:
slice_to_cut = slice(0, batch_size % num_processes)
else:
# Second special case: bs < num_processes, we add a buffer to the batch size to bring it to num_processes
if num_processes > batch_size:
batch_size += (num_processes - batch_size) + 1
slice_to_cut = slice((batch_size % num_processes) + 1, batch_size)
return slice_to_cut

# First base case: 2 processes, batch size of 1
num_processes = 2
batch_size = 1
batch = torch.rand(batch_size, 4)
slice_to_cut = get_slice(batch_size, num_processes)
result = slice_and_concatenate(batch, slice_to_cut)
# We should expect there to be 2 items now
assert result.shape == torch.Size([2, 4])

# Second base case: 2 processes, batch size of 3
num_processes = 2
batch_size = 3
batch = torch.rand(batch_size, 4)
slice_to_cut = get_slice(batch_size, num_processes)
result = slice_and_concatenate(batch, slice_to_cut)
# We should expect there to be 4 items now
assert result.shape == torch.Size([4, 4])

# Third base case: 3 processes, batch size of 4
num_processes = 3
batch_size = 4
batch = torch.rand(batch_size, 4, 4)
slice_to_cut = get_slice(batch_size, num_processes)
result = slice_and_concatenate(batch, slice_to_cut)
# We should expect there to be 6 items now
assert result.shape == torch.Size([6, 4, 4])

# Fourth base case: 4 processes, batch size of 3
num_processes = 4
batch_size = 3
batch = torch.rand(batch_size, 4, 4)
slice_to_cut = get_slice(batch_size, num_processes)
result = slice_and_concatenate(batch, slice_to_cut)
# We should expect there to be 4 items now
assert result.shape == torch.Size([4, 4, 4])

# Fifth base case: 6 processes, batch size of 4
num_processes = 6
batch_size = 4
batch = torch.rand(batch_size, 4, 4)
slice_to_cut = get_slice(batch_size, num_processes)
result = slice_and_concatenate(batch, slice_to_cut)
# We should expect there to be 6 items now
assert result.shape == torch.Size([6, 4, 4])

# Sixth base case: 6 processes, batch size of 1
num_processes = 6
batch_size = 1
batch = torch.rand(batch_size, 4, 4)
slice_to_cut = get_slice(batch_size, num_processes)
result = slice_and_concatenate(batch, slice_to_cut)
# We should expect there to be 6 items now
assert result.shape == torch.Size([6, 4, 4])

# Seventh base case: 6 processes, batch size of 2
num_processes = 6
batch_size = 2
batch = torch.rand(batch_size, 4, 4)
slice_to_cut = get_slice(batch_size, num_processes)
result = slice_and_concatenate(batch, slice_to_cut)
# We should expect there to be 6 items now
assert result.shape == torch.Size([6, 4, 4])

# Eighth base case: 6 processes, batch size of 61
num_processes = 6
batch_size = 61
batch = torch.rand(batch_size, 4, 4)
slice_to_cut = get_slice(batch_size, num_processes)
result = slice_and_concatenate(batch, slice_to_cut)
# We should expect there to be 6 items now
assert result.shape == torch.Size([66, 4, 4])
Loading