[infer] Fix tp inference engine (#4564)
* fix engine prepare data

* add engine test

* use bloom for testing

* revise on test

* revise on test
yuanheng-zhao committed Sep 7, 2023
1 parent 483b937 commit 66454d9
Showing 2 changed files with 88 additions and 21 deletions.
colossalai/inference/tensor_parallel/engine.py (17 additions, 9 deletions)

@@ -163,29 +163,37 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
         if not isinstance(inputs, (BatchEncoding, dict, list, torch.Tensor)):
             raise TypeError(f"inputs type {type(inputs)} is not supported in prepare_batch_state")
 
+        input_ids_list = None
+        attention_mask = None
+
         if isinstance(inputs, (BatchEncoding, dict)):
-            attn_masks = inputs['attention_mask']
-            batch_size = attn_masks.shape[0]
-            max_len_in_batch = attn_masks.shape[1]
-        elif isinstance(inputs, list):
-            batch_size = len(inputs)
+            input_ids_list = inputs['input_ids']
+            attention_mask = inputs['attention_mask']
         else:
-            batch_size = inputs.shape[0]
+            input_ids_list = inputs
+        if isinstance(input_ids_list[0], int):    # for a single input
+            input_ids_list = [input_ids_list]
+            attention_mask = [attention_mask] if attention_mask is not None else attention_mask
+
+        batch_size = len(input_ids_list)
 
         seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
         seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
         start_index = 0
 
         max_len_in_batch = -1
         if isinstance(inputs, (BatchEncoding, dict)):
-            for i, attn_mask in enumerate(attn_masks):
-                curr_seq_len = int(torch.sum(attn_mask))
+            for i, attn_mask in enumerate(attention_mask):
+                if isinstance(attn_mask, torch.Tensor):
+                    curr_seq_len = int(torch.sum(attn_mask))
+                else:
+                    curr_seq_len = int(sum(attn_mask))
                 seq_lengths[i] = curr_seq_len
                 seq_start_indexes[i] = start_index
                 start_index += curr_seq_len
                 max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
         else:
-            for i, input_ids in enumerate(inputs):
+            for i, input_ids in enumerate(input_ids_list):
                 curr_seq_len = len(input_ids)
                 seq_lengths[i] = curr_seq_len
                 seq_start_indexes[i] = start_index
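For context, a minimal sketch of the call patterns the revised prepare_batch_state is meant to accept, mirroring the new test below. The engine object here is an assumed, already-initialized TPInferEngine (its construction is omitted), and the token ids are purely illustrative:

from transformers.tokenization_utils_base import BatchEncoding

# `engine` is assumed to be an already-initialized TPInferEngine (hypothetical setup omitted).
token_ids = [[80540, 15473, 3331, 11970], [80540, 15473]]    # ragged batch of token ids
masks = [[1, 1, 1, 1], [0, 0, 1, 1]]                         # per-sequence attention masks in plain list form

# 1) plain nested list of token ids
state = engine.prepare_batch_state(token_ids)

# 2) BatchEncoding (or dict) carrying both input_ids and attention_mask
batch = BatchEncoding(data=dict(input_ids=token_ids, attention_mask=masks))
state = engine.prepare_batch_state(batch)

# 3) a single flat sequence of ids is wrapped into a batch of one
state = engine.prepare_batch_state([80540, 15473, 3331])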
tests/test_infer/test_infer_engine.py (71 additions, 12 deletions)

@@ -1,54 +1,112 @@
+from itertools import accumulate
+
 import pytest
 import torch
-from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
+import torch.nn as nn
+from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM
+from transformers.tokenization_utils_base import BatchEncoding
 
 import colossalai
 from colossalai.inference.tensor_parallel import TPInferEngine
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
 TP_SIZE = 2
-BATCH_SIZE = 4
+MAX_BATCH_SIZE = 4
 MAX_INPUT_LEN = 16
 MAX_OUTPUT_LEN = 8
 
 
+def test_prepare_data():
+    # dummy module used for testing
+    class DummyModule(nn.Module):
+
+        def __init__(self, config):
+            super(DummyModule, self).__init__()
+            self.config = config
+
+        def forward(self, x):
+            return x
+
+    # dummy config used for testing
+    class DummyModelConfig:
+
+        def __init__(self):
+            self.hidden_size = 4096
+            self.num_attention_heads = 32
+            self.num_hidden_layers = 8
+
+    dummy_config = DummyModelConfig()
+    model = DummyModule(dummy_config)
+    infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+
+    input_ids_list = [[80540, 15473, 3331, 11970, 90472, 361, 61335], [80540, 15473, 3331, 11970],
+                      [80540, 15473, 3331, 11970], [80540, 15473]]
+    batch_size = len(input_ids_list)
+    max_seq_len = max(len(li) for li in input_ids_list)
+    attention_mask = [[0] * max_seq_len for _ in range(batch_size)]
+    for i, li in enumerate(input_ids_list):
+        attention_mask[i][max_seq_len - len(li):] = [1 for _ in range(len(li))]
+    data = dict(input_ids=input_ids_list, attention_mask=attention_mask)
+    inputs_batch_encoding = BatchEncoding(data=data)
+
+    seq_lengths = [len(li) for li in input_ids_list]
+    start_loc = list(accumulate([0] + seq_lengths[:-1]))
+    seq_lengths = torch.tensor(seq_lengths, dtype=torch.int32)
+    start_loc = torch.tensor(start_loc, dtype=torch.int32)
+
+    # BatchEncoding as inputs
+    batch_state_out1 = infer_engine.prepare_batch_state(inputs_batch_encoding)
+    # input token id list as inputs
+    batch_state_out2 = infer_engine.prepare_batch_state(input_ids_list)
+
+    assert batch_state_out1.batch_size == batch_state_out2.batch_size == batch_size
+    assert torch.equal(batch_state_out1.seq_len.to(seq_lengths.device), seq_lengths)
+    assert torch.equal(batch_state_out2.seq_len.to(seq_lengths.device), seq_lengths)
+    assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc)
+    assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc)
+
+
 def test_orig_generate():
-    input_ids = torch.randint(low=10, high=1000, size=(BATCH_SIZE, MAX_INPUT_LEN))
+    input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN))
 
     model_config = LlamaConfig()
     model = LlamaForCausalLM(model_config)
     model = model.half()
     model.to(torch.cuda.current_device())
 
     shard_config = ShardConfig(enable_tensor_parallelism=False)
 
     # init TPInferEngine and
-    infer_engine = TPInferEngine(model, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     infer_engine.prepare_with_shard_config(shard_config)
 
     # original model generate
     generate_kwargs = dict(do_sample=False)
     infer_engine.generate(input_ids, generate_kwargs)
 
     torch.cuda.empty_cache()
 
 
 def run():
-    model_config = LlamaConfig()
-    model = LlamaForCausalLM(model_config)
+    model_config = BloomConfig()
+    model = BloomForCausalLM(model_config)
     model = model.half()
     model.to(torch.cuda.current_device())
 
     shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
     shardformer = ShardFormer(shard_config=shard_config)
 
-    infer_engine = TPInferEngine(model, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     infer_engine.prepare_with_shard_config(shard_config=shard_config)
     infer_engine.shard_model_by(shardformer)
 
     assert infer_engine.cache_manager is not None
     assert infer_engine.tp_size == TP_SIZE
     assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE
 
     # TODO After adding forward replacement for CausalLM,
     # uncomment these lines to test sharded model generate
     # generate_kwargs = dict(do_sample=False)
     # infer_engine.generate(input_ids, generate_kwargs)
 
     torch.cuda.empty_cache()
 
 
@@ -66,5 +124,6 @@ def test_engine_infer():
 
 
 if __name__ == '__main__':
+    test_prepare_data()
     test_orig_generate()
     test_engine_infer()
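As a cross-check on the expected values in test_prepare_data: start_loc is simply the exclusive prefix sum of the per-sequence lengths, giving the packed start offset of each sequence. A small self-contained sketch with the values copied from the test above (plain Python, no GPU or ColossalAI needed):

from itertools import accumulate

# the same ragged batch used in test_prepare_data
input_ids_list = [[80540, 15473, 3331, 11970, 90472, 361, 61335],
                  [80540, 15473, 3331, 11970],
                  [80540, 15473, 3331, 11970],
                  [80540, 15473]]

seq_lengths = [len(ids) for ids in input_ids_list]      # [7, 4, 4, 2]
start_loc = list(accumulate([0] + seq_lengths[:-1]))    # [0, 7, 11, 15]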
