Skip to content

Improve KE for commandline and programmatically tuning dispatch#18778

Merged
mindest merged 13 commits intomainfrom
guangyunhan/improve-ke
Apr 8, 2024
Merged

Improve KE for commandline and programmatically tuning dispatch#18778
mindest merged 13 commits intomainfrom
guangyunhan/improve-ke

Conversation

@cloudhan
Copy link
Copy Markdown
Contributor

No description provided.

@cloudhan
Copy link
Copy Markdown
Contributor Author

For example:

import os
import sys

sys.path.insert(0, "/home/guangyunhan/onnxruntime/onnxruntime/python/tools/kernel_explorer/kernels")
sys.path.insert(0, "/home/guangyunhan/onnxruntime/build_rocm/Release")
os.environ["KERNEL_EXPLORER_BUILD_DIR"] = "/home/guangyunhan/onnxruntime/build_rocm/Release"


import multiprocessing as mp
from multiprocessing import Pool, current_process


def profile(name, *args, **kwargs):
    import kernel_explorer as ke

    ke.set_return_tuning_results()
    ke.set_dispatchable_pattern("*Tunable*")
    print(os.environ["HIP_VISIBLE_DEVICES"])
    if name == "gemm":
        from gemm_test import profile_with_args as profile

        return profile(*args, **kwargs)
    elif name == "softmax":
        from softmax_test import profile_with_args as profile

        return profile(*args, **kwargs)
    else:
        return []


def init():
    pidx = int(current_process()._identity[0]) - 1
    start_gpu = 2
    num_gpu = 14
    os.environ["HIP_VISIBLE_DEVICES"] = str(pidx % num_gpu + start_gpu)


if __name__ == "__main__":
    configs = [
        ("gemm", "float16", False, False, 1, 8912, 8912),
        ("gemm", "float16", False, False, 8, 8912, 8912),
        ("gemm", "float16", False, False, 16, 8912, 8912),
        ("gemm", "float16", False, False, 24, 8912, 8912),
        ("gemm", "float16", False, False, 32, 8912, 8912),
        ("gemm", "float16", False, False, 40, 8912, 8912),
        ("gemm", "float16", False, False, 48, 8912, 8912),
        ("softmax", 1, 1024, False, "float16"),
        ("softmax", 2, 1024, False, "float16"),
    ]

    mp.set_start_method("spawn")

    with Pool(processes=4, initializer=init) as pool:
        ret = pool.starmap(profile, configs, chunksize=1)

    from pprint import pprint
    from onnxruntime.tools.offline_tuning import Merger

    m = Merger()
    for tr in ret:
        m.merge(tr)

    pprint(m.get_merged())

Comment thread onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py Fixed
Comment thread onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py Fixed
Comment thread onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py Fixed
Comment thread onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py Fixed
Comment thread onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py Fixed
Comment thread onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py Fixed
Comment thread onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py Fixed
Comment thread onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py
Comment thread onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py Fixed
Comment thread onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py Fixed
Comment thread onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py Fixed
Comment thread onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py Fixed
_ke_context.instance_dispatchable[f.__name__] = f

@wraps(f)
def wrapper(*args, **kwargs):

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
Comment thread onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py Fixed
Copy link
Copy Markdown
Contributor

@mindest mindest left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this PR. Tested locally, it worked fine. Please fix the lint warnings/errors.

Comment thread onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py Outdated
Comment thread onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py Outdated
Comment thread onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py
Comment thread onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py Outdated
@cloudhan cloudhan force-pushed the guangyunhan/improve-ke branch from 65698f4 to d4310c3 Compare February 23, 2024 09:30
@cloudhan cloudhan marked this pull request as ready for review February 23, 2024 09:30
@cloudhan cloudhan requested review from kailums and mindest February 23, 2024 09:30
kailums
kailums previously approved these changes Feb 26, 2024
mindest
mindest previously approved these changes Feb 26, 2024
Copy link
Copy Markdown
Contributor

@mindest mindest left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the new feature support.

Comment thread onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py
Comment thread onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py Outdated
Comment thread onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py Outdated
mindest
mindest previously approved these changes Feb 27, 2024
@cloudhan cloudhan force-pushed the guangyunhan/improve-ke branch from 1f9ee51 to aeec2ed Compare March 28, 2024 07:49
@cloudhan cloudhan requested review from kailums and mindest and removed request for mindest April 8, 2024 03:07
@mindest mindest merged commit e19c778 into main Apr 8, 2024
@mindest mindest deleted the guangyunhan/improve-ke branch April 8, 2024 03:09
TedThemistokleous pushed a commit to TedThemistokleous/onnxruntime that referenced this pull request May 7, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants