diff --git a/intel_pytorch_extension_py/ops/__init__.py b/intel_pytorch_extension_py/ops/__init__.py
index 930157eed..182ab2325 100644
--- a/intel_pytorch_extension_py/ops/__init__.py
+++ b/intel_pytorch_extension_py/ops/__init__.py
@@ -5,5 +5,4 @@
 from .reshape import *
 from .mlp import *
 from .linear_fuse_relu import *
-from .module import *
 from .jit_script import *
diff --git a/tests/cpu/test_lazy_reorder.py b/tests/cpu/test_lazy_reorder.py
index c1dfe8cf6..d06bce7a9 100644
--- a/tests/cpu/test_lazy_reorder.py
+++ b/tests/cpu/test_lazy_reorder.py
@@ -13,6 +13,8 @@
 import torch
 import intel_pytorch_extension as ipex
 
+from common_ipex_conf import AutoMixPrecision, AutoDNNL
+
 import torch.nn as nn
 import torch.backends.cudnn as cudnn
 from torch.nn import Parameter
@@ -882,7 +884,7 @@ def test_view(self):
         y = torch.randn(new_shape)
         out_cpu = x_cpu_view * y
         # test if the shape of x_dpcpp_view is compatible with y
-        out_dpcpp = x_dpcpp_view * y
+        out_dpcpp = x_dpcpp_view * y.to(device)
         self.assertTrue(ipex.is_dil_tensor(out_dpcpp))
         self.assertEqual(ipex.get_dil_tensor_sizes(out_dpcpp), [1, 4, 4, 4])
         self.assertEqual(ipex.get_dil_tensor_strides(out_dpcpp), [64, 16, 4, 1])
diff --git a/torch_ipex/csrc/cpu/DevOPs.cpp b/torch_ipex/csrc/cpu/DevOPs.cpp
index 66b506d33..859f656b8 100644
--- a/torch_ipex/csrc/cpu/DevOPs.cpp
+++ b/torch_ipex/csrc/cpu/DevOPs.cpp
@@ -14,6 +14,7 @@
 #include "dbl/Common.h"
 #include "dbl/Conv.h"
 #include "dbl/Pool.h"
+#include "dbl/DNNLChecker.h"
 #include "ShadeDataContext.h"
 
 #include "dil/dil.hpp"
@@ -173,19 +174,31 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> AtenIpexCPUDev::dil_convolution_bac
 
 at::Tensor AtenIpexCPUDev::dil_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) {
   DEBUG("AtenIpexCPUDev::convolution_overrideable\n");
-  // NOTE: DO NOT always call contiguous. It may break lazy-reorder. Because contiguous will call reorder instantly.
-  if (check_auto_dnnl()) {
-    return dil_convolution(
-      input.is_contiguous() ? input : input.contiguous(),
-      weight.is_contiguous() ? weight : weight.contiguous(),
-      bias.defined() ? (bias.is_contiguous() ? bias :bias.contiguous()) : bias,
-      stride,
-      padding,
-      dilation,
-      groups);
-  } else {
-    return mkldnn_convolution(input, weight, bias, padding, stride, dilation, groups);
+
+  try {
+    if (check_auto_dnnl()) {
+      std::vector<at::Tensor> dnnl_input_tensors;
+      dnnl_input_tensors.push_back(input);
+      dnnl_input_tensors.push_back(weight);
+      dnnl_input_tensors.push_back(bias);
+      if (dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors))
+        return AtenIpexCPUDev::dil_convolution(input.is_contiguous() ? input : input.contiguous(), weight.is_contiguous() ? weight : weight.contiguous(), bias.is_contiguous() ? bias : bias.contiguous(), stride, padding, dilation, groups);
+    }
+  } catch (std::exception& e) {
+#if defined(_DEBUG)
+    TORCH_WARN(e.what());
+#endif
   }
+
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.layout() == c10::kStrided);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weight.layout() == c10::kStrided);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(bias.layout() == c10::kStrided);
+  auto&& _ipex_input = bridge::shallowFallbackToCPUTensor(input);
+  auto&& _ipex_weight = bridge::shallowFallbackToCPUTensor(weight);
+  auto&& _ipex_bias = bridge::shallowFallbackToCPUTensor(bias);
+  auto&& _ipex_result = at::mkldnn_convolution(_ipex_input, _ipex_weight, _ipex_bias, padding, stride, dilation, groups);
+  static_cast<void>(_ipex_result); // Avoid warnings in case not used
+  return bridge::shallowUpgradeToDPCPPTensor(_ipex_result);
 }
 
 at::Tensor AtenIpexCPUDev::mkldnn_convolution(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) {
diff --git a/torch_ipex/csrc/cpu/dbl/DNNLChecker.cpp b/torch_ipex/csrc/cpu/dbl/DNNLChecker.cpp
index b49f346e8..9984ae180 100644
--- a/torch_ipex/csrc/cpu/dbl/DNNLChecker.cpp
+++ b/torch_ipex/csrc/cpu/dbl/DNNLChecker.cpp
@@ -9,7 +9,8 @@ namespace dbl {
 namespace chk {
 
 bool dnnl_support_the_tensors(const std::vector<at::Tensor> &tensor_vec) {
-  return dnnl_tensor_has_data(tensor_vec) &&
+  return all_is_dpcpp(tensor_vec) &&
+         dnnl_tensor_has_data(tensor_vec) &&
          dnnl_support_the_dimension_of(tensor_vec) &&
          dnnl_support_the_data_type_of(tensor_vec);
 }
@@ -62,6 +63,14 @@ bool dnnl_tensor_has_data(const std::vector<at::Tensor> &tensor_vec) {
   return true;
 }
 
+bool all_is_dpcpp(const std::vector<at::Tensor> &tensor_vec) {
+  for (auto it = tensor_vec.begin(); it != tensor_vec.end(); ++it)
+    if (!it->device().is_dpcpp())
+      return false;
+
+  return true;
+}
+
 } // namespace chk
 } // namespace dbl
 } // namespace cpu
diff --git a/torch_ipex/csrc/cpu/dbl/DNNLChecker.h b/torch_ipex/csrc/cpu/dbl/DNNLChecker.h
index fc6eae28a..5ca5606fd 100644
--- a/torch_ipex/csrc/cpu/dbl/DNNLChecker.h
+++ b/torch_ipex/csrc/cpu/dbl/DNNLChecker.h
@@ -62,13 +62,21 @@ bool dnnl_support_the_data_type_of(const std::vector<at::Tensor> &tensor_vec);
 bool dnnl_support_the_dimension_of(const std::vector<at::Tensor> &tensor_vec);
 
 /**
- * Check if the input tensor has data
+ * Check if all input tensors have data
  *
  * @param tensor_vec input tensors
  *
  */
 static inline bool dnnl_tensor_has_data(const std::vector<at::Tensor> &tensor_vec);
 
+/**
+ * Check if all input tensors are dpcpp tensors
+ *
+ * @param tensor_vec input tensors
+ *
+ */
+bool all_is_dpcpp(const std::vector<at::Tensor> &tensor_vec);
+
 } // namespace chk
 } // namespace dbl
 } // namespace cpu
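
For reviewers who want the new control flow at a glance: the sketch below distills the guarded fast-path/fallback shape that `dil_convolution_overrideable` now follows. It is a minimal, self-contained illustration, not the IPEX API; the `Tensor` struct, its `on_dpcpp`/`has_data` fields, and the `convolution` wrapper are hypothetical stand-ins for `at::Tensor`, `dbl::chk::all_is_dpcpp`, `dnnl_tensor_has_data`, and the real convolution entry point.

```cpp
#include <exception>
#include <iostream>
#include <vector>

struct Tensor {   // hypothetical stand-in for at::Tensor
  bool on_dpcpp;  // models tensor.device().is_dpcpp()
  bool has_data;  // models the dnnl_tensor_has_data() check
};

bool all_is_dpcpp(const std::vector<Tensor>& ts) {
  for (const auto& t : ts)
    if (!t.on_dpcpp) return false;
  return true;
}

bool dnnl_support_the_tensors(const std::vector<Tensor>& ts) {
  // As in the patch, the device check runs first and short-circuits,
  // so the remaining checks can assume DPCPP tensors.
  if (!all_is_dpcpp(ts)) return false;
  for (const auto& t : ts)
    if (!t.has_data) return false;
  return true;
}

const char* convolution(const std::vector<Tensor>& inputs) {
  try {
    if (dnnl_support_the_tensors(inputs))
      return "dil_convolution";  // fast DNNL path
  } catch (const std::exception&) {
    // Swallow the error and fall through, mirroring the patch
    // (which only emits TORCH_WARN in _DEBUG builds).
  }
  return "mkldnn_convolution fallback";  // shallow-fallback CPU path
}

int main() {
  std::vector<Tensor> good = {{true, true}, {true, true}};
  std::vector<Tensor> bad  = {{true, true}, {false, true}};  // one non-DPCPP input
  std::cout << convolution(good) << "\n";  // dil_convolution
  std::cout << convolution(bad)  << "\n";  // mkldnn_convolution fallback
}
```

Placing `all_is_dpcpp` first in `dnnl_support_the_tensors` lets the cheap device check reject mixed-device inputs before any data, dimension, or dtype inspection, and the surrounding `try`/`catch` ensures that a DNNL-path failure degrades to the CPU fallback path instead of propagating out of the operator.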