feat: remastered TFModel class
mu-arkhipov committed Feb 18, 2018
1 parent 62d8d10 · commit 2e530e7
Showing 1 changed file with 80 additions and 87 deletions: deeppavlov/core/models/tf_model.py
@@ -13,119 +13,112 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
 import sys
 from abc import abstractmethod
+from collections import defaultdict
+import numpy as np
 
 import tensorflow as tf
-from overrides import overrides
 
 from deeppavlov.core.models.nn_model import NNModel
-from deeppavlov.core.models.serializable import Serializable
-from deeppavlov.core.common.attributes import check_attr_true
 from deeppavlov.core.common.log import get_logger
-from deeppavlov.core.common.errors import ConfigError
 from .tf_backend import TfModelMeta
 
-"""
-Here is an abstract class for neural network models based on TensorFlow.
-If you use something different, e.g. PyTorch, write a similar class, inherit it from the
-Trainable and Inferable interfaces, and make a pull request to deeppavlov.
-"""
-
 log = get_logger(__name__)
 
 
-class TFModel(Serializable, metaclass=TfModelMeta):
-    def __init__(self, **kwargs):
-        self._saver = tf.train.Saver
-        super().__init__(**kwargs)
-
-    @abstractmethod
-    def _add_placeholders(self):
-        """
-        Add all needed placeholders for a computational graph.
-        """
-        pass
-
-    @abstractmethod
-    def run_sess(self, *args, **kwargs):
-        """
-        1. Call _build_graph()
-        2. Define all computations.
-        3. Run tf.sess.
-        4. Reset state if needed.
-        :return:
-        """
-        pass
-
-    @abstractmethod
-    def _train_step(self, features, *args):
-        """
-        Define a single training step. Feed dict to tf session.
-        :param features: input features
-        :param args: any other inputs you need to pass for training, including the target vector
-        :return: metric to return, usually loss
-        """
-        pass
-
-    @abstractmethod
-    def _forward(self, features, *args):
-        """
-        Pass an instance to get a prediction.
-        :param features: input features
-        :param args: any other inputs you need to pass for training
-        :return: prediction
-        """
-        pass
-
-    @check_attr_true('train_now')
-    def train(self, features, *args, **kwargs):
-        """
-        Just a wrapper for a private method.
-        """
-        return self._train_step(features, *args, **kwargs)
-
-    def infer(self, instance, *args, **kwargs):
-        """
-        Just a wrapper for a private method.
-        """
-        return self._forward(instance, *args, **kwargs)
-
-    def save(self):
-        save_path = str(self.save_path)
-        saver = tf.train.Saver()
-        log.info('[saving model to {}]'.format(save_path))
-        saver.save(self.sess, save_path)
-        log.info('model saved')
-
-    def get_checkpoint_state(self):
-        if self.load_path:
-            try:
-                if self.load_path.parent.is_dir():
-                    return tf.train.get_checkpoint_state(self.load_path.parent)
-                else:
-                    raise ConfigError
-            except ConfigError:
-                log.error('Provided `load_path` is incorrect!', exc_info=True)
-                sys.exit(1)
-        else:
-            log.warning('No `load_path` is provided for {}'.format(self.__class__.__name__))
-
-    @overrides
-    def load(self):
-        """
-        Load session from checkpoint
-        """
-        ckpt = self.get_checkpoint_state()
-        if ckpt and ckpt.model_checkpoint_path:
-            log.info('[restoring checkpoint from {}]'.format(ckpt.model_checkpoint_path))
-            self._saver().restore(self.sess, ckpt.model_checkpoint_path)
-            log.info('session restored')
-        else:
-            log.error('checkpoint not found!')
-
-
-class SimpleTFModel(NNModel, metaclass=TfModelMeta):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+class TFModel(NNModel, metaclass=TfModelMeta):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @abstractmethod
+    def train_on_batch(self, x_batch, y_batch):
+        """Perform a single step of optimization on a batch of samples.
+
+        Args:
+            x_batch: batch of x's; it can be anything (dict, list, numpy array),
+                but it must contain the same number of samples as y_batch
+            y_batch: batch of y's; it must contain the same number of samples as x_batch
+
+        Returns:
+            loss: mean loss over the batch
+        """
+        pass
+
+    def load(self):
+        """Load model parameters from self.load_path"""
+        path = str(self.load_path.resolve())
+        # Check presence of the model files
+        if tf.train.checkpoint_exists(path):
+            print('[loading model from {}]'.format(path), file=sys.stderr)
+            saver = tf.train.Saver()
+            saver.restore(self.sess, path)
+
+    def save(self):
+        """Save model parameters to self.save_path"""
+        path = str(self.save_path.resolve())
+        print('[saving model to {}]'.format(path), file=sys.stderr)
+        saver = tf.train.Saver()
+        saver.save(self.sess, path)
+
+    @abstractmethod
+    def __call__(self, x_batch):
+        """Infer y_batch from x_batch.
+
+        Args:
+            x_batch: a batch of samples
+
+        Returns:
+            y_batch: a batch of predictions inferred from x_batch
+        """
+        pass
+
+    def get_train_op(self, loss, learning_rate, learnable_scopes=None, optimizer=None):
+        """Get a train operation for a given loss.
+
+        Args:
+            loss: loss, tf tensor or scalar
+            learning_rate: scalar or placeholder
+            learnable_scopes: which scopes are trainable (None for all)
+            optimizer: instance of tf.train.Optimizer, Adam by default
+
+        Returns:
+            train_op
+        """
+        if learnable_scopes is None:
+            variables_to_train = tf.trainable_variables()
+        else:
+            variables_to_train = []
+            for scope_name in learnable_scopes:
+                for var in tf.trainable_variables():
+                    if var.name.startswith(scope_name):
+                        variables_to_train.append(var)
+
+        if optimizer is None:
+            optimizer = tf.train.AdamOptimizer
+
+        # For batch norm it is necessary to update running averages
+        extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+        with tf.control_dependencies(extra_update_ops):
+            train_op = optimizer(learning_rate).minimize(loss, var_list=variables_to_train)
+        return train_op
+
+    @staticmethod
+    def print_number_of_parameters():
+        """Print the number of *trainable* parameters in the network."""
+        log.info('Number of parameters: ')
+        variables = tf.trainable_variables()
+        blocks = defaultdict(int)
+        for var in variables:
+            # Get the top-level scope name of the variable
+            block_name = var.name.split('/')[0]
+            number_of_parameters = np.prod(var.get_shape().as_list())
+            blocks[block_name] += number_of_parameters
+        for block_name in blocks:
+            log.info('{}: {}'.format(block_name, blocks[block_name]))
+        total_num_parameters = np.sum(list(blocks.values()))
+        log.info('Total number of parameters equals {}'.format(total_num_parameters))
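
Editor's note: to make the remastered interface concrete, here is a minimal sketch, not part of the commit, of how a subclass might wire train_on_batch and __call__ to a graph built in __init__. All names below (LinearRegressor, n_features, the learning rate) are hypothetical, and a TF 1.x runtime is assumed.

# Hypothetical subclass of the remastered TFModel; illustrative only.
import tensorflow as tf

from deeppavlov.core.models.tf_model import TFModel


class LinearRegressor(TFModel):
    def __init__(self, n_features, **kwargs):
        # Build the graph before calling the parent constructor
        self.x_ph = tf.placeholder(tf.float32, [None, n_features])
        self.y_ph = tf.placeholder(tf.float32, [None, 1])
        self.pred = tf.layers.dense(self.x_ph, 1)
        self.loss = tf.losses.mean_squared_error(self.y_ph, self.pred)
        # get_train_op defaults to Adam and wires in UPDATE_OPS dependencies
        self.train_op = self.get_train_op(self.loss, learning_rate=1e-3)
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
        # Forward serialization kwargs (save_path, load_path, ...) to the parent
        super().__init__(**kwargs)

    def train_on_batch(self, x_batch, y_batch):
        # Single optimization step; returns mean loss over the batch
        _, loss = self.sess.run([self.train_op, self.loss],
                                feed_dict={self.x_ph: x_batch, self.y_ph: y_batch})
        return loss

    def __call__(self, x_batch):
        # Inference only: no gradient updates
        return self.sess.run(self.pred, feed_dict={self.x_ph: x_batch})

Note that self.sess must exist before load() or save() is called, since both build a tf.train.Saver against the live session.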
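The learnable_scopes argument of get_train_op enables partial fine-tuning: trainable variables are filtered by the prefix of their scope name before being handed to the optimizer. A hedged snippet, reusing the hypothetical model above and assuming placeholders x and y already exist:

# Illustrative only: freeze everything outside the 'head' scope.
with tf.variable_scope('encoder'):
    h = tf.layers.dense(x, 128, activation=tf.nn.relu)
with tf.variable_scope('head'):
    logits = tf.layers.dense(h, 10)
loss = tf.losses.sparse_softmax_cross_entropy(labels=y, logits=logits)
# Only variables whose names start with 'head' receive gradient updates
train_op = model.get_train_op(loss, learning_rate=1e-3, learnable_scopes=['head'])
TFModel.print_number_of_parameters()  # logs per-scope and total trainable counts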
