[Mosaic GPU] Use PTX ISA version = min(ptxas, LLVM) #28595
jaxlib/mosaic/gpu/custom_call.cc
Outdated
    }
    stdout += buf;
    }
    if (close(stdout_pipe[0]) == -1) {
Please use absl::Cleanup to close the pipes. Otherwise this close, or the one above, might not get executed if something else fails along the way.
Done.
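For readers following along, the concern above is that an early return between `pipe()` and `close()` leaks file descriptors. Here is a minimal sketch of the scope-guard idea, assuming only POSIX and the standard library. `ScopeExit` and `RoundTripThroughPipe` are hypothetical names invented for this illustration; the PR itself was asked to use `absl::Cleanup`, which plays the same role as the hand-written guard here.

```cpp
#include <unistd.h>

#include <string>
#include <utility>

// Hypothetical stand-in for absl::Cleanup: runs `f` when the guard is
// destroyed, no matter which return path leaves the enclosing scope.
template <typename F>
class ScopeExit {
 public:
  explicit ScopeExit(F f) : f_(std::move(f)) {}
  ~ScopeExit() { f_(); }
  ScopeExit(const ScopeExit&) = delete;
  ScopeExit& operator=(const ScopeExit&) = delete;

 private:
  F f_;
};

// Hypothetical helper: sends `payload` through a fresh pipe and reads it back
// into `out`. Returns false on any failure. Both pipe ends are closed on every
// return path. `payload` must fit in the pipe buffer, since nothing drains the
// pipe while we write.
bool RoundTripThroughPipe(const std::string& payload, std::string* out) {
  int fds[2];
  if (pipe(fds) == -1) return false;
  ScopeExit close_read([&fds] { close(fds[0]); });
  bool write_end_open = true;
  ScopeExit close_write([&] {
    if (write_end_open) close(fds[1]);
  });

  ssize_t written = write(fds[1], payload.data(), payload.size());
  if (written != static_cast<ssize_t>(payload.size())) return false;

  // Close the write end early so the reader sees EOF; the guard then skips it.
  close(fds[1]);
  write_end_open = false;

  char buf[256];
  ssize_t bytes_read;
  while ((bytes_read = read(fds[0], buf, sizeof(buf))) > 0) {
    out->append(buf, static_cast<size_t>(bytes_read));
  }
  return bytes_read == 0;
}
```

The early `return false` after a short write is the interesting path: without the guards it would leave both descriptors open.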
jaxlib/mosaic/gpu/custom_call.cc
Outdated
    return absl::InternalError(
        absl::StrCat("Failed to read from pipe: ", strerror(errno)));
    }
    stdout += buf;
This looks sketchy to me. The += will look for a null byte to decide where to stop the concatenation, but there might not be one in buf (read does not insert one). Use stdout.append(buf, status) instead (and perhaps rename status to bytes_read).
Agreed. I'm more used to writing C for systems code like that and forming the string explicitly.
Fixed.
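To illustrate the point about `read` not null-terminating the buffer, here is a minimal sketch of a read loop that appends exactly the number of bytes returned. `ReadAll` is a hypothetical helper written for this example, not the code from the PR; it assumes a blocking file descriptor.

```cpp
#include <unistd.h>

#include <cerrno>
#include <cstddef>
#include <string>

// Hypothetical helper: drains `fd` into `out`. Returns false on a read error.
bool ReadAll(int fd, std::string* out) {
  char buf[4096];
  while (true) {
    ssize_t bytes_read = read(fd, buf, sizeof(buf));
    if (bytes_read == 0) return true;  // EOF.
    if (bytes_read == -1) {
      if (errno == EINTR) continue;  // Interrupted by a signal; retry.
      return false;
    }
    // Append exactly bytes_read bytes. `*out += buf` would instead scan for a
    // '\0' that read() never wrote, and would also truncate at embedded nulls.
    out->append(buf, static_cast<std::size_t>(bytes_read));
  }
}
```

Because the length is passed explicitly, output containing embedded null bytes is preserved as well.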
jaxlib/mosaic/gpu/custom_call.cc
Outdated
    return absl::InternalError("Failed to wait for CUDA tool invocation");
    }
    if (status != 0) return absl::InternalError("CUDA tool failed");
    if (status != 0) return absl::InternalError(stdout);
Please prepend a failure message to stdout in case the tool didn't write anything helpful.
Done.
jaxlib/mosaic/gpu/target.cc
Outdated
    const std::string sm = sm_arch_specific ? sm_arch_specific : sm_base;
    absl::StatusOr<std::string> GetPtxIsaVersion(int major, int minor) {
This is dead now, isn't it? Perhaps delete it?
Sorry, this PR is a bit raw; I'll refactor it and add more inline docs. This function does play a role here: it takes an input PTX ISA version specified as major.minor and returns min(major.minor, <latest PTX ISA supported by LLVM>). This covers cases where ptxas is more recent and LLVM hasn't caught up to it yet.
I refactored the function for clarity.
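The min(ptxas, LLVM) rule described above can be sketched in a few lines. This assumes versions are encoded as a single integer, major * 10 + minor (so 8.6 becomes 86). That encoding is an assumption suggested by the later review comments, and `ClampPtxIsaVersion` and `FormatPtxIsaVersion` are hypothetical names rather than functions from the PR.

```cpp
#include <algorithm>
#include <string>

// Assumption: a PTX ISA version is encoded as major * 10 + minor (8.6 -> 86).
// Picks the newest version that both ptxas and LLVM understand.
int ClampPtxIsaVersion(int ptxas_version, int llvm_version) {
  return std::min(ptxas_version, llvm_version);
}

// Hypothetical formatter: turns the encoded version back into "major.minor".
std::string FormatPtxIsaVersion(int version) {
  return std::to_string(version / 10) + "." + std::to_string(version % 10);
}
```

For example, if ptxas accepts 8.7 but LLVM only knows 8.5, the sketch selects 8.5 so that LLVM never emits a version it cannot target.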
Force-pushed: 8413c48 to ff213dc
jaxlib/mosaic/gpu/custom_call.cc
Outdated
    absl::StatusOr<int> GetLatestPtxasPtxIsaVersion() {
    std::vector<const char*> ptxas_args = {"ptxas", "--input-as-string",
                                           ".version 99.99", nullptr};
    auto result = RunCUDATool("ptxas", ptxas_args);
Nit: avoid using auto when the type isn't obvious. Something like
auto status = RunCUDATool("ptxas", ptxas_args).status();
might be fine.
Thank you, fixed.
    absl::StrFormat("Failed to parse PTX ISA minor version, expected a "
                    "parsable integer, instead got: %s",
                    major_minor[1]));
    }
We might want to fail loudly here if minor is ever >= 10.
Fixed.
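The reason a minor version of 10 or more is dangerous becomes clear if the version is packed as major * 10 + minor: 8.10 would collide with 9.0. The sketch below shows one way to reject such input, assuming that packing (it is not confirmed by the diff shown here). `ParseNonNegativeInt` and `ParsePtxIsaVersion` are hypothetical helpers that use `std::optional` instead of the `absl::StatusOr` used in the PR.

```cpp
#include <charconv>
#include <cstddef>
#include <limits>
#include <optional>
#include <string_view>
#include <system_error>

// Parses a non-negative decimal integer that occupies all of `s`.
std::optional<int> ParseNonNegativeInt(std::string_view s) {
  int value = 0;
  const char* end = s.data() + s.size();
  auto [ptr, ec] = std::from_chars(s.data(), end, value);
  if (ec != std::errc() || ptr != end || value < 0) return std::nullopt;
  return value;
}

// Assumption: the result is encoded as major * 10 + minor (8.6 -> 86).
std::optional<int> ParsePtxIsaVersion(std::string_view version) {
  std::size_t dot = version.find('.');
  if (dot == std::string_view::npos) return std::nullopt;
  std::optional<int> major = ParseNonNegativeInt(version.substr(0, dot));
  std::optional<int> minor = ParseNonNegativeInt(version.substr(dot + 1));
  if (!major || !minor) return std::nullopt;
  // Fail loudly: with this encoding 8.10 would be indistinguishable from 9.0.
  if (*minor >= 10) return std::nullopt;
  // Guard against overflow in the multiplication below.
  if (*major > (std::numeric_limits<int>::max() - 9) / 10) return std::nullopt;
  return *major * 10 + *minor;
}
```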
jaxlib/mosaic/gpu/custom_call.cc
Outdated
    absl::StatusOr<std::string> GetPtxIsaVersion() {
    int ptxas_latest_version;
    TF_ASSIGN_OR_RETURN(ptxas_latest_version, GetLatestPtxasPtxIsaVersion());
Nit: why not just put the declaration in here?
TF_ASSIGN_OR_RETURN(int ptxas_latest_version, GetLatestPtxasPtxIsaVersion());
Fixed, thank you.
jaxlib/mosaic/gpu/target.cc
Outdated
    for (const llvm::SubtargetFeatureKV& feature :
    subtarget_info->getEnabledProcessorFeatures()) {
    subtarget_info->getAllProcessorFeatures()) {
    if (absl::StartsWith(feature.Key, "ptx")) {
Nothing really wrong here, but it's not super idiomatic:
absl::string_view key = feature.Key;
if (absl::ConsumePrefix(&key, "ptx")) {
// No need for version_string
}
Cool, didn't know about ConsumePrefix.
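For anyone else unfamiliar with the pattern, the benefit of a consume-prefix call is that the check and the strip happen in one step, leaving the view positioned at the version digits. Below is a standard-library sketch of the same behavior. `StripPrefix` is a hypothetical analogue written for illustration, and the feature names in the usage are assumed examples; real code in this PR would call `absl::ConsumePrefix` directly.

```cpp
#include <string_view>

// Hypothetical standard-library analogue of absl::ConsumePrefix: if `*s`
// starts with `prefix`, removes it from `*s` and returns true; otherwise
// leaves `*s` untouched and returns false.
bool StripPrefix(std::string_view* s, std::string_view prefix) {
  if (s->substr(0, prefix.size()) != prefix) return false;
  s->remove_prefix(prefix.size());
  return true;
}
```

With a key such as "ptx86" the view ends up as "86", so no separate version_string variable is needed before parsing the number.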
    continue;
    }
    // Dump SASS.
    std::cout << *result << std::endl;
I think you missed an explicit print for the ptxas output above (if dump_ptxas is on).
Fixed that (see last force push) and squashed the commits.
Force-pushed: b8d2305 to c1e8f25
Separate commits to make it easier to review.