Skip to content
This repository has been archived by the owner on May 25, 2020. It is now read-only.

Commit

Permalink
Save and load weights separately for each model's param (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicolas committed Jul 30, 2018
1 parent 97c3782 commit d231b03
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 13 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,16 @@ Then we marked out each utterance with our emotions classifier that predicts one
To mark-up your own corpus with emotions you can use, for example, [DeepMoji tool](https://github.com/bfelbo/DeepMoji)
or any other emotions classifier that you have.

#### Initializing model weights from file
For some tools (for example [`tools/train.py`](tools/train.py)) you can specify the path to model's initialization weights via `--init_weights` argument.

The weights may come from a trained CakeChat model or from a model with a different architecture.
In the latter case some parameters of the CakeChat model may be left without initialization:
a parameter will be initialized with a saved value if the parameter's name and shape are
identical to the saved parameter, otherwise the parameter will keep its default initialization weights.

See `load_weights` function for the details.

### Training your own model

1. Put your training text corpus to [`data/corpora_processed/`](data/corpora_processed/).
Expand Down
2 changes: 1 addition & 1 deletion cakechat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
LEARNING_RATE = 1.0 # Learning rate for the chosen optimizer (currently using Adadelta, see model.py)

# model params
NN_MODEL_PREFIX = 'cakechat' # Specify prefix to be prepended to model's name
NN_MODEL_PREFIX = 'cakechat_v1.3' # Specify prefix to be prepended to model's name

# predictions params
MAX_PREDICTIONS_LENGTH = 40 # Max. number of tokens which can be generated on the prediction step
Expand Down
57 changes: 45 additions & 12 deletions cakechat/dialog_model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,20 +622,54 @@ def is_reverse_model(self):
return self._is_reverse_model

def load_weights(self):
    """Initialize network params from the weights saved in ``self.model_load_path``.

    Weights are matched to params by name: a param is restored only when the
    saved archive contains an array with the same name *and* shape; otherwise
    the param keeps its default initialization. Restored, missing,
    shape-mismatched and unused saved params are all logged.
    """
    _logger.info('\nLoading saved weights from file:\n{}\n'.format(self.model_load_path))

    # np.load on an .npz archive returns a lazy NpzFile that keeps the file
    # handle open; materialize the arrays and close it explicitly so the
    # handle is released.
    loaded_file = np.load(self.model_load_path)
    try:
        saved_var_name_to_var = OrderedDict((name, loaded_file[name]) for name in loaded_file.files)
    finally:
        loaded_file.close()

    var_name_to_var = OrderedDict([(v.name, v) for v in get_all_params(self._net['dist'])])
    initialized_vars, missing_vars, mismatched_vars = [], [], []

    for var_name, var in var_name_to_var.iteritems():
        if var_name not in saved_var_name_to_var:
            missing_vars.append(var_name)
            continue

        default_var_value = var.get_value()
        saved_var_value = saved_var_name_to_var[var_name]

        if default_var_value.shape != saved_var_value.shape:
            mismatched_vars.append((var_name, default_var_value.shape, saved_var_value.shape))
            continue

        # Checks passed, set parameter value
        var.set_value(saved_var_value)
        initialized_vars.append(var_name)
        # Anything left in saved_var_name_to_var afterwards is reported as unused
        del saved_var_name_to_var[var_name]

    laconic_logger.info('\nRestored saved params:')
    for var_name in initialized_vars:
        laconic_logger.info('\t' + var_name)

    laconic_logger.warning('\nMissing saved params:')
    for var_name in missing_vars:
        laconic_logger.warning('\t' + var_name)

    laconic_logger.warning('\nShapes-mismatched params (saved -> current):')
    for var_name, default_shape, saved_shape in mismatched_vars:
        laconic_logger.warning('\t{0:<40} {1:<12} -> {2:<12}'.format(var_name, saved_shape, default_shape))

    laconic_logger.warning('\nUnused saved params:')
    for var_name in saved_var_name_to_var:
        laconic_logger.warning('\t' + var_name)

    laconic_logger.info('')

def save_model(self, save_path):
    """Save all network params to an .npz archive at ``save_path``, keyed by param name.

    NOTE: params sharing the same name overwrite each other in the dict
    below, so only the last one is saved — load_weights restores by name.
    """
    # Make sure the target directory exists before opening the file for writing
    # (the previous implementation did this; without it saving into a fresh
    # model directory fails).
    ensure_dir(os.path.dirname(save_path))

    all_params = get_all_params(self._net['dist'])
    with open(save_path, 'wb') as f:
        params = {v.name: v.get_value() for v in all_params}
        np.savez(f, **params)

    _logger.info('\nSaved model:\n{}\n'.format(save_path))

@staticmethod
def delete_model(delete_path):
Expand Down Expand Up @@ -684,7 +718,6 @@ def get_nn_model(index_to_token, index_to_condition, model_init_path=None, w2v_m
model_exists = resolver.resolve()

if model_exists:
_logger.info('\nLoading weights from file:\n{}\n'.format(model.model_load_path))
model.load_weights()
elif model_init_path:
raise FileNotFoundException('Can\'t initialize model from file:\n{}\n'.format(model_init_path))
Expand Down

0 comments on commit d231b03

Please sign in to comment.