[MHLO] Init Torch to MHLO conversion. #1025

Closed · wants to merge 1 commit
3 changes: 3 additions & 0 deletions .gitmodules
@@ -1,3 +1,6 @@
[submodule "external/llvm-project"]
path = externals/llvm-project
url = https://github.com/llvm/llvm-project.git
Collaborator: We probably should land the LLVM update, and even the mlir-hlo submodule, first.

[submodule "externals/mlir-hlo"]
path = externals/mlir-hlo
url = https://github.com/tensorflow/mlir-hlo.git
18 changes: 18 additions & 0 deletions CMakeLists.txt
@@ -41,7 +41,13 @@ torch_mlir_add_llvm_external_project(
TORCH_MLIR_DIALECTS
${CMAKE_CURRENT_SOURCE_DIR}/externals/llvm-external-projects/torch-mlir-dialects)

torch_mlir_add_llvm_external_project(
Collaborator: Can we add a top-level TORCH_MLIR_MHLO CMake flag (the name can be anything; @silvasean, any suggestions?) that can enable/disable the MHLO backend? This can be the big hammer in case we have to disable it for any reason (broken on macOS, etc.).

mlir-hlo
MLIR_HLO
${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo)
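Note: a minimal sketch of the toggle suggested in the review comment above; the exact option name, default value, and placement are assumptions, not part of this PR.

# Hypothetical flag (name taken from the review comment); default is an assumption.
option(TORCH_MLIR_MHLO "Enable building the Torch-to-MHLO backend" ON)
if(TORCH_MLIR_MHLO)
  # Only register the mlir-hlo external project when the backend is enabled.
  torch_mlir_add_llvm_external_project(
    mlir-hlo
    MLIR_HLO
    ${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo)
endif()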

if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
message(STATUS "Torch-MLIR out-of-tree build.")
# Out-of-tree build

#-------------------------------------------------------------------------------
@@ -81,7 +87,16 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
set(TORCH-MLIR_BUILT_STANDALONE 1)
set(BACKEND_PACKAGE_STRING "LLVM ${LLVM_PACKAGE_VERSION}")
add_subdirectory(externals/llvm-external-projects/torch-mlir-dialects)

Collaborator: Wrap this in the same TORCH_MLIR_MHLO flag?

set(MHLO_BUILD_EMBEDDED ON)
Collaborator: What is MHLO_BUILD_EMBEDDED meant for? It seems unused.

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo
${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo
EXCLUDE_FROM_ALL)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR})
Collaborator: Is including the whole CMAKE_CURRENT_BINARY_DIR required, given that we are adding it globally?

else()
message(STATUS "Torch-MLIR in-tree build.")
# In-tree build with LLVM_EXTERNAL_PROJECTS=torch-mlir
# FIXME: This should really be inherited from the LLVM tree. In particular,
# it's going to change when cross-compiling.
@@ -95,6 +110,9 @@ else()
set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include)
set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include)
set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}")
# mlir-hlo does not set INTERFACE_INCLUDE_DIRECTORIES on its targets, so we need to add the mhlo include directories globally.
Collaborator: Probably good to file an issue upstream so they are aware of this, and, if it is easy, add a PR upstream as well.

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo/include)
include_directories(${LLVM_BINARY_DIR}/tools/mlir_hlo/include)
endif()
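Note: the upstream fix suggested above would roughly amount to mlir-hlo exporting its include paths on its own targets, so consumers would not need global include_directories() calls. A hedged sketch; the target and variable names are illustrative, not actual mlir-hlo CMake:

# Hypothetical upstream change: attach include paths to the exported target so
# they propagate to anything that links against it.
target_include_directories(MhloDialect PUBLIC
  $<BUILD_INTERFACE:${MLIR_HLO_SOURCE_DIR}/include>
  $<BUILD_INTERFACE:${MLIR_HLO_BINARY_DIR}/include>)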

set(TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
79 changes: 79 additions & 0 deletions examples/aot_autograd_bert.py
@@ -0,0 +1,79 @@
import torch
import torch.utils._pytree as pytree

from functorch.compile import aot_module, aot_function
from functorch.compile import nop
from functorch.compile import get_decompositions

import torch_mlir
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend

import transformers
from transformers import BertForMaskedLM

pytree._register_pytree_node(transformers.modeling_outputs.MaskedLMOutput, lambda x: (
[x.logits], None), lambda values, _: transformers.modeling_outputs.MaskedLMOutput(logits=values[0]))

model = BertForMaskedLM.from_pretrained('prajjwal1/bert-tiny')

BATCH_SIZE = 2
SEQ_LEN = 128
data = {
'input_ids': torch.randint(30522, (BATCH_SIZE, SEQ_LEN)),
# 'labels': torch.randint(30522, (BATCH_SIZE, SEQ_LEN)),
Collaborator: Remove the commented-out code?

}
output = model(**data)


def mlir_compile(fx_g, inputs):
for node in fx_g.graph.nodes:
# TODO(byronyi): aten::t is not supported in DecomposeComplexOps
if node.target == torch.ops.aten.t:
fx_g.graph.inserting_after(node)
new_node = fx_g.graph.call_function(
torch.ops.aten.transpose, args=(node.args[0], 0, 1))
node.replace_all_uses_with(new_node)
fx_g.graph.erase_node(node)
# TODO(byronyi): fx_g returning list breaks DecomposeComplexOps
elif node.op == 'output':
outputs = node.args
num_outputs = len(node.args)
node.args = (tuple(outputs) if num_outputs > 1 else outputs[0])
fx_g.graph.lint()
fx_g.recompile()

module = torch_mlir.compile(
fx_g, inputs, output_type=torch_mlir.OutputType.MHLO)
fname = "bert_forward.mlir"
with open(fname, "w+") as fout:
fout.write(str(module))
print("MHLO module has been save to {}".format(fname))
print("MHLO execution is not support yet. Stopped.")
exit(0)
Collaborator: Nit: the code that follows will never be reached and should be removed.


backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)

graph = torch.fx.Graph()
args = [graph.placeholder(n.name)
for n in fx_g.graph.nodes if n.op == 'placeholder']

def execute(*args):
rets = jit_module.forward(*[t.numpy() for t in args])
return tuple([torch.from_numpy(t) for t in rets])
graph.output(graph.call_function(execute, tuple(args)))
graph.lint()
return torch.fx.GraphModule(fx_g, graph)


decompositions = get_decompositions([
torch.ops.aten.embedding,
])
compiled_model = aot_function(
model, fw_compiler=mlir_compile, bw_compiler=nop, decompositions=decompositions)
compiled_output = compiled_model(**data)

for k in output:
torch.testing.assert_close(
output[k], compiled_output[k], atol=1e-4, rtol=1e-4)
78 changes: 78 additions & 0 deletions examples/aot_autograd_linear.py
@@ -0,0 +1,78 @@
import torch

from functorch.compile import aot_module
from functorch.compile import get_decompositions

import torch_mlir
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend

_CHECK_MHLO = True

model = torch.nn.Linear(3, 4)

# TODO: none in output breaks AdjustCallingConventions
data = torch.rand((2, 3)).requires_grad_()
output = model(data)


def mlir_compile(fx_g, inputs):
for node in fx_g.graph.nodes:
# TODO(byronyi): aten::t is not supported in DecomposeComplexOps
if node.target == torch.ops.aten.t:
fx_g.graph.inserting_after(node)
new_node = fx_g.graph.call_function(
torch.ops.aten.transpose, args=(node.args[0], 0, 1))
node.replace_all_uses_with(new_node)
fx_g.graph.erase_node(node)
# TODO(byronyi): fx_g returning list breaks DecomposeComplexOps
elif node.op == 'output':
node.args = (tuple(node.args[0]),)
fx_g.graph.lint()
fx_g.recompile()

if _CHECK_MHLO:
module = torch_mlir.compile(
fx_g, inputs, output_type=torch_mlir.OutputType.MHLO)
print(module)
exit(0)
module = torch_mlir.compile(
fx_g, inputs, output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)

graph = torch.fx.Graph()
args = [graph.placeholder(n.name)
for n in fx_g.graph.nodes if n.op == 'placeholder']

def execute(*args):
rets = jit_module.forward(*[t.numpy() for t in args])
return tuple([torch.from_numpy(t) for t in rets])
graph.output(graph.call_function(execute, tuple(args)))
graph.lint()
return torch.fx.GraphModule(fx_g, graph)


decompositions = get_decompositions([
torch.ops.aten.detach,
])
compiled_model = aot_module(
model, mlir_compile, decompositions=decompositions)
compiled_output = compiled_model(data)

torch.testing.assert_close(output, compiled_output)

output.sum().backward()
grads = {k: torch.clone(v.grad) for k, v in model.named_parameters()}
grads['data'] = torch.clone(data.grad)
data.grad.zero_()
model.zero_grad()

compiled_output.sum().backward()
compiled_grads = {k: torch.clone(v.grad) for k, v in model.named_parameters()}
compiled_grads['data'] = torch.clone(data.grad)
data.grad.zero_()
model.zero_grad()

for k in grads:
torch.testing.assert_close(grads[k], compiled_grads[k])
81 changes: 81 additions & 0 deletions examples/aot_autograd_resnet18.py
@@ -0,0 +1,81 @@
import torch
import torchvision.models as models

from functorch.compile import aot_module
from functorch.compile import nop
from functorch.compile import get_decompositions

import torch_mlir
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend

model = models.resnet18(pretrained=True)
model.train(False)

data = torch.randn(1,3,200,200)
output = model(data)

def mlir_compile(fx_g, inputs):
for node in fx_g.graph.nodes:
# TODO(byronyi): aten::t is not supported in DecomposeComplexOps
if node.target == torch.ops.aten.t:
fx_g.graph.inserting_after(node)
new_node = fx_g.graph.call_function(
torch.ops.aten.transpose, args=(node.args[0], 0, 1))
node.replace_all_uses_with(new_node)
fx_g.graph.erase_node(node)
# TODO(byronyi): fx_g returning list breaks DecomposeComplexOps
elif node.op == 'output':
node.args = (tuple(node.args[0]),)
fx_g.graph.lint()
fx_g.recompile()

module = torch_mlir.compile(
fx_g, inputs, output_type=torch_mlir.OutputType.TORCH)
module = torch_mlir.compile(
fx_g, inputs, output_type=torch_mlir.OutputType.MHLO)
with open("./resnet18_functorch_forward_mhlo.mlir", "w", encoding="utf-8") as outf:
outf.write(str(module))
exit(0)

module = torch_mlir.compile(
fx_g, inputs, output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)

graph = torch.fx.Graph()
args = [graph.placeholder(n.name)
for n in fx_g.graph.nodes if n.op == 'placeholder']

def execute(*args):
rets = jit_module.forward(*[t.numpy() for t in args])
return tuple([torch.from_numpy(t) for t in rets])
graph.output(graph.call_function(execute, tuple(args)))
graph.lint()
return torch.fx.GraphModule(fx_g, graph)


decompositions = get_decompositions([
torch.ops.aten.native_batch_norm,
torch.ops.aten.mean.dim
])
compiled_model = aot_module(
model, mlir_compile, bw_compiler=nop, decompositions=decompositions)
compiled_output = compiled_model(data)

# torch.testing.assert_close(output, compiled_output)

# output.sum().backward()
# grads = {k: torch.clone(v.grad) for k, v in model.named_parameters()}
# grads['data'] = torch.clone(data.grad)
# data.grad.zero_()
# model.zero_grad()

# compiled_output.sum().backward()
# compiled_grads = {k: torch.clone(v.grad) for k, v in model.named_parameters()}
# compiled_grads['data'] = torch.clone(data.grad)
# data.grad.zero_()
# model.zero_grad()

# for k in grads:
# torch.testing.assert_close(grads[k], compiled_grads[k])
78 changes: 78 additions & 0 deletions examples/torchscript_matmul.py
@@ -0,0 +1,78 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from PIL import Image
import requests
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms

from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder

from torch_mlir.passmanager import PassManager
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend


mb = ModuleBuilder()

def predictions(torch_func, jit_func, data, output):
golden_prediction = torch_func(data)
print("PyTorch prediction")
print(golden_prediction)
prediction = torch.from_numpy(jit_func(data.numpy()))
print("torch-mlir prediction")
print(prediction)


class ModelModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(64, 10)
self.train(False)

def forward(self, data):
return self.fc(data)


class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.s = ModelModule()

def forward(self, x):
return self.s.forward(x)


data = torch.rand(4, 64)
output = ModelModule().forward(data)

test_module = TestModule()
class_annotator = ClassAnnotator()
recursivescriptmodule = torch.jit.script(test_module)
torch.jit.save(recursivescriptmodule, "/tmp/foo.pt")

class_annotator.exportNone(recursivescriptmodule._c._type())
class_annotator.exportPath(recursivescriptmodule._c._type(), ["forward"])
class_annotator.annotateArgs(
recursivescriptmodule._c._type(),
["forward"],
[
None,
([4, 64], torch.float32, True),
],
)
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c, class_annotator)

backend = refbackend.RefBackendLinalgOnTensorsBackend()
with mb.module.context:
pm = PassManager.parse('torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline')
pm.run(mb.module)

compiled = backend.compile(mb.module)
jit_module = backend.load(compiled)

predictions(test_module.forward, jit_module.forward, data, output)
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 619 files
1 change: 1 addition & 0 deletions externals/mlir-hlo
Submodule mlir-hlo added at 1bafb1
8 changes: 8 additions & 0 deletions include/torch-mlir/Conversion/Passes.td
@@ -125,4 +125,12 @@ def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToTMTensorPass()";
}

def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
let summary = "Convert Torch ops to MHLO ops";
let description = [{
Convert ATen ops to mhlo ops.
Collaborator (quoting "Convert ATen ops to mhlo ops."): Should this say "Convert Torch ops ..."? It seems this pass handles more than just ATen ops.

}];
let constructor = "mlir::torch::createConvertTorchToMhloPass()";
}

#endif // TORCHMLIR_CONVERSION_PASSES