Skip to content

Commit

Permalink
Add Espresso test file (#193)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jun 4, 2024
1 parent 78be4a9 commit 79451fc
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 51 deletions.
228 changes: 180 additions & 48 deletions source/coreml.js
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,6 @@ coreml.ModelFactory = class {
}
}
}
if (identifier === 'metadata.json') {
const obj = context.peek('json');
if (obj && obj.rootModelIdentifier && obj.itemInfoEntries) {
context.type = 'coreml.metadata';
return;
}
if (Array.isArray(obj) && obj.some((item) => item && item.metadataOutputVersion && item.specificationVersion)) {
context.type = 'coreml.metadata.mlmodelc';
return;
}
}
if (identifier === 'model.mil') {
try {
const reader = context.read('text', 2048);
Expand All @@ -69,6 +58,37 @@ coreml.ModelFactory = class {
return;
}
}
if (identifier.endsWith('.espresso.net')) {
const obj = context.peek('json');
if (obj && Array.isArray(obj.layers) && obj.format_version) {
context.type = 'espresso.net';
context.target = obj;
return;
}
}
if (identifier.endsWith('.espresso.shape')) {
const obj = context.peek('json');
if (obj && obj.layer_shapes) {
context.type = 'espresso.shape';
context.target = obj;
return;
}
}
if (identifier.endsWith('.espresso.weights')) {
context.type = 'espresso.weights';
return;
}
if (identifier === 'metadata.json') {
const obj = context.peek('json');
if (obj && obj.rootModelIdentifier && obj.itemInfoEntries) {
context.type = 'coreml.metadata';
return;
}
if (Array.isArray(obj) && obj.some((item) => item && item.metadataOutputVersion && item.specificationVersion)) {
context.type = 'coreml.metadata.mlmodelc';
return;
}
}
if (extension === 'bin' && stream.length > 16) {
const buffer = stream.peek(Math.min(256, stream.length));
for (let i = 0; i < buffer.length - 4; i++) {
Expand All @@ -82,7 +102,16 @@ coreml.ModelFactory = class {
}

filter(context, type) {
return context.type !== 'coreml.manifest.mlmodelc' && type !== 'coreml.mil';
if (context.type === 'coreml.metadata.mlmodelc' && (type === 'coreml.mil')) {
return false;
}
if (context.type === 'espresso.net' && (type === 'espresso.weights' || type === 'espresso.shape' || type === 'coreml.metadata.mlmodelc')) {
return false;
}
if (context.type === 'espresso.shape' && (type === 'espresso.weights' || type === 'coreml.metadata.mlmodelc')) {
return false;
}
return true;
}

async open(context) {
Expand Down Expand Up @@ -147,7 +176,10 @@ coreml.ModelFactory = class {
// continue regardless of error
}
}
return new coreml.Model(metadata, format, model, weights);
format = format || 'Core ML';
format = `${format} v${model.specificationVersion}`;
context = new coreml.Context.Model(metadata, format, model, weights);
return new coreml.Model(context);
};
const openText = async (context) => {
let model = null;
Expand All @@ -158,8 +190,9 @@ coreml.ModelFactory = class {
const message = error && error.message ? error.message : error.toString();
throw new coreml.Error(`File format is not coreml.Model (${message.replace(/\.$/, '')}).`);
}
const weights = new Map();
return new coreml.Model(metadata, null, model, weights);
const format = `Core ML v${model.specificationVersion}`;
context = new coreml.Context.Model(metadata, format, model);
return new coreml.Model(context, null);
};
const openManifest = async (obj, context, path) => {
const entries = Object.values(obj.itemInfoEntries).filter((entry) => entry.path.toLowerCase().endsWith('.mlmodel'));
Expand Down Expand Up @@ -200,6 +233,21 @@ coreml.ModelFactory = class {
case 'coreml.weights': {
return openManifestStream(context, '../../../');
}
case 'espresso.net': {
const reader = new coreml.Context.Espresso(metadata, context.target, null, null);
await reader.read(context);
return new coreml.Model(reader);
}
case 'espresso.weights': {
const reader = new coreml.Context.Espresso(metadata, null, context.target, null);
await reader.read(context);
return new coreml.Model(reader);
}
case 'espresso.shape': {
const reader = new coreml.Context.Espresso(metadata, null, null, context.target);
await reader.read(context);
return new coreml.Model(reader);
}
default: {
throw new coreml.Error(`Unsupported Core ML format '${context.type}'.`);
}
Expand All @@ -209,29 +257,18 @@ coreml.ModelFactory = class {

coreml.Model = class {

constructor(metadata, format, model, weights) {
this.format = `${format || 'Core ML'} v${model.specificationVersion}`;
constructor(context) {
this.format = context.format;
this.metadata = [];
const context = new coreml.Context(metadata, model, weights);
const graph = new coreml.Graph(context);
this.graphs = [graph];
if (model.description && model.description.metadata) {
const properties = model.description.metadata;
if (properties.versionString) {
this.version = properties.versionString;
}
if (properties.shortDescription) {
this.description = properties.shortDescription;
}
if (properties.author) {
this.metadata.push(new coreml.Argument('author', properties.author));
}
if (properties.license) {
this.metadata.push(new coreml.Argument('license', properties.license));
}
if (metadata.userDefined && Object.keys(properties.userDefined).length > 0) {
/* empty */
}
this.graphs = [new coreml.Graph(context)];
if (context.version) {
this.version = context.version;
}
if (context.description) {
this.description = context.description;
}
for (const argument of context.properties) {
this.metadata.push(argument);
}
}
};
Expand All @@ -247,16 +284,16 @@ coreml.Graph = class {
const type = value.type;
const description = value.description;
const initializer = value.initializer;
if (!value.obj) {
value.obj = new coreml.Value(name, type, description, initializer);
if (!value.value) {
value.value = new coreml.Value(name, type, description, initializer);
}
}
this.inputs = context.inputs.map((argument) => {
const values = argument.value.map((value) => value.obj);
const values = argument.value.map((value) => value.value);
return new coreml.Argument(argument.name, values, null, argument.visible);
});
this.outputs = context.outputs.map((argument) => {
const values = argument.value.map((value) => value.obj);
const values = argument.value.map((value) => value.value);
return new coreml.Argument(argument.name, values, null, argument.visible);
});
for (const obj of context.nodes) {
Expand Down Expand Up @@ -317,11 +354,11 @@ coreml.Node = class {
this.name = obj.name || '';
this.description = obj.description || '';
this.inputs = (obj.inputs || []).map((argument) => {
const values = argument.value.map((value) => value.obj);
const values = argument.value.map((value) => value.value);
return new coreml.Argument(argument.name, values, null, argument.visible);
});
this.outputs = (obj.outputs || []).map((argument) => {
const values = argument.value.map((value) => value.obj);
const values = argument.value.map((value) => value.value);
return new coreml.Argument(argument.name, values, null, argument.visible);
});
this.attributes = Object.entries(obj.attributes).map(([name, value]) => {
Expand Down Expand Up @@ -522,11 +559,15 @@ coreml.OptionalType = class {
}
};

coreml.Context = class {
coreml.Context = {};

coreml.Context.Model = class {

constructor(metadata, model, weights, values) {
constructor(metadata, format, model, weights, values) {
this.metadata = metadata;
this.weights = weights;
this.properties = [];
this.format = format;
this.weights = weights || new Map();
this.values = values || new Map();
this.nodes = [];
this.inputs = [];
Expand All @@ -546,11 +587,29 @@ coreml.Context = class {
this.update(value, description);
this.outputs.push({ name: description.name, visible: true, value: [value] });
}
if (description && description.metadata) {
const properties = description.metadata;
if (properties.versionString) {
this.version = properties.versionString;
}
if (properties.shortDescription) {
this.description = properties.shortDescription;
}
if (properties.author) {
this.properties.push(new coreml.Argument('author', properties.author));
}
if (properties.license) {
this.properties.push(new coreml.Argument('license', properties.license));
}
if (properties.userDefined && Object.keys(properties.userDefined).length > 0) {
/* empty */
}
}
}
}

context() {
return new coreml.Context(this.metadata, null, this.weights, this.values);
return new coreml.Context.Model(this.metadata, this.format, null, this.weights, this.values);
}

network(obj) {
Expand Down Expand Up @@ -650,7 +709,7 @@ coreml.Context = class {
const tensor = new coreml.Tensor(tensorType, values, quantization, 'Weights');
const input = this.metadata.input(type, name);
const visible = input && input.visible === false ? false : true;
const value = { obj: new coreml.Value('', null, null, tensor) };
const value = { value: new coreml.Value('', null, null, tensor) };
initializers.push({ name, visible, value: [value] });
};
const vector = (value) => {
Expand Down Expand Up @@ -1490,6 +1549,79 @@ coreml.Utility = class {
}
};

coreml.Context.Espresso = class {

constructor(metadata, format, net, weights, shape) {
this.metadata = metadata;
this.format = 'Espresso';
this.properties = [];
this.inputs = [];
this.outputs = [];
this.nodes = [];
this.net = net;
this.weights = weights;
this.shape = shape;
const values = new Map();
values.map = (name) => {
if (!values.has(name)) {
values.set(name, { name });
}
return values.get(name);
};
this.values = values;
}

async read(context) {
if (!this.net) {
const name = context.identifier.replace(/\.espresso\.(net|weights|shape)$/i, '.espresso.net');
const content = await context.fetch(name);
this.net = content.read('json');
}
if (!this.weights) {
const name = context.identifier.replace(/\.espresso\.(net|weights|shape)$/i, '.espresso.weights');
try {
const content = await context.fetch(name);
this.weights = content.stream;
} catch {
// continue regardless of error
}
}
if (!this.shape) {
const name = context.identifier.replace(/\.espresso\.(net|weights|shape)$/i, '.espresso.shape');
try {
const content = await context.fetch(name);
this.shape = content.stream;
} catch {
// continue regardless of error
}
}
if (this.net && this.net.format_version) {
const major = Math.floor(this.net.format_version / 100);
const minor = this.net.format_version % 100;
this.format += ` v${major}.${minor}`;
}
if (this.net && Array.isArray(this.net.layers)) {
for (const layer of this.net.layers) {
const top = layer.top.split(',').map((name) => this.values.map(name));
const bottom = layer.bottom.split(',').map((name) => this.values.map(name));
const obj = {
type: layer.type,
name: layer.name,
attributes: { ...layer },
inputs: [{ name: 'top', value: bottom }],
outputs: [{ name: 'bottom', value: top }]
};
delete obj.attributes.name;
delete obj.attributes.type;
delete obj.attributes.top;
delete obj.attributes.bottom;
delete obj.attributes.weights;
this.nodes.push(obj);
}
}
}
};

coreml.Error = class extends Error {
constructor(message) {
super(message);
Expand Down
2 changes: 1 addition & 1 deletion source/view.js
Original file line number Diff line number Diff line change
Expand Up @@ -5566,7 +5566,7 @@ view.ModelFactoryService = class {
this.register('./onnx', ['.onnx', '.onnx.data', '.onn', '.pb', '.onnxtxt', '.pbtxt', '.prototxt', '.txt', '.model', '.pt', '.pth', '.pkl', '.ort', '.ort.onnx', '.ngf', '.json', '.bin', 'onnxmodel']);
this.register('./tflite', ['.tflite', '.lite', '.tfl', '.bin', '.pb', '.tmfile', '.h5', '.model', '.json', '.txt', '.dat', '.nb', '.ckpt']);
this.register('./mxnet', ['.json', '.params'], ['.mar']);
this.register('./coreml', ['.mlmodel', '.bin', 'manifest.json', 'metadata.json', 'featuredescriptions.json', '.pb', '.pbtxt', '.mil'], ['.mlpackage', '.mlmodelc']);
this.register('./coreml', ['.mlmodel', '.bin', 'manifest.json', 'metadata.json', 'featuredescriptions.json', '.pb', '.pbtxt', '.mil', '.espresso.net', '.espresso.shape', '.espresso.weights'], ['.mlpackage', '.mlmodelc']);
this.register('./caffe', ['.caffemodel', '.pbtxt', '.prototxt', '.pt', '.txt']);
this.register('./caffe2', ['.pb', '.pbtxt', '.prototxt']);
this.register('./torch', ['.t7', '.net']);
Expand Down
10 changes: 8 additions & 2 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -1551,8 +1551,7 @@
"type": "coreml",
"target": "segmentation.mlmodelc.zip",
"source": "https://github.com/user-attachments/files/15315833/segmentation.mlmodelc.zip",
"error": "Core ML Model Archive format is not supported.",
"format": "Core ML Model Archive",
"format": "Espresso v2.0",
"link": "https://github.com/lutzroeder/netron/issues/193"
},

Expand Down Expand Up @@ -1583,6 +1582,13 @@
"format": "Core ML v1",
"link": "https://developer.apple.com/machine-learning/models"
},
{
"type": "coreml",
"target": "SqueezeNet.mlmodelc.zip",
"source": "https://github.com/user-attachments/files/15535445/SqueezeNet.mlmodelc.zip",
"format": "Espresso v2.0",
"link": "https://github.com/lutzroeder/netron/issues/193"
},
{
"type": "coreml",
"target": "SqueezeNetFP16.mlmodel",
Expand Down

0 comments on commit 79451fc

Please sign in to comment.