Skip to content

Commit

Permalink
Merge pull request #66 from lscheinkman/RES-1038.1
Browse files Browse the repository at this point in the history
refactor load_gsc_weights_from_pytorch
  • Loading branch information
lscheinkman committed Sep 13, 2019
2 parents 08ea3a7 + ca143e2 commit 1c598d4
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 11 deletions.
2 changes: 1 addition & 1 deletion nupic/research/frameworks/tensorflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@
# http://numenta.org/licenses/
#

from .pytorch_utils import load_pytorch_weights
from .pytorch_utils import load_gsc_weights_from_pytorch
13 changes: 9 additions & 4 deletions nupic/research/frameworks/tensorflow/utils/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import tensorflow as tf
import tensorflow.keras.backend as K

import nupic.torch.models

TF_LOGGER = tf.get_logger()


Expand Down Expand Up @@ -109,16 +111,19 @@ def _reflatten_linear_weight(x):
}


def load_pytorch_weights(model_tf, model_pt, weights_map=None):
def load_gsc_weights_from_pytorch(model_tf, model_pt, weights_map=None):
"""
Update tensorflow model weights using pre-trained pytorch model
:param model_tf: Clean tensorflow model
Update tensorflow model weights using pre-trained GSC pytorch model
:param model_tf: Untrained GSC model (tensorflow).
:type model_tf: :class:`nupic.tensorflow.models.GSCSparseCNN`
:param model_pt: Pre-trained pytorch model
:param model_pt: Pre-trained GSC model (pytorch).
:type model_pt: :class:`nupic.torch.models.GSCSparseCNN`
:param weights_map: Dictionay mapping tensorflow variables to pytorch state
:type weights_map: dict
"""
if not isinstance(model_pt, nupic.torch.models.GSCSparseCNN):
raise NotImplementedError()

if weights_map is None:
weights_map = _GSC_SPARSE_MAP
state_dict = model_pt.state_dict()
Expand Down
6 changes: 3 additions & 3 deletions projects/tensorflow/gsc/import_pytorch_gsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import torch.hub

from nupic.research.frameworks.tensorflow.utils import load_pytorch_weights
from nupic.research.frameworks.tensorflow.utils import load_gsc_weights_from_pytorch
from nupic.tensorflow.models import GSCSparseCNN, GSCSuperSparseCNN

if __name__ == "__main__":
Expand All @@ -35,7 +35,7 @@

print("Converting gsc_sparse_cnn from pytorch to tensorflow")
model_tf = GSCSparseCNN(data_format="channels_last")
load_pytorch_weights(model_tf, model_pt)
load_gsc_weights_from_pytorch(model_tf, model_pt)
model_tf.compile(optimizer="sgd",
loss="sparse_categorical_crossentropy",
metrics=["accuracy"])
Expand All @@ -50,7 +50,7 @@
pretrained=True)
print("Converting gsc_super_sparse_cnn from pytorch to tensorflow")
model_tf = GSCSuperSparseCNN(data_format="channels_last")
load_pytorch_weights(model_tf, model_pt)
load_gsc_weights_from_pytorch(model_tf, model_pt)
print("Saving pre-trained tensorflow version of gsc_super_sparse_cnn as "
"gsc_super_sparse_cnn.h5")
model_tf.save_weights("gsc_super_sparse_cnn.h5")
6 changes: 3 additions & 3 deletions tests/unit/frameworks/tensorflow/load_pytorch_weights_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.platform import test

from nupic.research.frameworks.tensorflow.utils import load_pytorch_weights
from nupic.research.frameworks.tensorflow.utils import load_gsc_weights_from_pytorch
from nupic.tensorflow.models import GSCSparseCNN, GSCSuperSparseCNN


Expand All @@ -45,7 +45,7 @@ def test_gsc_sparse_cnn(self):
model_pt.eval()

model_tf = GSCSparseCNN(pre_trained=False)
load_pytorch_weights(model_tf, model_pt)
load_gsc_weights_from_pytorch(model_tf, model_pt)
model_tf.compile(optimizer="sgd",
loss="sparse_categorical_crossentropy",
metrics=["accuracy"])
Expand All @@ -69,7 +69,7 @@ def test_gsc_super_sparse_cnn(self):
model_pt.eval()

model_tf = GSCSuperSparseCNN(pre_trained=False)
load_pytorch_weights(model_tf, model_pt)
load_gsc_weights_from_pytorch(model_tf, model_pt)
model_tf.compile(optimizer="sgd",
loss="sparse_categorical_crossentropy",
metrics=["accuracy"])
Expand Down

0 comments on commit 1c598d4

Please sign in to comment.