Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Regression] Local performance regression #321

Merged
merged 15 commits into from
Jul 20, 2023
2 changes: 1 addition & 1 deletion python/hidet/graph/flow_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def latency(
if dummy_inputs is None:
dummy_inputs = self.dummy_inputs()
for _ in range(warmup):
self.forward(*dummy_inputs)
self.forward(dummy_inputs)
results = []
for _ in range(repeat):
hidet.cuda.synchronize()
Expand Down
Empty file added scripts/regression/__init__.py
Empty file.
23 changes: 23 additions & 0 deletions scripts/regression/email_sender.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import smtplib
from email.mime.text import MIMEText
import getpass


class EmailSender:
    """Send a plain-text regression report through Gmail's SMTP-over-SSL endpoint.

    The recipients, the sender address and the sender's Google app password
    are all collected interactively from the terminal at construction time.
    """

    def __init__(self) -> None:
        """Prompt for recipients, sender address and app password."""
        raw_recipients = input('Enter comma separated recipient email addresses:\n')
        self.recipients = self._parse_recipients(raw_recipients)
        self.sender = input('Enter your Gmail address:\n')
        # getpass keeps the password from being echoed back to the terminal.
        self.password = getpass.getpass(prompt='Enter your 16-digit Google app password: ')

    @staticmethod
    def _parse_recipients(raw):
        """Split a comma separated address string into a list of addresses.

        All whitespace (not only spaces) is removed from every entry, and
        entries that end up empty (e.g. from a trailing or doubled comma) are
        dropped so that no empty recipient is handed to the SMTP server.
        """
        cleaned = (''.join(piece.split()) for piece in raw.split(','))
        return [address for address in cleaned if address]

    def send_email(self, body):
        """Mail ``body`` as plain text to every configured recipient.

        Args:
            body: Text used as the message payload.
        """
        subject = 'Hidet Performance Regression'
        msg = MIMEText(body)
        msg['Subject'] = subject
        msg['From'] = self.sender
        msg['To'] = ', '.join(self.recipients)
        # The context manager closes the SMTP session even if login/send fails.
        with smtplib.SMTP_SSL('smtp.gmail.com', 465) as smtp_server:
            smtp_server.login(self.sender, self.password)
            smtp_server.sendmail(self.sender, self.recipients, msg.as_string())
        print("Results sent to", msg['To'])

128 changes: 128 additions & 0 deletions scripts/regression/model_performance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import os
import json
import numpy as np
import torch
import torchvision
import hidet
import argparse
from result_entry import ResultEntry, ResultGroup, load_regression_data
from transformers import AutoTokenizer, AutoModelForMaskedLM, logging
from torch import _dynamo
# Set TOKENIZERS_PARALLELISM to "false" at import time, before any tokenizer is
# created (presumably to silence the HuggingFace tokenizers parallelism
# warning -- confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# `logging` here is the transformers logging module imported above.
logging.set_verbosity_error()

# Name of the current CUDA device, decoded to str; the regression functions
# below use it as the top-level key into the loaded regression data.
device_name = str(hidet.cuda.properties().name, 'UTF-8')

def bench_torch_model(model, torch_inputs, bench_iters=100, warmup_iters=10):
    """Time repeated calls of ``model(*torch_inputs)`` with CUDA events.

    Args:
        model: Callable invoked with the unpacked ``torch_inputs``.
        torch_inputs: Sequence of positional arguments for ``model``.
        bench_iters: Number of timed calls.
        warmup_iters: Number of untimed calls made before measuring.

    Returns:
        The elapsed time between the two CUDA events divided by
        ``bench_iters`` (``elapsed_time`` reports milliseconds).
    """
    # Untimed warm-up runs, followed by releasing cached GPU memory.
    for _ in range(warmup_iters):
        output = model(*torch_inputs)
    torch.cuda.empty_cache()

    timer_begin, timer_end = (torch.cuda.Event(enable_timing=True) for _ in range(2))
    # Drain any pending GPU work so it is not attributed to the timed region.
    torch.cuda.synchronize()
    timer_begin.record()
    for _ in range(bench_iters):
        output = model(*torch_inputs)
    timer_end.record()
    # Wait for the end event so that elapsed_time() is valid.
    timer_end.synchronize()
    torch.cuda.empty_cache()

    return timer_begin.elapsed_time(timer_end) / bench_iters

def bench_hf_transformers(model_name, seqlen, dtype):
    """Benchmark a HuggingFace masked-LM model compiled with the hidet backend.

    Args:
        model_name: Name passed to the ``from_pretrained`` loaders.
        seqlen: Length the dummy sentence is padded to.
        dtype: Name of a torch dtype attribute, e.g. ``'float16'``.

    Returns:
        The latency reported by ``bench_torch_model`` for the compiled model.
    """
    fp16_enabled = dtype == 'float16'
    dynamo_cfg = hidet.torch.dynamo_config
    dynamo_cfg.search_space(2)
    dynamo_cfg.use_fp16(fp16_enabled)
    dynamo_cfg.use_fp16_reduction(fp16_enabled)
    dynamo_cfg.use_attention(True)
    dynamo_cfg.use_tensor_core(True)
    dynamo_cfg.use_cuda_graph(True)

    torch_dtype = getattr(torch, dtype)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # max_position_embeddings is overridden, so mismatched weight sizes are ignored.
    model = AutoModelForMaskedLM.from_pretrained(
        model_name, max_position_embeddings=8192, ignore_mismatched_sizes=True
    )
    model = model.eval().to(torch_dtype).cuda()

    encoded = tokenizer(
        "Dummy sentence", padding='max_length', max_length=seqlen, return_tensors='pt'
    )
    # Only the token ids are fed to the model.
    torch_inputs = (encoded['input_ids'].clone().cuda(),)

    with torch.no_grad(), torch.autocast("cuda"):
        model = torch.compile(model, backend='hidet')
        latency = bench_torch_model(model, torch_inputs)
    del model
    return latency

def bench_torchvision(model_cls, shape, dtype):
    """Benchmark a torchvision-style model compiled with the hidet backend.

    Args:
        model_cls: Model factory called as ``model_cls(weights=None)``.
        shape: Shape of the single random input tensor.
        dtype: Name of a torch dtype attribute, e.g. ``'float16'``.

    Returns:
        The latency reported by ``bench_torch_model`` for the compiled model.
    """
    fp16_enabled = dtype == 'float16'
    dynamo_cfg = hidet.torch.dynamo_config
    dynamo_cfg.search_space(2)
    dynamo_cfg.use_fp16(fp16_enabled)
    dynamo_cfg.use_fp16_reduction(fp16_enabled)
    dynamo_cfg.use_attention(True)
    dynamo_cfg.use_tensor_core(True)
    dynamo_cfg.use_cuda_graph(True)

    torch_dtype = getattr(torch, dtype)
    # Built with weights=None, i.e. no pretrained weights are requested.
    model = model_cls(weights=None).eval().to(torch_dtype).cuda()
    torch_inputs = [torch.randn(shape, device='cuda', dtype=torch_dtype)]

    with torch.no_grad(), torch.autocast("cuda"):
        model = torch.compile(model, backend='hidet')
        latency = bench_torch_model(model, torch_inputs)
    del model
    return latency

def bert_regression():
    """Benchmark bert-base-uncased for every recorded (sequence length, dtype) pair.

    Returns:
        A ResultGroup holding one ResultEntry per benchmarked configuration,
        each pairing the measured latency with the stored reference latency.
    """
    regression_data = load_regression_data()
    group = ResultGroup(name='bert-base Regression', device_name=device_name)
    shape_table = regression_data[device_name]['bert_base_shapes']
    # Flatten the {shape: {dtype: ref_latency}} mapping into a lazy stream of cases.
    cases = (
        (shape, dtype, ref_latency)
        for shape, perf in shape_table.items()
        for dtype, ref_latency in perf.items()
    )
    for shape, dtype, ref_latency in cases:
        # The shape key must hold exactly one integer: the sequence length.
        (seqlen,) = (int(token.strip()) for token in shape.split(','))
        measured = bench_hf_transformers('bert-base-uncased', seqlen, dtype)
        group.add_entry(ResultEntry(shape, dtype, measured, ref_latency))
    return group

def resnet_regression():
    """Benchmark torchvision's resnet50 for every recorded (input shape, dtype) pair.

    Returns:
        A ResultGroup holding one ResultEntry per benchmarked configuration,
        each pairing the measured latency with the stored reference latency.
    """
    regression_data = load_regression_data()
    group = ResultGroup(name='resnet50 Regression', device_name=device_name)
    shape_table = regression_data[device_name]['resnet50_shapes']
    model_cls = torchvision.models.resnet50
    # Flatten the {shape: {dtype: ref_latency}} mapping into a lazy stream of cases.
    cases = (
        (shape, dtype, ref_latency)
        for shape, perf in shape_table.items()
        for dtype, ref_latency in perf.items()
    )
    for shape, dtype, ref_latency in cases:
        # The shape key is a comma separated list of integer dimensions.
        dims = [int(token.strip()) for token in shape.split(',')]
        measured = bench_torchvision(model_cls, dims, dtype)
        group.add_entry(ResultEntry(shape, dtype, measured, ref_latency))
    return group


def efficientnet_regression():
    """Placeholder for the EfficientNet regression benchmark.

    Returns:
        None, because the benchmark is not implemented yet.
    """
    # ToDo
    return None

def llama_regression():
    """Placeholder for the LLaMA regression benchmark.

    Returns:
        None, because the benchmark is not implemented yet.
    """
    # ToDo
    return None

def model_performance_regression(report_file):
    """Run every model regression benchmark and write the results to a report.

    Args:
        report_file: Path of the text file to (over)write with the report.
    """
    benchmarks = (
        bert_regression,
        resnet_regression,
        efficientnet_regression,
        llama_regression,
    )
    # All benchmarks run, in order, before the report file is opened.
    result_groups = [run() for run in benchmarks]
    with open(report_file, 'w') as f:
        f.write("---------------- Model Performance Regression -----------------\n")
        # Benchmarks that are not implemented yet return None and are skipped.
        f.writelines(str(group) for group in result_groups if group is not None)

if __name__ == '__main__':
    # Script entry point: parse the report path and run the model benchmarks.
    parser = argparse.ArgumentParser(prog='Model Performance Regression')
    parser.add_argument(
        '--report', type=str, help='Specify report output path',
        default='./report_model_performance.txt',
    )
    args = parser.parse_args()
    model_performance_regression(args.report)
83 changes: 83 additions & 0 deletions scripts/regression/op_performance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import os
import json
import numpy as np
import torch
import hidet
import argparse
from result_entry import ResultEntry, ResultGroup, load_regression_data

# Name of the current CUDA device, decoded to str; the regression functions
# below use it as the top-level key into the loaded regression data.
device_name = str(hidet.cuda.properties().name, 'UTF-8')

def bench_matmul(m, n, k, dtype):
    """Benchmark an optimized hidet graph multiplying an [m, k] by a [k, n] matrix.

    Args:
        m: Row count of the left operand.
        n: Column count of the right operand.
        k: Shared inner dimension.
        dtype: Data type given to both symbolic operands.

    Returns:
        Whatever the optimized graph's ``latency`` method reports.
    """
    hidet.option.search_space(2)
    lhs = hidet.symbol([m, k], dtype=dtype, device='cuda')
    rhs = hidet.symbol([k, n], dtype=dtype, device='cuda')
    product = hidet.ops.matmul(lhs, rhs)
    traced = hidet.trace_from(product, [lhs, rhs])
    optimized = hidet.graph.optimize(traced)
    return optimized.latency(warmup=3, number=3, repeat=10)

def bench_fmha(sq, skv, d):
    """Benchmark an optimized hidet graph running the fused attention operator.

    All three symbolic inputs are float16 CUDA tensors.

    Args:
        sq: First dimension of the query, shaped [sq, d].
        skv: Key/value length; the key is [d, skv] and the value is [skv, d].
        d: Feature dimension shared by the three inputs.

    Returns:
        Whatever the optimized graph's ``latency`` method reports.
    """
    hidet.option.search_space(2)
    query = hidet.symbol([sq, d], dtype='float16', device='cuda')
    # The key is declared with shape [d, skv], unlike the value's [skv, d].
    key = hidet.symbol([d, skv], dtype='float16', device='cuda')
    value = hidet.symbol([skv, d], dtype='float16', device='cuda')
    attended = hidet.ops.attention(query, key, value)
    traced = hidet.trace_from(attended, [query, key, value])
    optimized = hidet.graph.optimize(traced)
    return optimized.latency(warmup=3, number=3, repeat=10)

def matmul_regression() -> ResultGroup:
    """Benchmark matmul for every recorded (shape, dtype) pair on this device.

    Returns:
        A ResultGroup holding one ResultEntry per benchmarked configuration,
        each pairing the measured latency with the stored reference latency.
    """
    regression_data = load_regression_data()
    group = ResultGroup(name='Matrix Multiply Regression', device_name=device_name)
    shape_table = regression_data[device_name]['matmul_shapes']
    # Flatten the {shape: {dtype: ref_latency}} mapping into a lazy stream of cases.
    cases = (
        (shape, dtype, ref_latency)
        for shape, perf in shape_table.items()
        for dtype, ref_latency in perf.items()
    )
    for shape, dtype, ref_latency in cases:
        # The shape key must hold exactly three integers: "m, n, k".
        m, n, k = (int(token.strip()) for token in shape.split(','))
        measured = bench_matmul(m, n, k, dtype)
        group.add_entry(ResultEntry(shape, dtype, measured, ref_latency))
    return group


def fmha_regression() -> ResultGroup:
    """Benchmark fused attention for every recorded (shape, dtype) pair on this device.

    Returns:
        A ResultGroup holding one ResultEntry per benchmarked configuration,
        each pairing the measured latency with the stored reference latency.
    """
    regression_data = load_regression_data()
    group = ResultGroup(name='Fused Attention Regression', device_name=device_name)
    shape_table = regression_data[device_name]['fmha_shapes']
    # Flatten the {shape: {dtype: ref_latency}} mapping into a lazy stream of cases.
    cases = (
        (shape, dtype, ref_latency)
        for shape, perf in shape_table.items()
        for dtype, ref_latency in perf.items()
    )
    for shape, dtype, ref_latency in cases:
        # The shape key must hold exactly three integers: "sq, skv, d".
        sq, skv, d = (int(token.strip()) for token in shape.split(','))
        # NOTE: dtype is only recorded in the entry; bench_fmha takes no dtype argument.
        measured = bench_fmha(sq, skv, d)
        group.add_entry(ResultEntry(shape, dtype, measured, ref_latency))
    return group

def conv2d_regression() -> ResultGroup:
    """Placeholder for the conv2d regression benchmark.

    Returns:
        None, because the benchmark is not implemented yet.
    """
    # ToDo
    return None

def reduce_regression() -> ResultGroup:
    """Placeholder for the reduce regression benchmark.

    Returns:
        None, because the benchmark is not implemented yet.
    """
    # ToDo
    return None


def op_performance_regression(report_file):
    """Run every operator regression benchmark and write the results to a report.

    Args:
        report_file: Path of the text file to (over)write with the report.
    """
    benchmarks = (
        matmul_regression,
        fmha_regression,
        conv2d_regression,
        reduce_regression,
    )
    # All benchmarks run, in order, before the report file is opened.
    result_groups = [run() for run in benchmarks]
    with open(report_file, 'w') as f:
        f.write("---------------- Operator Performance Regression -----------------\n")
        # Benchmarks that are not implemented yet return None and are skipped.
        f.writelines(str(group) for group in result_groups if group is not None)

if __name__ == '__main__':
    # Script entry point: parse the report path and run the operator benchmarks.
    parser = argparse.ArgumentParser(prog='Operator Performance Regression')
    parser.add_argument(
        '--report', type=str, help='Specify report output path',
        default='./report_op_performance.txt',
    )
    args = parser.parse_args()
    op_performance_regression(args.report)