Skip to content

Commit

Permalink
Allow user to define the maximum number of kernel arguments and autom…
Browse files Browse the repository at this point in the history
…atically run script to generate the codes
  • Loading branch information
deukhyun-cha committed Nov 14, 2023
1 parent ec6e59d commit 63dbab5
Show file tree
Hide file tree
Showing 11 changed files with 131 additions and 54 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ opt
/.compiledDefines
/include/occa/defines/compiledDefines.hpp
/include/occa/scripts
/include/occa/core/kernelOperators.hpp_codegen
/src/core/kernelOperators.cpp_codegen
/src/occa/internal/utils/runFunction.cpp_codegen
/include/occa/defines/macros.hpp_codegen

# Binaries generated to fetch compiler information
/scripts/compiler/compilerSupportsMPI
Expand Down
14 changes: 14 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,20 @@ else()
add_compile_definitions(OCCA_THREAD_SHARABLE_ENABLED=0)
endif()

# Upper bound on the number of arguments a kernel call may take.
#   MAX_NUM_KERNEL_ARGS_DEFAULT : limit used when the user requests nothing else.
#   MAX_NUM_KERNEL_ARGS         : user-facing cache entry (-DMAX_NUM_KERNEL_ARGS=<n>).
# Requests above the default re-run scripts/codegen/setup_kernel_operators.py
# at configure time; that script needs python >= 3.7.2.
set(MAX_NUM_KERNEL_ARGS_DEFAULT "128")
set(MAX_NUM_KERNEL_ARGS ${MAX_NUM_KERNEL_ARGS_DEFAULT} CACHE STRING "The maximum number of allowed kernel arguments")

# Reject anything that is not a plain non-negative integer before it reaches
# a numeric comparison or a -D compile definition.
if (NOT "${MAX_NUM_KERNEL_ARGS}" MATCHES "^[0-9]+$")
  message(FATAL_ERROR "MAX_NUM_KERNEL_ARGS must be a non-negative integer, got \"${MAX_NUM_KERNEL_ARGS}\"")
endif()

# Value actually compiled in. Kept separate from the cache entry so that a
# fallback to the default does not overwrite or shadow the user's setting.
set(occa_max_args "${MAX_NUM_KERNEL_ARGS}")

if (MAX_NUM_KERNEL_ARGS GREATER MAX_NUM_KERNEL_ARGS_DEFAULT)
  execute_process(COMMAND python --version
    RESULT_VARIABLE python_status
    OUTPUT_VARIABLE python_version)
  # Quoted on purpose: the capture is empty when `python` is missing or
  # reports its version on stderr, and an unquoted empty expansion would
  # drop the input argument of string(REGEX MATCH) entirely.
  string(REGEX MATCH "[0-9.]+" python_version "${python_version}")
  if (NOT "${python_status}" STREQUAL "0"
      OR "${python_version}" STREQUAL ""
      OR "${python_version}" VERSION_LESS "3.7.2")
    message(WARNING "-- Failed to set the maximum number of kernel arguments to ${MAX_NUM_KERNEL_ARGS}, required minimum python version 3.7.2. The default value ${MAX_NUM_KERNEL_ARGS_DEFAULT} will be used.")
    # Make the warning true: without regenerated sources the compiled limit
    # must not exceed the default.
    set(occa_max_args "${MAX_NUM_KERNEL_ARGS_DEFAULT}")
  else()
    message("-- Codegen for the maximum number of kernel arguments : ${MAX_NUM_KERNEL_ARGS}")
    execute_process(
      COMMAND ${CMAKE_COMMAND} -E env OCCA_DIR=${CMAKE_CURRENT_SOURCE_DIR} python ${CMAKE_CURRENT_SOURCE_DIR}/scripts/codegen/setup_kernel_operators.py -N ${MAX_NUM_KERNEL_ARGS}
      RESULT_VARIABLE codegen_status)
    # A failed run may leave the generated sources out of step with the
    # requested limit, so stop instead of configuring an inconsistent build.
    if (NOT "${codegen_status}" STREQUAL "0")
      message(FATAL_ERROR "Codegen for ${MAX_NUM_KERNEL_ARGS} kernel arguments failed (exit status: ${codegen_status})")
    endif()
  endif()
endif()
add_compile_definitions(OCCA_MAX_ARGS=${occa_max_args})

set(OCCA_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(OCCA_BUILD_DIR ${CMAKE_BINARY_DIR})

Expand Down
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ MAKE_COMPILED_DEFINES := $(shell cat "$(OCCA_DIR)/scripts/build/compiledDefinesT
s,@@OCCA_HIP_ENABLED@@,$(OCCA_HIP_ENABLED),g;\
s,@@OCCA_OPENCL_ENABLED@@,$(OCCA_OPENCL_ENABLED),g;\
s,@@OCCA_METAL_ENABLED@@,$(OCCA_METAL_ENABLED),g;\
s,@@OCCA_DPCPP_ENABLED@@,$(OCCA_DPCPP_ENABLED),g;\
s,@@OCCA_DPCPP_ENABLED@@,$(OCCA_DPCPP_ENABLED),g;\
s,@@OCCA_THREAD_SHARABLE_ENABLED@@,$(OCCA_THREAD_SHARABLE_ENABLED),g;\
s,@@OCCA_MAX_ARGS@@,$(OCCA_MAX_ARGS),g;\
s,@@OCCA_BUILD_DIR@@,$(OCCA_BUILD_DIR),g;"\
> "$(NEW_COMPILED_DEFINES)")

Expand Down
33 changes: 1 addition & 32 deletions include/occa/defines/macros.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,38 +24,7 @@
// Just in case someone wants to run with an older format than C99
#ifndef OCCA_DISABLE_VARIADIC_MACROS

# define OCCA_ARG_COUNT(...) OCCA_ARG_COUNT2(\
__VA_ARGS__, \
128, 127, 126, 125, 124, 123, 122, 121, \
120, 119, 118, 117, 116, 115, 114, 113, 112, 111, \
110, 109, 108, 107, 106, 105, 104, 103, 102, 101, \
100, 99, 98, 97, 96, 95, 94, 93, 92, 91, \
90, 89, 88, 87, 86, 85, 84, 83, 82, 81, \
80, 79, 78, 77, 76, 75, 74, 73, 72, 71, \
70, 69, 68, 67, 66, 65, 64, 63, 62, 61, \
60, 59, 58, 57, 56, 55, 54, 53, 52, 51, \
50, 49, 48, 47, 46, 45, 44, 43, 42, 41, \
40, 39, 38, 37, 36, 35, 34, 33, 32, 31, \
30, 29, 28, 27, 26, 25, 24, 23, 22, 21, \
20, 19, 18, 17, 16, 15, 14, 13, 12, 11, \
10, 9, 8, 7, 6, 5, 4, 3, 2, 1, \
0)

# define OCCA_ARG_COUNT2( \
_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
_11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
_21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
_31, _32, _33, _34, _35, _36, _37, _38, _39, _40, \
_41, _42, _43, _44, _45, _46, _47, _48, _49, _50, \
_51, _52, _53, _54, _55, _56, _57, _58, _59, _60, \
_61, _62, _63, _64, _65, _66, _67, _68, _69, _70, \
_71, _72, _73, _74, _75, _76, _77, _78, _79, _80, \
_81, _82, _83, _84, _85, _86, _87, _88, _89, _90, \
_91, _92, _93, _94, _95, _96, _97, _98, _99, _100, \
_101, _102, _103, _104, _105, _106, _107, _108, _109, _110, \
_111, _112, _113, _114, _115, _116, _117, _118, _119, _120, \
_121, _122, _123, _124, _125, _126, _127, _128, \
N, ...) N
#include "macros.hpp_codegen"

#endif // OCCA_DISABLE_VARIADIC_MACROS

Expand Down
38 changes: 38 additions & 0 deletions include/occa/defines/macros.hpp_codegen
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// -------------[ DO NOT EDIT ]-------------
// THIS IS AN AUTOMATICALLY GENERATED FILE
// EDIT: scripts/codegen/setup_kernel_operators.py
// =========================================

# define OCCA_ARG_COUNT2( \
_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
_11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
_21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
_31, _32, _33, _34, _35, _36, _37, _38, _39, _40, \
_41, _42, _43, _44, _45, _46, _47, _48, _49, _50, \
_51, _52, _53, _54, _55, _56, _57, _58, _59, _60, \
_61, _62, _63, _64, _65, _66, _67, _68, _69, _70, \
_71, _72, _73, _74, _75, _76, _77, _78, _79, _80, \
_81, _82, _83, _84, _85, _86, _87, _88, _89, _90, \
_91, _92, _93, _94, _95, _96, _97, _98, _99, _100, \
_101, _102, _103, _104, _105, _106, _107, _108, _109, _110, \
_111, _112, _113, _114, _115, _116, _117, _118, _119, _120, \
_121, _122, _123, _124, _125, _126, _127, _128, \
N, ...) N

# define OCCA_ARG_COUNT(...) OCCA_ARG_COUNT2( \
__VA_ARGS__, \
128, 127, 126, 125, 124, 123, 122, 121, \
120, 119, 118, 117, 116, 115, 114, 113, 112, 111, \
110, 109, 108, 107, 106, 105, 104, 103, 102, 101, \
100, 99, 98, 97, 96, 95, 94, 93, 92, 91, \
90, 89, 88, 87, 86, 85, 84, 83, 82, 81, \
80, 79, 78, 77, 76, 75, 74, 73, 72, 71, \
70, 69, 68, 67, 66, 65, 64, 63, 62, 61, \
60, 59, 58, 57, 56, 55, 54, 53, 52, 51, \
50, 49, 48, 47, 46, 45, 44, 43, 42, 41, \
40, 39, 38, 37, 36, 35, 34, 33, 32, 31, \
30, 29, 28, 27, 26, 25, 24, 23, 22, 21, \
20, 19, 18, 17, 16, 15, 14, 13, 12, 11, \
10, 9, 8, 7, 6, 5, 4, 3, 2, 1, \
0)

2 changes: 0 additions & 2 deletions include/occa/defines/occa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
#define OKL_VERSION 10600
#define OKL_VERSION_STR "1.6.0"

#define OCCA_MAX_ARGS 128

#define OCCA_DEFAULT_MEM_BYTE_ALIGN 32

#endif
44 changes: 28 additions & 16 deletions scripts/build/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,16 @@ endif


#---[ Variable Dependencies ]---------------------
fortranEnabled = 0
mpiEnabled = 0
openmpEnabled = 0
cudaEnabled = 0
hipEnabled = 0
openclEnabled = 0
metalEnabled = 0
dpcppEnabled = 0

fortranEnabled = 0
mpiEnabled = 0
openmpEnabled = 0
cudaEnabled = 0
hipEnabled = 0
openclEnabled = 0
metalEnabled = 0
dpcppEnabled = 0
threadSharableEnabled = 0
maxArgs = 128

#---[ Fortran ]-------------------------
ifdef OCCA_FORTRAN_ENABLED
Expand Down Expand Up @@ -480,6 +481,15 @@ ifeq ($(usingMacOS),1)
endif
endif

#---[ Other parameters ]---------------------------
# Optional overrides: each lowercase setting takes the value of the matching
# OCCA_* variable when that variable is defined.

# Override for threadSharableEnabled.
ifdef OCCA_THREAD_SHARABLE_ENABLED
threadSharableEnabled = $(OCCA_THREAD_SHARABLE_ENABLED)
endif

# Override for maxArgs.
ifdef OCCA_MAX_ARGS
maxArgs = $(OCCA_MAX_ARGS)
endif

ifeq ($(cudaEnabled),1)
compilerFlags += -Wno-c++11-long-long
endif
Expand All @@ -491,11 +501,13 @@ else
OCCA_CHECK_ENABLED := 0
endif

OCCA_FORTRAN_ENABLED := $(fortranEnabled)
OCCA_OPENMP_ENABLED := $(openmpEnabled)
OCCA_CUDA_ENABLED := $(cudaEnabled)
OCCA_HIP_ENABLED := $(hipEnabled)
OCCA_OPENCL_ENABLED := $(openclEnabled)
OCCA_METAL_ENABLED := $(metalEnabled)
OCCA_DPCPP_ENABLED := $(dpcppEnabled)
OCCA_FORTRAN_ENABLED := $(fortranEnabled)
OCCA_OPENMP_ENABLED := $(openmpEnabled)
OCCA_CUDA_ENABLED := $(cudaEnabled)
OCCA_HIP_ENABLED := $(hipEnabled)
OCCA_OPENCL_ENABLED := $(openclEnabled)
OCCA_METAL_ENABLED := $(metalEnabled)
OCCA_DPCPP_ENABLED := $(dpcppEnabled)
OCCA_THREAD_SHARABLE_ENABLED := $(threadSharableEnabled)
OCCA_MAX_ARGS := $(maxArgs)
#=================================================
3 changes: 3 additions & 0 deletions scripts/build/compiledDefinesTemplate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
#define OCCA_METAL_ENABLED @@OCCA_METAL_ENABLED@@
#define OCCA_DPCPP_ENABLED @@OCCA_DPCPP_ENABLED@@

#define OCCA_THREAD_SHARABLE_ENABLED @@OCCA_THREAD_SHARABLE_ENABLED@@
#define OCCA_MAX_ARGS @@OCCA_MAX_ARGS@@

#define OCCA_BUILD_DIR "@@OCCA_BUILD_DIR@@"

#endif
41 changes: 39 additions & 2 deletions scripts/codegen/setup_kernel_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import os
import functools

import argparse

OCCA_DIR = os.environ.get(
'OCCA_DIR',
Expand Down Expand Up @@ -79,7 +79,7 @@ def run_function_from_arguments(N):
content = '\nswitch (argc) {\n'
for n in range(N + 1):
content += run_function_from_argument(n)
content += '}\n';
content += ' default:\n OCCA_FORCE_ERROR("TOO MANY KERNEL ARGUMENTS REQUESTED");\n}\n'

return content

Expand Down Expand Up @@ -148,7 +148,44 @@ def operator_definition(N):
'''
return content

def macro_count2(N):
    """Build the text of the OCCA_ARG_COUNT2 helper macro.

    The macro takes placeholder parameters _1 .. _N, laid out ten per
    continued line, followed by ``N, ...`` and expands to ``N``.

    N -- number of placeholder parameters to emit.
    Returns the macro definition as a single string ending in a newline.
    """
    pad = ' ' * 2
    pieces = ['# define OCCA_ARG_COUNT2( \\\n']
    for idx in range(1, N + 1):
        slot = idx % 10
        # The first placeholder of every row carries the indentation.
        prefix = pad if slot == 1 else ''
        # The tenth placeholder of a row closes it with a line continuation.
        suffix = '\\\n' if slot == 0 else ''
        pieces.append(prefix + '_' + str(idx) + ', ' + suffix)
    # A partially filled last row still needs its line continuation.
    if N % 10 != 0:
        pieces.append('\\\n')
    pieces.append(pad + 'N, ...) N\n')
    return ''.join(pieces)

def macro_count(N):
    """Build the text of the OCCA_ARG_COUNT variadic macro.

    The macro forwards ``__VA_ARGS__`` to OCCA_ARG_COUNT2 followed by the
    descending literals N, N-1, ..., 1, 0. A continued line is started
    after every literal whose last digit is 1.

    N -- largest literal to emit.
    Returns the macro definition as a single string ending in a newline.
    """
    pad = ' ' * 2
    pieces = [
        '# define OCCA_ARG_COUNT(...) OCCA_ARG_COUNT2( \\\n',
        pad + '__VA_ARGS__, \\\n' + pad,
    ]
    for value in reversed(range(1, N + 1)):
        pieces.append(str(value) + ', ')
        # ..., 21, 11, 1 each end a row; the next row starts indented.
        if value % 10 == 1:
            pieces.append('\\\n' + pad)
    pieces.append('0)\n')
    return ''.join(pieces)

# Produces the contents of macros.hpp_codegen: the OCCA_ARG_COUNT2 helper
# first, a blank line, then the OCCA_ARG_COUNT macro that uses it.
# NOTE(review): the to_file decorator is defined outside this view;
# presumably it writes the returned text to the given path -- confirm.
@to_file('include/occa/defines/macros.hpp_codegen')
def macro_declarations(N):
    helper_macro = macro_count2(N)
    counting_macro = macro_count(N)
    return '\n'.join([helper_macro, counting_macro])

if __name__ == '__main__':
    # -N / --NargsMax overrides the module-level MAX_ARGS default.
    cli = argparse.ArgumentParser(usage=__doc__)
    cli.add_argument("-N", "--NargsMax", type=int, default=MAX_ARGS)
    args = cli.parse_args()
    MAX_ARGS = args.NargsMax

    # Run every generator with the same limit, in the original order.
    generators = (
        run_function_from_arguments,
        operator_declarations,
        operator_definitions,
        macro_declarations,
    )
    for generate in generators:
        generate(MAX_ARGS)
1 change: 1 addition & 0 deletions src/core/kernelOperators.cpp_codegen
Original file line number Diff line number Diff line change
Expand Up @@ -3490,3 +3490,4 @@ void kernel::operator() (const kernelArg &arg1, const kernelArg &arg2, const ker
modeKernel->setArguments(args, 128);
run();
}

2 changes: 1 addition & 1 deletion src/occa/internal/utils/runFunction.cpp_codegen
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,7 @@ switch (argc) {
args[95], args[96], args[97], args[98], args[99],
args[100], args[101]);
break;
case 103:
case 103:
f(args[0], args[1], args[2], args[3], args[4],
args[5], args[6], args[7], args[8], args[9],
args[10], args[11], args[12], args[13], args[14],
Expand Down

0 comments on commit 63dbab5

Please sign in to comment.