diff --git a/source/protobuf.js b/source/protobuf.js index 00f55111c3..27035d3f6c 100644 --- a/source/protobuf.js +++ b/source/protobuf.js @@ -500,6 +500,7 @@ protobuf.TextReader = class { static open(data) { const buffer = data instanceof Uint8Array ? data : data.peek(); const decoder = base.TextDecoder.open(buffer); + let first = true; for (let i = 0; i < 0x100; i++) { const c = decoder.decode(); if (c === undefined || c === '\0') { @@ -508,10 +509,12 @@ protobuf.TextReader = class { } break; } - if (c < ' ' && c !== '\n' && c !== '\r' && c !== '\t') { + const whitespace = c === ' ' || c === '\n' || c === '\r' || c === '\t'; + if (c < ' ' && !whitespace) { return null; } - if (i === 0) { + if (first && !whitespace) { + first = false; if (c === '#' || c === '[') { continue; } @@ -534,7 +537,7 @@ protobuf.TextReader = class { this.reset(); try { this.start(false); - while (!this.end(false)) { + while (!this.end()) { const tag = this.tag(); tags.set(tag, true); if (this.token() === '{') { @@ -576,13 +579,23 @@ protobuf.TextReader = class { } end() { - if (this._depth > 0 && this._token === '}') { + if (this._depth <= 0) { + throw new protobuf.Error('Invalid depth ' + this.location()); + } + if (this._token === '}') { this.expect('}'); this.match(';'); this._depth--; return true; } - return this._token === undefined; + if (this._token === undefined) { + if (this._depth !== 1) { + throw new protobuf.Error('Unexpected end of input' + this.location()); + } + this._depth--; + return true; + } + return false; } tag() { @@ -728,7 +741,7 @@ protobuf.TextReader = class { else { value = Number.parseInt(token, 10); if (Number.isNaN(token - value)) { - throw new protobuf.Error("Couldn't parse enum '" + token + "'" + this.location()); + throw new protobuf.Error("Couldn't parse enum '" + (token === undefined ? '' : token) + "'" + this.location()); } } this.next(); diff --git a/source/tf.js b/source/tf.js index aea8befcbc..435d7f1659 100644 --- a/source/tf.js +++ b/source/tf.js @@ -199,54 +199,51 @@ tf.ModelFactory = class { throw new tf.Error('File text format is not TensorFlow.js graph-model (' + error.message + ').'); } }; - const openTextProto = (context) => { - const tags = context.tags('pbtxt'); - let format = null; - let saved_model = null; - if (tags.has('saved_model_schema_version') || tags.has('meta_graphs')) { - try { - const stream = context.stream; - const reader = protobuf.TextReader.open(stream); - saved_model = tf.proto.tensorflow.SavedModel.decodeText(reader); - format = 'TensorFlow Saved Model'; - if (saved_model && Object.prototype.hasOwnProperty.call(saved_model, 'saved_model_schema_version')) { - format = format + ' v' + saved_model.saved_model_schema_version.toString(); - } - } - catch (error) { - throw new tf.Error('File text format is not tensorflow.SavedModel (' + error.message + ').'); - } + const openTextGraphDef = (context) => { + try { + const stream = context.stream; + const reader = protobuf.TextReader.open(stream); + const graph_def = tf.proto.tensorflow.GraphDef.decodeText(reader); + const meta_graph = new tf.proto.tensorflow.MetaGraphDef(); + meta_graph.graph_def = graph_def; + const saved_model = new tf.proto.tensorflow.SavedModel(); + saved_model.meta_graphs.push(meta_graph); + const format = 'TensorFlow Graph'; + return openSavedModel(saved_model, format, null); } - else if (tags.has('graph_def')) { - try { - const stream = context.stream; - const reader = protobuf.TextReader.open(stream); - const meta_graph = tf.proto.tensorflow.MetaGraphDef.decodeText(reader); - saved_model = new tf.proto.tensorflow.SavedModel(); - saved_model.meta_graphs.push(meta_graph); - format = 'TensorFlow MetaGraph'; - } - catch (error) { - throw new tf.Error('File text format is not tensorflow.MetaGraphDef (' + error.message + ').'); - } + catch (error) { + const message = error && error.message ? error.message : error.toString(); + throw new tf.Error('File text format is not tensorflow.GraphDef (' + message.replace(/\.$/, '') + ').'); } - else if (tags.has('node')) { - try { - const stream = context.stream; - const reader = protobuf.TextReader.open(stream); - const graph_def = tf.proto.tensorflow.GraphDef.decodeText(reader); - const meta_graph = new tf.proto.tensorflow.MetaGraphDef(); - meta_graph.graph_def = graph_def; - saved_model = new tf.proto.tensorflow.SavedModel(); - saved_model.meta_graphs.push(meta_graph); - format = 'TensorFlow Graph'; - } - catch (error) { - const message = error && error.message ? error.message : error.toString(); - throw new tf.Error('File text format is not tensorflow.GraphDef (' + message.replace(/\.$/, '') + ').'); + }; + const openTextMetaGraphDef = (context) => { + try { + const stream = context.stream; + const reader = protobuf.TextReader.open(stream); + const meta_graph = tf.proto.tensorflow.MetaGraphDef.decodeText(reader); + const saved_model = new tf.proto.tensorflow.SavedModel(); + saved_model.meta_graphs.push(meta_graph); + const format = 'TensorFlow MetaGraph'; + return openSavedModel(saved_model, format, null); + } + catch (error) { + throw new tf.Error('File text format is not tensorflow.MetaGraphDef (' + error.message + ').'); + } + }; + const openTextSavedModel = (context) => { + try { + const stream = context.stream; + const reader = protobuf.TextReader.open(stream); + const saved_model = tf.proto.tensorflow.SavedModel.decodeText(reader); + let format = 'TensorFlow Saved Model'; + if (saved_model && Object.prototype.hasOwnProperty.call(saved_model, 'saved_model_schema_version')) { + format = format + ' v' + saved_model.saved_model_schema_version.toString(); } + return openSavedModel(saved_model, format, null); + } + catch (error) { + throw new tf.Error('File text format is not tensorflow.SavedModel (' + error.message + ').'); } - return openSavedModel(saved_model, format, null); }; const openBinaryProto = (stream, identifier) => { let saved_model = null; @@ -320,8 +317,12 @@ tf.ModelFactory = class { return openEventFile(context); case 'json': return openJson(context); - case 'pbtxt': - return openTextProto(context); + case 'pbtxt.GraphDef': + return openTextGraphDef(context); + case 'pbtxt.MetaGraphDef': + return openTextMetaGraphDef(context); + case 'pbtxt.SavedModel': + return openTextSavedModel(context); case 'pb': return openBinaryProto(context.stream, context.identifier); case 'saved_metadata': @@ -346,12 +347,24 @@ tf.ModelFactory = class { identifier.endsWith('init_net.pbtxt') || identifier.endsWith('init_net.prototxt')) { return ''; } + const stream = context.stream; + const reader = base.TextReader.open(stream.peek(), 65536); + const line = reader.read(); + if (/\s*node\s*\{/.exec(line)) { + return 'pbtxt.GraphDef'; + } const tags = context.tags('pbtxt'); if (['input_stream', 'output_stream', 'input_side_packet', 'output_side_packet'].some((key) => tags.has(key) || tags.has('node.' + key))) { return ''; } - if (tags.has('node') || tags.has('saved_model_schema_version') || tags.has('meta_graphs') || tags.has('graph_def')) { - return 'pbtxt'; + if (tags.has('saved_model_schema_version') || tags.has('meta_graphs')) { + return 'pbtxt.SavedModel'; + } + if (tags.has('graph_def')) { + return 'pbtxt.MetaGraphDef'; + } + if (tags.has('node')) { + return 'pbtxt.GraphDef'; } } if (extension === 'pb' || extension === 'pbtxt' || extension === 'prototxt' || extension === 'graphdef') { @@ -440,8 +453,14 @@ tf.ModelFactory = class { if (['input_stream', 'output_stream', 'input_side_packet', 'output_side_packet'].some((key) => tags.has(key) || tags.has('node.' + key))) { return false; } - if (tags.has('node') || tags.has('saved_model_schema_version') || tags.has('meta_graphs') || tags.has('graph_def')) { - return true; + if (tags.has('node')) { + return 'pbtxt.GraphDef'; + } + if (tags.has('graph_def')) { + return 'pbtxt.MetaGraphDef'; + } + if (tags.has('saved_model_schema_version') || tags.has('meta_graphs')) { + return 'pbtxt.SavedModel'; } } } diff --git a/test/models.json b/test/models.json index 916a4113f1..0f5e89ea8b 100644 --- a/test/models.json +++ b/test/models.json @@ -5249,6 +5249,12 @@ "target": "inception5h.pb", "source": "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip[tensorflow_inception_graph.pb]" }, + { + "type": "tf", + "target": "invalid_graph.pbtxt", + "source": "https://github.com/lutzroeder/netron/files/6859459/invalid_graph.pbtxt.zip[invalid_graph.pbtxt]", + "error": "File text format is not tensorflow.GraphDef (Unexpected end of input at 8:1) in 'invalid_graph.pbtxt'." + }, { "type": "tf", "target": "mask_rcnn_resnet50_atrous_coco_2018_01_28.ckpt.meta",