[MHLO] Init Torch to MHLO conversion. #1025
File: .gitmodules
@@ -1,3 +1,6 @@
[submodule "external/llvm-project"]
	path = externals/llvm-project
	url = https://github.com/llvm/llvm-project.git
[submodule "externals/mlir-hlo"]
	path = externals/mlir-hlo
	url = https://github.com/tensorflow/mlir-hlo.git
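With both entries in place, a fresh checkout pulls the two dependencies via the usual "git submodule update --init --recursive" (standard git workflow; the command itself is not part of this diff).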
File: CMakeLists.txt
@@ -41,7 +41,13 @@ torch_mlir_add_llvm_external_project(
    TORCH_MLIR_DIALECTS
    ${CMAKE_CURRENT_SOURCE_DIR}/externals/llvm-external-projects/torch-mlir-dialects)

torch_mlir_add_llvm_external_project(
    mlir-hlo
    MLIR_HLO
    ${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo)

[Review comment] Can we add a top-level TORCH_MLIR_MHLO (name can be anything; @silvasean, any suggestions?) CMake flag that can enable/disable the MHLO backend? This can be the big hammer in case we have to disable it for any reason (broken on macOS, etc.). See the sketch after this hunk.

if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
  message(STATUS "Torch-MLIR out-of-tree build.")
  # Out-of-tree build

#-------------------------------------------------------------------------------
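A minimal sketch of the guard proposed in the comment above. The flag name TORCH_MLIR_MHLO is the reviewer's placeholder; nothing in this PR defines it:

option(TORCH_MLIR_MHLO "Enable the MHLO backend" ON)

if(TORCH_MLIR_MHLO)
  # Only register the mlir-hlo external project when the backend is enabled.
  torch_mlir_add_llvm_external_project(
      mlir-hlo
      MLIR_HLO
      ${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo)
endif()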
@@ -81,7 +87,16 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
  set(TORCH-MLIR_BUILT_STANDALONE 1)
  set(BACKEND_PACKAGE_STRING "LLVM ${LLVM_PACKAGE_VERSION}")
  add_subdirectory(externals/llvm-external-projects/torch-mlir-dialects)

[Review comment] Wrap in the same TORCH_MLIR_MHLO flag(?)

  set(MHLO_BUILD_EMBEDDED ON)

[Review comment] What is MHLO_BUILD_EMBEDDED?

  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo
                   ${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo
                   EXCLUDE_FROM_ALL)
  include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo/include)
  include_directories(${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo/include)
  include_directories(${CMAKE_CURRENT_BINARY_DIR})

[Review comment] Is the whole CMAKE_CURRENT_BINARY_DIR required, since we are adding it globally?
else()
  message(STATUS "Torch-MLIR in-tree build.")
  # In-tree build with LLVM_EXTERNAL_PROJECTS=torch-mlir
  # FIXME: This should really be inherited from the LLVM tree. In particular,
  # it's going to change when cross-compiling.
@@ -95,6 +110,9 @@ else()
  set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include)
  set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include)
  set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}")
  # Since mhlo does not set INTERFACE_INCLUDE_DIRECTORIES on its targets, we
  # need to include the mhlo directories globally.
  include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo/include)
  include_directories(${LLVM_BINARY_DIR}/tools/mlir_hlo/include)
endif()

[Review comment] Probably good to file an issue upstream so they are aware of this, and maybe, if it is easy, add a PR upstream. A sketch of such a change follows this hunk.

set(TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
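For reference, the upstream fix the comments allude to would look roughly like the following on the mlir-hlo side. This is a sketch only: the target name MhloDialect and the two path variables are illustrative, not taken from this PR.

# Hypothetical upstream mlir-hlo change: export include paths on the target
# itself so embedders no longer need global include_directories().
target_include_directories(MhloDialect INTERFACE
  $<BUILD_INTERFACE:${MLIR_HLO_SOURCE_DIR}/include>
  $<BUILD_INTERFACE:${MLIR_HLO_BINARY_DIR}/include>)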
New file (Python): BERT masked-LM example, functorch AOT compile to MHLO
@@ -0,0 +1,79 @@
import torch
import torch.utils._pytree as pytree

from functorch.compile import aot_module, aot_function
from functorch.compile import nop
from functorch.compile import get_decompositions

import torch_mlir
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend

import transformers
from transformers import BertForMaskedLM

# Teach pytree how to flatten/unflatten MaskedLMOutput so functorch can trace
# through the model's structured output.
pytree._register_pytree_node(transformers.modeling_outputs.MaskedLMOutput, lambda x: (
    [x.logits], None), lambda values, _: transformers.modeling_outputs.MaskedLMOutput(logits=values[0]))

model = BertForMaskedLM.from_pretrained('prajjwal1/bert-tiny')

BATCH_SIZE = 2
SEQ_LEN = 128
data = {
    'input_ids': torch.randint(30522, (BATCH_SIZE, SEQ_LEN)),
    # 'labels': torch.randint(30522, (BATCH_SIZE, SEQ_LEN)),
}

[Review comment] Remove the commented-out code?

output = model(**data)


def mlir_compile(fx_g, inputs):
    for node in fx_g.graph.nodes:
        # TODO(byronyi): aten::t is not supported in DecomposeComplexOps
        if node.target == torch.ops.aten.t:
            fx_g.graph.inserting_after(node)
            new_node = fx_g.graph.call_function(
                torch.ops.aten.transpose, args=(node.args[0], 0, 1))
            node.replace_all_uses_with(new_node)
            fx_g.graph.erase_node(node)
        # TODO(byronyi): fx_g returning list breaks DecomposeComplexOps
        elif node.op == 'output':
            outputs = node.args
            num_outputs = len(node.args)
            node.args = (tuple(outputs) if num_outputs > 1 else outputs[0])
    fx_g.graph.lint()
    fx_g.recompile()

    module = torch_mlir.compile(
        fx_g, inputs, output_type=torch_mlir.OutputType.MHLO)
    fname = "bert_forward.mlir"
    with open(fname, "w+") as fout:
        fout.write(str(module))
    print("MHLO module has been saved to {}".format(fname))
    print("MHLO execution is not supported yet. Stopped.")
    exit(0)

[Review comment] Nit: the code following this point is unreachable and should be removed.

    backend = refbackend.RefBackendLinalgOnTensorsBackend()
    compiled = backend.compile(module)
    jit_module = backend.load(compiled)

    graph = torch.fx.Graph()
    args = [graph.placeholder(n.name)
            for n in fx_g.graph.nodes if n.op == 'placeholder']

    # Wrap the compiled refbackend module in a new FX graph so the AOT
    # machinery can call it like the original forward graph.
    def execute(*args):
        rets = jit_module.forward(*[t.numpy() for t in args])
        return tuple([torch.from_numpy(t) for t in rets])
    graph.output(graph.call_function(execute, tuple(args)))
    graph.lint()
    return torch.fx.GraphModule(fx_g, graph)


decompositions = get_decompositions([
    torch.ops.aten.embedding,
])
compiled_model = aot_function(
    model, fw_compiler=mlir_compile, bw_compiler=nop, decompositions=decompositions)
compiled_output = compiled_model(**data)

for k in output:
    torch.testing.assert_close(
        output[k], compiled_output[k], atol=1e-4, rtol=1e-4)
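The pytree._register_pytree_node call at the top of this example is what lets functorch flatten the model's structured MaskedLMOutput into plain tensors and rebuild it afterwards. A minimal round-trip illustrating what that registration enables (a sketch; it only works after the registration above has run):

import torch
import torch.utils._pytree as pytree
from transformers.modeling_outputs import MaskedLMOutput

out = MaskedLMOutput(logits=torch.zeros(2, 128, 30522))
leaves, spec = pytree.tree_flatten(out)        # -> [logits tensor], TreeSpec
rebuilt = pytree.tree_unflatten(leaves, spec)  # -> a MaskedLMOutput again
assert torch.equal(rebuilt.logits, out.logits)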
New file (Python): torch.nn.Linear example, functorch AOT forward and backward
@@ -0,0 +1,78 @@
import torch

from functorch.compile import aot_module
from functorch.compile import get_decompositions

import torch_mlir
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend

# Toggle: emit the MHLO module and stop, instead of executing via the
# linalg-on-tensors reference backend.
_CHECK_MHLO = True

model = torch.nn.Linear(3, 4)

# TODO: none in output breaks AdjustCallingConventions
data = torch.rand((2, 3)).requires_grad_()
output = model(data)


def mlir_compile(fx_g, inputs):
    for node in fx_g.graph.nodes:
        # TODO(byronyi): aten::t is not supported in DecomposeComplexOps
        if node.target == torch.ops.aten.t:
            fx_g.graph.inserting_after(node)
            new_node = fx_g.graph.call_function(
                torch.ops.aten.transpose, args=(node.args[0], 0, 1))
            node.replace_all_uses_with(new_node)
            fx_g.graph.erase_node(node)
        # TODO(byronyi): fx_g returning list breaks DecomposeComplexOps
        elif node.op == 'output':
            node.args = (tuple(node.args[0]),)
    fx_g.graph.lint()
    fx_g.recompile()

    if _CHECK_MHLO:
        module = torch_mlir.compile(
            fx_g, inputs, output_type=torch_mlir.OutputType.MHLO)
        print(module)
        exit(0)
    module = torch_mlir.compile(
        fx_g, inputs, output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
    backend = refbackend.RefBackendLinalgOnTensorsBackend()
    compiled = backend.compile(module)
    jit_module = backend.load(compiled)

    graph = torch.fx.Graph()
    args = [graph.placeholder(n.name)
            for n in fx_g.graph.nodes if n.op == 'placeholder']

    def execute(*args):
        rets = jit_module.forward(*[t.numpy() for t in args])
        return tuple([torch.from_numpy(t) for t in rets])
    graph.output(graph.call_function(execute, tuple(args)))
    graph.lint()
    return torch.fx.GraphModule(fx_g, graph)


decompositions = get_decompositions([
    torch.ops.aten.detach,
])
compiled_model = aot_module(
    model, mlir_compile, decompositions=decompositions)
compiled_output = compiled_model(data)

torch.testing.assert_close(output, compiled_output)

# Collect golden gradients, then reset them so the compiled run starts clean.
output.sum().backward()
grads = {k: torch.clone(v.grad) for k, v in model.named_parameters()}
grads['data'] = torch.clone(data.grad)
data.grad.zero_()
model.zero_grad()

# Repeat with the compiled module and compare gradient-by-gradient.
compiled_output.sum().backward()
compiled_grads = {k: torch.clone(v.grad) for k, v in model.named_parameters()}
compiled_grads['data'] = torch.clone(data.grad)
data.grad.zero_()
model.zero_grad()

for k in grads:
    torch.testing.assert_close(grads[k], compiled_grads[k])
New file (Python): ResNet-18 example, functorch AOT compile to MHLO
@@ -0,0 +1,81 @@
import torch
import torchvision.models as models

from functorch.compile import aot_module
from functorch.compile import nop
from functorch.compile import get_decompositions

import torch_mlir
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend

model = models.resnet18(pretrained=True)
model.train(False)

data = torch.randn(1, 3, 200, 200)
output = model(data)


def mlir_compile(fx_g, inputs):
    for node in fx_g.graph.nodes:
        # TODO(byronyi): aten::t is not supported in DecomposeComplexOps
        if node.target == torch.ops.aten.t:
            fx_g.graph.inserting_after(node)
            new_node = fx_g.graph.call_function(
                torch.ops.aten.transpose, args=(node.args[0], 0, 1))
            node.replace_all_uses_with(new_node)
            fx_g.graph.erase_node(node)
        # TODO(byronyi): fx_g returning list breaks DecomposeComplexOps
        elif node.op == 'output':
            node.args = (tuple(node.args[0]),)
    fx_g.graph.lint()
    fx_g.recompile()

    module = torch_mlir.compile(
        fx_g, inputs, output_type=torch_mlir.OutputType.TORCH)
    module = torch_mlir.compile(
        fx_g, inputs, output_type=torch_mlir.OutputType.MHLO)
    with open("./resnet18_functorch_forward_mhlo.mlir", "w", encoding="utf-8") as outf:
        outf.write(str(module))
    exit(0)

    module = torch_mlir.compile(
        fx_g, inputs, output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
    backend = refbackend.RefBackendLinalgOnTensorsBackend()
    compiled = backend.compile(module)
    jit_module = backend.load(compiled)

    graph = torch.fx.Graph()
    args = [graph.placeholder(n.name)
            for n in fx_g.graph.nodes if n.op == 'placeholder']

    def execute(*args):
        rets = jit_module.forward(*[t.numpy() for t in args])
        return tuple([torch.from_numpy(t) for t in rets])
    graph.output(graph.call_function(execute, tuple(args)))
    graph.lint()
    return torch.fx.GraphModule(fx_g, graph)


decompositions = get_decompositions([
    torch.ops.aten.native_batch_norm,
    torch.ops.aten.mean.dim
])
compiled_model = aot_module(
    model, mlir_compile, bw_compiler=nop, decompositions=decompositions)
compiled_output = compiled_model(data)

# torch.testing.assert_close(output, compiled_output)

# output.sum().backward()
# grads = {k: torch.clone(v.grad) for k, v in model.named_parameters()}
# grads['data'] = torch.clone(data.grad)
# data.grad.zero_()
# model.zero_grad()

# compiled_output.sum().backward()
# compiled_grads = {k: torch.clone(v.grad) for k, v in model.named_parameters()}
# compiled_grads['data'] = torch.clone(data.grad)
# data.grad.zero_()
# model.zero_grad()

# for k in grads:
#     torch.testing.assert_close(grads[k], compiled_grads[k])
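The decompositions list above is how unsupported ops are lowered before the compiler callback runs: get_decompositions returns a mapping from ATen overloads to Python reimplementations that functorch inlines into the traced graph, so mlir_compile only ever sees the decomposed forms. A small sketch of inspecting that mapping (assuming the functorch.compile API used throughout these examples):

from functorch.compile import get_decompositions
import torch

decomps = get_decompositions([torch.ops.aten.native_batch_norm])
# Keys are resolved overloads, e.g. aten.native_batch_norm.default;
# values are the Python decompositions functorch will inline.
print(list(decomps.keys()))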
New file (Python): TorchScript import example via ModuleBuilder and the linalg reference backend
@@ -0,0 +1,78 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from PIL import Image
import requests
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms

from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder

from torch_mlir.passmanager import PassManager
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend


mb = ModuleBuilder()

def predictions(torch_func, jit_func, data, output):
    golden_prediction = torch_func(data)
    print("PyTorch prediction")
    print(golden_prediction)
    prediction = torch.from_numpy(jit_func(data.numpy()))
    print("torch-mlir prediction")
    print(prediction)


class ModelModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(64, 10)
        self.train(False)

    def forward(self, data):
        return self.fc(data)


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.s = ModelModule()

    def forward(self, x):
        return self.s.forward(x)


data = torch.rand(4, 64)
output = ModelModule().forward(data)

test_module = TestModule()
class_annotator = ClassAnnotator()
recursivescriptmodule = torch.jit.script(test_module)
torch.jit.save(recursivescriptmodule, "/tmp/foo.pt")

class_annotator.exportNone(recursivescriptmodule._c._type())
class_annotator.exportPath(recursivescriptmodule._c._type(), ["forward"])
class_annotator.annotateArgs(
    recursivescriptmodule._c._type(),
    ["forward"],
    [
        None,
        ([4, 64], torch.float32, True),
    ],
)
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c, class_annotator)

backend = refbackend.RefBackendLinalgOnTensorsBackend()
with mb.module.context:
    pm = PassManager.parse('torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline')
    pm.run(mb.module)

compiled = backend.compile(mb.module)
jit_module = backend.load(compiled)

predictions(test_module.forward, jit_module.forward, data, output)
File: Conversion Passes.td
@@ -125,4 +125,12 @@ def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
  let constructor = "mlir::torch::createConvertTorchToTMTensorPass()";
}

def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
  let summary = "Convert Torch ops to MHLO ops";
  let description = [{
    Convert ATen ops to mhlo ops.
  }];
  let constructor = "mlir::torch::createConvertTorchToMhloPass()";
}

[Review comment] Should the description read "Convert Torch ops ..."? It seems this pass handles more than just ATen ops.

#endif // TORCHMLIR_CONVERSION_PASSES
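For context, here is how the new conversion might be invoked from Python, mirroring the PassManager usage in the TorchScript example above. A sketch under assumptions: module is already lowered to the torch backend contract, and the pipeline-string nesting follows the conventions of this torch-mlir revision.

from torch_mlir.passmanager import PassManager

# Hypothetical driver for the pass defined above; the pass name comes from
# the Passes.td entry ("convert-torch-to-mhlo", anchored on func::FuncOp).
with module.context:
    pm = PassManager.parse('func.func(convert-torch-to-mhlo)')
    pm.run(module)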
[Review comment] We probably should land the LLVM update, and even mhlo as submodules, first.