
Fx importer does not support bfloat16 #2843

Open
kumardeepakamd opened this issue Jan 31, 2024 · 1 comment

Error seen:

Traceback (most recent call last):
  File "/proj/gdba/kumar/nod/SHARK-TestSuite/e2eshark/t-r-bf16-direct-fx-importer/pytorch/combinations/mlp/runmodel.py", line 131, in <module>
    torch_mlir_model = export_and_import(model, test_input)
  File "/proj/gdba/kumar/nod/SHARK-TestSuite/e2eshark/t-r-bf16-direct-fx-importer/pytorch/combinations/mlp/runmodel.py", line 59, in export_and_import
    fx_importer.import_frozen_exported_program(prog)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 351, in import_frozen_exported_program
    self.import_stateless_graph(g)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 377, in import_stateless_graph
    node_importer.import_nodes(g.nodes)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 620, in import_nodes
    self._import_torch_op_overload(loc, node, target)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 798, in _import_torch_op_overload
    self._import_argument(loc, node.args[i], parameter.type)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 860, in _import_argument
    return self._import_literal(arg)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 877, in _import_literal
    return converter(py_value, self, self._cc)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 1169, in <lambda>
    lambda arg, gni, cc: _make_vtensor_literal_op(
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch_mlir/extras/fx_importer.py", line 1006, in _make_vtensor_literal_op
    npy_dtype is not None
AssertionError: Can not create literal tensor for unsupported datatype: torch.bfloat16
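For context: the assertion fires because the importer materializes tensor literals through NumPy, and stock NumPy has no native bfloat16 dtype, so `npy_dtype` ends up `None` for `torch.bfloat16`. As background (not the importer's actual code path), bfloat16 is just the top 16 bits of an IEEE 754 float32, so raw bit patterns can be converted with plain NumPy integer shifts. The helper names below are hypothetical, for illustration only:

```python
import numpy as np

def bf16_bits_to_f32(bits_u16):
    # bfloat16 is the high half of a float32 bit pattern, so widening is
    # a zero-extended shift into the top 16 bits of a uint32.
    return (bits_u16.astype(np.uint32) << 16).view(np.float32)

def f32_to_bf16_bits(x_f32):
    # Truncating (round-toward-zero) narrowing: keep the high 16 bits
    # of the float32 bit pattern, drop the low mantissa bits.
    return (np.asarray(x_f32, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)
```

For example, float32 `1.0` has bit pattern `0x3F800000`, so its bfloat16 bits are `0x3F80`, and widening those bits back recovers `1.0` exactly.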

Steps to Reproduce:
Make sure you have a Python environment with the torch-mlir package installed.
Save the following file as model.py:

import torch
import torch.nn as nn

# Fx importer related
from typing import Optional
import torch.export
from torch_mlir.extras.fx_importer import FxImporter
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d


def export_and_import(
    f,
    *args,
    fx_importer: Optional[FxImporter] = None,
    constraints: Optional[torch.export.Constraint] = None,
    **kwargs,
):
    context = ir.Context()
    torch_d.register_dialect(context)

    if fx_importer is None:
        fx_importer = FxImporter(context=context)
    prog = torch.export.export(f, args, kwargs, constraints=constraints)
    fx_importer.import_frozen_exported_program(prog)
    return fx_importer.module_op


class mlp(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            # 3 input, 4 output
            nn.Linear(3, 4),
            nn.ReLU(),
            # 4 input, 5 output
            nn.Linear(4, 5),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.layers(x)

model = mlp()
test_input = torch.randn(8, 3)
test_output = model(test_input)
print("Input:", test_input)
print("Output:", test_output)
model = model.to(torch.bfloat16)
test_input = test_input.to(torch.bfloat16)
torch_mlir_model = export_and_import(model, test_input)
with open("mlp.torch.mlir", "w+") as f:
    f.write(torch_mlir_model.operation.get_asm())

Run:

python ./model.py
kumardeepakamd (Collaborator, Author) commented:

Any work started on this?

dan-garvey added a commit that referenced this issue Feb 14, 2024:

This introduces an additional soft dependency on the ml_dtypes Python package in order to support bfloat16.

Addresses #2843
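For reference, a minimal sketch of what the ml_dtypes soft dependency enables: it registers a bfloat16 scalar type that NumPy arrays can carry, which lets a converter cast float data to bfloat16 precision. The `to_bf16_array` helper and its bit-truncation fallback are assumptions for illustration, not the actual fx_importer code:

```python
import numpy as np

try:
    import ml_dtypes  # soft dependency, as in the fix for this issue
    BF16 = np.dtype(ml_dtypes.bfloat16)
except ImportError:
    BF16 = None

def to_bf16_array(values):
    """Reduce float inputs to bfloat16 precision.

    Uses ml_dtypes when available; otherwise falls back to clearing the
    low 16 bits of the float32 bit pattern (truncating rounding), which
    returns a float32 array holding bfloat16-representable values.
    """
    x = np.asarray(values, dtype=np.float32)
    if BF16 is not None:
        return x.astype(BF16)
    bits = x.view(np.uint32) & np.uint32(0xFFFF0000)
    return bits.view(np.float32)
```

Values whose float32 bit patterns already fit in 16 bits (such as 1.0 or 3.0) round-trip exactly on either path.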