Skip to content

Commit

Permalink
Add kernel/include_occa and kernel/link_occa properties to control in…
Browse files Browse the repository at this point in the history
…cluding and linking occa into kernels
  • Loading branch information
deukhyun-cha committed Mar 29, 2023
1 parent 5790d72 commit 247c1e8
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 47 deletions.
1 change: 1 addition & 0 deletions examples/cpp/18_nonblocking_streams/main.cpp
Expand Up @@ -27,6 +27,7 @@ int main(int argc, const char **argv) {
occa::json kernelProps({
{"defines/block", block},
{"defines/group", group},
{"serial/include_std", true},
});
occa::kernel powerOfPi = occa::buildKernel("powerOfPi.okl",
"powerOfPi",
Expand Down
2 changes: 2 additions & 0 deletions src/loops/typelessForLoop.cpp
Expand Up @@ -34,6 +34,8 @@ namespace occa {
loopScope.device = device;
}

loopScope.props["kernel/include_occa"] = true;

const int outerIterationCount = (int) outerIterations.size();
const int innerIterationCount = (int) innerIterations.size();

Expand Down
11 changes: 8 additions & 3 deletions src/occa/internal/lang/modes/serial.cpp
Expand Up @@ -34,8 +34,11 @@ namespace occa {

void serialParser::setupHeaders() {
strVector headers;
const bool includingStd = settings.get("serial/include_std", true);
headers.push_back("include <occa.hpp>\n");
const bool includeOcca = settings.get("kernel/include_occa", false);
if (includeOcca) {
headers.push_back("include <occa.hpp>\n");
}
const bool includingStd = settings.get("serial/include_std", false);
if (includingStd) {
headers.push_back("include <stdint.h>");
headers.push_back("include <cstdlib>");
Expand All @@ -51,7 +54,9 @@ namespace occa {
if (includingStd) {
header += "\nusing namespace std;";
}
header += "\nusing namespace occa;";
if (includeOcca) {
header += "\nusing namespace occa;";
}
}
directiveToken token(root.source->origin,
header);
Expand Down
19 changes: 14 additions & 5 deletions src/occa/internal/modes/cuda/device.cpp
Expand Up @@ -107,7 +107,8 @@ namespace occa {
return (
occa::hash(props["compiler"])
^ props["compiler_flags"]
^ props["compiler_env_script"]
^ props["kernel/include_occa"]
^ props["kernel/link_occa"]
);
}

Expand Down Expand Up @@ -287,6 +288,9 @@ namespace occa {
sys::addCompilerLibraryFlags(compilerFlags);
}

const bool includeOcca = kernelProps.get("kernel/include_occa", false);
const bool linkOcca = kernelProps.get("kernel/link_occa", false);

//---[ Compiling Command ]--------
std::stringstream command;
command << allProps["compiler"]
Expand All @@ -295,10 +299,15 @@ namespace occa {
#if (OCCA_OS == OCCA_WINDOWS_OS)
<< " -D OCCA_OS=OCCA_WINDOWS_OS -D _MSC_VER=1800"
#endif
<< " -I" << env::OCCA_DIR << "include"
<< " -I" << env::OCCA_INSTALL_DIR << "include"
<< " -L" << env::OCCA_INSTALL_DIR << "lib -locca"
<< " -x cu " << sourceFilename
;
if (includeOcca) {
command << " -I" << env::OCCA_DIR << "include"
<< " -I" << env::OCCA_INSTALL_DIR << "include";
}
if (linkOcca) {
command << " -L" << env::OCCA_INSTALL_DIR << "lib -locca";
}
command << " -x cu " << sourceFilename
<< " -o " << binaryFilename
<< " 2>&1";

Expand Down
21 changes: 15 additions & 6 deletions src/occa/internal/modes/hip/device.cpp
Expand Up @@ -105,6 +105,9 @@ namespace occa {
occa::hash(props["compiler"])
^ props["compiler_flags"]
^ props["compiler_env_script"]
^ props["hipcc_compiler_flags"]
^ props["kernel/include_occa"]
^ props["kernel/link_occa"]
);
}

Expand Down Expand Up @@ -282,14 +285,20 @@ namespace occa {
#else
<< " -f=\\\"" << compilerFlags << "\\\""
#endif
<< ' ' << hipccCompilerFlags
<< ' ' << hipccCompilerFlags;
#if defined(__HIP_PLATFORM_NVCC___) || (HIP_VERSION >= 305)
<< " -I" << env::OCCA_DIR << "include"
<< " -I" << env::OCCA_INSTALL_DIR << "include"
const bool includeOcca = kernelProps.get("kernel/include_occa", false);
const bool linkOcca = kernelProps.get("kernel/link_occa", false);
if (includeOcca) {
command << " -I" << env::OCCA_DIR << "include"
<< " -I" << env::OCCA_INSTALL_DIR << "include";
}
if (linkOcca) {
/* NC: hipcc doesn't seem to like linking a library in */
//<< " -L" << env::OCCA_INSTALL_DIR << "lib -locca";
}
#endif
/* NC: hipcc doesn't seem to like linking a library in */
//<< " -L" << env::OCCA_INSTALL_DIR << "lib -locca"
<< ' ' << sourceFilename
command << ' ' << sourceFilename
<< " -o " << binaryFilename
<< " 2>&1";

Expand Down
35 changes: 24 additions & 11 deletions src/occa/internal/modes/serial/device.cpp
Expand Up @@ -36,6 +36,8 @@ namespace occa {
^ props["compiler_language"]
^ props["compiler_linker_flags"]
^ props["compiler_shared_flags"]
^ props["include_occa"]
^ props["link_occa"]
);
}

Expand Down Expand Up @@ -315,6 +317,9 @@ namespace occa {
sys::addCompilerLibraryFlags(compilerFlags);
}

const bool includeOcca = kernelProps.get("kernel/include_occa", isLauncherKernel);
const bool linkOcca = kernelProps.get("kernel/link_occa", isLauncherKernel);

io::stageFile(
binaryFilename,
true,
Expand All @@ -323,11 +328,15 @@ namespace occa {
command << compiler
<< ' ' << compilerFlags
<< ' ' << sourceFilename
<< " -o " << tempFilename
<< " -I" << env::OCCA_DIR << "include"
<< " -I" << env::OCCA_INSTALL_DIR << "include"
<< " -L" << env::OCCA_INSTALL_DIR << "lib -locca"
<< ' ' << compilerLinkerFlags
<< " -o " << tempFilename;
if (includeOcca) {
command << " -I" << env::OCCA_DIR << "include"
<< " -I" << env::OCCA_INSTALL_DIR << "include";
}
if (linkOcca) {
command << " -L" << env::OCCA_INSTALL_DIR << "lib -locca";
}
command << ' ' << compilerLinkerFlags
<< " 2>&1"
<< std::endl;
#else
Expand All @@ -336,12 +345,16 @@ namespace occa {
<< " /D OCCA_OS=OCCA_WINDOWS_OS"
<< " /EHsc"
<< " /wd4244 /wd4800 /wd4804 /wd4018"
<< ' ' << compilerFlags
<< " /I" << env::OCCA_DIR << "include"
<< " /I" << env::OCCA_INSTALL_DIR << "include"
<< ' ' << sourceFilename
<< " /link " << env::OCCA_INSTALL_DIR << "lib/libocca.lib",
<< ' ' << compilerLinkerFlags
<< ' ' << compilerFlags;
if (includeOcca) {
command << " /I" << env::OCCA_DIR << "include"
<< " /I" << env::OCCA_INSTALL_DIR << "include";
}
command << ' ' << sourceFilename;
if (linkOcca) {
command << " /link " << env::OCCA_INSTALL_DIR << "lib/libocca.lib";
}
command << ' ' << compilerLinkerFlags
<< " /OUT:" << tempFilename
<< std::endl;
#endif
Expand Down
2 changes: 1 addition & 1 deletion tests/src/c/kernel.cpp
Expand Up @@ -87,7 +87,7 @@ void testRun() {
occa::env::OCCA_DIR + "tests/files/argKernel.okl"
);
occaJson kernelProps = occaJsonParse(
"{type_validation: false}"
"{type_validation: false, serial: {include_std: true}}"
);
occaKernel argKernel = (
occaBuildKernel(argKernelFile.c_str(),
Expand Down
3 changes: 2 additions & 1 deletion tests/src/core/kernel.cpp
Expand Up @@ -154,7 +154,8 @@ void testRun() {
);
occa::kernel argKernel = occa::buildKernel(argKernelFile,
"argKernel",
{{"type_validation", false}});
{{"type_validation", false},
{"serial/include_std", true}});

argKernel.setRunDims(occa::dim(1, 1, 1),
occa::dim(1, 1, 1));
Expand Down
27 changes: 15 additions & 12 deletions tests/src/math/fpMath.cpp
Expand Up @@ -55,45 +55,48 @@ std::string kernel_back_half =

void testUnaryFunctions(const occa::device& d) {
for (auto fp_type : arg_types) {
std::string arg_decl =
std::string arg_decl =
" " + fp_type + " " + unary_args + ";\n";
for(auto func : unary_functions) {
std::string function_call =
std::string function_call =
" " + fp_type + " w = " + func + "(" + unary_args + ");\n";
std::string kernel_src =
std::string kernel_src =
kernel_front_half + arg_decl + function_call +kernel_back_half;

occa::kernel k = d.buildKernelFromString(kernel_src,"f");
occa::kernel k = d.buildKernelFromString(kernel_src, "f",
{{"serial/include_std", true}});
}
}
}

void testBinaryFunctions(const occa::device& d) {
for (auto fp_type : arg_types) {
std::string arg_decl =
std::string arg_decl =
" " + fp_type + " " + binary_args + ";\n";
for(auto func : binary_functions) {
std::string function_call =
std::string function_call =
" " + fp_type + " w = " + func + "(" + binary_args + ");\n";
std::string kernel_src =
std::string kernel_src =
kernel_front_half + arg_decl + function_call +kernel_back_half;

occa::kernel k = d.buildKernelFromString(kernel_src,"f");
occa::kernel k = d.buildKernelFromString(kernel_src, "f",
{{"serial/include_std", true}});
}
}
}

void testTernaryFunctions(const occa::device& d) {
for (auto fp_type : arg_types) {
std::string arg_decl =
std::string arg_decl =
" " + fp_type + " " + ternary_args + ";\n";
for(auto func : ternary_functions) {
std::string function_call =
std::string function_call =
" " + fp_type + " w = " + func + "(" + ternary_args + ");\n";
std::string kernel_src =
std::string kernel_src =
kernel_front_half + arg_decl + function_call +kernel_back_half;

occa::kernel k = d.buildKernelFromString(kernel_src,"f");
occa::kernel k = d.buildKernelFromString(kernel_src, "f",
{{"serial/include_std", true}});
}
}
}
Expand Down
19 changes: 11 additions & 8 deletions tests/src/math/intMath.cpp
Expand Up @@ -29,30 +29,33 @@ std::string kernel_back_half =

void testUnaryFunctions(const occa::device& d) {
for (auto&& int_type : arg_types) {
const std::string arg_decl =
const std::string arg_decl =
" " + int_type + " " + unary_args + "; \n";
for (auto&& func : unary_functions) {
const std::string function_call =
const std::string function_call =
" " + int_type + " w = " + func + "(" + unary_args + "); \n";
const std::string kernel_src =
const std::string kernel_src =
kernel_front_half + arg_decl + function_call +kernel_back_half;

occa::kernel k = d.buildKernelFromString(kernel_src,"f");
occa::kernel k = d.buildKernelFromString(kernel_src, "f",
{{"serial/include_std", true}});
}
}
}

void testBinaryFunctions(const occa::device& d) {
for (auto&& int_type : arg_types) {
const std::string arg_decl =
const std::string arg_decl =
" " + int_type + " " + binary_args + "; \n";
for (auto&& func : binary_functions) {
const std::string function_call =
const std::string function_call =
" " + int_type + " w = " + func + "(" + binary_args + "); \n";
const std::string kernel_src =
const std::string kernel_src =
kernel_front_half + arg_decl + function_call +kernel_back_half;

occa::kernel k = d.buildKernelFromString(kernel_src,"f");
occa::kernel k = d.buildKernelFromString(kernel_src, "f",
{{"serial/include_std", true},
{"kernel/include_occa", true}}); // For min/max
}
}
}
Expand Down

0 comments on commit 247c1e8

Please sign in to comment.