[fx] test tracer on diffuser modules. #1750

Merged · 3 commits · Oct 20, 2022
1 change: 1 addition & 0 deletions requirements/requirements-test.txt
@@ -1,3 +1,4 @@
+diffusers
 pytest
 torchvision
 transformers
12 changes: 6 additions & 6 deletions tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
@@ -1,10 +1,10 @@
-import transformers
-import torch
 import pytest
+import torch
+import transformers
 from utils import trace_model_and_compare_output

 BATCH_SIZE = 2
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16


 def test_single_sentence_albert():
@@ -23,9 +23,9 @@ def test_single_sentence_albert():
                                        intermediate_size=256)

     def data_gen():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
         return meta_args

12 changes: 6 additions & 6 deletions tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
@@ -1,10 +1,10 @@
-import transformers
-import torch
 import pytest
+import torch
+import transformers
 from utils import trace_model_and_compare_output

 BATCH_SIZE = 2
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16


 def test_single_sentence_bert():
@@ -20,9 +20,9 @@ def test_single_sentence_bert():
     config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)

     def data_gen():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
         return meta_args

116 changes: 116 additions & 0 deletions tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py
@@ -0,0 +1,116 @@
import diffusers
import pytest
import torch
import transformers
from torch.fx import GraphModule
from utils import trace_model_and_compare_output

from colossalai.fx import ColoTracer

BATCH_SIZE = 2
SEQ_LENGTH = 5
HEIGHT = 224
WIDTH = 224
Comment on lines +12 to +13

Contributor Author: unfortunately, input for CLIPVisionModel should be (2, 3, 224, 224)

Contributor: so?

Contributor Author: so the input has to be 224 that large

Contributor: ok.
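The 224 here comes from CLIP's vision defaults; a quick check, assuming a stock transformers.CLIPVisionConfig, is:

import transformers

# The default CLIP vision config expects 224x224 pixel inputs, so
# pixel_values for CLIPVisionModel must be (batch, 3, 224, 224).
assert transformers.CLIPVisionConfig().image_size == 224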
IN_CHANNELS = 3
LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 8, WIDTH // 8)
TIME_STEP = 2


def test_vae():
    MODEL_LIST = [
        diffusers.AutoencoderKL,
        diffusers.VQModel,
    ]

    for model_cls in MODEL_LIST:
        model = model_cls()
        sample = torch.zeros(LATENTS_SHAPE)

        tracer = ColoTracer()
        graph = tracer.trace(root=model)

        gm = GraphModule(model, graph, model.__class__.__name__)
        gm.recompile()

        model.eval()
        gm.eval()

        with torch.no_grad():
            fx_out = gm(sample)
            non_fx_out = model(sample)
        assert torch.allclose(
            fx_out['sample'],
            non_fx_out['sample']), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'


def test_clip():
    MODEL_LIST = [
        transformers.CLIPModel,
        transformers.CLIPTextModel,
        transformers.CLIPVisionModel,
    ]

    CONFIG_LIST = [
        transformers.CLIPConfig,
        transformers.CLIPTextConfig,
        transformers.CLIPVisionConfig,
    ]

    def data_gen():
        if isinstance(model, transformers.CLIPModel):
            input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
            attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
            position_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
            pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32)
            kwargs = dict(input_ids=input_ids,
                          attention_mask=attention_mask,
                          position_ids=position_ids,
                          pixel_values=pixel_values)
        elif isinstance(model, transformers.CLIPTextModel):
            input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
            attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
            kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        elif isinstance(model, transformers.CLIPVisionModel):
            pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32)
            kwargs = dict(pixel_values=pixel_values)
        return kwargs

    for model_cls, config in zip(MODEL_LIST, CONFIG_LIST):
        model = model_cls(config=config())
        trace_model_and_compare_output(model, data_gen)


@pytest.mark.skip(reason='cannot pass the test yet')
def test_unet():
    MODEL_LIST = [
        diffusers.UNet2DModel,
        diffusers.UNet2DConditionModel,
    ]

    for model_cls in MODEL_LIST:
        model = model_cls()
        sample = torch.zeros(LATENTS_SHAPE)

        tracer = ColoTracer()
        graph = tracer.trace(root=model)

        gm = GraphModule(model, graph, model.__class__.__name__)
        gm.recompile()

        model.eval()
        gm.eval()

        with torch.no_grad():
            fx_out = gm(sample, TIME_STEP)
            non_fx_out = model(sample, TIME_STEP)
        assert torch.allclose(
            fx_out['sample'],
            non_fx_out['sample']), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'


if __name__ == "__main__":
    test_vae()
    test_clip()

    # skip because of failure
    # test_unet()
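For reference, the file can be run directly via the __main__ guard above, or collected with pytest; a minimal sketch using pytest's Python API (path assumed relative to the repository root):

import pytest

# Run only this test file; test_vae and test_clip execute, while
# test_unet is reported as skipped because of @pytest.mark.skip.
pytest.main(['tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py', '-v'])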
12 changes: 6 additions & 6 deletions tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
@@ -1,10 +1,10 @@
-import transformers
-import torch
 import pytest
+import torch
+import transformers
 from utils import trace_model_and_compare_output

 BATCH_SIZE = 1
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16


 def test_gpt():
@@ -19,9 +19,9 @@ def test_gpt():
     config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)

     def data_gen():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
         return kwargs

8 changes: 4 additions & 4 deletions tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
@@ -1,10 +1,10 @@
 import pytest
-import transformers
 import torch
+import transformers
 from utils import trace_model_and_compare_output

 BATCH_SIZE = 1
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16


 def test_opt():
@@ -16,8 +16,8 @@ def test_opt():
     config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)

     def data_gen():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
         return kwargs

10 changes: 5 additions & 5 deletions tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
@@ -1,10 +1,10 @@
 import pytest
-import transformers
 import torch
+import transformers
 from utils import trace_model_and_compare_output

 BATCH_SIZE = 1
-SEQ_LENGHT = 16
+SEQ_LENGTH = 16


 def test_t5():
@@ -17,13 +17,13 @@ def test_t5():
     config = transformers.T5Config(d_model=128, num_layers=2)

     def data_gen():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
-        decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
         return kwargs

     def data_gen_for_encoder_only():
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
         kwargs = dict(input_ids=input_ids)
         return kwargs

5 changes: 3 additions & 2 deletions tests/test_fx/test_tracer/test_hf_model/utils.py
@@ -1,9 +1,10 @@
-from numpy import isin
 import torch
-from colossalai.fx import ColoTracer
+from numpy import isin
 from torch.fx import GraphModule
 from torch.utils._pytree import tree_flatten

+from colossalai.fx import ColoTracer
+

 def trace_model_and_compare_output(model, data_gen):
     # must turn on eval mode to ensure the output is consistent
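The diff truncates the helper's body. Based on the imports above (ColoTracer, GraphModule, tree_flatten) and the tracing pattern used in test_hf_diffuser.py, a minimal sketch of what trace_model_and_compare_output might look like follows; the meta_args construction and the use of torch.allclose are assumptions, not the repository's exact code:

import torch
from torch.fx import GraphModule
from torch.utils._pytree import tree_flatten

from colossalai.fx import ColoTracer


def trace_model_and_compare_output(model, data_gen):
    # must turn on eval mode to ensure the output is consistent
    model.eval()

    # generate concrete inputs, then trace the model with meta tensors
    inputs = data_gen()
    meta_args = {k: v.to('meta') for k, v in inputs.items()}
    tracer = ColoTracer()
    graph = tracer.trace(root=model, meta_args=meta_args)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
    gm.eval()

    # run the traced and original modules on the same data
    with torch.no_grad():
        fx_out = gm(**inputs)
        non_fx_out = model(**inputs)

    # compare outputs leaf by leaf, whatever container they come in
    fx_leaves, _ = tree_flatten(fx_out)
    ref_leaves, _ = tree_flatten(non_fx_out)
    for fx_leaf, ref_leaf in zip(fx_leaves, ref_leaves):
        if torch.is_tensor(fx_leaf):
            assert torch.allclose(fx_leaf, ref_leaf), \
                f'{model.__class__.__name__} has inconsistent outputs'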