1 change: 0 additions & 1 deletion intel_pytorch_extension_py/ops/__init__.py
@@ -5,5 +5,4 @@
 from .reshape import *
 from .mlp import *
 from .linear_fuse_relu import *
-from .module import *
 from .jit_script import *
4 changes: 3 additions & 1 deletion tests/cpu/test_lazy_reorder.py
@@ -13,6 +13,8 @@
 import torch
 import intel_pytorch_extension as ipex
 
+from common_ipex_conf import AutoMixPrecision, AutoDNNL
+
 import torch.nn as nn
 import torch.backends.cudnn as cudnn
 from torch.nn import Parameter
@@ -882,7 +884,7 @@ def test_view(self):
         y = torch.randn(new_shape)
         out_cpu = x_cpu_view * y
         # test if the shape of x_dpcpp_view is compatible with y
-        out_dpcpp = x_dpcpp_view * y
+        out_dpcpp = x_dpcpp_view * y.to(device)
         self.assertTrue(ipex.is_dil_tensor(out_dpcpp))
         self.assertEqual(ipex.get_dil_tensor_sizes(out_dpcpp), [1, 4, 4, 4])
         self.assertEqual(ipex.get_dil_tensor_strides(out_dpcpp), [64, 16, 4, 1])
37 changes: 25 additions & 12 deletions torch_ipex/csrc/cpu/DevOPs.cpp
@@ -14,6 +14,7 @@
 #include "dbl/Common.h"
 #include "dbl/Conv.h"
 #include "dbl/Pool.h"
+#include "dbl/DNNLChecker.h"
 #include "ShadeDataContext.h"
 
 #include "dil/dil.hpp"
@@ -173,19 +174,31 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> AtenIpexCPUDev::dil_convolution_bac
 
 at::Tensor AtenIpexCPUDev::dil_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) {
   DEBUG("AtenIpexCPUDev::convolution_overrideable\n");
   // NOTE: DO NOT always call contiguous. It may break lazy-reorder. Because contiguous will call reorder instantly.
-  if (check_auto_dnnl()) {
-    return dil_convolution(
-      input.is_contiguous() ? input : input.contiguous(),
-      weight.is_contiguous() ? weight : weight.contiguous(),
-      bias.defined() ? (bias.is_contiguous() ? bias :bias.contiguous()) : bias,
-      stride,
-      padding,
-      dilation,
-      groups);
-  } else {
-    return mkldnn_convolution(input, weight, bias, padding, stride, dilation, groups);
-  }
+  try {
+    if (check_auto_dnnl()) {
+      std::vector<at::Tensor> dnnl_input_tensors;
+      dnnl_input_tensors.push_back(input);
+      dnnl_input_tensors.push_back(weight);
+      dnnl_input_tensors.push_back(bias);
+      if (dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors))
+        return AtenIpexCPUDev::dil_convolution(input.is_contiguous() ? input : input.contiguous(), weight.is_contiguous() ? weight : weight.contiguous(), bias.is_contiguous() ? bias : bias.contiguous(), stride, padding, dilation, groups);
+    }
+  } catch (std::exception& e) {
+#if defined(_DEBUG)
+    TORCH_WARN(e.what());
+#endif
+  }
+
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.layout() == c10::kStrided);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weight.layout() == c10::kStrided);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(bias.layout() == c10::kStrided);
+  auto&& _ipex_input = bridge::shallowFallbackToCPUTensor(input);
+  auto&& _ipex_weight = bridge::shallowFallbackToCPUTensor(weight);
+  auto&& _ipex_bias = bridge::shallowFallbackToCPUTensor(bias);
+  auto&& _ipex_result = at::mkldnn_convolution(_ipex_input, _ipex_weight, _ipex_bias, padding, stride, dilation, groups);
+  static_cast<void>(_ipex_result); // Avoid warnings in case not used
+  return bridge::shallowUpgradeToDPCPPTensor(_ipex_result);
 }
 
 at::Tensor AtenIpexCPUDev::mkldnn_convolution(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) {
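The new control flow above is the core of the change: the DNNL fast path now runs only when every input tensor passes the checker, and any exception it throws demotes the op to the strided CPU fallback instead of failing the call. Below is a minimal standalone sketch of that check-then-fallback pattern; `check_supported`, `fast_path`, and `fallback_path` are hypothetical stand-ins for `dnnl_support_the_tensors`, `dil_convolution`, and the `bridge` fallback, not the real IPEX API.

#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

struct Tensor { bool on_dpcpp; };  // toy stand-in for at::Tensor

bool check_supported(const std::vector<Tensor>& inputs) {
  // Mirrors dnnl_support_the_tensors(): every input must qualify.
  for (const Tensor& t : inputs)
    if (!t.on_dpcpp) return false;
  return true;
}

std::string fast_path() { return "dnnl"; }     // cf. dil_convolution
std::string fallback_path() { return "cpu"; }  // cf. at::mkldnn_convolution via bridge

std::string run(const std::vector<Tensor>& inputs) {
  try {
    // Only attempt the fast path when the checker approves; if it still
    // throws, swallow the error (the real code warns in debug builds)...
    if (check_supported(inputs))
      return fast_path();
  } catch (const std::exception&) {
  }
  // ...and fall through to the always-correct fallback.
  return fallback_path();
}

int main() {
  std::cout << run({{true}, {true}}) << "\n";   // prints: dnnl
  std::cout << run({{true}, {false}}) << "\n";  // prints: cpu
}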
11 changes: 10 additions & 1 deletion torch_ipex/csrc/cpu/dbl/DNNLChecker.cpp
@@ -9,7 +9,8 @@ namespace dbl {
 namespace chk {
 
 bool dnnl_support_the_tensors(const std::vector<at::Tensor> &tensor_vec) {
-  return dnnl_tensor_has_data(tensor_vec) &&
+  return all_is_dpcpp(tensor_vec) &&
+         dnnl_tensor_has_data(tensor_vec) &&
          dnnl_support_the_dimension_of(tensor_vec) &&
          dnnl_support_the_data_type_of(tensor_vec);
 }
@@ -62,6 +63,14 @@ bool dnnl_tensor_has_data(const std::vector<at::Tensor> &tensor_vec) {
   return true;
 }
 
+bool all_is_dpcpp(const std::vector<at::Tensor> &tensor_vec) {
+  for (auto it = tensor_vec.begin(); it != tensor_vec.end(); ++it)
+    if (!it->device().is_dpcpp())
+      return false;
+
+  return true;
+}
+
 } // namespace chk
 } // namespace dbl
 } // namespace cpu
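The iterator loop is correct as written; for reference, the same predicate can be stated directly with `std::all_of`. A sketch, assuming the IPEX-patched ATen where `device().is_dpcpp()` exists (as the diff itself uses):

#include <algorithm>
#include <vector>
#include <ATen/ATen.h>

// Behaviorally identical alternative to the loop in all_is_dpcpp():
// std::all_of returns true for an empty vector, just like the loop does.
bool all_is_dpcpp(const std::vector<at::Tensor> &tensor_vec) {
  return std::all_of(tensor_vec.begin(), tensor_vec.end(),
                     [](const at::Tensor &t) { return t.device().is_dpcpp(); });
}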
10 changes: 9 additions & 1 deletion torch_ipex/csrc/cpu/dbl/DNNLChecker.h
@@ -62,13 +62,21 @@ bool dnnl_support_the_data_type_of(const std::vector<at::Tensor> &tensor_vec);
 bool dnnl_support_the_dimension_of(const std::vector<at::Tensor> &tensor_vec);
 
 /**
- * Check if the input tensor has data
+ * Check if all input tensors have data
 *
 * @param tensor_vec input tensors
 *
 */
 static inline bool dnnl_tensor_has_data(const std::vector<at::Tensor> &tensor_vec);
 
+/**
+ * Check if all input tensors are dpcpp tensors
+ *
+ * @param tensor_vec input tensors
+ *
+ */
+bool all_is_dpcpp(const std::vector<at::Tensor> &tensor_vec);
+
 } // namespace chk
 } // namespace dbl
 } // namespace cpu
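For a quick view of how this header is consumed, here is a trimmed call-site sketch mirroring `dil_convolution_overrideable` in DevOPs.cpp above. The `torch_ipex` outer namespace and the include path are assumptions inferred from the file layout, not confirmed by the diff.

#include <vector>
#include <ATen/ATen.h>
#include "cpu/dbl/DNNLChecker.h"  // assumed include path

// Gate a DNNL kernel on the combined predicates, as the call site in
// DevOPs.cpp does: the device check added in this PR runs first, then
// the data, dimension, and dtype checks inside dnnl_support_the_tensors().
bool can_use_dnnl(const at::Tensor &input, const at::Tensor &weight,
                  const at::Tensor &bias) {
  std::vector<at::Tensor> dnnl_input_tensors{input, weight, bias};
  return torch_ipex::cpu::dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors);
}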