Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Actually migrate to StableHLO #2177

Closed
wants to merge 1 commit into from

Conversation

makslevental
Copy link
Collaborator

@makslevental makslevental commented May 26, 2023

Not sure what the story is but after #1840 I believe we should be pointing to https://github.com/openxla/stablehlo/ instead of https://github.com/tensorflow/mlir-hlo, which is no longer being updated.

@makslevental
Copy link
Collaborator Author

makslevental commented May 26, 2023

So I ran this locally with TORCH_MLIR_ENABLE_STABLEHLO=OFF (whoops, duh) and thought it would be a simple thing. But it's not - the blocker seems to be MLIRBufferTransforms which probably can't/won't/shouldn't go to stablehlo. So now I'm just wondering what the plan is since mlir-hlo is being sunsetted. I.e., we're effectively stuck on an LLVM commit.

cc @powderluv

EDIT: Okay I sounded the alarm too soon - looks like at least LLVM gets bumped regularly (and it just hasn't caught up yet). Still I'm curious what the plan here is.

@makslevental makslevental force-pushed the migratetostablehlo branch 2 times, most recently from 2bad4cf to 78c5c1f Compare May 26, 2023 19:20
@powderluv
Copy link
Collaborator

I think there was some other reason to use the old repo (which everytime we bump topples our git with index.lock issues). @ramiro050 do you know why we need the tf/mhlo repo ?

@ashay
Copy link
Collaborator

ashay commented May 26, 2023

Huh.. I was unaware that MHLO was being sunset. Thanks for the link!

The reason we were pointing to the MHLO repo was because we used the MHLO passes to lower from MHLO to Linalg, after converting from StableHlo to MHLO (see pass pipeline here). Now that MHLO is going away, does StableHlo include the lowering-to-linalg pass?

@ashay
Copy link
Collaborator

ashay commented May 26, 2023

@burmako
Copy link

burmako commented May 28, 2023

Hi folks! We are still updating the MLIR-HLO repository, but indeed we are planning to sunset it in the near future. We have not yet discussed the exact date, but will share both the exact details and the proposed date on Discourse.

"The reason we were pointing to the MHLO repo was because we used the MHLO passes to lower from MHLO to Linalg". Meanwhile, I think you may be interested in the work that IREE folks are doing to switch from the MLIR-HLO repository to the StableHLO repository (cc @kuhar who's driving this work). Here's the tracking issue: iree-org/iree#12678, and here's the OpenXLA RFC about this which goes into a bit more detail: https://groups.google.com/a/openxla.org/g/openxla-discuss/c/EWuUbyL5n3c.

@powderluv
Copy link
Collaborator

Wait is https://github.com/openxla/xla/blob/main/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td#L107 the canonical spot for all mhlo->linalg lowerings ? I am not sure how viable it would be to add lowerings to XLA for Pytorch->StableHLO->LinAlg .

@burmako where do we add any new stablehlo->linalg lowerings ? Is it in the XLA repo ?

@ashay @qingyunqu @ZihengJiang @Vremold @tanyokwok should we fully switch to stablhlo and remove mhlo as a submodule ? (the mhlo repo still topples over git with the index.lock issue every time it is updated)

@ashay
Copy link
Collaborator

ashay commented May 28, 2023

Oh hey Eugene, thanks for the links to the IREE issue! Minor question, what's the reason for IREE making a copy of the MHLO source files instead of, say, a submodule dependency? Would it be a bad idea for Torch-MLIR to add a submodule dependency on XLA (and build the MHLO libraries using CMake from https://github.com/openxla/xla/tree/main/xla/mlir_hlo)?

More importantly, I noticed that XLA integrates StableHlo in what looks like a patch applied to the last update, but this presents two challenges. First, to build Torch-MLIR, we'd need the StableHLO commit hash that XLA uses, and while it can be parsed out of workspace.bzl, I wonder if you'd be open to saving it in a plaintext file, similar to how StableHLO points to the LLVM commit.

The other challenge would be in sharing green commits between Torch-MLIR and ONNX-MLIR (and potentially, Triton?) so that all involved projects can use a single LLVM commit. While we currently create green commit branches in the MHLO repo, we'd probably need to setup a similar system for StableHLO and/or XLA, depending on whether we link XLA and StableHLO using commit hashes versus copying the StableHLO source into XLA. Let me know your thoughts on either approach (or other approaches).

Anush, yeah, it makes sense to drop the MHLO submodule, but I wonder if we can just depend on XLA, which would point to the appropriate commit hash of StableHLO instead of adding submodule dependencies on both.

@powderluv
Copy link
Collaborator

I think an implicit goal of StableHLO was to not depend on XLA (?) and be standalone hence my earlier question on where do we put these passes. I would rather copy / fork if we have to than take an XLA dep because a few passes live there.

Also @stellaraccident fyi

@kuhar
Copy link
Member

kuhar commented May 28, 2023

I think an implicit goal of StableHLO was to not depend on XLA (?) and be standalone hence my earlier question on where do we put these passes. I would rather copy / fork if we have to than take an XLA dep because a few passes live there.

In IREE we ported the code to understand what it would take to lower directly to linalg + other dialects. This allowed us to iterate on this quickly and after a little over a month of work we are almost done.

Along the way we learned that it takes more effort than we anticipated: the C++ api is close but not exactly the same as MHLO, StableHLO is doesn't provide folders and canonicalzers, there's some utility code that needed to be ported, the printed form is often different which requires modifications to tests, CHLO needs to be lowered to MHLO, etc.

IMO it would make sense to upstream the work IREE's done to StableHLO, provided that this aligns with the direction of the StableHLO project itself. It would be a shame to duplicate porting this effort across other projects.

@burmako
Copy link

burmako commented May 28, 2023

"where do we add any new stablehlo->linalg lowerings ? Is it in the XLA repo ?". I think that these lowerings could live in the StableHLO repository, and I've just written this up as an OpenXLA RFC. Thank you, @kuhar, for making this RFC possible through your excellent work!

"I think an implicit goal of StableHLO was to not depend on XLA (?) and be standalone hence my earlier question on where do we put these passes. I would rather copy / fork if we have to than take an XLA dep because a few passes live there". The RFC goes into details about goals of StableHLO, but in a nutshell: yes, the goal of StableHLO is to be framework- and compiler-agnostic. This is why I think it can be a great fit for this purpose - it's relatively lightweight to depend on, supports both CMake and Bazel builds and doesn't involve heavyweight CI.

"The other challenge would be in sharing green commits between Torch-MLIR and ONNX-MLIR (and potentially, Triton?) so that all involved projects can use a single LLVM commit". Happy to discuss the logistics of this! Although, based on some recent discussions, I think that Stella may already have some opinion on how this could be done, so let's ask her first: @stellaraccident.

@powderluv
Copy link
Collaborator

Folks - should we now switch over to stableHLO proper and drop mhlo ?

@stellaraccident
Copy link
Collaborator

Folks - should we now switch over to stableHLO proper and drop mhlo ?

Those passes are effectively a testing-only dep, right? (You can convert to stablehlo without them but verifying on the ref backend needs them)

If there isn't consensus from the stablehlo side, I would just fork whatever is necessary into the torch-mlir repo and drop the dep on mhlo. Ideally, they can just go into stablehlo, even as an optional contrib kind of thing.

In the Turbine work, I've already had to fork around and not use stablehlo at all because there isn't a reasonable path to use that part of the codebase without unusable deps.

@stellaraccident
Copy link
Collaborator

I'll also say that I haven't experienced much trouble with sharing commits if we leave projects like torch-mlir and stablehlo pretty lightweight (just dialects and high level conversion passes)... The stuff doesn't drift that much at that level. I regularly use both projects at other commits (but only by excluding the testing infra). It may be worthwhile to further isolate testing and lowering in the projects so that things stay pretty portable across versions.

@qingyunqu
Copy link
Collaborator

Folks - should we now switch over to stableHLO proper and drop mhlo ?

@powderluv If StableHLO-to-Linalg has been completed, I think it's feasible to migrate to it. Otherwise we will lack some pass to do more strong support (ex. hlo-to-memeref, see #2154).

@burmako
Copy link

burmako commented Jul 19, 2023

@qingyunqu To the best of my knowledge, the StableHLO-to-Linalg work has been completed, and I was planning to migrate it from IREE into the StableHLO repository later this month. If this is time-sensitive on your side, I can hurry up on my side as well.

@stellaraccident "In the Turbine work, I've already had to fork around and not use stablehlo at all because there isn't a reasonable path to use that part of the codebase without unusable deps". Can you tell more about this? The StableHLO repository was designed to specifically avoid this problem of the MLIR-HLO repository, so if you can elaborate on what dependencies you find unusable, we may be able to address that.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants