Error seen:
Traceback (most recent call last):
  File "/proj/gdba/kumar/nod/SHARK-TestSuite/e2eshark/t-r-bf16-direct-fx-importer/pytorch/combinations/mlp/runmodel.py", line 131, in <module>
    torch_mlir_model = export_and_import(model, test_input)
  File "/proj/gdba/kumar/nod/SHARK-TestSuite/e2eshark/t-r-bf16-direct-fx-importer/pytorch/combinations/mlp/runmodel.py", line 59, in export_and_import
    fx_importer.import_frozen_exported_program(prog)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 351, in import_frozen_exported_program
    self.import_stateless_graph(g)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 377, in import_stateless_graph
    node_importer.import_nodes(g.nodes)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 620, in import_nodes
    self._import_torch_op_overload(loc, node, target)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 798, in _import_torch_op_overload
    self._import_argument(loc, node.args[i], parameter.type)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 860, in _import_argument
    return self._import_literal(arg)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 877, in _import_literal
    return converter(py_value, self, self._cc)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 1169, in <lambda>
    lambda arg, gni, cc: _make_vtensor_literal_op(
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 1006, in _make_vtensor_literal_op
    npy_dtype is not None
AssertionError: Can not create literal tensor for unsupported datatype: torch.bfloat16
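The assertion fires while _make_vtensor_literal_op is turning a frozen weight into a dense literal via NumPy, and NumPy has no native bfloat16 dtype, so the torch-to-NumPy dtype lookup comes back empty. A minimal sketch of that failure mode is below; the TORCH_TO_NPY map is hypothetical, for illustration only, and is not torch-mlir's actual table. Running it raises the same "unsupported datatype" message.

import numpy as np
import torch

# Hypothetical torch->NumPy dtype map, for illustration only.
TORCH_TO_NPY = {
    torch.float32: np.float32,
    torch.float16: np.float16,
    torch.int64: np.int64,
    # torch.bfloat16 has no NumPy counterpart, so it is absent here.
}

weight = torch.randn(4, 3, dtype=torch.bfloat16)
npy_dtype = TORCH_TO_NPY.get(weight.dtype)
assert npy_dtype is not None, (
    f"Can not create literal tensor for unsupported datatype: {weight.dtype}"
)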
Steps to Reproduce:
Make sure you have a Python env with the torch-mlir package installed.
Save the following file as model.py:
import torch
import torch.nn as nn

# Fx importer related
from typing import Optional
import torch.export
from torch_mlir.extras.fx_importer import FxImporter
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d


def export_and_import(
    f,
    *args,
    fx_importer: Optional[FxImporter] = None,
    constraints: Optional[torch.export.Constraint] = None,
    **kwargs,
):
    context = ir.Context()
    torch_d.register_dialect(context)
    if fx_importer is None:
        fx_importer = FxImporter(context=context)
    prog = torch.export.export(f, args, kwargs, constraints=constraints)
    fx_importer.import_frozen_exported_program(prog)
    return fx_importer.module_op


class mlp(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            # 3 inputs, 4 outputs
            nn.Linear(3, 4),
            nn.ReLU(),
            # 4 inputs, 5 outputs
            nn.Linear(4, 5),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.layers(x)


model = mlp()
test_input = torch.randn(8, 3)
test_output = model(test_input)
print("Input:", test_input)
print("Output:", test_output)

model = model.to(torch.bfloat16)
test_input = test_input.to(torch.bfloat16)

torch_mlir_model = export_and_import(model, test_input)
with open("mlp.torch.mlir", "w+") as f:
    f.write(torch_mlir_model.operation.get_asm())
Run:
python ./model.py
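For context, the gap is in NumPy itself: a bfloat16 tensor cannot be converted to a NumPy array directly, which is what the importer ultimately runs into. Below is a minimal sketch showing the failure and one possible bit-level workaround direction (viewing the bf16 storage as int16 before crossing into NumPy); this is only an illustration of the dtype gap, not a claim about how torch-mlir should or does resolve it.

import torch

w = torch.randn(4, 3, dtype=torch.bfloat16)

# Direct conversion fails: NumPy has no bfloat16 representation.
try:
    w.numpy()
except TypeError as e:
    print(e)  # e.g. "Got unsupported ScalarType BFloat16"

# Possible workaround sketch: bf16 and int16 share an element size, so a
# bit-level view carries the raw bf16 payload into a NumPy-friendly dtype.
bits = w.view(torch.int16).numpy()
print(bits.dtype, bits.shape)  # int16 (4, 3)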