Skip to content

Commit

Permalink
feat: add new abstract classes for components
Browse files Browse the repository at this point in the history
BREAKING CHANGE:
nothing works after this
  • Loading branch information
yoptar committed Feb 15, 2018
1 parent 1d42119 commit 696dd29
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 152 deletions.
29 changes: 29 additions & 0 deletions deeppavlov/core/models/component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""
Copyright 2017 Neural Networks and Deep Learning lab, MIPT
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from abc import ABCMeta, abstractmethod


class Component(metaclass=ABCMeta):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@abstractmethod
def __call__(self, *args, **kwargs):
pass

def reset(self):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,15 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
:class:`deeppavlov.models.model.Trainable` is an abstract base class that expresses the interface
for all models that can be trained (ex. neural networks, scikit-learn estimators, gensim models,
etc.). All trainable models should inherit from this class.
"""

from abc import abstractmethod

from typing import Tuple, Iterable

from .component import Component
from .serializable import Serializable


class Trainable(Serializable):
class Estimator(Component, Serializable):
"""
:attr:`train_now` expresses a developer intent for whether a model as part of a pipeline
should be trained in the current experiment run or not.
Expand All @@ -40,9 +36,5 @@ def __init__(self, train_now=False, **kwargs):
super().__init__(**kwargs)

@abstractmethod
def save(self, *args, **kwargs):
pass

@abstractmethod
def load(self, *args, **kwargs):
def fit(self, data: Tuple[list, list]):
pass
6 changes: 2 additions & 4 deletions deeppavlov/core/models/keras_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@
from keras.models import Model
from keras.layers import Dense, Input

from deeppavlov.core.models.trainable import Trainable
from deeppavlov.core.models.inferable import Inferable
from deeppavlov.core.common.attributes import check_attr_true
from deeppavlov.core.models.nn_model import NNModel
from deeppavlov.core.common.file import save_json, read_json
from deeppavlov.core.common.errors import ConfigError
from deeppavlov.core.common.log import get_logger
Expand All @@ -38,7 +36,7 @@
log = get_logger(__name__)


class KerasModel(Trainable, Inferable, metaclass=TfModelMeta):
class KerasModel(NNModel, metaclass=TfModelMeta):
"""
Class builds keras model with tensorflow backend
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,34 +13,28 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
:class:`deeppavlov.models.model.Inferable` is an abstract base class that expresses the interface
for all models that can work for inferring. The scope of all inferring models is larger than the scope
of trainable models. For example, encoders can be inferred, but can't be trained.
All inferring models should inherit from this class.
"""
from abc import abstractmethod

from typing import Tuple

from .component import Component
from .serializable import Serializable


class Inferable(Serializable):
class NNModel(Component, Serializable):
"""
:attr:`train_now` expresses a developer intent for whether a model as part of a pipeline
should be trained in the current experiment run or not.
"""

def __init__(self, **kwargs):
def __init__(self, train_now=False, **kwargs):
mode = kwargs.get('mode', None)
if mode == 'train':
self.train_now = train_now
else:
self.train_now = False
super().__init__(**kwargs)

@abstractmethod
def infer(self, instance):
"""
Infer a model. Any model can infer other model and ask it to do something (predict, encode,
etc. via this method)
:param instance: pass data instance to an inferring model
:param args: all needed params for inferring
:return a result of inferring
"""
def train_on_batch(self, batch: Tuple[list, list]):
pass
21 changes: 13 additions & 8 deletions deeppavlov/core/models/serializable.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"""
from deeppavlov.core.commands.utils import expand_path
from deeppavlov.core.common.log import get_logger
from abc import ABCMeta
from abc import ABCMeta, abstractmethod

"""
:class:`deeppavlov.models.model.Serializable` is an abstract base class that expresses the interface
for all models that can serialize data to a path.
Expand All @@ -37,15 +38,15 @@ def __init__(self, save_path, load_path=None, **kwargs):

if save_path:
self.save_path = expand_path(save_path)
self.save_path.parent.mkdir(exist_ok=True)
self.save_path.parent.mkdir(parents=True, exist_ok=True)
else:
self.save_path = None

mode = kwargs.get('mode', 'infer')

if load_path:
self.load_path = expand_path(load_path)
if mode != 'train' and self.load_path != self.save_path:
if mode != 'train' and self.save_path and self.load_path != self.save_path:
log.warning("Load path '{}' differs from save path '{}' in '{}' mode for {}."
.format(self.load_path, self.save_path, mode, self.__class__.__name__))
elif mode != 'train' and self.save_path:
Expand All @@ -56,8 +57,12 @@ def __init__(self, save_path, load_path=None, **kwargs):
self.load_path = None
log.warning("No load path is set for {}!".format(self.__class__.__name__))

def __new__(cls, *args, **kwargs):
if cls is Serializable:
raise TypeError(
"TypeError: Can't instantiate abstract class {} directly".format(cls.__name__))
return object.__new__(cls)
super().__init__()

@abstractmethod
def save(self, *args, **kwargs):
pass

@abstractmethod
def load(self, *args, **kwargs):
pass
106 changes: 0 additions & 106 deletions deeppavlov/core/models/sklearn_model.py

This file was deleted.

8 changes: 4 additions & 4 deletions deeppavlov/core/models/tf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
import tensorflow as tf
from overrides import overrides

from deeppavlov.core.models.trainable import Trainable
from deeppavlov.core.models.inferable import Inferable
from deeppavlov.core.models.nn_model import NNModel
from deeppavlov.core.models.component import Component
from deeppavlov.core.common.attributes import check_attr_true
from deeppavlov.core.common.log import get_logger
from deeppavlov.core.common.errors import ConfigError
Expand All @@ -36,7 +36,7 @@
log = get_logger(__name__)


class TFModel(Trainable, Inferable, metaclass=TfModelMeta):
class TFModel(NNModel, metaclass=TfModelMeta):
def __init__(self, **kwargs):
self._saver = tf.train.Saver
super().__init__(**kwargs)
Expand Down Expand Up @@ -126,6 +126,6 @@ def load(self):
log.error('checkpoint not found!')


class SimpleTFModel(Trainable, Inferable, metaclass=TfModelMeta):
class SimpleTFModel(NNModel, metaclass=TfModelMeta):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

0 comments on commit 696dd29

Please sign in to comment.