Skip to content

Commit

Permalink
refactor benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Jan 22, 2019
1 parent c890d36 commit 8629805
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,16 @@ def get_benchmark_parser():
param_str = '\n'.join(['%20s = %s' % (k, v) for k, v in sorted(common.items())])
tprint('%20s %s\n%s\n%s\n' % ('ARG', 'VALUE', '_' * 50, param_str))

with open(common['client_vocab_source'], encoding='utf8') as fp:
vocab = list(set(vv for v in fp for vv in v.strip().split()))
tprint('vocabulary size: %d' % len(vocab))

args = namedtuple('args_nt', ','.join(common.keys()))
globals()[args.__name__] = args


class BenchmarkClient(threading.Thread):
def __init__(self, vocab):
def __init__(self):
super().__init__()
self.batch = [' '.join(random.choices(vocab, k=args.max_seq_len)) for _ in range(args.client_batch_size)]
self.num_repeat = args.num_repeat
Expand All @@ -83,10 +87,6 @@ def run(self):


if __name__ == '__main__':
with open(args.client_vocab_source, encoding='utf8') as fp:
vocab = list(set(vv for v in fp for vv in v.strip().split()))
tprint('vocabulary size: %d' % len(vocab))

experiments = {k: common['test_%s' % k] for k in
['client_batch_size', 'max_batch_size', 'max_seq_len', 'num_client', 'pooling_layer']}

Expand All @@ -106,7 +106,7 @@ def run(self):
time.sleep(args.wait_till_ready)

# sleep until server is ready
all_clients = [BenchmarkClient(vocab) for _ in range(args.num_client)]
all_clients = [BenchmarkClient() for _ in range(args.num_client)]
for bc in all_clients:
bc.start()

Expand Down

0 comments on commit 8629805

Please sign in to comment.