Skip to content

Commit

Permalink
PyTorch ByteStorage support (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 25, 2018
1 parent e4fbfdf commit 3274714
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 37 deletions.
2 changes: 1 addition & 1 deletion src/coreml-model.js
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ class CoreMLTensor {
}
else if (data.rawValue && data.rawValue.length > 0) {
this._data = null;
dataType = 'byte';
dataType = 'uint8';
shape = [];
}
}
Expand Down
19 changes: 12 additions & 7 deletions src/pytorch-model.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@ class PyTorchModelFactory {

match(context, host) {
var extension = context.identifier.split('.').pop();
if (extension == 'pt' || extension == 'pth') {
return true;
}
if (extension == 'pkl') {
if (extension == 'pt' || extension == 'pth' || extension == 'pkl') {
var buffer = context.buffer;
var torch = [ 0x80, 0x02, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ];
if (buffer && buffer.length > torch.length) {
Expand All @@ -35,6 +32,7 @@ class PyTorchModelFactory {

_openModel(context, host, callback) {
try {
var identifier = context.identifier;
var unpickler = new pickle.Unpickler(context.buffer);

var signature = [ 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ];
Expand Down Expand Up @@ -110,9 +108,10 @@ class PyTorchModelFactory {
constructorTable['torchvision.models.vgg.VGG'] = function () {};
constructorTable['torch.nn.backends.thnn._get_thnn_function_backend'] = function () {};
constructorTable['torch.nn.parameter.Parameter'] = function(data, requires_grad) { this.data = data; this.requires_grad = requires_grad; };
constructorTable['torch.ByteStorage'] = function (size) { this.size = size; this.dataTypeSize = 1; this.dataType = 'uint8'; };
constructorTable['torch.LongStorage'] = function (size) { this.size = size; this.dataTypeSize = 4; this.dataType = 'int64'; };
constructorTable['torch.FloatStorage'] = function (size) { this.size = size; this.dataTypeSize = 4; this.dataType = 'float32'; };
constructorTable['torch.DoubleStorage'] = function (size) { this.size = size; this.dataTypeSize = 8; this.dataType = 'float64'; };
constructorTable['torch.LongStorage'] = function (size) { this.size = size; this.dataTypeSize = 4; this.dataType = 'int64'; };

functionTable['torch._utils._rebuild_tensor'] = function (storage, storage_offset, size, stride) {
var obj = {};
Expand Down Expand Up @@ -156,7 +155,7 @@ class PyTorchModelFactory {
constructor.apply(obj, args);
}
else {
host.exception(new SklearnError("Unknown function '" + name + "'."), false);
host.exception(new PyTorchError("Unknown function '" + name + "' in '" + identifier + "'."), false);
}
return obj;
};
Expand Down Expand Up @@ -258,7 +257,8 @@ class PyTorchGraph {
_loadModule(parent, groups, inputs) {

if (parent.__type__ &&
!parent.__type__.startsWith('torch.nn.modules.container.')) {
!parent.__type__.startsWith('torch.nn.modules.container.') &&
(!parent._modules || parent._modules.length == 0)) {
var node = new PyTorchNode(parent, groups, inputs);
this._nodes.push(node);
return [];
Expand Down Expand Up @@ -569,6 +569,11 @@ class PyTorchTensor {
}
switch (this._dataType)
{
case 'uint8':
results.push(context.dataView.getUint8(context.index, true));
context.index += 1;
context.count++;
break;
case 'float32':
results.push(context.dataView.getFloat32(context.index, true));
context.index += 4;
Expand Down
3 changes: 2 additions & 1 deletion src/sklearn-model.js
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class SklearnModelFactory {

var obj = null;
try {
var identifier = context.identifier;
var unpickler = new pickle.Unpickler(context.buffer);

var constructorTable = {};
Expand Down Expand Up @@ -233,7 +234,7 @@ class SklearnModelFactory {
constructor.apply(obj, args);
}
else {
host.exception(new SklearnError("Unknown function '" + name + "'."), false);
host.exception(new SklearnError("Unknown function '" + name + "' in '" + identifier + "'."), false);
}
return obj;
};
Expand Down
26 changes: 5 additions & 21 deletions src/tflite-model.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class TensorFlowLiteModelFactory {
if (!tflite.Model.bufferHasIdentifier(byteBuffer))
{
var identifier = (buffer && buffer.length >= 8 && buffer.slice(4, 8).every((c) => c >= 32 && c <= 127)) ? String.fromCharCode.apply(null, buffer.slice(4, 8)) : '';
callback(new TensorFlowLiteError("Invalid FlatBuffers identifier '" + identifier + "'."));
callback(new TensorFlowLiteError("Invalid FlatBuffers identifier '" + identifier + "' in '" + context.identifier + "'."));
return;
}
model = tflite.Model.getRootAsModel(byteBuffer);
Expand Down Expand Up @@ -493,7 +493,7 @@ class TensorFlowLiteTensor {
}
switch (context.dataType)
{
case 'byte':
case 'uint8':
results.push(context.data.getUint8(context.index));
context.index += 1;
context.count++;
Expand Down Expand Up @@ -543,7 +543,8 @@ class TensorFlowLiteTensor {
class TensorFlowLiteTensorType {

constructor(tensor) {
this._dataType = tensor.type();
var dataType = tflite.TensorType[tensor.type()];
this._dataType = (dataType) ? dataType.toLowerCase() : '?';
this._shape = [];
var shapeLength = tensor.shapeLength();
if (shapeLength > 0) {
Expand All @@ -554,24 +555,7 @@ class TensorFlowLiteTensorType {
}

get dataType() {
if (!TensorFlowLiteTensorType._typeMap)
{
TensorFlowLiteTensorType._typeMap = {};
TensorFlowLiteTensorType._typeMap[tflite.TensorType.FLOAT32] = 'float32';
TensorFlowLiteTensorType._typeMap[tflite.TensorType.FLOAT16] = 'float16';
TensorFlowLiteTensorType._typeMap[tflite.TensorType.INT32] = 'int32';
TensorFlowLiteTensorType._typeMap[tflite.TensorType.UINT8] = 'byte';
TensorFlowLiteTensorType._typeMap[tflite.TensorType.INT64] = 'int64';
TensorFlowLiteTensorType._typeMap[tflite.TensorType.STRING] = 'string';
TensorFlowLiteTensorType._typeMap[tflite.TensorType.BOOL] = 'bool';
TensorFlowLiteTensorType._typeMap[tflite.TensorType.INT16] = 'int16';
TensorFlowLiteTensorType._typeMap[tflite.TensorType.COMPLEX64] = 'complex64';
}
var result = TensorFlowLiteTensorType._typeMap[this._dataType];
if (result) {
return result;
}
return '?';
return this._dataType;
}

get shape() {
Expand Down
14 changes: 7 additions & 7 deletions src/view.js
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ class View {
}
if (model || factoryList.length == 0) {
if (!model && factoryCount > 1 && errorList.length > 1) {
callback(new NameError(errorList.map((err) => err.message).join('\n'), "Error loading model."), null);
callback(new ModelError(errorList.map((err) => err.message).join('\n')), null);
}
else {
callback(err, model);
Expand All @@ -322,10 +322,10 @@ class View {
case 'pbtxt':
case 'prototxt':
case 'model':
callback(new NameError('Unsupported file content for extension \'.' + extension + '\'.', "Error loading model."), null);
callback(new ModelError("Unsupported file content for extension '." + extension + "' in '" + context.identifier + "'."), null);
break;
default:
callback(new NameError('Unsupported file extension \'.' + extension + '\'.', "Error loading model."), null);
callback(new ModelError("Unsupported file extension '." + extension + "'."), null);
break;
}
}
Expand Down Expand Up @@ -1158,13 +1158,13 @@ if (!DataView.prototype.setFloat16) {
class ArchiveError extends Error {
constructor(message) {
super(message);
this.name = "Error loading archive";
this.name = 'Error loading archive.';
}
}

class NameError extends Error {
constructor(message, name) {
class ModelError extends Error {
constructor(message) {
super(message);
this.name = name;
this.name = 'Error loading model.';
}
}

0 comments on commit 3274714

Please sign in to comment.