-
Notifications
You must be signed in to change notification settings - Fork 81
/
utils.cpp
156 lines (141 loc) · 4.82 KB
/
utils.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#include <stdio.h>
#include <occa/internal/modes/dpcpp/utils.hpp>
#include <occa/internal/modes/dpcpp/device.hpp>
#include <occa/internal/modes/dpcpp/memory.hpp>
#include <occa/internal/modes/dpcpp/kernel.hpp>
#include <occa/internal/modes/dpcpp/stream.hpp>
#include <occa/internal/modes/dpcpp/streamTag.hpp>
#include <occa/internal/io.hpp>
#include <occa/internal/utils/sys.hpp>
#include <occa/internal/utils/env.hpp>
#include <occa/core/base.hpp>
namespace occa
{
namespace dpcpp
{
/*
 * Reports whether DPC++ can be used on this machine: true when the SYCL
 * runtime lists at least one device, false when the list is empty.
 */
bool isEnabled()
{
  return !::sycl::device::get_devices().empty();
}
/*
 * Chooses the compiler for DPC++ kernels and stores it in
 * dpcpp_properties["compiler"].
 *
 * The first applicable source wins:
 *   1. environment variable OCCA_DPCPP_COMPILER (when non-empty)
 *   2. an existing "compiler" entry in dpcpp_properties
 *   3. environment variable OCCA_CXX (when non-empty)
 *   4. environment variable CXX (when non-empty)
 *   5. "clang++", after emitting a warning
 *
 * NOTE(review): the function is noexcept; if any of the env/json calls
 * can throw, an escaping exception would terminate the program — confirm
 * this is acceptable.
 */
void setCompiler(occa::json &dpcpp_properties) noexcept
{
  const std::string dpcpp_env = env::var("OCCA_DPCPP_COMPILER");
  if (!dpcpp_env.empty()) {
    dpcpp_properties["compiler"] = dpcpp_env;
    return;
  }

  if (dpcpp_properties.has("compiler")) {
    // Re-store the existing entry in its string form.
    const std::string from_props = dpcpp_properties["compiler"].toString();
    dpcpp_properties["compiler"] = from_props;
    return;
  }

  const std::string occa_cxx = env::var("OCCA_CXX");
  if (!occa_cxx.empty()) {
    dpcpp_properties["compiler"] = occa_cxx;
    return;
  }

  const std::string cxx = env::var("CXX");
  if (!cxx.empty()) {
    dpcpp_properties["compiler"] = cxx;
    return;
  }

  // Nothing was configured anywhere: fall back to clang++ and say so.
  OCCA_FORCE_WARNING("OCCA_DPCPP_COMPILER is defaulting to clang++");
  const std::string fallback = "clang++";
  dpcpp_properties["compiler"] = fallback;
}
/*
 * Chooses the compiler flags for DPC++ kernels and stores them in
 * dpcpp_properties["compiler_flags"].
 *
 * The first applicable source wins:
 *   1. an existing "compiler_flags" entry in dpcpp_properties
 *   2. environment variable OCCA_DPCPP_COMPILER_FLAGS (when non-empty)
 *   3. the default "-O3 -fsycl"
 *
 * NOTE(review): here the properties entry outranks the environment
 * variable, whereas setCompiler checks the environment first — confirm
 * the difference is intentional.
 */
void setCompilerFlags(occa::json &dpcpp_properties) noexcept
{
  std::string flags{"-O3 -fsycl"};

  if (dpcpp_properties.has("compiler_flags")) {
    flags = dpcpp_properties["compiler_flags"].toString();
  } else {
    // The environment is consulted only when no property was supplied.
    const std::string env_flags = env::var("OCCA_DPCPP_COMPILER_FLAGS");
    if (!env_flags.empty()) {
      flags = env_flags;
    }
  }

  dpcpp_properties["compiler_flags"] = flags;
}
/*
 * Chooses the flags used to build shared objects and stores them in
 * dpcpp_properties["compiler_shared_flags"].
 *
 * The first applicable source wins:
 *   1. environment variable OCCA_COMPILER_SHARED_FLAGS (when non-empty)
 *   2. an existing "compiler_shared_flags" entry in dpcpp_properties
 *   3. the default "-shared -fPIC"
 */
void setSharedFlags(occa::json &dpcpp_properties) noexcept
{
  const std::string env_flags = env::var("OCCA_COMPILER_SHARED_FLAGS");
  if (!env_flags.empty()) {
    dpcpp_properties["compiler_shared_flags"] = env_flags;
    return;
  }

  if (dpcpp_properties.has("compiler_shared_flags")) {
    // Convert the stored entry to std::string before writing it back.
    const std::string from_props =
      static_cast<std::string>(dpcpp_properties["compiler_shared_flags"]);
    dpcpp_properties["compiler_shared_flags"] = from_props;
    return;
  }

  const std::string fallback = "-shared -fPIC";
  dpcpp_properties["compiler_shared_flags"] = fallback;
}
/*
 * Chooses the linker flags for DPC++ kernels and stores them in
 * dpcpp_properties["linker_flags"].
 *
 * The first applicable source wins:
 *   1. environment variable OCCA_DPCPP_LINKER_FLAGS (when non-empty)
 *   2. an existing "linker_flags" entry in dpcpp_properties
 *   3. the empty string
 *
 * The entry is always written, even when the result is empty.
 */
void setLinkerFlags(occa::json &dpcpp_properties) noexcept
{
  std::string flags = env::var("OCCA_DPCPP_LINKER_FLAGS");

  // Fall back to the properties entry only when the environment gave nothing.
  if (flags.empty() && dpcpp_properties.has("linker_flags")) {
    flags = dpcpp_properties["linker_flags"].toString();
  }

  dpcpp_properties["linker_flags"] = flags;
}
/*
 * Converts a generic mode-device pointer into a reference to the DPC++
 * device it points at.
 *
 * The conversion uses dynamic_cast; OCCA_ERROR is handed the condition
 * that the cast produced a non-null pointer, so a null or non-DPC++
 * device_ is reported there before the pointer is dereferenced.
 */
occa::dpcpp::device& getDpcppDevice(modeDevice_t* device_)
{
  auto* const as_dpcpp = dynamic_cast<occa::dpcpp::device*>(device_);
  OCCA_ERROR("[dpcpp::getDpcppDevice] Dynamic cast failed!",
             as_dpcpp != nullptr);
  return *as_dpcpp;
}
/*
 * Returns the DPC++ stream behind an occa::stream handle.
 *
 * The handle's mode stream is converted with dynamic_cast; OCCA_ERROR is
 * handed the condition that the cast produced a non-null pointer before
 * the result is dereferenced.
 */
occa::dpcpp::stream& getDpcppStream(const occa::stream& stream_)
{
  auto* const as_dpcpp =
    dynamic_cast<occa::dpcpp::stream*>(stream_.getModeStream());
  OCCA_ERROR("[dpcpp::getDpcppStream]: Dynamic cast failed!",
             as_dpcpp != nullptr);
  return *as_dpcpp;
}
/*
 * Returns the DPC++ stream tag behind an occa::streamTag handle.
 *
 * The handle's mode stream tag is converted with dynamic_cast; OCCA_ERROR
 * is handed the condition that the cast produced a non-null pointer
 * before the result is dereferenced.
 */
occa::dpcpp::streamTag& getDpcppStreamTag(const occa::streamTag& tag_)
{
  auto* const as_dpcpp =
    dynamic_cast<occa::dpcpp::streamTag*>(tag_.getModeStreamTag());
  OCCA_ERROR("[dpcpp::getDpcppStreamTag]: Dynamic cast failed!",
             as_dpcpp != nullptr);
  return *as_dpcpp;
}
/*
 * Builds an occa::device around an existing SYCL device.
 *
 * The device properties start with "mode" = "dpcpp" and "wrapped" = true;
 * props is then merged in with +=. A dpcpp::device is created from the
 * merged properties and sycl_device, given a current stream created from
 * the merged "stream" entry, and returned inside an occa::device.
 */
occa::device wrapDevice(::sycl::device sycl_device,
                        const occa::properties &props)
{
  // Set the defaults before merging so that props is applied after them.
  occa::properties merged;
  merged["mode"] = "dpcpp";
  merged["wrapped"] = true;
  merged += props;

  dpcpp::device* const wrapped = new dpcpp::device(merged, sycl_device);
  wrapped->dontUseRefs();
  wrapped->currentStream = wrapped->createStream(merged["stream"]);

  return occa::device(wrapped);
}
/*
 * Forwards a SYCL exception to occa::warn.
 *
 * The text passed on is message, a newline, the literal "DPCPP Error:"
 * and then e.what(). filename, function and line are passed through
 * unchanged.
 */
void warn(const ::sycl::exception &e,
          const std::string &filename,
          const std::string &function,
          const int line,
          const std::string &message)
{
  std::string report = message;
  report += "\n";
  report += "DPCPP Error:";
  report += e.what();
  occa::warn(filename, function, line, report);
}
/*
 * Forwards a SYCL exception to occa::error.
 *
 * The text passed on is message, a newline, the literal "DPCPP Error:"
 * and then e.what(). filename, function and line are passed through
 * unchanged.
 */
void error(const ::sycl::exception &e,
           const std::string &filename,
           const std::string &function,
           const int line,
           const std::string &message)
{
  std::string report = message;
  report += "\n";
  report += "DPCPP Error:";
  report += e.what();
  occa::error(filename, function, line, report);
}
} // namespace dpcpp
} // namespace occa