In [None]:
import os, sys

# Capture the working directory only the first time this cell runs; re-running
# the cell keeps the value that is already stored in the notebook namespace.
if 'workbookDir' not in globals():
    workbookDir = os.getcwd()

# Collect the positions of every sys.path entry that is the empty string or
# that resolves to the notebook directory.
# NOTE(review): presumably this keeps a local checkout from shadowing the
# build that is inserted below -- confirm.
to_remove = []
d_to_remove = workbookDir
for i, p in enumerate(sys.path):
    try:
        if p == "" or os.path.samefile(p, d_to_remove):
            to_remove.append(i)
    except (OSError, TypeError, ValueError):
        # samefile() fails for entries that do not exist or are not usable as
        # paths; those entries are left in place. Only these errors are
        # ignored so that KeyboardInterrupt/SystemExit still propagate.
        continue

# Delete from the back so the remaining recorded indices stay valid.
for i in reversed(to_remove):
    del sys.path[i]

# Give the locally built package directory the highest import priority.
sys.path.insert(0, os.path.expanduser("~/onnxruntime/build_rocm/Release/build/lib"))

In [None]:
import onnxruntime as ort
import onnx
import numpy as np
import matplotlib.pyplot as plt

# Configure ONNX Runtime's default logger: severity 0 and verbosity 1000.
# NOTE(review): 0 is presumably the most verbose severity level -- confirm
# against the onnxruntime documentation.
ort.set_default_logger_severity(0)
ort.set_default_logger_verbosity(1000)

In [None]:
def multinormal_distribution(num_distribution, num_element_per_dist):
    """Sample ``num_distribution`` rows, each from its own normal distribution.

    For every row a mean is drawn with ``np.random.randn()`` and a standard
    deviation with ``np.random.rand()``; the row then holds
    ``num_element_per_dist`` samples of ``np.random.normal(mean, std)``.
    The global NumPy random state is consumed in exactly that order, so the
    result is reproducible after ``np.random.seed``.

    Returns:
        An array built from the list of rows, i.e. of shape
        ``(num_distribution, num_element_per_dist)`` when at least one row
        is requested.
    """
    def draw_row():
        # Order matters for reproducibility: mean first, then std, then samples.
        loc = np.random.randn()
        spread = np.random.rand()
        return np.random.normal(loc, spread, (num_element_per_dist,))

    return np.array([draw_row() for _ in range(num_distribution)])

In [None]:
# Input layout of the MultiHeadAttention node. Exactly one mode must be on:
#   pack_kv  -> two inputs: "q" and a packed "kv" tensor
#   pack_qkv -> one input: a packed "qkv" tensor
pack_kv = False
pack_qkv = True
assert int(pack_kv) + int(pack_qkv) == 1, "enable exactly one of pack_kv / pack_qkv"

# B: batch size, S: sequence length, N: number of heads (passed as num_heads),
# H: size of each head.
B,S,N,H = 2,64,8,160

# Tiny alternative configuration:
# B,S,N,H = 1,2,1,8


# Fixed seed so the generated inputs are reproducible.
np.random.seed(1)
# Raw random data in the packed layout (B, S, N, 3, H); index 0/1/2 on the
# fourth axis selects Q/K/V. Kept under its own name so that `qkv` below only
# ever refers to the graph input description.
qkv_raw = multinormal_distribution(B * S * N * 3, H).reshape(B, S, N, 3, H)

# float16 feeds for both packing modes, all derived from the same raw data.
q_data = qkv_raw[:, :, :, 0, :].reshape(B, S, N*H).astype(np.float16)
kv_data = qkv_raw[:, :, :, 1:, :].astype(np.float16)
qkv_data = qkv_raw.astype(np.float16)

# Graph input/output descriptions; the batch dimension is symbolic.
q = onnx.helper.make_tensor_value_info("q", onnx.TensorProto.FLOAT16, ["batchsize", S, N*H])
kv = onnx.helper.make_tensor_value_info("kv", onnx.TensorProto.FLOAT16, ["batchsize", S, N, 2, H])
qkv = onnx.helper.make_tensor_value_info("qkv", onnx.TensorProto.FLOAT16, ["batchsize", S, N, 3, H])

output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT16, ["batchsize", S, N*H])

# Single-node graph: the com.microsoft MultiHeadAttention contrib op.
node_inputs = ["q", "kv"] if pack_kv else ["qkv"]
node = onnx.helper.make_node("MultiHeadAttention", inputs=node_inputs, outputs=["output"], domain="com.microsoft", num_heads=N)


graph_inputs = [q, kv] if pack_kv else [qkv]
graph = onnx.helper.make_graph([node], "Attn", graph_inputs, [output])

model = onnx.helper.make_model(graph, producer_name="tmp", opset_imports=[
    onnx.helper.make_opsetid('com.microsoft', 1),
    onnx.helper.make_opsetid('ai.onnx.ml', 1),
    onnx.helper.make_opsetid('', 14),
])

# check_model() reports problems by raising and returns None on success, so
# its result is not printed.
onnx.checker.check_model(model)

In [None]:
# Per-session logging: severity 0, verbosity 1000.
so = ort.SessionOptions()
so.log_verbosity_level = 1000
so.log_severity_level = 0

# Load the in-memory model on the ROCm execution provider with the
# "tunable_op_enabled" provider option set to "1".
rocm_provider = ("ROCMExecutionProvider", {"tunable_op_enabled": "1"})
sess = ort.InferenceSession(
    model.SerializeToString(),
    sess_options=so,
    providers=[rocm_provider],
)

In [None]:
# Feed dictionary keyed by graph input name; which inputs exist depends on
# the packing mode chosen when the model was built.
if pack_kv:
    input_feed = {"q": q_data, "kv": kv_data}
else:
    input_feed = {"qkv": qkv_data}

In [None]:
# Request every output the session declares and keep the first one.
output_names = [out_info.name for out_info in sess.get_outputs()]
our = sess.run(output_names=output_names, input_feed=input_feed)[0]

In [None]:
# Shape of the packed float16 QKV array, expected to be (B, S, N, 3, H).
qkv_data.shape

In [None]:
# Import the submodule explicitly: a bare `import scipy` does not guarantee
# that `scipy.special` is loaded.
import scipy.special

def ref_impl(qkv):
    """Reference scaled dot-product attention for a packed QKV tensor.

    Args:
        qkv: array of shape (batch, seq, num_heads, 3, head_size); index
            0/1/2 on the fourth axis selects Q/K/V.

    Returns:
        Array of shape (batch, seq, num_heads * head_size) holding
        ``softmax(Q @ K^T / sqrt(head_size)) @ V`` with the heads merged back
        into the last axis.

    The head size is read from ``qkv`` itself, so the function does not
    depend on any notebook-level variable.
    NOTE(review): no upcast is performed, so the first matmul runs in the
    dtype of ``qkv``; with float16 input the reference carries
    half-precision rounding error of its own.
    """
    head_size = qkv.shape[-1]

    # (batch, seq, heads, head_size) -> (batch, heads, seq, head_size)
    Q = np.swapaxes(qkv[:, :, :, 0, :], 2, 1)
    K = np.swapaxes(qkv[:, :, :, 1, :], 2, 1)
    V = np.swapaxes(qkv[:, :, :, 2, :], 2, 1)

    # Scaled attention logits of shape (batch, heads, seq, seq).
    pre_softmax_attn_scores = Q @ np.swapaxes(K, 2, 3)
    scale = 1.0/np.sqrt(head_size)
    pre_softmax_attn_scores = pre_softmax_attn_scores * scale

    attn_scores = scipy.special.softmax(pre_softmax_attn_scores, axis=-1)
    attn = attn_scores @ V
    attn = np.swapaxes(attn, 2, 1)  # permute 0213
    # Merge (heads, head_size) into one trailing axis.
    return np.reshape(attn, attn.shape[:2] + (-1,))
    

In [None]:
# Reference attention output computed by ref_impl from the same packed
# float16 QKV data.
ref = ref_impl(qkv_data)

In [None]:
# Print the shapes of the reference output and the session output.
print(ref.shape)
print(our.shape)

In [None]:
# Element-wise difference between the reference output and the session output.
diff = ref - our

In [None]:
# Display the shape of the difference array.
diff.shape

In [None]:
# Display the raw element-wise differences.
diff

In [None]:
# Index along the first (batch) axis used by the inspection cells that
# display ref[i] and our[i].
i = 0

In [None]:
# Reference output for the selected batch entry.
ref[i]

In [None]:
# Session output for the selected batch entry.
our[i]

In [None]:
# Render the 2-D session output of the selected batch entry as an image.
plt.imshow(our[i])

In [None]:
# Overlay row 2 of the selected batch entry: the first curve plotted is the
# reference, the second is the session output.
plt.figure()
plt.plot(ref[i][2])
plt.plot(our[i][2])

In [None]:
# Number of NaN values in the session output.
np.isnan(our).sum()

In [None]:
# Number of NaN values in the reference output.
np.isnan(ref).sum()

In [None]:
# Divisor for the relative error: a copy of the reference in which exact
# zeros are replaced by +inf, so those positions divide by infinity instead
# of by zero.
denorm = ref.copy()
zero_mask = denorm == 0
denorm[zero_mask] = float("inf")

# Relative error |diff / reference|, element-wise.
rtol = np.absolute(np.divide(diff, denorm))

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Flatten the relative errors for counting and plotting.
rtol_1d = rtol.reshape(-1)

# How many elements are below / at-or-above a relative error of 1.
print((rtol_1d < 1).sum())
print((rtol_1d >= 1).sum())

# Log-scaled histogram over the finite relative errors only.
finite_rtol = rtol_1d[np.isfinite(rtol_1d)]
_ = plt.hist(finite_rtol, bins=500, log=True)

In [None]:
# Print the entries with the largest relative error, worst first.
num_topk = 1000
# Never request more rows than rtol has elements; otherwise the loop below
# would index past the end of `indices` for small tensors.
num_shown = min(num_topk, rtol.size)

# Multi-dimensional indices of the num_shown largest values, ascending order.
indices = np.unravel_index(np.argsort(rtol.reshape(-1))[rtol.size - num_shown:], rtol.shape)
print("{:<20} {:<16.8} {:<20.8} {:<20.8}".format("index", "rtol", "ref", "our"))
# A dedicated loop variable keeps the notebook-level `i` untouched.
for rank in reversed(range(num_shown)):
    # Plain Python ints keep the printed tuple compact regardless of how
    # NumPy renders its integer scalars; works for any number of dimensions.
    idx = tuple(int(axis_indices[rank]) for axis_indices in indices)
    print(f"{str(idx):<20} {rtol[idx]:<16.8} {ref[idx]:<20.8} {our[idx]:<20.8}")