diff --git a/source/coreml.js b/source/coreml.js index 90ed389f6c..ef33a4a94e 100644 --- a/source/coreml.js +++ b/source/coreml.js @@ -39,17 +39,6 @@ coreml.ModelFactory = class { } } } - if (identifier === 'metadata.json') { - const obj = context.peek('json'); - if (obj && obj.rootModelIdentifier && obj.itemInfoEntries) { - context.type = 'coreml.metadata'; - return; - } - if (Array.isArray(obj) && obj.some((item) => item && item.metadataOutputVersion && item.specificationVersion)) { - context.type = 'coreml.metadata.mlmodelc'; - return; - } - } if (identifier === 'model.mil') { try { const reader = context.read('text', 2048); @@ -69,6 +58,37 @@ coreml.ModelFactory = class { return; } } + if (identifier.endsWith('.espresso.net')) { + const obj = context.peek('json'); + if (obj && Array.isArray(obj.layers) && obj.format_version) { + context.type = 'espresso.net'; + context.target = obj; + return; + } + } + if (identifier.endsWith('.espresso.shape')) { + const obj = context.peek('json'); + if (obj && obj.layer_shapes) { + context.type = 'espresso.shape'; + context.target = obj; + return; + } + } + if (identifier.endsWith('.espresso.weights')) { + context.type = 'espresso.weights'; + return; + } + if (identifier === 'metadata.json') { + const obj = context.peek('json'); + if (obj && obj.rootModelIdentifier && obj.itemInfoEntries) { + context.type = 'coreml.metadata'; + return; + } + if (Array.isArray(obj) && obj.some((item) => item && item.metadataOutputVersion && item.specificationVersion)) { + context.type = 'coreml.metadata.mlmodelc'; + return; + } + } if (extension === 'bin' && stream.length > 16) { const buffer = stream.peek(Math.min(256, stream.length)); for (let i = 0; i < buffer.length - 4; i++) { @@ -82,7 +102,16 @@ coreml.ModelFactory = class { } filter(context, type) { - return context.type !== 'coreml.manifest.mlmodelc' && type !== 'coreml.mil'; + if (context.type === 'coreml.metadata.mlmodelc' && (type === 'coreml.mil')) { + return false; + } + if (context.type === 'espresso.net' && (type === 'espresso.weights' || type === 'espresso.shape' || type === 'coreml.metadata.mlmodelc')) { + return false; + } + if (context.type === 'espresso.shape' && (type === 'espresso.weights' || type === 'coreml.metadata.mlmodelc')) { + return false; + } + return true; } async open(context) { @@ -147,7 +176,10 @@ coreml.ModelFactory = class { // continue regardless of error } } - return new coreml.Model(metadata, format, model, weights); + format = format || 'Core ML'; + format = `${format} v${model.specificationVersion}`; + context = new coreml.Context.Model(metadata, format, model, weights); + return new coreml.Model(context); }; const openText = async (context) => { let model = null; @@ -158,8 +190,9 @@ coreml.ModelFactory = class { const message = error && error.message ? error.message : error.toString(); throw new coreml.Error(`File format is not coreml.Model (${message.replace(/\.$/, '')}).`); } - const weights = new Map(); - return new coreml.Model(metadata, null, model, weights); + const format = `Core ML v${model.specificationVersion}`; + context = new coreml.Context.Model(metadata, format, model); + return new coreml.Model(context, null); }; const openManifest = async (obj, context, path) => { const entries = Object.values(obj.itemInfoEntries).filter((entry) => entry.path.toLowerCase().endsWith('.mlmodel')); @@ -200,6 +233,21 @@ coreml.ModelFactory = class { case 'coreml.weights': { return openManifestStream(context, '../../../'); } + case 'espresso.net': { + const reader = new coreml.Context.Espresso(metadata, context.target, null, null); + await reader.read(context); + return new coreml.Model(reader); + } + case 'espresso.weights': { + const reader = new coreml.Context.Espresso(metadata, null, context.target, null); + await reader.read(context); + return new coreml.Model(reader); + } + case 'espresso.shape': { + const reader = new coreml.Context.Espresso(metadata, null, null, context.target); + await reader.read(context); + return new coreml.Model(reader); + } default: { throw new coreml.Error(`Unsupported Core ML format '${context.type}'.`); } @@ -209,29 +257,18 @@ coreml.ModelFactory = class { coreml.Model = class { - constructor(metadata, format, model, weights) { - this.format = `${format || 'Core ML'} v${model.specificationVersion}`; + constructor(context) { + this.format = context.format; this.metadata = []; - const context = new coreml.Context(metadata, model, weights); - const graph = new coreml.Graph(context); - this.graphs = [graph]; - if (model.description && model.description.metadata) { - const properties = model.description.metadata; - if (properties.versionString) { - this.version = properties.versionString; - } - if (properties.shortDescription) { - this.description = properties.shortDescription; - } - if (properties.author) { - this.metadata.push(new coreml.Argument('author', properties.author)); - } - if (properties.license) { - this.metadata.push(new coreml.Argument('license', properties.license)); - } - if (metadata.userDefined && Object.keys(properties.userDefined).length > 0) { - /* empty */ - } + this.graphs = [new coreml.Graph(context)]; + if (context.version) { + this.version = context.version; + } + if (context.description) { + this.description = context.description; + } + for (const argument of context.properties) { + this.metadata.push(argument); } } }; @@ -247,16 +284,16 @@ coreml.Graph = class { const type = value.type; const description = value.description; const initializer = value.initializer; - if (!value.obj) { - value.obj = new coreml.Value(name, type, description, initializer); + if (!value.value) { + value.value = new coreml.Value(name, type, description, initializer); } } this.inputs = context.inputs.map((argument) => { - const values = argument.value.map((value) => value.obj); + const values = argument.value.map((value) => value.value); return new coreml.Argument(argument.name, values, null, argument.visible); }); this.outputs = context.outputs.map((argument) => { - const values = argument.value.map((value) => value.obj); + const values = argument.value.map((value) => value.value); return new coreml.Argument(argument.name, values, null, argument.visible); }); for (const obj of context.nodes) { @@ -317,11 +354,11 @@ coreml.Node = class { this.name = obj.name || ''; this.description = obj.description || ''; this.inputs = (obj.inputs || []).map((argument) => { - const values = argument.value.map((value) => value.obj); + const values = argument.value.map((value) => value.value); return new coreml.Argument(argument.name, values, null, argument.visible); }); this.outputs = (obj.outputs || []).map((argument) => { - const values = argument.value.map((value) => value.obj); + const values = argument.value.map((value) => value.value); return new coreml.Argument(argument.name, values, null, argument.visible); }); this.attributes = Object.entries(obj.attributes).map(([name, value]) => { @@ -522,11 +559,15 @@ coreml.OptionalType = class { } }; -coreml.Context = class { +coreml.Context = {}; + +coreml.Context.Model = class { - constructor(metadata, model, weights, values) { + constructor(metadata, format, model, weights, values) { this.metadata = metadata; - this.weights = weights; + this.properties = []; + this.format = format; + this.weights = weights || new Map(); this.values = values || new Map(); this.nodes = []; this.inputs = []; @@ -546,11 +587,29 @@ coreml.Context = class { this.update(value, description); this.outputs.push({ name: description.name, visible: true, value: [value] }); } + if (description && description.metadata) { + const properties = description.metadata; + if (properties.versionString) { + this.version = properties.versionString; + } + if (properties.shortDescription) { + this.description = properties.shortDescription; + } + if (properties.author) { + this.properties.push(new coreml.Argument('author', properties.author)); + } + if (properties.license) { + this.properties.push(new coreml.Argument('license', properties.license)); + } + if (properties.userDefined && Object.keys(properties.userDefined).length > 0) { + /* empty */ + } + } } } context() { - return new coreml.Context(this.metadata, null, this.weights, this.values); + return new coreml.Context.Model(this.metadata, this.format, null, this.weights, this.values); } network(obj) { @@ -650,7 +709,7 @@ coreml.Context = class { const tensor = new coreml.Tensor(tensorType, values, quantization, 'Weights'); const input = this.metadata.input(type, name); const visible = input && input.visible === false ? false : true; - const value = { obj: new coreml.Value('', null, null, tensor) }; + const value = { value: new coreml.Value('', null, null, tensor) }; initializers.push({ name, visible, value: [value] }); }; const vector = (value) => { @@ -1490,6 +1549,79 @@ coreml.Utility = class { } }; +coreml.Context.Espresso = class { + + constructor(metadata, format, net, weights, shape) { + this.metadata = metadata; + this.format = 'Espresso'; + this.properties = []; + this.inputs = []; + this.outputs = []; + this.nodes = []; + this.net = net; + this.weights = weights; + this.shape = shape; + const values = new Map(); + values.map = (name) => { + if (!values.has(name)) { + values.set(name, { name }); + } + return values.get(name); + }; + this.values = values; + } + + async read(context) { + if (!this.net) { + const name = context.identifier.replace(/\.espresso\.(net|weights|shape)$/i, '.espresso.net'); + const content = await context.fetch(name); + this.net = content.read('json'); + } + if (!this.weights) { + const name = context.identifier.replace(/\.espresso\.(net|weights|shape)$/i, '.espresso.weights'); + try { + const content = await context.fetch(name); + this.weights = content.stream; + } catch { + // continue regardless of error + } + } + if (!this.shape) { + const name = context.identifier.replace(/\.espresso\.(net|weights|shape)$/i, '.espresso.shape'); + try { + const content = await context.fetch(name); + this.shape = content.stream; + } catch { + // continue regardless of error + } + } + if (this.net && this.net.format_version) { + const major = Math.floor(this.net.format_version / 100); + const minor = this.net.format_version % 100; + this.format += ` v${major}.${minor}`; + } + if (this.net && Array.isArray(this.net.layers)) { + for (const layer of this.net.layers) { + const top = layer.top.split(',').map((name) => this.values.map(name)); + const bottom = layer.bottom.split(',').map((name) => this.values.map(name)); + const obj = { + type: layer.type, + name: layer.name, + attributes: { ...layer }, + inputs: [{ name: 'top', value: bottom }], + outputs: [{ name: 'bottom', value: top }] + }; + delete obj.attributes.name; + delete obj.attributes.type; + delete obj.attributes.top; + delete obj.attributes.bottom; + delete obj.attributes.weights; + this.nodes.push(obj); + } + } + } +}; + coreml.Error = class extends Error { constructor(message) { super(message); diff --git a/source/view.js b/source/view.js index cf5c89e8bf..1d3262e592 100644 --- a/source/view.js +++ b/source/view.js @@ -5566,7 +5566,7 @@ view.ModelFactoryService = class { this.register('./onnx', ['.onnx', '.onnx.data', '.onn', '.pb', '.onnxtxt', '.pbtxt', '.prototxt', '.txt', '.model', '.pt', '.pth', '.pkl', '.ort', '.ort.onnx', '.ngf', '.json', '.bin', 'onnxmodel']); this.register('./tflite', ['.tflite', '.lite', '.tfl', '.bin', '.pb', '.tmfile', '.h5', '.model', '.json', '.txt', '.dat', '.nb', '.ckpt']); this.register('./mxnet', ['.json', '.params'], ['.mar']); - this.register('./coreml', ['.mlmodel', '.bin', 'manifest.json', 'metadata.json', 'featuredescriptions.json', '.pb', '.pbtxt', '.mil'], ['.mlpackage', '.mlmodelc']); + this.register('./coreml', ['.mlmodel', '.bin', 'manifest.json', 'metadata.json', 'featuredescriptions.json', '.pb', '.pbtxt', '.mil', '.espresso.net', '.espresso.shape', '.espresso.weights'], ['.mlpackage', '.mlmodelc']); this.register('./caffe', ['.caffemodel', '.pbtxt', '.prototxt', '.pt', '.txt']); this.register('./caffe2', ['.pb', '.pbtxt', '.prototxt']); this.register('./torch', ['.t7', '.net']); diff --git a/test/models.json b/test/models.json index 28123a759d..816e99844b 100644 --- a/test/models.json +++ b/test/models.json @@ -1551,8 +1551,7 @@ "type": "coreml", "target": "segmentation.mlmodelc.zip", "source": "https://github.com/user-attachments/files/15315833/segmentation.mlmodelc.zip", - "error": "Core ML Model Archive format is not supported.", - "format": "Core ML Model Archive", + "format": "Espresso v2.0", "link": "https://github.com/lutzroeder/netron/issues/193" }, @@ -1583,6 +1582,13 @@ "format": "Core ML v1", "link": "https://developer.apple.com/machine-learning/models" }, + { + "type": "coreml", + "target": "SqueezeNet.mlmodelc.zip", + "source": "https://github.com/user-attachments/files/15535445/SqueezeNet.mlmodelc.zip", + "format": "Espresso v2.0", + "link": "https://github.com/lutzroeder/netron/issues/193" + }, { "type": "coreml", "target": "SqueezeNetFP16.mlmodel",