Skip to content

Commit

Permalink
Update TensorFlow .pbtxt detection (#774)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jul 22, 2021
1 parent 33e95f0 commit bfa29fd
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 56 deletions.
25 changes: 19 additions & 6 deletions source/protobuf.js
Expand Up @@ -500,6 +500,7 @@ protobuf.TextReader = class {
static open(data) {
const buffer = data instanceof Uint8Array ? data : data.peek();
const decoder = base.TextDecoder.open(buffer);
let first = true;
for (let i = 0; i < 0x100; i++) {
const c = decoder.decode();
if (c === undefined || c === '\0') {
Expand All @@ -508,10 +509,12 @@ protobuf.TextReader = class {
}
break;
}
if (c < ' ' && c !== '\n' && c !== '\r' && c !== '\t') {
const whitespace = c === ' ' || c === '\n' || c === '\r' || c === '\t';
if (c < ' ' && !whitespace) {
return null;
}
if (i === 0) {
if (first && !whitespace) {
first = false;
if (c === '#' || c === '[') {
continue;
}
Expand All @@ -534,7 +537,7 @@ protobuf.TextReader = class {
this.reset();
try {
this.start(false);
while (!this.end(false)) {
while (!this.end()) {
const tag = this.tag();
tags.set(tag, true);
if (this.token() === '{') {
Expand Down Expand Up @@ -576,13 +579,23 @@ protobuf.TextReader = class {
}

end() {
if (this._depth > 0 && this._token === '}') {
if (this._depth <= 0) {
throw new protobuf.Error('Invalid depth ' + this.location());
}
if (this._token === '}') {
this.expect('}');
this.match(';');
this._depth--;
return true;
}
return this._token === undefined;
if (this._token === undefined) {
if (this._depth !== 1) {
throw new protobuf.Error('Unexpected end of input' + this.location());
}
this._depth--;
return true;
}
return false;
}

tag() {
Expand Down Expand Up @@ -728,7 +741,7 @@ protobuf.TextReader = class {
else {
value = Number.parseInt(token, 10);
if (Number.isNaN(token - value)) {
throw new protobuf.Error("Couldn't parse enum '" + token + "'" + this.location());
throw new protobuf.Error("Couldn't parse enum '" + (token === undefined ? '' : token) + "'" + this.location());
}
}
this.next();
Expand Down
119 changes: 69 additions & 50 deletions source/tf.js
Expand Up @@ -199,54 +199,51 @@ tf.ModelFactory = class {
throw new tf.Error('File text format is not TensorFlow.js graph-model (' + error.message + ').');
}
};
const openTextProto = (context) => {
const tags = context.tags('pbtxt');
let format = null;
let saved_model = null;
if (tags.has('saved_model_schema_version') || tags.has('meta_graphs')) {
try {
const stream = context.stream;
const reader = protobuf.TextReader.open(stream);
saved_model = tf.proto.tensorflow.SavedModel.decodeText(reader);
format = 'TensorFlow Saved Model';
if (saved_model && Object.prototype.hasOwnProperty.call(saved_model, 'saved_model_schema_version')) {
format = format + ' v' + saved_model.saved_model_schema_version.toString();
}
}
catch (error) {
throw new tf.Error('File text format is not tensorflow.SavedModel (' + error.message + ').');
}
const openTextGraphDef = (context) => {
try {
const stream = context.stream;
const reader = protobuf.TextReader.open(stream);
const graph_def = tf.proto.tensorflow.GraphDef.decodeText(reader);
const meta_graph = new tf.proto.tensorflow.MetaGraphDef();
meta_graph.graph_def = graph_def;
const saved_model = new tf.proto.tensorflow.SavedModel();
saved_model.meta_graphs.push(meta_graph);
const format = 'TensorFlow Graph';
return openSavedModel(saved_model, format, null);
}
else if (tags.has('graph_def')) {
try {
const stream = context.stream;
const reader = protobuf.TextReader.open(stream);
const meta_graph = tf.proto.tensorflow.MetaGraphDef.decodeText(reader);
saved_model = new tf.proto.tensorflow.SavedModel();
saved_model.meta_graphs.push(meta_graph);
format = 'TensorFlow MetaGraph';
}
catch (error) {
throw new tf.Error('File text format is not tensorflow.MetaGraphDef (' + error.message + ').');
}
catch (error) {
const message = error && error.message ? error.message : error.toString();
throw new tf.Error('File text format is not tensorflow.GraphDef (' + message.replace(/\.$/, '') + ').');
}
else if (tags.has('node')) {
try {
const stream = context.stream;
const reader = protobuf.TextReader.open(stream);
const graph_def = tf.proto.tensorflow.GraphDef.decodeText(reader);
const meta_graph = new tf.proto.tensorflow.MetaGraphDef();
meta_graph.graph_def = graph_def;
saved_model = new tf.proto.tensorflow.SavedModel();
saved_model.meta_graphs.push(meta_graph);
format = 'TensorFlow Graph';
}
catch (error) {
const message = error && error.message ? error.message : error.toString();
throw new tf.Error('File text format is not tensorflow.GraphDef (' + message.replace(/\.$/, '') + ').');
};
const openTextMetaGraphDef = (context) => {
try {
const stream = context.stream;
const reader = protobuf.TextReader.open(stream);
const meta_graph = tf.proto.tensorflow.MetaGraphDef.decodeText(reader);
const saved_model = new tf.proto.tensorflow.SavedModel();
saved_model.meta_graphs.push(meta_graph);
const format = 'TensorFlow MetaGraph';
return openSavedModel(saved_model, format, null);
}
catch (error) {
throw new tf.Error('File text format is not tensorflow.MetaGraphDef (' + error.message + ').');
}
};
const openTextSavedModel = (context) => {
try {
const stream = context.stream;
const reader = protobuf.TextReader.open(stream);
const saved_model = tf.proto.tensorflow.SavedModel.decodeText(reader);
let format = 'TensorFlow Saved Model';
if (saved_model && Object.prototype.hasOwnProperty.call(saved_model, 'saved_model_schema_version')) {
format = format + ' v' + saved_model.saved_model_schema_version.toString();
}
return openSavedModel(saved_model, format, null);
}
catch (error) {
throw new tf.Error('File text format is not tensorflow.SavedModel (' + error.message + ').');
}
return openSavedModel(saved_model, format, null);
};
const openBinaryProto = (stream, identifier) => {
let saved_model = null;
Expand Down Expand Up @@ -320,8 +317,12 @@ tf.ModelFactory = class {
return openEventFile(context);
case 'json':
return openJson(context);
case 'pbtxt':
return openTextProto(context);
case 'pbtxt.GraphDef':
return openTextGraphDef(context);
case 'pbtxt.MetaGraphDef':
return openTextMetaGraphDef(context);
case 'pbtxt.SavedModel':
return openTextSavedModel(context);
case 'pb':
return openBinaryProto(context.stream, context.identifier);
case 'saved_metadata':
Expand All @@ -346,12 +347,24 @@ tf.ModelFactory = class {
identifier.endsWith('init_net.pbtxt') || identifier.endsWith('init_net.prototxt')) {
return '';
}
const stream = context.stream;
const reader = base.TextReader.open(stream.peek(), 65536);
const line = reader.read();
if (/\s*node\s*\{/.exec(line)) {
return 'pbtxt.GraphDef';
}
const tags = context.tags('pbtxt');
if (['input_stream', 'output_stream', 'input_side_packet', 'output_side_packet'].some((key) => tags.has(key) || tags.has('node.' + key))) {
return '';
}
if (tags.has('node') || tags.has('saved_model_schema_version') || tags.has('meta_graphs') || tags.has('graph_def')) {
return 'pbtxt';
if (tags.has('saved_model_schema_version') || tags.has('meta_graphs')) {
return 'pbtxt.SavedModel';
}
if (tags.has('graph_def')) {
return 'pbtxt.MetaGraphDef';
}
if (tags.has('node')) {
return 'pbtxt.GraphDef';
}
}
if (extension === 'pb' || extension === 'pbtxt' || extension === 'prototxt' || extension === 'graphdef') {
Expand Down Expand Up @@ -440,8 +453,14 @@ tf.ModelFactory = class {
if (['input_stream', 'output_stream', 'input_side_packet', 'output_side_packet'].some((key) => tags.has(key) || tags.has('node.' + key))) {
return false;
}
if (tags.has('node') || tags.has('saved_model_schema_version') || tags.has('meta_graphs') || tags.has('graph_def')) {
return true;
if (tags.has('node')) {
return 'pbtxt.GraphDef';
}
if (tags.has('graph_def')) {
return 'pbtxt.MetaGraphDef';
}
if (tags.has('saved_model_schema_version') || tags.has('meta_graphs')) {
return 'pbtxt.SavedModel';
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions test/models.json
Expand Up @@ -5249,6 +5249,12 @@
"target": "inception5h.pb",
"source": "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip[tensorflow_inception_graph.pb]"
},
{
"type": "tf",
"target": "invalid_graph.pbtxt",
"source": "https://github.com/lutzroeder/netron/files/6859459/invalid_graph.pbtxt.zip[invalid_graph.pbtxt]",
"error": "File text format is not tensorflow.GraphDef (Unexpected end of input at 8:1) in 'invalid_graph.pbtxt'."
},
{
"type": "tf",
"target": "mask_rcnn_resnet50_atrous_coco_2018_01_28.ckpt.meta",
Expand Down

0 comments on commit bfa29fd

Please sign in to comment.