Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion intel_pytorch_extension_py/ops/embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@ def embeddingbag(weights, inputs, offsets, scale_grad_by_freq, mode, sparse, per
else:
assert(0, "unimplement embeddingbag path in extension")
'''
torch_embedding_bag = torch.embedding_bag

def embeddingbag(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset):
ret = EmbeddingBagFunction.apply(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset)
if indices.device==torch.device("dpcpp"):
ret = EmbeddingBagFunction.apply(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset)
else:
ret = torch_embedding_bag(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset)
return ret


Expand Down
13 changes: 12 additions & 1 deletion torch_ipex/csrc/cpu/dbl/DNNLChecker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ namespace chk {
bool dnnl_support_the_tensors(const std::vector<at::Tensor> &tensor_vec) {
return dnnl_tensor_has_data(tensor_vec) &&
dnnl_support_the_dimension_of(tensor_vec) &&
dnnl_support_the_data_type_of(tensor_vec);
dnnl_support_the_data_type_of(tensor_vec) &&
dnnl_support_the_device_type_of(tensor_vec);
}

bool dnnl_inplace_support_the_tensors(const std::vector<at::Tensor> &tensor_vec) {
Expand Down Expand Up @@ -44,6 +45,16 @@ bool dnnl_support_the_data_type_of(const std::vector<at::Tensor> &tensor_vec) {
return true;
}

bool dnnl_support_the_device_type_of(const std::vector<at::Tensor> &tensor_vec) {
for (auto it = tensor_vec.begin(); it != tensor_vec.end(); ++it) {
if (!it.device().is_dpcpp()) {
return false;
}
}

return true;
}

bool dnnl_support_the_dimension_of(const std::vector<at::Tensor> &tensor_vec) {
for (auto it = tensor_vec.begin(); it != tensor_vec.end(); ++it) {
if (it->dim() <= 0) {
Expand Down
8 changes: 8 additions & 0 deletions torch_ipex/csrc/cpu/dbl/DNNLChecker.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ bool dnnl_support_the_memory_layout_of(const std::vector<at::Tensor> &tensor_vec
*/
bool dnnl_support_the_data_type_of(const std::vector<at::Tensor> &tensor_vec);

/**
* Check if the device type of the input tenosrs can be supported by DNNL
*
* @param tensor_vec input tensors
*
*/
bool dnnl_support_the_device_type_of(const std::vector<at::Tensor> &tensor_vec);

/**
* Check if the dimension of the input tenosrs can be supported by DNNL. The dimension
* of the input tensor should be > 0.
Expand Down