Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 51 additions & 49 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,55 @@ class SYCLToolchain {
}
};

std::vector<std::string> createCommandLine(const InputArgList &UserArgList,
BinaryFormat Format,
std::string_view SourceFilePath) {
DerivedArgList DAL{UserArgList};
const auto &OptTable = getDriverOptTable();
DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_fsycl_device_only));
// User args may contain options not intended for the frontend, but we can't
// claim them here to tell the driver they're used later. Hence, suppress
// the unused argument warning.
DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_Qunused_arguments));

if (Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN) {
auto [CPU, Features] =
Translator::getTargetCPUAndFeatureAttrs(nullptr, "", Format);
(void)Features;
StringRef OT = Format == BinaryFormat::PTX ? "nvptx64-nvidia-cuda"
: "amdgcn-amd-amdhsa";
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_fsycl_targets_EQ), OT);
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_Xsycl_backend_EQ), OT);
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_offload_arch_EQ), CPU);
}

ArgStringList ASL;
for_each(DAL, [&DAL, &ASL](Arg *A) { A->render(DAL, ASL); });
for_each(UserArgList,
[&UserArgList, &ASL](Arg *A) { A->render(UserArgList, ASL); });

std::vector<std::string> CommandLine;
CommandLine.reserve(ASL.size() + 2);
CommandLine.emplace_back(ClangXXExe);
transform(ASL, std::back_inserter(CommandLine),
[](const char *AS) { return std::string{AS}; });
CommandLine.emplace_back(SourceFilePath);
return CommandLine;
}

public:
/// Returns the single shared SYCLToolchain object. It is a function-local
/// static, so it is constructed the first time this function is called.
static SYCLToolchain &instance() {
  static SYCLToolchain TheToolchain;
  return TheToolchain;
}

bool run(const std::vector<std::string> &CommandLine,
FrontendAction &FEAction,
bool run(const InputArgList &UserArgList, BinaryFormat Format,
const char *SourceFilePath, FrontendAction &FEAction,
IntrusiveRefCntPtr<FileSystem> FSOverlay = nullptr,
DiagnosticConsumer *DiagConsumer = nullptr) {
std::vector<std::string> CommandLine =
createCommandLine(UserArgList, Format, SourceFilePath);

auto FS = llvm::makeIntrusiveRefCnt<llvm::vfs::OverlayFileSystem>(
llvm::vfs::getRealFileSystem());
FS->pushOverlay(ToolchainFS);
Expand Down Expand Up @@ -226,42 +265,6 @@ class LLVMDiagnosticWrapper : public llvm::DiagnosticHandler {

} // anonymous namespace

static std::vector<std::string>
createCommandLine(const InputArgList &UserArgList, BinaryFormat Format,
std::string_view SourceFilePath) {
DerivedArgList DAL{UserArgList};
const auto &OptTable = getDriverOptTable();
DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_fsycl_device_only));
// User args may contain options not intended for the frontend, but we can't
// claim them here to tell the driver they're used later. Hence, suppress the
// unused argument warning.
DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_Qunused_arguments));

if (Format == BinaryFormat::PTX || Format == BinaryFormat::AMDGCN) {
auto [CPU, Features] =
Translator::getTargetCPUAndFeatureAttrs(nullptr, "", Format);
(void)Features;
StringRef OT = Format == BinaryFormat::PTX ? "nvptx64-nvidia-cuda"
: "amdgcn-amd-amdhsa";
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_fsycl_targets_EQ), OT);
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_Xsycl_backend_EQ), OT);
DAL.AddJoinedArg(nullptr, OptTable.getOption(OPT_offload_arch_EQ), CPU);
}

ArgStringList ASL;
for_each(DAL, [&DAL, &ASL](Arg *A) { A->render(DAL, ASL); });
for_each(UserArgList,
[&UserArgList, &ASL](Arg *A) { A->render(UserArgList, ASL); });

std::vector<std::string> CommandLine;
CommandLine.reserve(ASL.size() + 2);
CommandLine.emplace_back(SYCLToolchain::instance().getClangXXExe());
transform(ASL, std::back_inserter(CommandLine),
[](const char *AS) { return std::string{AS}; });
CommandLine.emplace_back(SourceFilePath);
return CommandLine;
}

static llvm::IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem>
getInMemoryFS(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles) {
auto InMemoryFS = llvm::makeIntrusiveRefCnt<llvm::vfs::InMemoryFileSystem>();
Expand All @@ -283,9 +286,6 @@ Expected<std::string> jit_compiler::calculateHash(
const InputArgList &UserArgList, BinaryFormat Format) {
TimeTraceScope TTS{"calculateHash"};

std::vector<std::string> CommandLine =
createCommandLine(UserArgList, Format, SourceFile.Path);

class HashPreprocessedAction : public PreprocessorFrontendAction {
protected:
void ExecuteAction() override {
Expand Down Expand Up @@ -315,7 +315,8 @@ Expected<std::string> jit_compiler::calculateHash(
BLAKE3 Hasher;
HashPreprocessedAction HashAction{Hasher};

if (!SYCLToolchain::instance().run(CommandLine, HashAction,
if (!SYCLToolchain::instance().run(UserArgList, Format, SourceFile.Path,
HashAction,
getInMemoryFS(SourceFile, IncludeFiles)))
return createStringError("Calculating source hash failed");

Expand All @@ -324,10 +325,11 @@ Expected<std::string> jit_compiler::calculateHash(
ArrayRef<uint8_t>{reinterpret_cast<const uint8_t *>(&Format),
reinterpret_cast<const uint8_t *>(&Format + 1)});

// Last argument is "rtc_N.cpp" source file name which is never the same,
// ignore it:
for (auto &Opt : drop_end(CommandLine, 1))
Hasher.update(Opt);
for (Arg *Opt : UserArgList) {
Hasher.update(Opt->getSpelling());
for (const char *Val : Opt->getValues())
Hasher.update(Val);
}

std::string EncodedHash = encodeBase64(Hasher.result());

Expand All @@ -346,9 +348,9 @@ Expected<ModuleUPtr> jit_compiler::compileDeviceCode(
DiagnosticOptions DiagOpts;
ClangDiagnosticWrapper Wrapper(BuildLog, &DiagOpts);

if (SYCLToolchain::instance().run(
createCommandLine(UserArgList, Format, SourceFile.Path), ELOA,
getInMemoryFS(SourceFile, IncludeFiles), Wrapper.consumer())) {
if (SYCLToolchain::instance().run(UserArgList, Format, SourceFile.Path, ELOA,
getInMemoryFS(SourceFile, IncludeFiles),
Wrapper.consumer())) {
return ELOA.takeModule();
} else {
return createStringError(BuildLog);
Expand Down