From 901628eab5a52cbb134e4468cf94524d44a1af03 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Fri, 9 Aug 2019 22:16:45 -0700 Subject: [PATCH] PyTorch state dict detection (#133) --- src/pytorch.js | 274 ++++++++++++++++++++++++++++++++++------------- test/models.json | 12 +++ 2 files changed, 211 insertions(+), 75 deletions(-) diff --git a/src/pytorch.js b/src/pytorch.js index 4fabc0a856..7128cbe6fa 100644 --- a/src/pytorch.js +++ b/src/pytorch.js @@ -547,32 +547,7 @@ pytorch.ModelFactory = class { }); root_module = null; - state_dict = []; - - if (obj && !Array.isArray(obj)) { - var array = []; - for (var key of Object.keys(obj)) { - array.push({ key: key, value: obj[key] }); - } - obj = array; - } - if (obj && Array.isArray(obj)) { - for (var item of obj) { - var value = null; - if (item && item.value && item.value.__type__ == 'torch.nn.parameter.Parameter') { - value = item.value[0]; - } - else if (item && item.value && - item.value.__type__.startsWith('torch.') && - item.value.__type__.endsWith('Tensor')) { - value = item.value; - } - else { - value = null; - } - state_dict.push({ key: item.key, value: value }); - } - } + state_dict = pytorch.ModelFactory._convertStateDictLegacyFormat(obj); } if (sys_info.type_sizes && @@ -666,28 +641,186 @@ pytorch.ModelFactory = class { root.state_dict_stylepredictor, root.state_dict_ghiasi ]; for (var dict of candidates) { - if (dict && Array.isArray(dict) && dict.__setitem__ && - dict.every((item) => item.value.__type__ && item.value.__type__.startsWith('torch.') && item.value.__type__.endsWith('Tensor'))) { - delete dict.__setitem__; - return dict; + let state_dict = + pytorch.ModelFactory._convertStateDictList(dict) || + pytorch.ModelFactory._convertStateDictMap(dict) || + pytorch.ModelFactory._convertStateDictGroupMap(dict); + if (state_dict) { + return state_dict; } - if (dict && !Array.isArray(dict)) { - var match = true; - var array = []; - for (var key of Object.keys(dict)) { - var value = dict[key] - if (!key || !value || !value.__type__ || !value.__type__.startsWith('torch.') || !value.__type__.endsWith('Tensor')) { - match = false; - break; + } + return null; + } + + static _convertStateDictList(list) { + if (!list || !Array.isArray(list) || + !list.every((item) => item && item.key && pytorch.ModelFactory._isTensor(item.value))) { + return null; + } + let state_dict = []; + let state_map = {}; + for (let item of list) { + let split = item.key.split('.'); + if (split.length < 2) { + return null; + } + let state = {}; + state.id = item.key; + state.name = split.pop(); + state.value = item.value; + let state_group_name = split.join('.'); + let state_group = state_map[state_group_name]; + if (!state_group) { + state_group = {}; + state_group.name = state_group_name; + state_group.states = []; + state_map[state_group_name] = state_group; + state_dict.push(state_group); + } + state_group.states.push(state); + } + return state_dict; + } + + static _convertStateDictMap(obj) { + if (!obj || Array.isArray(obj)) { + return null + } + let state_dict = []; + let state_map = {}; + for (var key in obj) { + let split = key.split('.'); + if (split.length < 1) { + return null; + } + let state = {}; + state.id = key; + state.name = split.pop(); + state.value = obj[key]; + if (!pytorch.ModelFactory._isTensor(state.value)) { + return null; + } + let state_group_name = split.join('.'); + let state_group = state_map[state_group_name]; + if (!state_group) { + state_group = {}; + state_group.name = state_group_name; + state_group.states = []; + state_map[state_group_name] = state_group; + state_dict.push(state_group); + } + state_group.states.push(state); + } + return state_dict; + } + + static _convertStateDictGroupMap(obj) { + if (!obj || Array.isArray(obj)) { + return null; + } + let state_dict = []; + let state_map = {}; + for (let state_group_name in obj) { + + let state_group = state_map[state_group_name]; + if (!state_group) { + state_group = {}; + state_group.name = state_group_name; + state_group.states = []; + state_group.attributes = []; + state_map[state_group_name] = state_group; + state_dict.push(state_group); + } + var item = obj[state_group_name]; + if (!item) { + return null; + } + if (Array.isArray(item)) { + for (let entry of item) { + if (!entry || !entry.key || !entry.value || !pytorch.ModelFactory._isTensor(entry.value)) { + return null; } - array.push({ key: key, value: value }); + let state = {}; + state.id = state_group_name + '.' + entry.key; + state.name = entry.key; + state.value = entry.value; + state_group.states.push(state); } - if (match) { - return array; + } + else { + for (let key in item) { + let value = item[key]; + if (pytorch.ModelFactory._isTensor(value)) { + state_group.states.push({ name: key, value: value, id: state_group_name + '.' + key }); + } + else if (value !== Object(value)) { + state_group.attributes.push({ name: key, value: value }); + } + else if (value && value.__type__ == 'torch.nn.parameter.Parameter' && value.data) { + state_group.states.push({ name: key, value: value.data, id: state_group_name + '.' + key }); + } + else { + return null; + } } } } - return null; + return state_dict; + } + + + static _convertStateDictLegacyFormat(obj) { + if (!obj) { + return null; + } + if (obj && !Array.isArray(obj)) { + var array = []; + for (var key of Object.keys(obj)) { + array.push({ key: key, value: obj[key] }); + } + obj = array; + } + var state_dict = []; + var state_map = {}; + if (obj && Array.isArray(obj)) { + for (var item of obj) { + if (!item || !item.key || !item.value) { + return null; + } + let state = {}; + state.id = item.key; + state.value = null; + if (item.value.__type__ == 'torch.nn.parameter.Parameter') { + state.value = item.value[0]; + } + else if (pytorch.ModelFactory._isTensor(item.value)) { + state.value = item.value; + } + if (!state.value) { + return null; + } + let split = state.id.split('.'); + if (split.length < 2) { + return null; + } + state.name = split.pop(); + let state_group_name = split.join('.'); + let state_group = state_map[state_group_name]; + if (!state_group) { + state_group = {}; + state_group.name = state_group_name; + state_group.states = []; + state_map[state_group_name] = state_group; + state_dict.push(state_group); + } + state_group.states.push(state); + } + } + return state_dict; + } + + static _isTensor(obj) { + return obj && obj.__type__ && obj.__type__.startsWith('torch.') && obj.__type__.endsWith('Tensor'); } }; @@ -725,31 +858,20 @@ pytorch.Graph = class { for (var output of outputs) { this._outputs.push(new pytorch.Parameter(output, true, [ new pytorch.Argument(output, null, null) ])); } - } else { - var state_group_map = {}; - var state_groups = []; - for (var state of state_dict) { - var key = state.key.split('.'); - state.name = key.pop(); - var id = key.join('.'); - var state_group = state_group_map[id]; - if (!state_group) { - state_group = { id: id, states: [] }; - state_groups.push(state_group); - state_group_map[id] = state_group; - } - state_group.states.push(state); - } - this._nodes = this._nodes.concat(state_groups.map((state_group) => { - var inputs = state_group.states.map((state) => { - var tensor = new pytorch.Tensor(state.key, state.value, sysInfo.little_endian); + } + else { + for (let state_group of state_dict) { + let type = 'torch.nn.modules._.Module'; + let attributes = state_group.attributes || []; + let inputs = state_group.states.map((state) => { + var tensor = new pytorch.Tensor(state.id, state.value, sysInfo.little_endian); var visible = state_group.states.length == 0 || tensor.type.toString() != 'int64' || tensor.value < 1000; return new pytorch.Parameter(state.name, visible, [ - new pytorch.Argument(state.key, null, tensor) + new pytorch.Argument(state.id, null, tensor) ]); }); - return new pytorch.Node(this._metadata, '', state_group.id, { __type__: 'torch.nn.modules._.Module' }, inputs, []); - })); + this._nodes.push(new pytorch.Node(this._metadata, '', state_group.name, type, attributes, inputs, [])); + } } } @@ -845,10 +967,18 @@ pytorch.Graph = class { var group = groups.join('/'); var name = group ? (group + '/' + key) : key; + var type = obj.__type__ || ''; var outputs = [ new pytorch.Parameter('output', true, [ new pytorch.Argument(name, null, null) ]) ]; - var node = new pytorch.Node(this._metadata, group, name, obj, inputs, outputs); + var attributes = []; + for (let name of Object.keys(obj)) { + if (!name.startsWith('_')) { + attributes.push({ name: name, value: obj[name] }); + } + } + + var node = new pytorch.Node(this._metadata, group, name, type, attributes, inputs, outputs); this._nodes.push(node); return node; } @@ -921,22 +1051,16 @@ pytorch.Argument = class { pytorch.Node = class { - constructor(metadata, group, name, obj, inputs, outputs) { + constructor(metadata, group, name, type, attributes, inputs, outputs) { this._metadata = metadata; this._group = group || ''; this._name = name || ''; - var type = obj.__type__.split('.'); - this._operator = type.pop(); - this._package = type.join('.'); - this._attributes = []; + let split = type.split('.'); + this._operator = split.pop(); + this._package = split.join('.'); + this._attributes = attributes.map((attribute) => new pytorch.Attribute(this._metadata, this, attribute.name, attribute.value)); this._inputs = inputs; this._outputs = outputs; - - for (var attributeName of Object.keys(obj)) { - if (!attributeName.startsWith('_')) { - this._attributes.push(new pytorch.Attribute(this._metadata, this, attributeName, obj[attributeName])); - } - } } get name() { diff --git a/test/models.json b/test/models.json index 694127fd8a..98f9bb435a 100644 --- a/test/models.json +++ b/test/models.json @@ -3581,6 +3581,12 @@ "target": "inception_v3_google-1a9a5a14.pth", "source": "https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth" }, + { + "type": "pytorch", + "target": "mask_r_cnn.pth", + "source": "https://github.com/facebookresearch/kill-the-bits/blob/master/src/models/compressed/mask_r_cnn.pth?raw=true", + "link": "https://github.com/facebookresearch/kill-the-bits/tree/master/src/models/compressed" + }, { "type": "pytorch", "target": "mnist_linear.ckpt", @@ -3613,6 +3619,12 @@ "target": "resnet18-5c106cde.pth", "source": "https://download.pytorch.org/models/resnet18-5c106cde.pth" }, + { + "type": "pytorch", + "target": "resnet18_large_blocks.pth", + "source": "https://github.com/facebookresearch/kill-the-bits/blob/master/src/models/compressed/resnet18_large_blocks.pth?raw=true", + "link": "https://github.com/facebookresearch/kill-the-bits/tree/master/src/models/compressed" + }, { "type": "pytorch", "target": "resnet-18-at-export.pth",