Skip to content

Commit

Permalink
Set SYCL subgroup size via kernel property or @simd_length attribut…
Browse files Browse the repository at this point in the history
…e. (#726)
  • Loading branch information
kris-rowe committed Dec 15, 2023
1 parent 8957558 commit 66ae951
Show file tree
Hide file tree
Showing 10 changed files with 227 additions and 13 deletions.
5 changes: 5 additions & 0 deletions src/occa/internal/lang/attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ namespace occa {
bool attributeArg_t::exists() const {
return expr;
}

bool attributeArg_t::canEvaluate() const {
if (!expr) return false;
return expr->canEvaluate();
}
//==================================

//---[ Attribute ]------------------
Expand Down
2 changes: 2 additions & 0 deletions src/occa/internal/lang/attribute.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ namespace occa {
void clear();

bool exists() const;

bool canEvaluate() const;
};
//==================================

Expand Down
1 change: 1 addition & 0 deletions src/occa/internal/lang/builtins/attributes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <occa/internal/lang/builtins/attributes/outer.hpp>
#include <occa/internal/lang/builtins/attributes/restrict.hpp>
#include <occa/internal/lang/builtins/attributes/shared.hpp>
#include <occa/internal/lang/builtins/attributes/simdLength.hpp>
#include <occa/internal/lang/builtins/attributes/tile.hpp>

#endif
50 changes: 50 additions & 0 deletions src/occa/internal/lang/builtins/attributes/simdLength.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include <occa/internal/lang/expr.hpp>
#include <occa/internal/lang/parser.hpp>
#include <occa/internal/lang/statement.hpp>
#include <occa/internal/lang/variable.hpp>
#include <occa/internal/lang/builtins/attributes/simdLength.hpp>

namespace occa {
namespace lang {
namespace attributes {

const std::string& simdLength::name() const { return name_;}

bool simdLength::forStatementType(const int sType) const {
return (sType & statementType::for_);
}

bool simdLength::isValid(const attributeToken_t &attr) const {
if (attr.kwargs.size()) {
attr.printError(name_ + " does not take kwargs");
return false;
}

if (1 != attr.args.size()) {
attr.printError(name_ + " takes one argument");
return false;
}

const auto& attr_arg = attr.args[0];
if (!attr_arg.canEvaluate()) {
attr.printError(name_ + " cannot evaluate argument");
return false;
}

primitive value = attr_arg.expr->evaluate();
if (!value.isInteger()) {
attr.printError(name_ + " take an integer argument");
return false;
}

if(0 > value.to<int>()) {
attr.printError(name_ + " arguments must be postive!");
return false;
}

return true;
}

}
}
}
24 changes: 24 additions & 0 deletions src/occa/internal/lang/builtins/attributes/simdLength.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef OCCA_INTERNAL_LANG_BUILTINS_ATTRIBUTES_SIMD_LENGTH_HEADER
#define OCCA_INTERNAL_LANG_BUILTINS_ATTRIBUTES_SIMD_LENGTH_HEADER

#include <occa/internal/lang/attribute.hpp>

namespace occa {
namespace lang {
namespace attributes {

class simdLength : public attribute_t {
public:
simdLength() = default;
const std::string& name() const override;
bool forStatementType(const int sType) const override;
bool isValid(const attributeToken_t &attr) const override;
private:
static const inline std::string name_{"simd_length"};
};

}
}
}

#endif
94 changes: 82 additions & 12 deletions src/occa/internal/lang/modes/dpcpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,78 @@
#include <occa/internal/lang/builtins/attributes.hpp>
#include <occa/internal/lang/builtins/types.hpp>
#include <occa/internal/lang/expr.hpp>
// #include <stringstream>
#include <occa/internal/lang/attribute.hpp>

namespace {

class dpcppLambda_t : public occa::lang::lambda_t {
public:
int simd_length{-1};

dpcppLambda_t(occa::lang::capture_t capture_, int simd_length_)
: lambda_t(capture_), simd_length(simd_length_) {}

dpcppLambda_t(const dpcppLambda_t& other)
: lambda_t(other), simd_length(other.simd_length) {}

~dpcppLambda_t() = default;

bool equals(const type_t &other) const override {
const dpcppLambda_t &other_ = other.to<dpcppLambda_t>();
if (simd_length != other_.simd_length) return false;
return lambda_t::equals(other);
}

void printDeclaration(occa::lang::printer &pout) const override {
pout << "[";

switch (this->capture) {
case occa::lang::capture_t::byValue:
pout << "=";
break;
case occa::lang::capture_t::byReference:
pout << "&";
break;
default:
pout << "???";
break;
}

pout << "](";

if (!args.empty()) {
const std::string argIndent = pout.indentFromNewline();
args[0]->printDeclaration(pout);
for (std::size_t i = 1; i < args.size(); ++i) {
pout << ",\n" << argIndent;
args[i]->printDeclaration(pout);
}
}
pout << ") ";

if (0 < simd_length) {
pout << "[[intel::reqd_sub_group_size(";
pout.print(simd_length);
pout << ")]]";
}

pout << " {";

pout.printNewline();
pout.pushInlined(false);
pout.addIndentation();

body->print(pout);

pout.removeIndentation();
pout.popInlined();
pout.printNewline();
pout.printIndentation();
pout << "}\n";
}
};

}

namespace occa
{
Expand All @@ -20,6 +91,7 @@ namespace occa
shared("auto", qualifierType::custom)
{
okl::addOklAttributes(*this);
simd_length_default = settings_.get("simd_length",-1);
}

void dpcppParser::onClear()
Expand Down Expand Up @@ -79,15 +151,7 @@ namespace occa

std::string dpcppParser::launchBoundsAttribute(const int innerDims[3])
{
std::stringstream ss;
ss << "[[sycl::reqd_work_group_size("
<< innerDims[2]
<< ","
<< innerDims[1]
<< ","
<< innerDims[0]
<< ")]]\n";
return ss.str();
return "";
}

// @note: As of SYCL 2020 this will need to change from `CL/sycl.hpp` to `sycl.hpp`
Expand Down Expand Up @@ -188,9 +252,15 @@ namespace occa
lambda_t &cg_function = *(new lambda_t(capture_t::byReference));
cg_function.addArgument(sycl_handler);

lambda_t &sycl_kernel = *(new lambda_t(capture_t::byValue));
sycl_kernel.addArgument(sycl_nditem);

int simd_length = simd_length_default;
if (k.hasAttribute("simd_length")) {
const attributeToken_t& attr = k.attributes["simd_length"];
simd_length = attr.args[0].expr->evaluate();
}

dpcppLambda_t& sycl_kernel = *(new dpcppLambda_t(capture_t::byValue, simd_length));
sycl_kernel.addArgument(sycl_nditem);
sycl_kernel.body->swap(k);

lambdaNode sycl_kernel_node(sycl_kernel.source, sycl_kernel);
Expand Down
3 changes: 2 additions & 1 deletion src/occa/internal/lang/modes/dpcpp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@ namespace occa
void setSharedQualifiers();
void setKernelQualifiers(function_t &function);
void migrateLocalDecls(functionDeclStatement &kernelSmnt);
void setLaunchBounds();

void setupAtomics();
static bool transformAtomicBlockStatement(blockStatement &blockSmnt);
static bool transformAtomicBasicExpressionStatement(expressionStatement &exprSmnt);

private:
int simd_length_default;

inline int dpcppDimensionOrder(const int index) { return 2 - index; }
};
} // namespace okl
Expand Down
1 change: 1 addition & 0 deletions src/occa/internal/lang/modes/okl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ namespace occa {
parser.addAttribute<attributes::shared>();
parser.addAttribute<attributes::maxInnerDims>();
parser.addAttribute<attributes::noBarrier>();
parser.addAttribute<attributes::simdLength>();
}

void setOklLoopIndices(functionDeclStatement &kernelSmnt) {
Expand Down
4 changes: 4 additions & 0 deletions src/occa/internal/lang/modes/withLauncher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,10 @@ namespace occa {
forStatement &newForSmnt = (forStatement&) forSmnt.clone();
newKernelSmnt.set(newForSmnt);

if (newForSmnt.hasAttribute("simd_length")) {
newKernelSmnt.addAttribute(newForSmnt.attributes["simd_length"]);
}

bool addLaunchBoundsAttribute{true};
int kernelInnerDims[3] = {1,1,1};
if (newForSmnt.hasAttribute("max_inner_dims")) {
Expand Down
56 changes: 56 additions & 0 deletions tests/src/internal/lang/modes/dpcpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ void testSharedAnnotation();
void testBarriers();
void testAtomic();
void testSource();
void testSimdLength();

int main(const int argc, const char **argv) {
parser.settings["okl/validate"] = true;
Expand All @@ -38,6 +39,7 @@ int main(const int argc, const char **argv) {
testSharedAnnotation();
testBarriers();
testSource();
testSimdLength();

return 0;
}
Expand Down Expand Up @@ -163,3 +165,57 @@ void testSource() {
"}\n"
);
}

void testSimdLengthAttribute() {
const std::string kernel_source = R"(
@kernel void f() {
@outer @simd_length(16)
for (int o = 0; o < 1; ++o) {
@inner for (int i = 0; i < 32; ++i) {
int j = i + o;
}
}
}
)";

parser.parseSource(kernel_source);
ASSERT_TRUE(parser.success);

printer pout;
parser.root.print(pout);
const std::string translated_source = pout.str();

auto pos = translated_source.find("[[intel::reqd_sub_group_size(16)]]");
ASSERT_TRUE(std::string::npos != pos);
}

void testSimdLengthProperty() {
const std::string kernel_source = R"(
@kernel void f() {
@outer for (int o = 0; o < 1; ++o) {
@inner for (int i = 0; i < 32; ++i) {
int j = i + o;
}
}
}
)";

occa::json properties;
properties["simd_length"] = 16;
occa::lang::okl::dpcppParser dpcpp_parser(properties);

dpcpp_parser.parseSource(kernel_source);
ASSERT_TRUE(parser.success);

printer pout;
dpcpp_parser.root.print(pout);
const std::string translated_source = pout.str();

auto pos = translated_source.find("[[intel::reqd_sub_group_size(16)]]");
ASSERT_TRUE(std::string::npos != pos);
}

void testSimdLength() {
testSimdLengthAttribute();
testSimdLengthProperty();
}

0 comments on commit 66ae951

Please sign in to comment.