Skip to content

Commit

Permalink
api: fixing bugs in views
Browse files Browse the repository at this point in the history
  • Loading branch information
mgouicem committed Mar 16, 2018
1 parent 25f6e3f commit 8186564
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
2 changes: 1 addition & 1 deletion include/mkldnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ struct view : public primitive {
mkldnn_primitive_t result;
primitive_desc view_pd(input.get_primitive_desc(), dims,
offsets);
mkldnn_primitive_at_t inputs[] = { {input.get(), 0} };
mkldnn_primitive_at_t inputs[] = { primitive::at(input).data };
error::wrap_c_api(mkldnn_primitive_create(&result,
view_pd.get(), inputs, nullptr),
"could not create a view primitive");
Expand Down
2 changes: 1 addition & 1 deletion src/common/memory_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ struct view_pd_t: public primitive_desc_t {
virtual const memory_pd_t *output_pd(int index = 0) const override
{ return index == 0 ? dst_pd() : nullptr; }
virtual int n_inputs() const override { return 1; }
virtual int n_outputs() const override { return 1; }
virtual int n_outputs() const override { return 0; }
};

struct concat_pd_t: public primitive_desc_t {
Expand Down
14 changes: 11 additions & 3 deletions src/common/primitive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "primitive.hpp"
#include "engine.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

using namespace mkldnn::impl;
using namespace mkldnn::impl::status;
Expand All @@ -36,10 +37,17 @@ status_t mkldnn_primitive_create(primitive_t **primitive,
const primitive_t **outputs) {
if (utils::any_null(primitive, primitive_desc))
return invalid_arguments;
for (int i = 0; i < primitive_desc->n_inputs(); ++i)
if (inputs[i].primitive == nullptr ||
inputs[i].output_index >= size_t(primitive_desc->n_outputs()))
for (int i = 0; i < primitive_desc->n_inputs(); ++i) {
const auto i_p = inputs[i].primitive;
const auto i_oi = (int)inputs[i].output_index;
const bool ok = true
&& i_p != nullptr
&& utils::implication(i_p->kind() == memory, i_oi == 0)
&& utils::implication(i_p->kind() != memory,
i_oi < i_p->pd()->n_outputs());
if (!ok)
return invalid_arguments;
}
for (int i = 0; i < primitive_desc->n_outputs(); ++i)
if (outputs[i] == nullptr) return invalid_arguments;
return primitive_desc->create_primitive(primitive, inputs, outputs);
Expand Down

0 comments on commit 8186564

Please sign in to comment.