@@ -43,22 +43,25 @@ using float32_t = float;
4343#include " mlir/Parser/Parser.h"
4444#include " mlir/Tools/mlir-opt/MlirOptMain.h"
4545
46+ #include " dnnl_types.h"
4647#include " graph/utils/json.hpp"
4748
49+ using Strides = llvm::SmallVector<int64_t , DNNL_MAX_NDIMS>;
50+
4851class JsonParser {
4952 dnnl::impl::graph::utils::json::json_reader_t _reader;
5053 mlir::OpBuilder _builder;
5154 mlir::Location _loc;
5255 mlir::Block *_entryBlock;
53- std::vector <size_t > &_inputIds;
54- std::vector< size_t > &_outputIds ;
56+ llvm::SmallVector <size_t > &_inputIds;
57+ std::unordered_map<std:: size_t , Strides> _strides ;
5558 // Function input and operations output values. Used to connect the
5659 // operations inputs and outputs.
5760 std::unordered_map<std::size_t , mlir::Value> _valueMap;
5861 // Temporary value holders, used by the parser
59- std::vector <mlir::Value> _operands;
60- std::vector <mlir::Type> _resultTypes;
61- std::vector <mlir::NamedAttribute> _attributes;
62+ llvm::SmallVector <mlir::Value> _operands;
63+ llvm::SmallVector <mlir::Type> _resultTypes;
64+ llvm::SmallVector <mlir::NamedAttribute> _attributes;
6265 std::string _str;
6366 std::string _str2;
6467 std::size_t _uS;
@@ -70,9 +73,9 @@ class JsonParser {
7073 std::vector<std::float32_t > _fa32;
7174
7275 JsonParser (mlir::MLIRContext &context, std::istream &stream,
73- std::vector <size_t > &inputIds, std::vector< size_t > &outputIds )
76+ llvm::SmallVector <size_t > &inputIds)
7477 : _reader(&stream), _builder(&context), _loc(_builder.getUnknownLoc()),
75- _inputIds (inputIds), _outputIds(outputIds ), _valueMap(), _operands(),
78+ _inputIds (inputIds), _strides( ), _valueMap(), _operands(),
7679 _resultTypes(), _attributes(), _str(), _str2(), _uS(), _i64(), _f32(),
7780 _uaS(), _ia64(), _ia642(), _fa32() {
7881 // Creating a dummy function since we don't know the actual type yet.
@@ -82,7 +85,8 @@ class JsonParser {
8285 _builder.setInsertionPointToStart (_entryBlock);
8386 }
8487
85- mlir::ModuleOp parse ();
88+ mlir::ModuleOp parse (llvm::SmallVector<size_t > &outputIds,
89+ std::unordered_map<std::size_t , Strides> &strides);
8690 void readOp ();
8791 mlir::Attribute readAttr ();
8892 mlir::Type readTensorType ();
@@ -120,11 +124,12 @@ class JsonParser {
120124 }
121125 }
122126
123- template <typename T> inline void readNumArray (std::vector<T> &vec) {
127+ template <typename T, template <typename ...> class Container , typename ... Any>
128+ inline void readNumArray (Container<T, Any...> &c) {
124129 _reader.begin_array ();
125130 for (T value; _reader.next_array_item ();) {
126131 _reader.read_number (&value);
127- vec .push_back (value);
132+ c .push_back (value);
128133 }
129134 }
130135
@@ -175,14 +180,16 @@ class JsonParser {
175180 * @param json JSON string containing the oneDNN graph.
176181 * @param inputIds Input tensor IDs are added to this vector.
177182 * @param outputIds Output tensor IDs are added to this vector.
183+ * @param strides Strides for each tensor are added to this map.
178184 * @return The resulting MLIR module.
179185 */
180- static mlir::ModuleOp parse (mlir::MLIRContext &context,
181- const std::string_view &json,
182- std::vector<size_t > &inputIds,
183- std::vector<size_t > &outputIds) {
186+ static mlir::ModuleOp
187+ parse (mlir::MLIRContext &context, const std::string_view &json,
188+ llvm::SmallVector<size_t > &inputIds,
189+ llvm::SmallVector<size_t > &outputIds,
190+ std::unordered_map<std::size_t , Strides> &strides) {
184191 std::istringstream stream (json.data ());
185- JsonParser parser (context, stream, inputIds, outputIds );
186- return parser.parse ();
192+ JsonParser parser (context, stream, inputIds);
193+ return parser.parse (outputIds, strides );
187194 }
188195};
0 commit comments