Skip to content

Commit

Permalink
unify the network QDLs
Browse files Browse the repository at this point in the history
  • Loading branch information
arjunsuresh committed Nov 30, 2023
1 parent 0547669 commit acc3024
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 69 deletions.
18 changes: 15 additions & 3 deletions language/bert/bert_base_QDL.py → language/bert/bert_QDL.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import squad_QSL


class bert_base_QDL:
class bert_QDL:
"""QDL acting as a proxy to the SUT.
This QDL communicates with the SUT via HTTP.
It uses two endpoints to communicate with the SUT:
Expand Down Expand Up @@ -77,10 +77,22 @@ def process_query_async(self, query_samples):
query_samples: A list of QuerySample objects.
"""

responses = []
for i in range(len(query_samples)):
responses = []
eval_features = self.qsl.get_features(query_samples[i].index)
'''implement this'''
encoded_eval_features = {
"input_ids": eval_features.input_ids,
"input_mask": eval_features.input_mask,
"segment_ids": eval_features.segment_ids
}
output = self.client_predict(encoded_eval_features, query_samples[i].index)
output = np.array(output).astype(np.float32)
response_array = array.array("B", output.tobytes())
bi = response_array.buffer_info()

responses.append(lg.QuerySampleResponse(query_samples[i].id, bi[0], bi[1]))
lg.QuerySamplesComplete(responses)


def get_sut_id_round_robin(self):
"""Get the SUT id in round robin."""
Expand Down
10 changes: 2 additions & 8 deletions language/bert/network_LON.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from absl import app
import squad_QSL
import mlperf_loadgen as lg
import bert_QDL

def set_args(argv, g_settings, g_log_settings, g_audit_conf, g_sut_server, g_backend, g_total_count_override=None, g_perf_count_override=None):

Expand All @@ -35,14 +36,7 @@ def set_args(argv, g_settings, g_log_settings, g_audit_conf, g_sut_server, g_bac

def main(argv):
qsl = squad_QSL.get_squad_QSL(total_count_override, perf_count_override)
if backend == "onnxruntime":
import bert_onnxruntime_QDL
qdl = bert_onnxruntime_QDL.bert_onnxruntime_QDL(qsl, sut_server_addr=sut_server)
elif backend == "pytorch":
import bert_pytorch_QDL
qdl = bert_pytorch_QDL.bert_pytorch_QDL(qsl, sut_server_addr=sut_server)
else:
raise ValueError('`backend` should be one of onnxruntime,pytorch for Loadgen over the network bert implementation')
qdl = bert_QDL.bert_QDL(qsl, sut_server_addr=sut_server)

lg.StartTestWithLogSettings(qdl.qdl, qsl.qsl, settings, log_settings, audit_conf)

Expand Down
68 changes: 34 additions & 34 deletions language/bert/onnxruntime_SUT.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,43 +58,43 @@ def issue_queries(self, query_samples):
for i in range(len(query_samples)):

eval_features = self.qsl.get_features(query_samples[i].index)
if self.quantized:
fd = {
"input_ids": np.array(eval_features.input_ids).astype(np.int64)[np.newaxis, :],
"attention_mask": np.array(eval_features.input_mask).astype(np.int64)[np.newaxis, :],
"token_type_ids": np.array(eval_features.segment_ids).astype(np.int64)[np.newaxis, :]
}
else:
fd = {
"input_ids": np.array(eval_features.input_ids).astype(np.int64)[np.newaxis, :],
"input_mask": np.array(eval_features.input_mask).astype(np.int64)[np.newaxis, :],
"segment_ids": np.array(eval_features.segment_ids).astype(np.int64)[np.newaxis, :]
}
if self.network == "sut":
for key in fd:
fd[key] = fd[key].tolist()

scores = self.sess.run([o.name for o in self.sess.get_outputs()], fd)
output = np.stack(scores, axis=-1)[0]

if self.network == "sut":
return output.tolist()

response_array = array.array("B", output.tobytes())
bi = response_array.buffer_info()
response = lg.QuerySampleResponse(query_samples[i].id, bi[0], bi[1])
lg.QuerySamplesComplete([response])

def process_sample(self, sample_input):
'''For Loadgen over the network'''
sample_input["input_ids"] = np.array(sample_input["input_ids"])
sample_input["input_mask"] = np.array(sample_input["input_mask"])
sample_input["segment_ids"] = np.array(sample_input["segment_ids"])
self.process_sample(eval_features, query_samples[i].id)

def process_sample(self, eval_features, query_id):

scores = self.sess.run([o.name for o in self.sess.get_outputs()], sample_input)
'''For Loadgen over the network'''
if self.network == "sut":
input_ids = eval_features['input_ids']
input_mask = eval_features['input_mask']
segment_ids = eval_features['segment_ids']
else:
input_ids = eval_features.input_ids
input_mask = eval_features.input_mask
segment_ids = eval_features.segment_ids

if self.quantized:
fd = {
"input_ids": np.array(input_ids).astype(np.int64)[np.newaxis, :],
"attention_mask": np.array(input_mask).astype(np.int64)[np.newaxis, :],
"token_type_ids": np.array(segment_ids).astype(np.int64)[np.newaxis, :]
}
else:
fd = {
"input_ids": np.array(input_ids).astype(np.int64)[np.newaxis, :],
"input_mask": np.array(input_mask).astype(np.int64)[np.newaxis, :],
"segment_ids": np.array(segment_ids).astype(np.int64)[np.newaxis, :]
}

scores = self.sess.run([o.name for o in self.sess.get_outputs()], fd)
output = np.stack(scores, axis=-1)[0]
return output.tolist()

if self.network == "sut":
return output.tolist()

response_array = array.array("B", output.tobytes())
bi = response_array.buffer_info()
response = lg.QuerySampleResponse(query_id, bi[0], bi[1])
lg.QuerySamplesComplete([response])

def flush_queries(self):
pass
Expand Down
51 changes: 27 additions & 24 deletions language/bert/pytorch_SUT.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(self, args):
type_vocab_size=config_json["type_vocab_size"],
vocab_size=config_json["vocab_size"])

self.network = args.network
self.dev = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
self.version = transformers.__version__

Expand All @@ -65,38 +66,40 @@ def __init__(self, args):
self.qsl = get_squad_QSL(args.max_examples)

def issue_queries(self, query_samples):
for i in range(len(query_samples)):
eval_features = self.qsl.get_features(query_samples[i].index)
self.process_sample(eval_features, query_samples[i].id)

def process_sample(self, sample_input, query_id = None):

if self.network == "sut":
input_ids = sample_input['input_ids']
input_mask = sample_input['input_mask']
segment_ids = sample_input['segment_ids']
else:
input_ids = sample_input.input_ids
input_mask = sample_input.input_mask
segment_ids = sample_input.segment_ids

with torch.no_grad():
for i in range(len(query_samples)):
eval_features = self.qsl.get_features(query_samples[i].index)
model_output = self.model.forward(input_ids=torch.LongTensor(eval_features.input_ids).unsqueeze(0).to(self.dev),
attention_mask=torch.LongTensor(eval_features.input_mask).unsqueeze(0).to(self.dev),
token_type_ids=torch.LongTensor(eval_features.segment_ids).unsqueeze(0).to(self.dev))
if self.version >= '4.0.0':
start_scores = model_output.start_logits
end_scores = model_output.end_logits
else:
start_scores, end_scores = model_output
output = torch.stack([start_scores, end_scores], axis=-1).squeeze(0).cpu().numpy()

response_array = array.array("B", output.tobytes())
bi = response_array.buffer_info()
response = lg.QuerySampleResponse(query_samples[i].id, bi[0], bi[1])
lg.QuerySamplesComplete([response])

def process_sample(self, sample_input):
with torch.no_grad():
'''For Loadgen over the network'''
model_output = self.model.forward(input_ids=torch.LongTensor(sample_input['input_ids']).unsqueeze(0).to(self.dev),
attention_mask=torch.LongTensor(sample_input['input_mask']).unsqueeze(0).to(self.dev),
token_type_ids=torch.LongTensor(sample_input['segment_ids']).unsqueeze(0).to(self.dev))
model_output = self.model.forward(input_ids=torch.LongTensor(input_ids).unsqueeze(0).to(self.dev),
attention_mask=torch.LongTensor(input_mask).unsqueeze(0).to(self.dev),
token_type_ids=torch.LongTensor(segment_ids).unsqueeze(0).to(self.dev))
if self.version >= '4.0.0':
start_scores = model_output.start_logits
end_scores = model_output.end_logits
else:
start_scores, end_scores = model_output
output = torch.stack([start_scores, end_scores], axis=-1).squeeze(0).cpu().numpy()

return output.tolist()
if self.network == "sut":
return output.tolist()

response_array = array.array("B", output.tobytes())
bi = response_array.buffer_info()
response = lg.QuerySampleResponse(query_id, bi[0], bi[1])
lg.QuerySamplesComplete([response])


def flush_queries(self):
pass
Expand Down

0 comments on commit acc3024

Please sign in to comment.