Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 38 additions & 15 deletions onnxruntime/python/tools/transformers/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,12 @@ def run_onnxruntime(
"datetime": str(datetime.now()),
}

logger.info(f"Run onnxruntime on {model_name} with input shape {[batch_size, sequence_length]}")
if config.model_type in ["vit", "swin"]:
logger.info(
f"Run onnxruntime on {model_name} with input shape {[batch_size, 3, config.image_size, config.image_size]}"
)
else:
logger.info(f"Run onnxruntime on {model_name} with input shape {[batch_size, sequence_length]}")

if disable_ort_io_binding:
result = inference_ort(
Expand Down Expand Up @@ -336,11 +341,16 @@ def run_pytorch(
cache_dir=cache_dir,
custom_model_class=model_class,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)

max_input_size = (
tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024
)
if config.model_type in ["vit", "swin"]:
# These models don't use sequence lengths, so just pick the first sequence length so that the summary still works
sequence_lengths = [sequence_lengths[0]]
else:
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)

max_input_size = (
tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024
)

logger.debug(f"Model {model}")
logger.debug(f"Number of parameters {model.num_parameters()}")
Expand All @@ -359,17 +369,27 @@ def run_pytorch(
continue

for sequence_length in sequence_lengths:
if max_input_size is not None and sequence_length > max_input_size:
continue
if config.model_type in ["vit", "swin"]:
logger.info(
f"Run PyTorch on {model_name} with input shape {[batch_size, 3, config.image_size, config.image_size]}"
)
input_ids = torch.randn(
size=(batch_size, 3, config.image_size, config.image_size),
dtype=torch.float16 if precision == Precision.FLOAT16 else torch.float32,
device=device,
)
else:
if max_input_size is not None and sequence_length > max_input_size:
continue

logger.info(f"Run PyTorch on {model_name} with input shape {[batch_size, sequence_length]}")
input_ids = torch.randint(
low=0,
high=config.vocab_size - 1,
size=(batch_size, sequence_length),
dtype=torch.long,
device=device,
)
logger.info(f"Run PyTorch on {model_name} with input shape {[batch_size, sequence_length]}")
input_ids = torch.randint(
low=0,
high=config.vocab_size - 1,
size=(batch_size, sequence_length),
dtype=torch.long,
device=device,
)
try:
inference = (
torch.jit.trace(model, input_ids) if torchscript else torch.compile(model) if torch2 else model
Expand Down Expand Up @@ -767,6 +787,9 @@ def main():
logger.error("int8 is for CPU only")
return

# Image models (ViT, Swin) take pixel_values and have no sequence-length dimension.
# When benchmarking a single such model, replace the sequence lengths with the [""]
# sentinel so that the CSV summary emits batch-only ("b{batch_size}") columns.
# The model-type string must be "swin" to match the entries registered in MODELS.
if len(args.models) == 1 and MODELS[args.models[0]][3] in ["vit", "swin"]:
    args.sequence_lengths = [""]

args.num_threads = sorted({cpu_count if x <= 0 else x for x in args.num_threads})

logger.info(f"Arguments: {args}")
Expand Down
12 changes: 9 additions & 3 deletions onnxruntime/python/tools/transformers/benchmark_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,11 @@ def output_summary(results, csv_filename, args):
]
data_names = []
for batch_size in args.batch_sizes:
for sequence_length in args.sequence_lengths:
data_names.append(f"b{batch_size}_s{sequence_length}")
if args.sequence_lengths == [""]:
data_names.append(f"b{batch_size}")
else:
for sequence_length in args.sequence_lengths:
data_names.append(f"b{batch_size}_s{sequence_length}")

csv_writer = csv.DictWriter(csv_file, fieldnames=header_names + data_names)
csv_writer.writeheader()
Expand All @@ -273,7 +276,10 @@ def output_summary(results, csv_filename, args):
assert row[k] == headers[k]
b = result["batch_size"]
s = result["sequence_length"]
row[f"b{b}_s{s}"] = result["average_latency_ms"]
if s != "":
row[f"b{b}_s{s}"] = result["average_latency_ms"]
else:
row[f"b{b}"] = result["average_latency_ms"]
if row:
csv_writer.writerow(row)

Expand Down
7 changes: 6 additions & 1 deletion onnxruntime/python/tools/transformers/huggingface_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,12 @@
False,
"bert",
),
"google/vit-base-patch16-224": (["pixel_values"], 12, False, "vit"),
# "google/pegasus-xsum": (["input_ids"], 11, False, "bert"),
# "google/pegasus-large": (["input_ids"], 11, False, "bert"),
# ViT
"google/vit-base-patch16-224": (["pixel_values"], 12, False, "vit"),
# Swin
"microsoft/swin-base-patch4-window7-224": (["pixel_values"], 12, False, "swin"),
"microsoft/swin-small-patch4-window7-224": (["pixel_values"], 12, False, "swin"),
"microsoft/swin-tiny-patch4-window7-224": (["pixel_values"], 12, False, "swin"),
}
25 changes: 20 additions & 5 deletions onnxruntime/python/tools/transformers/onnx_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,12 @@ def restore_torch_functions():


def create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_names, config, data_type=numpy.int64):
input_ids = numpy.random.randint(low=0, high=vocab_size - 1, size=(batch_size, sequence_length), dtype=data_type)
if config.model_type in ["vit", "swin"]:
input_ids = numpy.random.rand(batch_size, 3, config.image_size, config.image_size).astype(numpy.float32)
inputs = {"pixel_values": input_ids}
return inputs

input_ids = numpy.random.randint(low=0, high=vocab_size - 1, size=(batch_size, sequence_length), dtype=data_type)
inputs = {"input_ids": input_ids}

if "attention_mask" in input_names:
Expand Down Expand Up @@ -241,6 +245,11 @@ def optimize_onnx_model(
if precision == Precision.INT8:
optimization_options.enable_embed_layer_norm = False

# For swin models, num_attention_heads is a list, which isn't supported yet, so set both num_attention_heads and hidden_size to 0 for now
if model_type == "swin":
num_attention_heads = 0
hidden_size = 0

# Use script to optimize model.
# Use opt_level <= 1 for models to be converted to fp16, because some fused op (like FusedGemm) has only fp32 and no fp16.
# It is better to be conservative so we use opt_level=0 here, in case MemcpyFromHost is added to the graph by OnnxRuntime.
Expand Down Expand Up @@ -439,7 +448,11 @@ def validate_and_optimize_onnx(
model_fusion_statistics,
)

return onnx_model_path, is_valid_onnx_model, None if model_type == "vit" else config.vocab_size
return (
onnx_model_path,
is_valid_onnx_model,
config.num_labels if model_type in ["vit", "swin"] else config.vocab_size,
)


def export_onnx_model_from_pt(
Expand Down Expand Up @@ -468,9 +481,11 @@ def export_onnx_model_from_pt(
example_inputs = None
max_input_size = None

if model_type == "vit":
if model_type in ["vit", "swin"]:
image_processor = AutoFeatureExtractor.from_pretrained(model_name, cache_dir=cache_dir)
data = numpy.random.randint(low=0, high=256, size=224 * 224 * 3, dtype=numpy.uint8).reshape(224, 224, 3)
data = numpy.random.randint(
low=0, high=256, size=config.image_size * config.image_size * 3, dtype=numpy.uint8
).reshape(config.image_size, config.image_size, 3)

example_inputs = image_processor(data, return_tensors="pt")
else:
Expand Down Expand Up @@ -509,7 +524,7 @@ def export_onnx_model_from_pt(
dynamic_axes = None
output_names = None

if model_type == "vit":
if model_type in ["vit", "swin"]:
dynamic_axes, output_names = {key: {0: "pixel_values"} for key in example_inputs}, ["logits"]
else:
dynamic_axes, output_names = build_dynamic_axes(example_inputs, example_outputs_flatten)
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/python/tools/transformers/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"vae": (VaeOnnxModel, "pytorch", 1),
"clip": (ClipOnnxModel, "pytorch", 1),
"vit": (BertOnnxModel, "pytorch", 1),
"swin": (BertOnnxModel, "pytorch", 1),
}


Expand Down Expand Up @@ -160,7 +161,7 @@ def optimize_by_fusion(
Returns:
object of an optimizer class.
"""
if model_type not in ["bert", "unet", "vae", "clip"] and (num_heads == 0 or hidden_size == 0):
if model_type not in ["bert", "swin", "unet", "vae", "clip"] and (num_heads == 0 or hidden_size == 0):
logger.warning(f"Please specify parameters of num_heads and hidden_size for model_type {model_type}")

(optimizer_class, producer, _) = MODEL_TYPES[model_type]
Expand Down