[shardformer] rewrite tests for opt/bloom/llama/vit/chatglm #4395

Merged: 6 commits, Aug 11, 2023
Changes from 5 commits
8 changes: 8 additions & 0 deletions colossalai/shardformer/layer/linear.py
@@ -143,6 +143,14 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]

tp_size = dist.get_world_size(process_group)
if out_features < tp_size:
return module

if out_features % tp_size != 0:
raise ValueError(
f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!")

linear_1d = Linear1D_Col(in_features=in_features,
out_features=out_features,
bias=bias,
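Note on the new guard in `Linear1D_Col.from_native_module`: when `out_features` is smaller than the tensor-parallel world size the native layer is returned unsharded, and a non-divisible `out_features` is rejected. A minimal standalone sketch of that logic follows; the helper name and the plain `nn.Linear` return are illustrative only, not part of this PR:

```python
import torch.nn as nn

def maybe_shard_out_features(module: nn.Linear, tp_size: int) -> nn.Linear:
    """Sketch of the guard added above; the real method builds a Linear1D_Col instead."""
    out_features = module.out_features
    if out_features < tp_size:
        # Too few output features to split across ranks: keep the native module.
        return module
    if out_features % tp_size != 0:
        # Column sharding splits out_features evenly, so it must be divisible.
        raise ValueError(f"out_features={out_features} is not divisible by tp_size={tp_size}")
    return module  # placeholder: the actual method constructs the sharded layer here

# With tp_size=4: nn.Linear(16, 2) is returned unsharded, nn.Linear(16, 6) raises ValueError,
# and nn.Linear(16, 8) would be sharded into 4 columns of 2 output features each.
print(maybe_shard_out_features(nn.Linear(16, 2), tp_size=4))
```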
497 changes: 495 additions & 2 deletions colossalai/shardformer/modeling/opt.py

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions colossalai/shardformer/policies/auto_policy.py
@@ -122,6 +122,12 @@ class PolicyLocation:
PolicyLocation(file_name="blip2", class_name="Blip2ModelPolicy"),
"transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration":
PolicyLocation(file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"),

# ChatGLM
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel":
PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy"),
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration":
PolicyLocation(file_name="chatglm", class_name="ChatGLMForConditionalGenerationPolicy"),
}


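The new ChatGLM entries are keyed by the fully qualified class name, the same way every other model resolves to its policy. A hedged sketch of how such a key can be derived from a model instance; the helper below is illustrative, the real lookup lives in `auto_policy.py`:

```python
def policy_registry_key(model) -> str:
    # Build "<module path>.<class name>", matching the keys of the registry above.
    cls = type(model)
    return f"{cls.__module__}.{cls.__qualname__}"

# For a ChatGLMModel imported from colossalai.shardformer.modeling.chatglm2_6b, this yields
# "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel",
# which now resolves to PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy").
```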
618 changes: 5 additions & 613 deletions colossalai/shardformer/policies/opt.py

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions tests/kit/model_zoo/transformers/bloom.py
@@ -53,7 +53,8 @@ def data_gen_for_question_answering():
# inputs = tokenizer(question, text, return_tensors="pt")

input_ids = torch.tensor(
[[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], dtype=torch.int64)
[[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]],
dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
start_positions = torch.tensor([1], dtype=torch.int64)
end_positions = torch.tensor([10], dtype=torch.int64)
@@ -73,12 +74,13 @@ def data_gen_for_question_answering():
loss_fn_for_classification = lambda x: x.loss
loss_fn_for_question_answering = lambda x: x.loss

config = transformers.BloomConfig(n_layer=1,
config = transformers.BloomConfig(n_layer=2,
n_head=4,
vocab_size=250880,
hidden_dropout=0,
attention_dropout=0,
hidden_size=64)
hidden_size=64,
pad_token_id=50256)

# register the following models
model_zoo.register(name='transformers_bloom',
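A plausible reading of the BloomConfig change (an inference, not stated in the diff): bumping `n_layer` from 1 to 2 gives the new `pp_size=2` pipeline configs in `test_shard_bloom.py` at least one transformer block per stage. Quick arithmetic under that assumption:

```python
n_layer, pp_size = 2, 2                 # updated config value and the pipeline size used in the tests
layers_per_stage = n_layer // pp_size
assert layers_per_stage >= 1            # every pipeline stage receives at least one Bloom block
```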
19 changes: 14 additions & 5 deletions tests/kit/model_zoo/transformers/chatglm.py
@@ -17,14 +17,24 @@ def data_gen():
return dict(input_ids=input_ids, attention_mask=attention_mask)


def data_gen_for_conditional_generation():
# conditional generation data gen
# `labels` are the target token ids, here simply a copy of `input_ids`
data = data_gen()
labels = data['input_ids'].clone()
data['labels'] = labels
return data


# define output transform function
output_transform_fn = lambda x: x

# define loss function
loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.sum()
loss_fn = lambda x: x.logits.sum()
loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state,
torch.ones_like(x.last_hidden_state))
loss_fn = lambda x: x.loss

config = ChatGLMConfig(num_layers=1,
config = ChatGLMConfig(num_layers=2,
padded_vocab_size=65024,
hidden_size=64,
num_attention_heads=8,
@@ -33,7 +43,6 @@ def data_gen():
use_cache=True,
torch_dtype=torch.float32)


model_zoo.register(name='transformers_chatglm',
model_fn=lambda: ChatGLMModel(config, empty_init=False),
data_gen_fn=data_gen,
@@ -43,7 +52,7 @@ def data_gen():

model_zoo.register(name="transformers_chatglm_for_conditional_generation",
model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
data_gen_fn=data_gen,
data_gen_fn=data_gen_for_conditional_generation,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
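As a quick sanity check of the new generator (assuming it is run inside this module, with torch already imported as above): the labels are just a clone of the input ids, so ChatGLMForConditionalGeneration can compute a language-modeling loss, which the updated loss_fn reads back via `x.loss`.

```python
batch = data_gen_for_conditional_generation()
assert set(batch) == {'input_ids', 'attention_mask', 'labels'}
assert torch.equal(batch['labels'], batch['input_ids'])    # labels mirror the inputs token-for-token
```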
6 changes: 1 addition & 5 deletions tests/kit/model_zoo/transformers/vit.py
@@ -7,11 +7,7 @@
# Register single-sentence VIT
# ===============================

config = transformers.ViTConfig(
num_hidden_layers=4,
# hidden_size=128,
# intermediate_size=256,
num_attention_heads=4)
config = transformers.ViTConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4)


# define data gen function
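One likely motivation for pinning `hidden_size=128` and `num_attention_heads=4` in the ViT config (again an inference, not stated in the diff) is to keep the attention heads and hidden dimension evenly divisible under the tensor-parallel sizes these shardformer tests use elsewhere; the tp sizes below mirror the bloom test configs in this PR and are an assumption for ViT:

```python
hidden_size, num_attention_heads = 128, 4      # values pinned by the updated ViTConfig
assert hidden_size % num_attention_heads == 0  # per-head dimension is 32
for tp_size in (2, 4):                         # assumed tensor-parallel sizes, mirroring the bloom test configs
    assert num_attention_heads % tp_size == 0  # heads can be split evenly across ranks
    assert hidden_size % tp_size == 0          # column/row parallel linears split evenly too
```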
37 changes: 17 additions & 20 deletions tests/test_shardformer/test_model/_utils.py
@@ -104,27 +104,22 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
if 'use_lazy_init' in test_config:
use_lazy_init = test_config.pop('use_lazy_init')

if use_lazy_init:
ctx = LazyInitContext()
else:
ctx = nullcontext()

plugin = HybridParallelPlugin(**test_config)
booster = Booster(plugin=plugin)

ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx:
org_model = model_fn().cuda()
sharded_model = copy.deepcopy(org_model)

org_model = model_fn()
if use_lazy_init:
org_model = ctx.materialize(org_model)
ctx.materialize(org_model)

org_model = org_model.cuda()
sharded_model = copy.deepcopy(org_model)
org_optimizer = Adam(org_model.parameters(), lr=1e-3)
sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
criterion = loss_fn

sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
plugin = HybridParallelPlugin(**test_config)
booster = Booster(plugin=plugin)

sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster


@@ -142,11 +137,12 @@ def _criterion(outputs, inputs):
data = data_gen_fn()
sharded_model.train()
if booster.plugin.stage_manager is not None:
data = {
k: v.to('cuda').repeat(*([4] + [1] *
(v.dim() - 1))) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
for k, v in data.items()
}
for k, v in data.items():
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 4
data[k] = v.to('cuda').repeat(*new_shape)

data_iter = iter([data])
sharded_output = booster.execute_pipeline(data_iter,
sharded_model,
@@ -176,15 +172,16 @@ def check_output_hidden_state(org_output: Tensor,
sharded_output: Tensor,
stage_manager: Optional[PipelineStageManager] = None,
atol: float = 1e-5,
rtol: float = 1e-3):
rtol: float = 1e-3,
dim: int = 0):

org_hidden_state = org_output.last_hidden_state

if stage_manager is None:
sharded_hidden_state = sharded_output.last_hidden_state

if stage_manager and stage_manager.is_last_stage():
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0)
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=dim)

assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
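The rewritten data-preparation loop in `run_forward_backward_with_hybrid_plugin` repeats every tensor along its first dimension so the pipeline plugin can split the batch into microbatches. A self-contained CPU illustration of that expansion; the tensor shapes are made up for the example, and the real code also moves tensors to CUDA:

```python
import torch

data = {'input_ids': torch.ones(1, 16, dtype=torch.long),
        'attention_mask': torch.ones(1, 16, dtype=torch.long)}
num_microbatches = 4                      # matches the hard-coded repeat factor above

for k, v in data.items():
    if torch.is_tensor(v):
        new_shape = [1] * v.dim()
        new_shape[0] = num_microbatches   # only the leading (batch) dimension grows
        data[k] = v.repeat(*new_shape)

print(data['input_ids'].shape)            # torch.Size([4, 16]): four microbatches of batch size 1
```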
118 changes: 81 additions & 37 deletions tests/test_shardformer/test_model/test_shard_bloom.py
@@ -3,57 +3,101 @@

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_loss,
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
)


def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):

org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)

def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster)

# do backward
org_loss.backward()
shard_loss.backward()
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group

assert torch.allclose(org_loss, shard_loss,
atol=1e-6), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():

if org_model.__class__.__name__ == 'BloomModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)

check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)

# unwrap model
if org_model.__class__.__name__ == 'BloomModel':
bloom = org_model
sharded_bloom = sharded_model
sharded_bloom = sharded_model.unwrap()
else:
bloom = org_model.transformer
sharded_bloom = sharded_model.transformer
sharded_bloom = sharded_model.unwrap().transformer

# check grad
col_layer_for_check = ['h[0].self_attention.query_key_value']
row_layer_for_check = ['h[0].self_attention.dense']
check_grad(bloom, sharded_bloom, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
check_grad(bloom, sharded_bloom, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)


@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
@parameterize('use_lazy_init', [False, True])
def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
use_lazy_init):
row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
col_layer_for_check = ['h[0].self_attention.dense']
if stage_manager is None or stage_manager.is_first_stage():
check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=1e-6, rtol=1e-5, dim=1, verbose=False)

# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)

torch.cuda.empty_cache()


@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': True,
'use_lazy_init': True
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': False,
'use_lazy_init': False
}, {
'tp_size': 4,
'pp_size': 1,
'enable_fused_normalization': True,
'use_lazy_init': False
}])
def run_bloom_test(test_config):

# TODO: add test_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}

# TODO: add test_config for flash attention & jit operator after supporting

sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing

for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
enable_flash_attention, enable_jit_fused, use_lazy_init)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

clear_layout_converter()
torch.cuda.empty_cache()


@@ -67,7 +111,7 @@ def check_bloom(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom():
spawn(check_bloom, 2)
spawn(check_bloom, 4)


if __name__ == "__main__":
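A note on the spawn change: `spawn(check_bloom, 4)` launches four ranks because the largest parameterized configs need `tp_size * pp_size = 4`, and the smaller one (tp=1, pp=2) still divides 4 evenly, with the leftover factor presumably absorbed as data parallelism by HybridParallelPlugin. Quick arithmetic under that assumption:

```python
world_size = 4                              # what spawn(check_bloom, 4) launches
configs = [(2, 2), (1, 2), (4, 1)]          # (tp_size, pp_size) from the parameterize list above
for tp, pp in configs:
    assert world_size % (tp * pp) == 0      # every config fits into the 4 launched ranks
    print(f"tp={tp} pp={pp} -> dp={world_size // (tp * pp)}")   # dp sizes: 1, 2, 1
```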