In [1]:
import os
import time

import numpy as np
import mxnet as mx
import gluonnlp as nlp
import tvm
from tvm import relay
from tvm import autotvm
import tvm.contrib.graph_runtime as runtime

from hybrid_bert import get_hybrid_model, HybridBERTClassifier

In [2]:
seq_length = 128
model_name = "bert_12_768_12"
dataset = "book_corpus_wiki_en_uncased"

# Seed both RNGs so Restart & Run All is reproducible: with pretrained=False
# the weights presumably come from initialize() below, and the next cell draws
# its random inputs from np.random.
seed = 0
np.random.seed(seed)
mx.random.seed(seed)

mx_ctx = mx.cpu()
# BERT backbone without decoder/classifier heads; use_pooler=True, presumably
# because HybridBERTClassifier consumes the pooled output.
# NOTE(review): seq_length is passed to the hybrid model, presumably to fix
# the sequence length of its graph — confirm in hybrid_bert.
bert, _ = get_hybrid_model(
    name=model_name,
    ctx=mx_ctx,
    dataset_name=dataset,
    pretrained=False,
    use_pooler=True,
    use_decoder=False,
    use_classifier=False,
    seq_length=seq_length)
mx_model = HybridBERTClassifier(bert, num_classes=2, dropout=0.1)
mx_model.initialize(ctx=mx_ctx)
# Run as a cached symbolic graph; static_alloc pre-allocates its memory.
mx_model.hybridize(static_alloc=True)

In [3]:
# One random sequence of token ids and segment ids in which every position is
# valid. All three arrays are float32; the same numpy arrays are later fed to
# the compiled TVM module, so both runtimes see identical inputs.
batch_shape = (1, seq_length)
inputs = np.random.randint(0, 1000, size=batch_shape).astype('float32')
token_types = np.random.choice((0, 1), size=batch_shape).astype('float32')
valid_length = np.array([seq_length], dtype='float32')

# Copy to MXNet NDArrays on the model's context and compute the reference output.
inputs_nd, token_types_nd, valid_length_nd = (
    mx.nd.array(arr, ctx=mx_ctx) for arr in (inputs, token_types, valid_length))
mx_out = mx_model(inputs_nd, token_types_nd, valid_length_nd)
print(mx_out)


[[ 0.29543823 -0.354685  ]]
<NDArray 1x2 @cpu(0)>


In [4]:
def measure_latency_ms(run_once, min_repeat_ms=2000, number=20, warmup=10):
    """Return the mean latency in milliseconds of one ``run_once()`` call.

    Uses the same adaptive strategy as TVM's ``time_evaluator`` with
    ``min_repeat_ms``: time ``number`` back-to-back calls, and while that batch
    takes less than ``min_repeat_ms``, grow ``number`` (by at least the golden
    ratio) and measure again. Only the final batch is reported.

    Args:
        run_once: zero-argument callable performing one *blocking* run.
        min_repeat_ms: minimum duration (ms) of the reported batch.
        number: initial number of calls per timed batch.
        warmup: untimed calls made first (dry run) to exclude one-off setup cost.

    Returns:
        Mean milliseconds per call over the final batch.
    """
    for _ in range(warmup):
        run_once()
    while True:
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # with clock adjustments and is coarse on some platforms.
        beg = time.perf_counter()
        for _ in range(number):
            run_once()
        lat = (time.perf_counter() - beg) * 1e3
        if lat >= min_repeat_ms:
            return lat / number
        target = number * 1.618
        # lat can be 0 when the batch is faster than the clock resolution;
        # skip the extrapolation then instead of dividing by zero.
        if lat > 0:
            target = max(min_repeat_ms / (lat / number) + 1, target)
        # Always grow by at least one so the loop is guaranteed to progress.
        number = max(int(target), number + 1)


def _run_mx():
    """One synchronous MXNet forward pass (wait_to_read blocks until done)."""
    mx_model(inputs_nd, token_types_nd, valid_length_nd).wait_to_read()


min_repeat_ms = 2000
mx_latency_ms = measure_latency_ms(_run_mx, min_repeat_ms=min_repeat_ms)
print('mxnet latency: %.2f ms' % mx_latency_ms)

mxnet latency: 37.45 ms


In [5]:
# Input shapes for the Relay importer. data0/data1/data2 are the positional
# inputs of mx_model: token ids, token types and valid length, in the order
# used when calling mx_model and when binding inputs on the compiled module.
input_shapes = [
    ('data0', (1, seq_length)),
    ('data1', (1, seq_length)),
    ('data2', (1,)),
]
shape_dict = dict(input_shapes)
mod, params = relay.frontend.from_mxnet(mx_model, shape_dict)

In [6]:
ctx = tvm.cpu()
# Compilation target. NOTE(review): skylake-avx512 presumably matches the
# machine "c5.log" was tuned on (AWS c5) — confirm before reusing elsewhere.
# -libs=cblas requires a TVM build with BLAS support.
target = "llvm -mcpu=skylake-avx512 -libs=cblas"
tuning_log = "c5.log"


def _build_relay_module():
    """Compile ``mod["main"]`` for ``target`` at opt_level 3.

    Returns:
        The ``(graph, lib, params)`` triple produced by ``relay.build``.
    """
    with relay.build_config(opt_level=3):
        return relay.build(mod["main"], target, params=params)


# Use the AutoTVM tuning records when they exist. apply_history_best raises
# if the log file is missing; falling back to TVM's default schedules keeps
# the notebook runnable, though the resulting latency will not be comparable.
if os.path.isfile(tuning_log):
    with autotvm.apply_history_best(tuning_log):
        graph, lib, new_params = _build_relay_module()
else:
    print("warning: tuning log %r not found; building with default schedules"
          % tuning_log)
    graph, lib, new_params = _build_relay_module()

In [7]:
# Instantiate the compiled graph on `ctx`, bind the inputs together with the
# parameters returned by relay.build, and run one forward pass.
ex = runtime.create(graph, lib, ctx)
ex.set_input(data0=inputs, data1=token_types, data2=valid_length, **new_params)
ex.run()
out = ex.get_output(0)
print(out)

# check correctness against the MXNet reference output.
# np.testing is used because `tvm.testing` is never imported here, and newer
# TVM releases do not import it as part of `import tvm` (AttributeError).
# atol=1e-7 mirrors the default of TVM's own assert_allclose wrapper.
np.testing.assert_allclose(out.asnumpy(), mx_out.asnumpy(), rtol=1e-3, atol=1e-7)

[[ 0.2954375  -0.35468534]]


In [8]:
# benchmark
# Benchmark the compiled module. Per the TVM docs, min_repeat_ms=2000 makes
# time_evaluator adjust the number of runs so each repeat lasts at least 2 s.
ftimer = ex.module.time_evaluator("run", ctx, min_repeat_ms=2000)
timing = ftimer()
prof_res = np.asarray(timing.results) * 1e3  # seconds -> milliseconds
mean_ms = prof_res.mean()
print("TVM latency for seq length %s: %.2f ms" % (seq_length, mean_ms))

TVM latency for seq length 128: 34.50 ms
