diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 3ea2b4a2a36a..2fbe82031c6b 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -28,18 +28,27 @@ py_library( cc_library( name = "passes", - srcs = ["launch_lowering.cc"], - hdrs = ["launch_lowering.h"], + srcs = [ + "launch_lowering.cc", + "passes.cc", + ], + hdrs = [ + "launch_lowering.h", + "pass_boilerplate.h", + "passes.h", + ], deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:DataLayoutInterfaces", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToGPURuntimeTransforms", "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", - "@com_google_absl//absl/log", + "@llvm-project//mlir:TransformUtils", ], ) @@ -97,29 +106,38 @@ cc_library( ":passes", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithToLLVM", "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:ComplexToLLVM", + "@llvm-project//mlir:ControlFlowToLLVM", "@llvm-project//mlir:ConversionPasses", "@llvm-project//mlir:ExecutionEngine", "@llvm-project//mlir:ExecutionEngineUtils", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncToLLVM", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUToLLVMIRTranslation", "@llvm-project//mlir:GPUTransforms", "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexToLLVM", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LLVMToLLVMIRTranslation", "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MathToLLVM", "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:MemRefToLLVM", "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:NVGPUDialect", "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:NVVMTarget", + "@llvm-project//mlir:NVVMToLLVM", "@llvm-project//mlir:NVVMToLLVMIRTranslation", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBToLLVM", "@llvm-project//mlir:VectorDialect", "@xla//xla/service:custom_call_status", "@xla//xla/service:custom_call_target_registry", diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 1c851d331a78..47072e659a76 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -37,7 +37,16 @@ limitations under the License. #include "llvm/include/llvm/ADT/SmallVector.h" #include "llvm/include/llvm/Support/CodeGen.h" #include "llvm/include/llvm/Support/TargetSelect.h" +#include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" #include "mlir/include/mlir/Conversion/Passes.h" +#include "mlir/include/mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" @@ -67,6 +76,7 @@ limitations under the License. #include "mlir/include/mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "mlir/include/mlir/Transforms/Passes.h" #include "jaxlib/mosaic/gpu/launch_lowering.h" +#include "jaxlib/mosaic/gpu/passes.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" @@ -100,6 +110,7 @@ mlir::FailureOr GetPassPipeline( mlir::memref::registerMemRefPasses(); mlir::registerGPUPasses(); mosaic::gpu::registerGpuLaunchLoweringPass(); + mosaic::gpu::registerConvertGpuToLLVMPass(); return true; }(); (void)register_once; @@ -123,7 +134,7 @@ mlir::FailureOr GetPassPipeline( gpu.module(canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true}), gpu.module(cse), gpu.module(reconcile-unrealized-casts), - gpu-to-llvm{gpu-binary-annotation=gpu.binary use-bare-pointers-for-host=false use-bare-pointers-for-kernels=false}, + mosaic-convert-gpu-to-llvm, gpu-module-to-binary{format=)" + mlir::gpu::stringifyCompilationTarget(target).str() + R"(}, convert-math-to-llvm{approximate-log1p=true}, @@ -152,6 +163,16 @@ void InitContext(mlir::MLIRContext* context) { mlir::scf::SCFDialect, mlir::vector::VectorDialect, mlir::gpu::GPUDialect, mlir::nvgpu::NVGPUDialect, mlir::NVVM::NVVMDialect, mlir::LLVM::LLVMDialect>(); + mlir::registerConvertNVVMToLLVMInterface(registry); + mlir::registerConvertComplexToLLVMInterface(registry); + mlir::registerConvertMemRefToLLVMInterface(registry); + mlir::registerConvertMathToLLVMInterface(registry); + mlir::registerConvertFuncToLLVMInterface(registry); + mlir::index::registerConvertIndexToLLVMInterface(registry); + mlir::cf::registerConvertControlFlowToLLVMInterface(registry); + mlir::ub::registerConvertUBToLLVMInterface(registry); // Arith needs this + mlir::arith::registerConvertArithToLLVMInterface(registry); + mlir::registerFinalizeMemRefToLLVMConversionPass(); mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry); mlir::NVVM::registerNVVMTargetInterfaceExternalModels(registry); mlir::registerBuiltinDialectTranslation(registry); diff --git a/jaxlib/mosaic/gpu/launch_lowering.cc b/jaxlib/mosaic/gpu/launch_lowering.cc index 0c66a0325581..f6c0237ffecb 100644 --- a/jaxlib/mosaic/gpu/launch_lowering.cc +++ b/jaxlib/mosaic/gpu/launch_lowering.cc @@ -171,7 +171,7 @@ void buildInitFunction(mlir::OpBuilder &module_builder, used_smem = builder.create( loc, i32, builder.getI32IntegerAttr( - mlir::cast(const_smem.getValue()).getSInt())); + mlir::cast(const_smem.getValue()).getInt())); } } mlir::Value kernel_handle = diff --git a/jaxlib/mosaic/gpu/pass_boilerplate.h b/jaxlib/mosaic/gpu/pass_boilerplate.h new file mode 100644 index 000000000000..b0241fca97ab --- /dev/null +++ b/jaxlib/mosaic/gpu/pass_boilerplate.h @@ -0,0 +1,64 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MOSAIC_GPU_PASS_BOILERPLATE_H_ +#define JAXLIB_MOSAIC_GPU_PASS_BOILERPLATE_H_ + +#include "mlir/include/mlir/IR/DialectRegistry.h" +#include "mlir/include/mlir/Pass/Pass.h" +#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/include/mlir/Support/TypeID.h" +namespace mosaic { +namespace gpu { + +template +class Pass : public ::mlir::OperationPass { + public: + Pass() : ::mlir::OperationPass(::mlir::TypeID::get()) {} + Pass(const Pass &other) : ::mlir::OperationPass(other) {} + Pass &operator=(const Pass &) = delete; + Pass(Pass &&) = delete; + Pass &operator=(Pass &&) = delete; + ~Pass() = default; + + static constexpr ::llvm::StringLiteral getArgumentName() { + return ::llvm::StringLiteral(Derived::kArgumentName); + } + ::llvm::StringRef getArgument() const override { return getArgumentName(); } + ::llvm::StringRef getDescription() const override { return ""; } + static constexpr ::llvm::StringLiteral getPassName() { + return ::llvm::StringLiteral(Derived::kPassName); + } + ::llvm::StringRef getName() const override { return getPassName(); } + static bool classof(const ::mlir::Pass *pass) { + return pass->getTypeID() == ::mlir::TypeID::get(); + } + std::unique_ptr<::mlir::Pass> clonePass() const override { + return std::make_unique(*static_cast(this)); + } + void getDependentDialects(::mlir::DialectRegistry ®istry) const override {} + + private: + using This = + Pass; // Can't have a comma in the macro instantiation + + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(This) +}; + +} // namespace gpu +} // namespace mosaic + +#endif // JAXLIB_MOSAIC_GPU_PASS_BOILERPLATE_H_ diff --git a/jaxlib/mosaic/gpu/passes.cc b/jaxlib/mosaic/gpu/passes.cc new file mode 100644 index 000000000000..a9a97275a346 --- /dev/null +++ b/jaxlib/mosaic/gpu/passes.cc @@ -0,0 +1,81 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/mosaic/gpu/passes.h" +#include +#include +#include + +#include "llvm/include/llvm/ADT/StringRef.h" +#include "llvm/include/llvm/Support/Debug.h" +#include "mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h" +#include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/include/mlir/IR/BuiltinOps.h" +#include "mlir/include/mlir/IR/SymbolTable.h" +#include "mlir/include/mlir/Pass/PassRegistry.h" +#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/include/mlir/Transforms/DialectConversion.h" +#include "jaxlib/mosaic/gpu/pass_boilerplate.h" + +namespace mosaic { +namespace gpu { + +namespace { + +class ConvertGpuToLLVMPass + : public mosaic::gpu::Pass { + public: + using mosaic::gpu::Pass::Pass; + static constexpr llvm::StringLiteral kArgumentName = + "mosaic-convert-gpu-to-llvm"; + static constexpr llvm::StringLiteral kPassName = "ConvertGpuToLLVMPass"; + + void runOnOperation() override { + llvm::DebugFlag = true; + mlir::MLIRContext *ctx = &getContext(); + mlir::RewritePatternSet patterns(ctx); + mlir::LLVMTypeConverter converter(ctx); + mlir::ConversionTarget target(*ctx); + target.addLegalDialect(); + target.addLegalOp(); + target.addDynamicallyLegalOp( + [&](mlir::gpu::LaunchFuncOp op) -> bool { + return converter.isLegal(op->getOperandTypes()) && + converter.isLegal(op->getResultTypes()); + }); + auto symtab = mlir::SymbolTable(getOperation()); + mlir::populateGpuToLLVMConversionPatterns(converter, patterns, "gpu.binary", + false, &symtab); + if (mlir::applyPartialConversion(getOperation(), target, + std::move(patterns)) + .failed()) { + signalPassFailure(); + } + llvm::DebugFlag = false; + } +}; + +} // namespace + +void registerConvertGpuToLLVMPass() { + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return std::make_unique(); + }); +} + +} // namespace gpu +} // namespace mosaic \ No newline at end of file diff --git a/jaxlib/mosaic/gpu/passes.h b/jaxlib/mosaic/gpu/passes.h new file mode 100644 index 000000000000..bf7a804ee217 --- /dev/null +++ b/jaxlib/mosaic/gpu/passes.h @@ -0,0 +1,27 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MOSAIC_GPU_CONVERT_GPU_TO_LLVM_H_ +#define JAXLIB_MOSAIC_GPU_CONVERT_GPU_TO_LLVM_H_ + +namespace mosaic { +namespace gpu { + +void registerConvertGpuToLLVMPass(); + +} // namespace gpu +} // namespace mosaic + +#endif // JAXLIB_MOSAIC_GPU_CONVERT_GPU_TO_LLVM_H_ diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 76e907bdc911..91bb7b5f5195 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -170,7 +170,7 @@ def setUp(self): class TestUtilTest(TestCase): - def test_copy(self): + def test_copy_basic(self): def kernel(ctx, src, dst, _): copy(src, dst) x = jnp.arange(2 * 3 * 5).reshape(2, 5, 3)