Using iree to convert jax models with f64 convs fails #5240

Closed
cioc-shoreline opened this issue Mar 28, 2021 · 2 comments
Labels: bug 🐞, help wanted

@cioc-shoreline

Describe the bug
Converting JAX-based models with IREE fails when JAX_ENABLE_X64=True. When lowering convs with float64 inputs, the compiler tries to convert the f64 constants to float32 and errors out. E.g.

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "sample_conversion_failure.py", line 15, in <module>
    binary = iree_jax_support.aot(func, 
  File "/home/charles/iree_jax/iree_jax_support.py", line 65, in aot
    binary = iree.compiler.xla.compile_str(hlo_proto, **options)
  File "/home/charles/.local/lib/python3.8/site-packages/iree/compiler/xla.py", line 184, in compile_str
    result = invoke_pipeline([import_cl, compile_cl], immediate_input=xla_content)
  File "/home/charles/.local/lib/python3.8/site-packages/iree/compiler/tools.py", line 220, in invoke_pipeline
    raise CompilerToolError(stage.completed)
iree.compiler.tools.CompilerToolError: Error invoking IREE compiler tool iree-translate
Diagnostics:
<unknown>:0: error: loc("constant.2"): unsupported attribute kind for conversion from 'tensor<1x2xf64>' to 'tensor<1x2xf32>'
<unknown>:0: error: loc("constant.2"): 'mhlo.constant' op unable to legalize operation types
<unknown>:0: note: loc("constant.2"): see current operation: %0 = "mhlo.constant"() {value = dense<[[1.000000e+00, -1.000000e+00]]> : tensor<1x2xf64>} : () -> tensor<1x2xf64>
<unknown>:0: error: conversion from source -> vm failed
<unknown>:0: note: see current operation: "module"() ( {
  "func"() ( {
  ^bb0(%arg0: tensor<128x128xf32>):  // no predecessors
  }) {iree.module.export, iree.reflection = {f = "I15!B11!t2d128d128R15!B11!t2d128d127", fv = "1"}, sym_name = "main", type = (tensor<128x128xf32>) -> tensor<128x127xf32>} : () -> ()
  "func"() ( {
  ^bb0(%arg0: tensor<128x128xf64>):  // no predecessors
    %0 = "mhlo.constant"() {value = dense<[[1.000000e+00, -1.000000e+00]]> : tensor<1x2xf64>} : () -> tensor<1x2xf64>
    %1 = "mhlo.constant"() {value = dense<1.000000e+02> : tensor<128x127xf64>} : () -> tensor<128x127xf64>
    %2 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<128x128xf64>) -> tensor<1x1x128x128xf64>
    %3 = "mhlo.reverse"(%0) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf64>) -> tensor<1x2xf64>
    %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<1x2xf64>) -> tensor<1x1x1x2xf64>
    %5 = "mhlo.transpose"(%2) {permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>} : (tensor<1x1x128x128xf64>) -> tensor<1x128x128x1xf64>
    %6 = "mhlo.transpose"(%4) {permutation = dense<[2, 3, 1, 0]> : tensor<4xi64>} : (tensor<1x1x1x2xf64>) -> tensor<1x2x1x1xf64>
    %7 = "mhlo.convolution"(%5, %6) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<1x128x128x1xf64>, tensor<1x2x1x1xf64>) -> tensor<1x128x127x1xf64>
    %8 = "mhlo.transpose"(%7) {permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>} : (tensor<1x128x127x1xf64>) -> tensor<1x1x128x127xf64>
    %9 = "mhlo.reshape"(%8) : (tensor<1x1x128x127xf64>) -> tensor<128x127xf64>
    %10 = "mhlo.multiply"(%9, %1) : (tensor<128x127xf64>, tensor<128x127xf64>) -> tensor<128x127xf64>
    "std.return"(%10) : (tensor<128x127xf64>) -> ()
  }) {iree.module.export, iree.reflection = {f = "I15!B11!t2d128d128R15!B11!t2d128d127", fv = "1"}, sym_name = "main", type = (tensor<128x128xf64>) -> tensor<128x127xf64>} : () -> ()
  "module_terminator"() : () -> ()
}) : () -> ()

To Reproduce
Here is a script that reproduces the failure:

import jax
import jax.numpy as jnp
import iree_jax_support  # local helper that wraps iree.compiler.xla.compile_str

def func(params, x):
    # The 1x2 f64 kernel below becomes the mhlo.constant that fails to legalize.
    v = jax.scipy.signal.convolve2d(x, jnp.array([[1, -1]], dtype=jnp.float64), mode='valid')
    return 100 * v

trace_args = [
  {},
  jnp.zeros((128, 128), dtype=jnp.float64)
]

binary = iree_jax_support.aot(func,
                              *trace_args,
                              target_backends=["vmla"])

with open("my_model.fbs", "wb+") as f:
    f.write(binary)

Run with:

JAX_ENABLE_X64=True python sample_conversion_failure.py
cioc-shoreline added the bug 🐞 and help wanted labels on Mar 28, 2021
@benvanik (Collaborator)

64-bit data types aren't fully supported yet - you'll want to use f32 for now
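
For reference, a minimal sketch of that workaround applied to the repro script (assuming the same iree_jax_support.aot helper from the report; everything is kept in f32 so no f64 constants are emitted):

import jax
import jax.numpy as jnp
import iree_jax_support  # the reporter's helper wrapping iree.compiler.xla.compile_str

def func(params, x):
    # Same convolution, but with an f32 kernel so the mhlo.constant legalizes.
    v = jax.scipy.signal.convolve2d(x, jnp.array([[1, -1]], dtype=jnp.float32), mode='valid')
    return 100 * v

trace_args = [
  {},
  jnp.zeros((128, 128), dtype=jnp.float32)  # f32 input instead of f64
]

binary = iree_jax_support.aot(func,
                              *trace_args,
                              target_backends=["vmla"])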

@GMNGeoffrey (Contributor)

I'm going to close this as a dupe of #5223. Sorry for the confusion. As that issue notes, we really need to document type support better, then find the gaps and figure out how to handle them in a way that doesn't throw someone off a performance cliff when they unnecessarily used i64 for shape dimensions or the like.
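
Relatedly, if f64 precision isn't actually required by the model, simply not enabling x64 mode keeps JAX's default f32/i32 types and sidesteps the failing legalization entirely. A one-line sketch (equivalent to dropping JAX_ENABLE_X64=True from the environment):

import jax
# Keep JAX's default 32-bit types; f64/i64 are only produced when x64 is enabled.
jax.config.update("jax_enable_x64", False)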
