PyTorch state dict detection (#133)
lutzroeder committed Aug 10, 2019
1 parent 52e6f40 commit 901628e
Showing 2 changed files with 211 additions and 75 deletions.
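
For orientation, a minimal sketch (illustrative only, not code from this commit) of the three deserialized shapes the new converter methods below are written to recognize; the keys and tensor objects are hypothetical stand-ins whose __type__ strings satisfy the _isTensor check added in the diff:

    // Hypothetical inputs, for illustration only.
    // Shape accepted by _convertStateDictList: an array of { key, value } entries.
    var list = [
        { key: 'conv1.weight', value: { __type__: 'torch.FloatTensor' } },
        { key: 'conv1.bias', value: { __type__: 'torch.FloatTensor' } }
    ];
    // Shape accepted by _convertStateDictMap: a flat object mapping dotted keys to tensors.
    var map = {
        'fc.weight': { __type__: 'torch.FloatTensor' },
        'fc.bias': { __type__: 'torch.FloatTensor' }
    };
    // Shape accepted by _convertStateDictGroupMap: module name mapped to tensors plus scalar attributes.
    var groups = {
        'bn1': {
            weight: { __type__: 'torch.FloatTensor' },
            eps: 1e-5
        }
    };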
274 changes: 199 additions & 75 deletions src/pytorch.js
@@ -547,32 +547,7 @@ pytorch.ModelFactory = class {
});

root_module = null;
state_dict = [];

if (obj && !Array.isArray(obj)) {
var array = [];
for (var key of Object.keys(obj)) {
array.push({ key: key, value: obj[key] });
}
obj = array;
}
if (obj && Array.isArray(obj)) {
for (var item of obj) {
var value = null;
if (item && item.value && item.value.__type__ == 'torch.nn.parameter.Parameter') {
value = item.value[0];
}
else if (item && item.value &&
item.value.__type__.startsWith('torch.') &&
item.value.__type__.endsWith('Tensor')) {
value = item.value;
}
else {
value = null;
}
state_dict.push({ key: item.key, value: value });
}
}
state_dict = pytorch.ModelFactory._convertStateDictLegacyFormat(obj);
}

if (sys_info.type_sizes &&
@@ -666,28 +641,186 @@ pytorch.ModelFactory = class {
root.state_dict_stylepredictor, root.state_dict_ghiasi
];
for (var dict of candidates) {
if (dict && Array.isArray(dict) && dict.__setitem__ &&
dict.every((item) => item.value.__type__ && item.value.__type__.startsWith('torch.') && item.value.__type__.endsWith('Tensor'))) {
delete dict.__setitem__;
return dict;
let state_dict =
pytorch.ModelFactory._convertStateDictList(dict) ||
pytorch.ModelFactory._convertStateDictMap(dict) ||
pytorch.ModelFactory._convertStateDictGroupMap(dict);
if (state_dict) {
return state_dict;
}
if (dict && !Array.isArray(dict)) {
var match = true;
var array = [];
for (var key of Object.keys(dict)) {
var value = dict[key]
if (!key || !value || !value.__type__ || !value.__type__.startsWith('torch.') || !value.__type__.endsWith('Tensor')) {
match = false;
break;
}
return null;
}

static _convertStateDictList(list) {
if (!list || !Array.isArray(list) ||
!list.every((item) => item && item.key && pytorch.ModelFactory._isTensor(item.value))) {
return null;
}
let state_dict = [];
let state_map = {};
for (let item of list) {
let split = item.key.split('.');
if (split.length < 2) {
return null;
}
let state = {};
state.id = item.key;
state.name = split.pop();
state.value = item.value;
let state_group_name = split.join('.');
let state_group = state_map[state_group_name];
if (!state_group) {
state_group = {};
state_group.name = state_group_name;
state_group.states = [];
state_map[state_group_name] = state_group;
state_dict.push(state_group);
}
state_group.states.push(state);
}
return state_dict;
}

static _convertStateDictMap(obj) {
if (!obj || Array.isArray(obj)) {
return null
}
let state_dict = [];
let state_map = {};
for (var key in obj) {
let split = key.split('.');
if (split.length < 1) {
return null;
}
let state = {};
state.id = key;
state.name = split.pop();
state.value = obj[key];
if (!pytorch.ModelFactory._isTensor(state.value)) {
return null;
}
let state_group_name = split.join('.');
let state_group = state_map[state_group_name];
if (!state_group) {
state_group = {};
state_group.name = state_group_name;
state_group.states = [];
state_map[state_group_name] = state_group;
state_dict.push(state_group);
}
state_group.states.push(state);
}
return state_dict;
}

static _convertStateDictGroupMap(obj) {
if (!obj || Array.isArray(obj)) {
return null;
}
let state_dict = [];
let state_map = {};
for (let state_group_name in obj) {

let state_group = state_map[state_group_name];
if (!state_group) {
state_group = {};
state_group.name = state_group_name;
state_group.states = [];
state_group.attributes = [];
state_map[state_group_name] = state_group;
state_dict.push(state_group);
}
var item = obj[state_group_name];
if (!item) {
return null;
}
if (Array.isArray(item)) {
for (let entry of item) {
if (!entry || !entry.key || !entry.value || !pytorch.ModelFactory._isTensor(entry.value)) {
return null;
}
array.push({ key: key, value: value });
let state = {};
state.id = state_group_name + '.' + entry.key;
state.name = entry.key;
state.value = entry.value;
state_group.states.push(state);
}
if (match) {
return array;
}
else {
for (let key in item) {
let value = item[key];
if (pytorch.ModelFactory._isTensor(value)) {
state_group.states.push({ name: key, value: value, id: state_group_name + '.' + key });
}
else if (value !== Object(value)) {
state_group.attributes.push({ name: key, value: value });
}
else if (value && value.__type__ == 'torch.nn.parameter.Parameter' && value.data) {
state_group.states.push({ name: key, value: value.data, id: state_group_name + '.' + key });
}
else {
return null;
}
}
}
}
return null;
return state_dict;
}


static _convertStateDictLegacyFormat(obj) {
if (!obj) {
return null;
}
if (obj && !Array.isArray(obj)) {
var array = [];
for (var key of Object.keys(obj)) {
array.push({ key: key, value: obj[key] });
}
obj = array;
}
var state_dict = [];
var state_map = {};
if (obj && Array.isArray(obj)) {
for (var item of obj) {
if (!item || !item.key || !item.value) {
return null;
}
let state = {};
state.id = item.key;
state.value = null;
if (item.value.__type__ == 'torch.nn.parameter.Parameter') {
state.value = item.value[0];
}
else if (pytorch.ModelFactory._isTensor(item.value)) {
state.value = item.value;
}
if (!state.value) {
return null;
}
let split = state.id.split('.');
if (split.length < 2) {
return null;
}
state.name = split.pop();
let state_group_name = split.join('.');
let state_group = state_map[state_group_name];
if (!state_group) {
state_group = {};
state_group.name = state_group_name;
state_group.states = [];
state_map[state_group_name] = state_group;
state_dict.push(state_group);
}
state_group.states.push(state);
}
}
return state_dict;
}

static _isTensor(obj) {
return obj && obj.__type__ && obj.__type__.startsWith('torch.') && obj.__type__.endsWith('Tensor');
}
};

@@ -725,31 +858,20 @@ pytorch.Graph = class {
for (var output of outputs) {
this._outputs.push(new pytorch.Parameter(output, true, [ new pytorch.Argument(output, null, null) ]));
}
} else {
var state_group_map = {};
var state_groups = [];
for (var state of state_dict) {
var key = state.key.split('.');
state.name = key.pop();
var id = key.join('.');
var state_group = state_group_map[id];
if (!state_group) {
state_group = { id: id, states: [] };
state_groups.push(state_group);
state_group_map[id] = state_group;
}
state_group.states.push(state);
}
this._nodes = this._nodes.concat(state_groups.map((state_group) => {
var inputs = state_group.states.map((state) => {
var tensor = new pytorch.Tensor(state.key, state.value, sysInfo.little_endian);
}
else {
for (let state_group of state_dict) {
let type = 'torch.nn.modules._.Module';
let attributes = state_group.attributes || [];
let inputs = state_group.states.map((state) => {
var tensor = new pytorch.Tensor(state.id, state.value, sysInfo.little_endian);
var visible = state_group.states.length == 0 || tensor.type.toString() != 'int64' || tensor.value < 1000;
return new pytorch.Parameter(state.name, visible, [
new pytorch.Argument(state.key, null, tensor)
new pytorch.Argument(state.id, null, tensor)
]);
});
return new pytorch.Node(this._metadata, '', state_group.id, { __type__: 'torch.nn.modules._.Module' }, inputs, []);
}));
this._nodes.push(new pytorch.Node(this._metadata, '', state_group.name, type, attributes, inputs, []));
}
}
}

@@ -845,10 +967,18 @@ pytorch.Graph = class {

var group = groups.join('/');
var name = group ? (group + '/' + key) : key;
var type = obj.__type__ || '';

var outputs = [ new pytorch.Parameter('output', true, [ new pytorch.Argument(name, null, null) ]) ];

var node = new pytorch.Node(this._metadata, group, name, obj, inputs, outputs);
var attributes = [];
for (let name of Object.keys(obj)) {
if (!name.startsWith('_')) {
attributes.push({ name: name, value: obj[name] });
}
}

var node = new pytorch.Node(this._metadata, group, name, type, attributes, inputs, outputs);
this._nodes.push(node);
return node;
}
@@ -921,22 +1051,16 @@ pytorch.Argument = class {

pytorch.Node = class {

constructor(metadata, group, name, obj, inputs, outputs) {
constructor(metadata, group, name, type, attributes, inputs, outputs) {
this._metadata = metadata;
this._group = group || '';
this._name = name || '';
var type = obj.__type__.split('.');
this._operator = type.pop();
this._package = type.join('.');
this._attributes = [];
let split = type.split('.');
this._operator = split.pop();
this._package = split.join('.');
this._attributes = attributes.map((attribute) => new pytorch.Attribute(this._metadata, this, attribute.name, attribute.value));
this._inputs = inputs;
this._outputs = outputs;

for (var attributeName of Object.keys(obj)) {
if (!attributeName.startsWith('_')) {
this._attributes.push(new pytorch.Attribute(this._metadata, this, attributeName, obj[attributeName]));
}
}
}

get name() {
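As a reading aid, a hedged sketch (not part of the commit) of the normalized shape every converter above returns on success: an array of state groups that pytorch.Graph then renders as one 'torch.nn.modules._.Module' node per group. The concrete names and tensor literals are invented for illustration:

    // Hypothetical converter output, for illustration only.
    var state_dict = [
        {
            name: 'conv1',
            states: [
                { id: 'conv1.weight', name: 'weight', value: { __type__: 'torch.FloatTensor' } },
                { id: 'conv1.bias', name: 'bias', value: { __type__: 'torch.FloatTensor' } }
            ]
        }
    ];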
12 changes: 12 additions & 0 deletions test/models.json
@@ -3581,6 +3581,12 @@
"target": "inception_v3_google-1a9a5a14.pth",
"source": "https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth"
},
{
"type": "pytorch",
"target": "mask_r_cnn.pth",
"source": "https://github.com/facebookresearch/kill-the-bits/blob/master/src/models/compressed/mask_r_cnn.pth?raw=true",
"link": "https://github.com/facebookresearch/kill-the-bits/tree/master/src/models/compressed"
},
{
"type": "pytorch",
"target": "mnist_linear.ckpt",
@@ -3613,6 +3619,12 @@
"target": "resnet18-5c106cde.pth",
"source": "https://download.pytorch.org/models/resnet18-5c106cde.pth"
},
{
"type": "pytorch",
"target": "resnet18_large_blocks.pth",
"source": "https://github.com/facebookresearch/kill-the-bits/blob/master/src/models/compressed/resnet18_large_blocks.pth?raw=true",
"link": "https://github.com/facebookresearch/kill-the-bits/tree/master/src/models/compressed"
},
{
"type": "pytorch",
"target": "resnet-18-at-export.pth",
