Skip to content

Commit

Permalink
Add TensorFlow.js Gzip support (#294) (#563)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Sep 21, 2021
1 parent ee9bf0b commit 1ea631c
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 9 deletions.
40 changes: 33 additions & 7 deletions source/tf.js
Expand Up @@ -4,6 +4,7 @@

var tf = tf || {};
var base = base || require('./base');
var gzip = gzip || require('./gzip');
var json = json || require('./json');
var protobuf = protobuf || require('./protobuf');

Expand Down Expand Up @@ -173,9 +174,11 @@ tf.ModelFactory = class {
}
}
if (extension === 'json') {
const obj = context.open('json');
if (obj && obj.modelTopology && (obj.format === 'graph-model' || Array.isArray(obj.modelTopology.node))) {
return 'tf.json';
for (const type of [ 'json', 'json.gz' ]) {
const obj = context.open(type);
if (obj && obj.modelTopology && (obj.format === 'graph-model' || Array.isArray(obj.modelTopology.node))) {
return 'tf.' + type;
}
}
}
if (extension === 'index' || extension === 'ckpt') {
Expand Down Expand Up @@ -338,9 +341,9 @@ tf.ModelFactory = class {
}
return openSavedModel(saved_model, format, producer);
};
const openJson = (context) => {
const openJson = (context, type) => {
try {
const obj = context.open('json');
const obj = context.open(type);
const format = 'TensorFlow.js ' + (obj.format || 'graph-model');
const producer = obj.convertedBy || obj.generatedBy || '';
const meta_graph = new tf.proto.tensorflow.MetaGraphDef();
Expand Down Expand Up @@ -399,7 +402,28 @@ tf.ModelFactory = class {
};
return Promise.all(shards.values()).then((streams) => {
for (const key of shards.keys()) {
shards.set(key, streams.shift().peek());
const stream = streams.shift();
const buffer = stream.peek();
shards.set(key, buffer);
}
if (type === 'json.gz') {
try {
for (const key of shards.keys()) {
const stream = shards.get(key);
const archive = gzip.Archive.open(stream);
if (archive) {
const entries = archive.entries;
if (entries.size === 1) {
const stream = entries.values().next().value;
const buffer = stream.peek();
shards.set(key, buffer);
}
}
}
}
catch (error) {
// continue regardless of error
}
}
return openShards(shards);
}).catch(() => {
Expand Down Expand Up @@ -528,7 +552,9 @@ tf.ModelFactory = class {
case 'tf.events':
return openEventFile(context);
case 'tf.json':
return openJson(context);
return openJson(context, 'json');
case 'tf.json.gz':
return openJson(context, 'json.gz');
case 'tf.pbtxt.GraphDef':
return openTextGraphDef(context);
case 'tf.pbtxt.MetaGraphDef':
Expand Down
24 changes: 22 additions & 2 deletions source/view.js
Expand Up @@ -1363,6 +1363,26 @@ view.ModelContext = class {
}
break;
}
case 'json.gz': {
try {
const archive = gzip.Archive.open(stream);
if (archive) {
const entries = archive.entries;
if (entries.size === 1) {
const stream = entries.values().next().value;
const reader = json.TextReader.open(stream);
if (reader) {
const obj = reader.read();
this._content.set(type, obj);
}
}
}
}
catch (err) {
// continue regardless of error
}
break;
}
case 'pkl': {
let unpickler = null;
try {
Expand Down Expand Up @@ -1595,8 +1615,8 @@ view.ModelFactoryService = class {
if (archive) {
const entries = archive.entries;
containers.set('gzip', entries);
if (archive.entries.size === 1) {
stream = archive.entries.values().next().value;
if (entries.size === 1) {
stream = entries.values().next().value;
}
}
}
Expand Down
7 changes: 7 additions & 0 deletions test/models.json
Expand Up @@ -5616,6 +5616,13 @@
"format": "TensorFlow.js graph-model",
"link": "https://github.com/intel/webml-polyfill/issues/880"
},
{
"type": "tfjs",
"target": "posenet_mobilenet_float_075_1_default_1.zip",
"source": "https://github.com/lutzroeder/netron/files/7204409/posenet_mobilenet_float_075_1_default_1.zip",
"format": "TensorFlow.js graph-model",
"link": "https://github.com/lutzroeder/netron/issues/294"
},
{
"type": "tfjs",
"target": "sentiment_cnn_v1/model.json",
Expand Down

0 comments on commit 1ea631c

Please sign in to comment.