
Missing Operations in emitC MHLO #41

Closed
2 tasks
OliverScherf opened this issue Jan 27, 2021 · 1 comment

Comments

@OliverScherf
Contributor

OliverScherf commented Jan 27, 2021

For a current use case, I need two operations that are missing from emitC MHLO:

  • mhlo::slice for Tensor3D inputs (only the 1D and 2D overloads are offered as candidates)
model_generated.h:251:114: error: no matching function for call to ‘slice<Tensor3D<float, 1, 882, 1> >(Tensor3D<float, 1, 882, 6>&, <brace-enclosed initializer list>, <brace-enclosed initializer list>, <brace-enclosed initializer list>)’
  251 | Tensor3D<float, 1, 882, 1> v234 = mhlo::slice<Tensor3D<float, 1, 882, 1>>(v233, {0, 0, 5}, {1, 882, 6}, {1, 1, 1});
      |                                                                                                                  ^
In file included from model_generated.h:3,
                 from test.cpp:2:
/home/oliver/git/mlir-emitc/include/emitc/emitc_mhlo.h:498:6: note: candidate: ‘template<class Dest, class Src, typename std::enable_if<is_tensor_of_dim<1, Src, void>::value, bool>::type <anonymous> > Dest mhlo::slice(Src, Tensor1D<long int, 1>, Tensor1D<long int, 1>, Tensor1D<long int, 1>)’
  498 | Dest slice(Src x, Tensor1D<int64_t, 1> start_indices,
      |      ^~~~~
/home/oliver/git/mlir-emitc/include/emitc/emitc_mhlo.h:498:6: note:   template argument deduction/substitution failed:
/home/oliver/git/mlir-emitc/include/emitc/emitc_mhlo.h:497:64: error: no type named ‘type’ in ‘struct std::enable_if<false, bool>’
  497 | template <typename Dest, typename Src, IsTensorOfDim<1, Src> = true>
      |                                                                ^~~~
/home/oliver/git/mlir-emitc/include/emitc/emitc_mhlo.h:512:6: note: candidate: ‘template<class Dest, class Src, typename std::enable_if<is_tensor_of_dim<2, Src, void>::value, bool>::type <anonymous> > Dest mhlo::slice(Src, Tensor1D<long int, 2>, Tensor1D<long int, 2>, Tensor1D<long int, 2>)’
  512 | Dest slice(Src x, Tensor1D<int64_t, 2> start_indices,
      |      ^~~~~
/home/oliver/git/mlir-emitc/include/emitc/emitc_mhlo.h:512:6: note:   template argument deduction/substitution failed:
/home/oliver/git/mlir-emitc/include/emitc/emitc_mhlo.h:511:64: error: no type named ‘type’ in ‘struct std::enable_if<false, bool>’
  511 | template <typename Dest, typename Src, IsTensorOfDim<2, Src> = true>
      |                                                                ^~~~
  • mhlo::concatenate along dimension 3 of Tensor4D inputs

Code that causes the issue:

Tensor4D<float, 1, 59, 87, 64> v101 = mhlo::max<>(v100, v74);
Tensor4D<float, 1, 59, 87, 64> v105 = mhlo::max<>(v104, v74);
Tensor4D<float, 1, 59, 87, 128> v106 = mhlo::concatenate<3, Tensor4D<float, 1, 59, 87, 128>>(v101, v105);

Error message:

/home/oliver/git/mlir-emitc/include/emitc/emitc_mhlo.h: In instantiation of ‘Dest mhlo::concatenate(Src1, Src ...) [with long int Dimension = 3; Dest = Tensor<float, 1, 59, 87, 128>; Src1 = Tensor<float, 1, 59, 87, 64>; Src = {Tensor<float, 1, 59, 87, 64>}]’:
model_generated.h:123:104:   required from here
/home/oliver/git/mlir-emitc/include/emitc/emitc_mhlo.h:458:9: error: no type named ‘type’ in ‘struct concat<3, float, Tensor<float, 1, 59, 87, 64> >’
  458 |   using Rest = typename concat<Dimension, ET_Src, Src...>::type;
      |         ^~~~
/home/oliver/git/mlir-emitc/include/emitc/emitc_mhlo.h: In instantiation of ‘Dest mhlo::concatenate(Src1, Src ...) [with long int Dimension = 3; Dest = Tensor<float, 1, 29, 43, 256>; Src1 = Tensor<float, 1, 29, 43, 128>; Src = {Tensor<float, 1, 29, 43, 128>}]’:
model_generated.h:151:104:   required from here
/home/oliver/git/mlir-emitc/include/emitc/emitc_mhlo.h:458:9: error: no type named ‘type’ in ‘struct concat<3, float, Tensor<float, 1, 29, 43, 128> >’
/home/oliver/git/mlir-emitc/include/emitc/emitc_mhlo.h: In instantiation of ‘Dest mhlo::concatenate(Src1, Src ...) [with long int Dimension = 3; Dest = Tensor<float, 1, 14, 21, 384>; Src1 = Tensor<float, 1, 14, 21, 192>; Src = {Tensor<float, 1, 14, 21, 192>}]’:
model_generated.h:179:104:   required from here
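For reference, the semantics the two missing overloads need can be sketched independently of the emitC headers. This is a minimal, hypothetical sketch and not the emitC API: it assumes row-major storage in a flat std::vector, and the helper names `slice3d` and `concat4d_last` are made up for illustration rather than emitC's Tensor3D/Tensor4D templates.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of emitC): rank-3 strided slice following
// mhlo.slice semantics, i.e. start is inclusive, limit is exclusive, and
// stride is applied per dimension. Data is assumed row-major in a flat vector.
std::vector<float> slice3d(const std::vector<float>& x,
                           std::array<int64_t, 3> shape,
                           std::array<int64_t, 3> start,
                           std::array<int64_t, 3> limit,
                           std::array<int64_t, 3> stride) {
  std::vector<float> out;
  for (int64_t i = start[0]; i < limit[0]; i += stride[0])
    for (int64_t j = start[1]; j < limit[1]; j += stride[1])
      for (int64_t k = start[2]; k < limit[2]; k += stride[2])
        out.push_back(x[(i * shape[1] + j) * shape[2] + k]);
  return out;
}

// Hypothetical helper (not part of emitC): concatenate two rank-4 row-major
// tensors along the last dimension (dimension 3). The three leading
// dimensions must match; for every leading index, a's innermost row is
// copied first, followed by b's innermost row.
std::vector<float> concat4d_last(const std::vector<float>& a,
                                 std::array<int64_t, 4> a_shape,
                                 const std::vector<float>& b,
                                 std::array<int64_t, 4> b_shape) {
  assert(a_shape[0] == b_shape[0] && a_shape[1] == b_shape[1] &&
         a_shape[2] == b_shape[2]);
  const int64_t outer = a_shape[0] * a_shape[1] * a_shape[2];
  const int64_t a_last = a_shape[3];
  const int64_t b_last = b_shape[3];
  std::vector<float> out;
  out.reserve(a.size() + b.size());
  for (int64_t r = 0; r < outer; ++r) {
    out.insert(out.end(), a.begin() + r * a_last, a.begin() + (r + 1) * a_last);
    out.insert(out.end(), b.begin() + r * b_last, b.begin() + (r + 1) * b_last);
  }
  return out;
}
```

Called as `slice3d(data, {1, 882, 6}, {0, 0, 5}, {1, 882, 6}, {1, 1, 1})`, the first helper mirrors the failing call on line 251 of model_generated.h and yields 1 × 882 × 1 = 882 elements; the second, applied to two `{1, 59, 87, 64}` inputs, yields the `{1, 59, 87, 128}` layout expected by `v106`.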
@marbre
Member

marbre commented Jan 27, 2021

I split this one into separate issues.

marbre closed this as completed Jan 27, 2021
This was referenced Jan 27, 2021
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants