In [None]:
import os, sys

# Remember the notebook's directory the first time this cell runs; on a re-run
# the stored value is kept even if the working directory has changed since.
if 'workbookDir' not in globals():
    workbookDir = os.getcwd()

# Location of the locally built onnxruntime Python package.
ort_build_lib = os.path.expanduser("~/onnxruntime/build_rocm/Release/build/lib")

# Filter sys.path: drop the implicit "" entry and every entry that resolves to
# the notebook directory (presumably so a source tree next to the notebook can
# not shadow the built package -- confirm), and drop any copy of the build
# directory left by a previous run so re-running the cell does not stack
# duplicates.
d_to_remove = workbookDir
kept_entries = []
for p in sys.path:
    if p == "" or p == ort_build_lib:
        continue
    try:
        if os.path.samefile(p, d_to_remove):
            continue
    except (OSError, TypeError, ValueError):
        # Entries that do not exist or are not path-like cannot be the
        # notebook directory; keep them, as the best-effort original did.
        pass
    kept_entries.append(p)

# Slice-assign so the very list object held by the import system is updated,
# then put the build directory first so it wins over any installed package.
sys.path[:] = [ort_build_lib] + kept_entries

In [None]:
import onnxruntime as ort
import onnx
import numpy as np
import matplotlib.pyplot as plt

# Turn ONNX Runtime's default logger up before any session is created below.
# NOTE(review): 0 is presumably the most verbose severity and 1000 a
# "log everything" verbosity level -- confirm against the ORT Python API docs.
ort.set_default_logger_severity(0)
ort.set_default_logger_verbosity(1000)

In [None]:
def multinormal_distribution(num_distribution, num_element_per_dist):
    """Sample several independent normal distributions, one per output row.

    For each row a mean is drawn from the standard normal distribution and a
    standard deviation from the uniform distribution on [0, 1); the row is then
    filled with ``num_element_per_dist`` samples of that normal distribution.
    All draws come from NumPy's global random state, so seed it beforehand for
    reproducible output. (A ``sqrt(num_element_per_dist)`` scaling of the
    standard deviation was tried in the original and left disabled.)

    Args:
        num_distribution: number of rows, i.e. of distinct (mean, std) pairs.
        num_element_per_dist: number of samples drawn per row.

    Returns:
        Array of shape ``(num_distribution, num_element_per_dist)``; an empty
        1-D array when ``num_distribution`` is 0.
    """

    def _sample_row():
        # Draw order matters for reproducibility: mean, then std, then samples.
        loc = np.random.randn()
        scale = np.random.rand()
        return np.random.normal(loc, scale, (num_element_per_dist,))

    return np.array([_sample_row() for _ in range(num_distribution)])

In [None]:
# Toggles for the optional Attention inputs. They decide which inputs are wired
# into the graph here and which tensors go into the feed dict further down.
use_attn_bias = True
use_attn_mask = False

# Graph-level tensor descriptions. The `_info` suffix keeps them from shadowing
# the `input` builtin and from sharing a name with the NumPy arrays `input`,
# `attn_bias` and `attn_mask` that a later cell creates for the feed dict, so
# re-running cells out of order cannot put a ValueInfoProto into the feed.
input_info = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT16, ["batchsize", 512, 768])
output_info = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT16, ["batchsize", 512, 768])
attn_mask_info = onnx.helper.make_tensor_value_info("attn_mask", onnx.TensorProto.INT32, ["batchsize", 512])
attn_bias_info = onnx.helper.make_tensor_value_info("attn_bias", onnx.TensorProto.FLOAT16, ["batchsize", 12, 512, 512])

# Seed the global RNG so the generated weights are reproducible.
np.random.seed(1)
# Alternative for debugging: all-ones weights.
# qkv_weight_data = np.ones((768, 3 * 768))
# 3*768*12 distributions of 768//12 = 64 samples each, laid out as a
# (768, 2304) matrix: every run of 64 consecutive values shares one mean/std.
qkv_weight_data = multinormal_distribution(3 * 768 * 12, 768 // 12).reshape((768, 3 * 768))
qkv_weight = onnx.helper.make_tensor("qkv_weight", onnx.TensorProto.FLOAT16, [768, 2304], qkv_weight_data)
# Alternative for debugging: random projection bias instead of zeros.
# qkv_bias = onnx.helper.make_tensor("qkv_bias", onnx.TensorProto.FLOAT16, [2304], np.random.random([2304]))
qkv_bias = onnx.helper.make_tensor("qkv_bias", onnx.TensorProto.FLOAT16, [2304], np.zeros([2304]))

# Positional inputs of the Attention node; "" marks an omitted optional input.
# Slot 3 receives the mask and slot 5 the attention bias when enabled.
node_inputs = ["input", "qkv_weight", "qkv_bias", "", "", ""]
if use_attn_bias:
    node_inputs[5] = attn_bias_info.name
if use_attn_mask:
    node_inputs[3] = attn_mask_info.name
node = onnx.helper.make_node("Attention", inputs=node_inputs, outputs=["output"], domain="com.microsoft", num_heads=12)

# Only the optional tensors that are actually wired in become graph inputs.
graph_inputs = [input_info]
if use_attn_bias:
    graph_inputs.append(attn_bias_info)
if use_attn_mask:
    graph_inputs.append(attn_mask_info)
graph = onnx.helper.make_graph([node], "Attn", graph_inputs, [output_info], initializer=[qkv_weight, qkv_bias])

model = onnx.helper.make_model(graph, producer_name="tmp", opset_imports=[
    onnx.helper.make_opsetid('com.microsoft', 1),
    onnx.helper.make_opsetid('ai.onnx.ml', 1),
    onnx.helper.make_opsetid('', 14),
])

# check_model signals problems by raising and returns nothing, so its result is
# not printed (the original printed a bare "None").
onnx.checker.check_model(model)

In [None]:
# Visual sanity check: render the generated (768, 2304) QKV weight matrix as an image.
plt.imshow(qkv_weight_data)

In [None]:
# Session options shared by both sessions, using the same logging levels that
# were applied to the default logger earlier in the notebook.
so = ort.SessionOptions()
so.log_severity_level = 0
so.log_verbosity_level = 1000

# Reference session on the CPU execution provider.
# NOTE(review): "tunable_op_enabled" looks like a GPU-provider option; confirm
# the CPU provider accepts (or silently ignores) it.
sess0 = ort.InferenceSession(
    model.SerializeToString(),
    providers=[("CPUExecutionProvider", {"tunable_op_enabled": "0"})],
    sess_options=so,
)

# Session under test on the ROCm execution provider, with tunable ops enabled.
sess1 = ort.InferenceSession(
    model.SerializeToString(),
    providers=[("ROCMExecutionProvider", {"tunable_op_enabled": "1"})],
    sess_options=so,
)

# Value stored for the GemmSoftmaxGemmPermute entry in the tuning results below
# (presumably the id of the kernel to force -- confirm). Note that the name `i`
# is reused later in the notebook as a batch index.
i = 17
# Hand-written tuning results. They are only applied if the commented-out
# set_tuning_results call below is re-enabled; otherwise `results` is unused.
results = [{'ep': 'ROCMExecutionProvider', 'results': {'onnxruntime::TunableOp<onnxruntime::contrib::rocm::GemmSoftmaxGemmPermuteParams<__half>, onnxruntime::rocm::tunable::Timer>': {'M512_N512_K64_O64_B768': i}, 'onnxruntime::TunableOp<onnxruntime::rocm::tunable::blas::StridedBatchedGemmParams<__half>, onnxruntime::rocm::tunable::Timer>': {'NN_512_64_512_B768': 0, 'NT_512_512_64_B768': 0}, 'onnxruntime::TunableOp<onnxruntime::rocm::tunable::blas::GemmParams<__half>, onnxruntime::rocm::tunable::Timer>': {'NN_32768_2304_768': 0, 'NN_32768_2304_1': 0}}, 'validators': {'ORT_VERSION': '1.15.0', 'ORT_GIT_COMMIT': '', 'ORT_BUILD_CONFIG': 'USE_CK=1|USE_ROCBLAS_EXTENSION_API=0|', 'HIP_VERSION': '50422803', 'ROCBLAS_VERSION': '2.46.0.ef7a9bb9-dirty', 'DEVICE_MODEL': 'AMD Instinct MI250X/MI250'}}]
# sess1.set_tuning_results(results)

In [None]:
# Concrete value for the symbolic "batchsize" dimension of the graph.
batchsize = 64

# Small-magnitude random hidden states, cast to float16 to match the graph input.
# The draws continue from NumPy's global RNG state (seeded where the model is
# built), so the values depend on the cells having been run in order.
# Note: this name shadows the `input` builtin for the rest of the notebook.
input = (0.01 * np.random.randn(batchsize, 512, 768)).astype(np.float16)
# Alternative for debugging: constant input.
# input = (0.01 * np.ones((batchsize, 512, 768))).astype(np.float16)

# Attention bias drawn uniformly from [-2, 2), one (512, 512) map per head.
# The float64 intermediate is large (batchsize * 12 * 512 * 512 elements)
# before the cast to float16.
attn_bias = np.random.uniform(-2, 2, size=(batchsize, 12, 512, 512)).astype(np.float16)
# Alternative for debugging: inject a NaN into the bias.
# attn_bias[0,0,0,0] = float("nan")


# All-ones int32 mask; only fed to the sessions when use_attn_mask is True.
attn_mask = np.ones([batchsize, 512], dtype=np.int32)
# Alternative masks for debugging:
# attn_mask[1, 1] = 2
# attn_mask = np.zeros([batchsize, 512], dtype=np.int32)
# attn_mask = np.random.randint(0, 2, size=(batchsize, 512), dtype=np.int32)

In [None]:
# Display the attention-mask array built in the previous cell.
attn_mask

In [None]:
# Feed dict shared by both sessions: the hidden states are always present,
# each optional tensor only when its toggle is switched on.
input_feed = {
    "input": input,
    **({"attn_bias": attn_bias} if use_attn_bias else {}),
    **({"attn_mask": attn_mask} if use_attn_mask else {}),
}

In [None]:
# Run the CPU session on the shared feed, requesting every declared output,
# and keep the first (only the first) returned array.
o0 = sess0.run(
    [out.name for out in sess0.get_outputs()],
    input_feed,
)[0]

In [None]:
# Run the ROCm session on the same feed, requesting every declared output,
# and keep the first (only the first) returned array.
o1 = sess1.run(
    [out.name for out in sess1.get_outputs()],
    input_feed,
)[0]

In [None]:
# The CPU result serves as the reference, the ROCm result is the one under
# test; `diff` is their element-wise difference (reference minus tested).
ref, my = o0, o1
diff = np.subtract(ref, my)

In [None]:
# Shape of the element-wise difference between the two outputs.
diff.shape

In [None]:
# Batch index inspected by the following cells. This reuses the name `i` that
# the session-setup cell set to 17 for the tuning-results dict.
i = 1

In [None]:
# Reference (CPU) output for batch element i.
ref[i]

In [None]:
# ROCm output for the same batch element i.
my[i]

In [None]:
# Render the ROCm output for batch element i as an image.
plt.imshow(my[i])

In [None]:
# Overlay the second-to-last row of batch element i from both outputs. The
# curves are labelled so the reference and the tested result can be told
# apart, and the title records which slice is shown.
plt.figure()
plt.plot(ref[i][-2], label="ref (CPUExecutionProvider)")
plt.plot(my[i][-2], label="my (ROCMExecutionProvider)")
plt.title(f"output[{i}][-2]")
plt.legend()

In [None]:
# Last row of the ROCm output for batch element i.
my[i][-1]

In [None]:
# Number of NaN entries in the ROCm output.
np.isnan(my).sum()

In [None]:
# Number of NaN entries in the reference (CPU) output.
np.isnan(ref).sum()

In [None]:
# Element-wise relative error |diff / ref|. Exact zeros in the reference are
# replaced by +inf in the denominator, so a finite difference there divides to
# 0 instead of dividing by zero. Despite its name, `rtol` holds the measured
# relative error, not a tolerance.
denorm = np.where(ref == 0, np.float16("inf"), ref)
rtol = np.absolute(np.divide(diff, denorm))

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Flatten the relative-error tensor so it can be counted and histogrammed.
rtol_1d = rtol.ravel()

# Element counts below / at-or-above a relative error of 1. NaN entries fail
# both comparisons and are therefore counted in neither line.
print((rtol_1d < 1).sum())
print((rtol_1d >= 1).sum())

# Log-scaled histogram of the finite relative errors only.
_ = plt.hist(rtol_1d[np.isfinite(rtol_1d)], log=True, bins=500)