Skip to content

Commit

Permalink
Merge pull request #99 from nicolay-r/0.21.0-rc
Browse files Browse the repository at this point in the history
0.21.0 rc
  • Loading branch information
nicolay-r authored Apr 27, 2021
2 parents 3051fb9 + 82f166b commit 18251dc
Show file tree
Hide file tree
Showing 55 changed files with 135 additions and 65 deletions.
2 changes: 1 addition & 1 deletion common/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def IO(self):
return self.__io

# TODO. Remove epochs count, since it is related to NeuralNetworks only.
def run_training(self, epochs_count):
def run_training(self, epochs_count, seed):
raise NotImplementedError()

def predict(self, data_type=DataType.Test):
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import logging

from arekit.common.experiment.formats.base import BaseExperiment
from arekit.contrib.experiments.common import entity_to_group_func
from arekit.contrib.experiments.ruattitudes.documents import RuAttitudesDocumentOperations
from arekit.contrib.experiments.ruattitudes.folding import create_ruattitudes_experiment_data_folding
from arekit.contrib.experiments.ruattitudes.opinions import RuAttitudesOpinionOperations
from arekit.contrib.experiments.ruattitudes.utils import read_ruattitudes_in_memory
from arekit.contrib.experiment_rusentrel.common import entity_to_group_func
from arekit.contrib.experiment_rusentrel.ds.documents import RuAttitudesDocumentOperations
from arekit.contrib.experiment_rusentrel.ds.folding import create_ruattitudes_experiment_data_folding
from arekit.contrib.experiment_rusentrel.ds.opinions import RuAttitudesOpinionOperations
from arekit.contrib.experiment_rusentrel.ds.utils import read_ruattitudes_in_memory
from arekit.contrib.source.ruattitudes.io_utils import RuAttitudesVersions

logger = logging.getLogger(__name__)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from arekit.common.experiment.data.base import DataIO
from arekit.common.experiment.folding.types import FoldingType
from arekit.contrib.experiments.ruattitudes.experiment import RuAttitudesExperiment
from arekit.contrib.experiments.rusentrel.experiment import RuSentRelExperiment
from arekit.contrib.experiments.rusentrel_ds.experiment import RuSentRelWithRuAttitudesExperiment
from arekit.contrib.experiments.types import ExperimentTypes
from arekit.contrib.experiment_rusentrel.ds.experiment import RuAttitudesExperiment
from arekit.contrib.experiment_rusentrel.joined.experiment import RuSentRelWithRuAttitudesExperiment
from arekit.contrib.experiment_rusentrel.sl.experiment import RuSentRelExperiment
from arekit.contrib.experiment_rusentrel.types import ExperimentTypes


def create_experiment(exp_type,
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from arekit.common.experiment.formats.documents import DocumentOperations
from arekit.contrib.experiments.ruattitudes.documents import RuAttitudesDocumentOperations
from arekit.contrib.experiments.rusentrel.documents import RuSentrelDocumentOperations
from arekit.contrib.experiment_rusentrel.ds.documents import RuAttitudesDocumentOperations
from arekit.contrib.experiment_rusentrel.sl.documents import RuSentrelDocumentOperations


class RuSentrelWithRuAttitudesDocumentOperations(DocumentOperations):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@

from arekit.common.experiment.folding.types import FoldingType
from arekit.common.experiment.formats.base import BaseExperiment
from arekit.contrib.experiments.common import entity_to_group_func
from arekit.common.experiment.io_utils import BaseIOUtils
from arekit.contrib.experiments.ruattitudes.documents import RuAttitudesDocumentOperations
from arekit.contrib.experiments.ruattitudes.folding import create_ruattitudes_experiment_data_folding
from arekit.contrib.experiments.ruattitudes.opinions import RuAttitudesOpinionOperations
from arekit.contrib.experiments.ruattitudes.utils import read_ruattitudes_in_memory
from arekit.contrib.experiments.rusentrel.documents import RuSentrelDocumentOperations
from arekit.contrib.experiments.rusentrel.folding import create_rusentrel_experiment_data_folding
from arekit.contrib.experiments.rusentrel.opinions import RuSentrelOpinionOperations
from arekit.contrib.experiments.rusentrel_ds.documents import RuSentrelWithRuAttitudesDocumentOperations
from arekit.contrib.experiments.rusentrel_ds.opinions import RuSentrelWithRuAttitudesOpinionOperations
from arekit.contrib.experiments.synonyms.provider import RuSentRelSynonymsCollectionProvider
from arekit.contrib.experiment_rusentrel.common import entity_to_group_func
from arekit.contrib.experiment_rusentrel.ds.documents import RuAttitudesDocumentOperations
from arekit.contrib.experiment_rusentrel.ds.folding import create_ruattitudes_experiment_data_folding
from arekit.contrib.experiment_rusentrel.ds.opinions import RuAttitudesOpinionOperations
from arekit.contrib.experiment_rusentrel.ds.utils import read_ruattitudes_in_memory
from arekit.contrib.experiment_rusentrel.joined.documents import RuSentrelWithRuAttitudesDocumentOperations
from arekit.contrib.experiment_rusentrel.joined.opinions import RuSentrelWithRuAttitudesOpinionOperations
from arekit.contrib.experiment_rusentrel.sl.documents import RuSentrelDocumentOperations
from arekit.contrib.experiment_rusentrel.sl.folding import create_rusentrel_experiment_data_folding
from arekit.contrib.experiment_rusentrel.sl.opinions import RuSentrelOpinionOperations
from arekit.contrib.experiment_rusentrel.synonyms.provider import RuSentRelSynonymsCollectionProvider
from arekit.contrib.source.ruattitudes.io_utils import RuAttitudesVersions
from arekit.contrib.source.rusentrel.io_utils import RuSentRelVersions

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from arekit.common.experiment.data_type import DataType
from arekit.common.experiment.formats.opinions import OpinionOperations
from arekit.contrib.experiments.rusentrel.opinions import RuSentrelOpinionOperations
from arekit.contrib.experiment_rusentrel.sl.opinions import RuSentrelOpinionOperations


class RuSentrelWithRuAttitudesOpinionOperations(OpinionOperations):
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from arekit.common.experiment.folding.types import FoldingType
from arekit.common.experiment.formats.base import BaseExperiment
from arekit.contrib.experiments.common import entity_to_group_func
from arekit.common.experiment.io_utils import BaseIOUtils
from arekit.contrib.experiments.rusentrel.documents import RuSentrelDocumentOperations
from arekit.contrib.experiments.rusentrel.folding import create_rusentrel_experiment_data_folding
from arekit.contrib.experiments.rusentrel.opinions import RuSentrelOpinionOperations
from arekit.contrib.experiments.synonyms.provider import RuSentRelSynonymsCollectionProvider
from arekit.contrib.experiment_rusentrel.common import entity_to_group_func
from arekit.contrib.experiment_rusentrel.sl.documents import RuSentrelDocumentOperations
from arekit.contrib.experiment_rusentrel.sl.folding import create_rusentrel_experiment_data_folding
from arekit.contrib.experiment_rusentrel.sl.opinions import RuSentrelOpinionOperations
from arekit.contrib.experiment_rusentrel.synonyms.provider import RuSentRelSynonymsCollectionProvider
from arekit.contrib.source.rusentrel.io_utils import RuSentRelVersions

logger = logging.getLogger(__name__)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from arekit.common.experiment.formats.opinions import OpinionOperations
from arekit.common.experiment.io_utils import BaseIOUtils
from arekit.common.opinions.collection import OpinionCollection
from arekit.contrib.experiments.rusentrel.labels_formatter import RuSentRelNeutralLabelsFormatter
from arekit.contrib.experiment_rusentrel.sl.labels_formatter import RuSentRelNeutralLabelsFormatter
from arekit.contrib.source.rusentrel.io_utils import RuSentRelVersions
from arekit.contrib.source.rusentrel.labels_fmt import RuSentRelLabelsFormatter
from arekit.contrib.source.rusentrel.opinions.collection import RuSentRelOpinionCollection
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from arekit.contrib.experiments.synonyms.collection import StemmerBasedSynonymCollection
from arekit.contrib.experiment_rusentrel.synonyms.collection import StemmerBasedSynonymCollection
from arekit.contrib.source.rusentrel.io_utils import RuSentRelVersions
from arekit.contrib.source.rusentrel.synonyms import RuSentRelSynonymsCollectionHelper
from arekit.processing.lemmatization.base import Stemmer
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import sys
import unittest


sys.path.append('../')

from arekit.contrib.experiments.synonyms.provider import RuSentRelSynonymsCollectionProvider
from arekit.contrib.bert.core.input.providers.label.binary import BinaryLabelProvider
from arekit.contrib.experiment_rusentrel.common import entity_to_group_func
from arekit.contrib.experiment_rusentrel.synonyms.provider import RuSentRelSynonymsCollectionProvider
from arekit.common.experiment.data_type import DataType
from arekit.contrib.experiments.common import entity_to_group_func
from arekit.common.experiment.input.formatters.helper.balancing import SampleRowBalancerHelper
from arekit.common.experiment.input.formatters.sample import BaseSampleFormatter
from arekit.common.experiment.input.providers.text.single import BaseSingleTextProvider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from arekit.common.entities.base import Entity
from arekit.common.news.parsed.base import ParsedNews
from arekit.contrib.experiments.synonyms.provider import RuSentRelSynonymsCollectionProvider
from arekit.contrib.experiment_rusentrel.synonyms.provider import RuSentRelSynonymsCollectionProvider
from arekit.contrib.source.ruattitudes.collection import RuAttitudesCollection
from arekit.contrib.source.ruattitudes.io_utils import RuAttitudesVersions
from arekit.contrib.source.ruattitudes.news.base import RuAttitudesNews
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

from arekit.processing.lemmatization.mystem import MystemWrapper

from arekit.contrib.experiments.synonyms.provider import RuSentRelSynonymsCollectionProvider
from arekit.contrib.experiment_rusentrel.synonyms.provider import RuSentRelSynonymsCollectionProvider

from arekit.contrib.source.rusentrel.news.base import RuSentRelNews
from arekit.contrib.source.rusentrel.sentence import RuSentRelSentence
from arekit.contrib.source.rusentrel.entities.entity import RuSentRelEntity
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from arekit.common.opinions.collection import OpinionCollection
from arekit.common.synonyms import SynonymsCollection
from arekit.common.utils import progress_bar_iter
from arekit.contrib.experiments.synonyms.provider import RuSentRelSynonymsCollectionProvider
from arekit.contrib.experiment_rusentrel.synonyms.provider import RuSentRelSynonymsCollectionProvider
from arekit.contrib.source.rusentrel.io_utils import RuSentRelVersions
from arekit.contrib.source.rusentrel.labels_fmt import RuSentRelLabelsFormatter
from arekit.contrib.source.rusentrel.opinions.collection import RuSentRelOpinionCollection
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from arekit.contrib.source.rusentiframes.collection import RuSentiFramesCollection
from arekit.contrib.source.rusentiframes.types import RuSentiFramesVersions
from arekit.contrib.source.tests.text.news import init_rusentrel_doc
from arekit.contrib.experiments.synonyms.provider import RuSentRelSynonymsCollectionProvider

from arekit.contrib.experiment_rusentrel.synonyms.provider import RuSentRelSynonymsCollectionProvider

from arekit.common.entities.str_fmt import StringEntitiesFormatter
from arekit.common.entities.formatters.str_rus_cased_fmt import RussianEntitiesCasedFormatter
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import unittest

from arekit.contrib.experiments.synonyms.provider import RuSentRelSynonymsCollectionProvider
from arekit.contrib.experiment_rusentrel.synonyms.provider import RuSentRelSynonymsCollectionProvider
from arekit.contrib.source.ruattitudes.io_utils import RuAttitudesVersions
from arekit.contrib.source.rusentrel.io_utils import RuSentRelVersions
from arekit.processing.lemmatization.mystem import MystemWrapper
Expand Down
File renamed without changes.
6 changes: 5 additions & 1 deletion contrib/networks/context/architectures/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,19 @@ def compile_hidden_states_only(self, config):
self.__init_embedding_hidden_states()
self.init_body_dependent_hidden_states()

def compile(self, config, reset_graph):
def compile(self, config, reset_graph, graph_seed=None):
assert(isinstance(config, DefaultNetworkConfig))
assert(isinstance(reset_graph, bool))
assert(isinstance(graph_seed, int) or graph_seed is None)

self.__cfg = config

if reset_graph:
tf.reset_default_graph()

if graph_seed is not None:
tf.set_random_seed(graph_seed)

self.init_input()
self.__init_embedding_hidden_states()
self.init_body_dependent_hidden_states()
Expand Down
5 changes: 2 additions & 3 deletions contrib/networks/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from tensorflow.python.training.saver import Saver

from arekit.common.evaluation.evaluators.base import BaseEvaluator
from arekit.common.experiment.scales.base import BaseLabelScaler
from arekit.common.experiment.labeling import LabeledCollection
from arekit.common.model.base import BaseModel
Expand Down Expand Up @@ -104,8 +103,8 @@ def __dispose_session(self):
"""
self.__sess.close()

def run_training(self, epochs_count):
self.__network.compile(self.Config, reset_graph=True)
def run_training(self, epochs_count, seed):
self.__network.compile(self.Config, reset_graph=True, graph_seed=seed)
self.set_optimiser()
self.__notify_initialized()

Expand Down
2 changes: 1 addition & 1 deletion contrib/networks/core/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def iter_input_dependent_hidden_parameters(self):
return
yield

def compile(self, config, reset_graph):
def compile(self, config, reset_graph, graph_seed):
raise NotImplementedError()

def create_feed_dict(self, input, data_type):
Expand Down
6 changes: 5 additions & 1 deletion contrib/networks/multi/architectures/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,16 @@ def DropoutKeepProb(self):

# region body

def compile(self, config, reset_graph):
def compile(self, config, reset_graph, graph_seed=None):
assert(isinstance(config, BaseMultiInstanceConfig))
assert(isinstance(graph_seed, int) or graph_seed is None)

self.__cfg = config
tf.reset_default_graph()

if graph_seed is not None:
tf.set_random_seed(graph_seed)

with tf.variable_scope(self.__ctx_network_scope):
self.__context_network.compile_hidden_states_only(config=config.ContextConfig)

Expand Down
8 changes: 6 additions & 2 deletions contrib/networks/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,22 @@ class NetworksTrainingEngine(ExperimentEngine):
def __init__(self, bags_collection_type, experiment,
load_model, config,
create_network_func,
prepare_model_root=True):
prepare_model_root=True,
seed=None):
assert(callable(create_network_func))
assert(isinstance(config, DefaultNetworkConfig))
assert(issubclass(bags_collection_type, BagsCollection))
assert(isinstance(load_model, bool))
assert(isinstance(seed, int) or seed is None)

super(NetworksTrainingEngine, self).__init__(experiment)

self.__clear_model_root_before_experiment = prepare_model_root
self.__config = config
self.__create_network_func = create_network_func
self.__bags_collection_type = bags_collection_type
self.__load_model = load_model
self.__seed = seed

def __get_model_dir(self):
return self._experiment.DataIO.ModelIO.get_model_dir()
Expand Down Expand Up @@ -93,7 +97,7 @@ def _handle_iteration(self, it_index):

# Run model
with callback:
model.run_training(epochs_count=callback.Epochs)
model.run_training(epochs_count=callback.Epochs, seed=self.__seed)

del network
del model
Expand Down
2 changes: 1 addition & 1 deletion contrib/networks/tests/test_tf_ctx_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test(self):
logger.info("Clases count: {}".format(config.ClassesCount))

init_config(config)
network.compile(config, reset_graph=True)
network.compile(config, reset_graph=True, graph_seed=42)


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion contrib/networks/tests/test_tf_ctx_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def run_feeding(network, network_config, create_minibatch_func, logger,
labels_scaler = ThreeLabelScaler()
init_config(network_config)
# Init network.
network.compile(config=network_config, reset_graph=True)
network.compile(config=network_config, reset_graph=True, graph_seed=42)
minibatch = create_minibatch_func(config=network_config,
labels_scaler=labels_scaler)

Expand Down
3 changes: 2 additions & 1 deletion contrib/networks/tests/test_tf_input_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@

from arekit.contrib.networks.tests.text.news import init_rusentrel_doc
from arekit.contrib.networks.features.term_indices import IndicesFeature
from arekit.contrib.experiments.synonyms.provider import RuSentRelSynonymsCollectionProvider

from arekit.tests.text.linked_opinions import iter_same_sentence_linked_text_opinions
from arekit.tests.text.utils import terms_to_str

from arekit.contrib.source.rusentiframes.collection import RuSentiFramesCollection
from arekit.contrib.source.rusentiframes.types import RuSentiFramesVersions

from arekit.contrib.experiment_rusentrel.synonyms.provider import RuSentRelSynonymsCollectionProvider


class TestTfInputFeatures(unittest.TestCase):

Expand Down
2 changes: 1 addition & 1 deletion contrib/networks/tests/test_tf_mi_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def mpmi(context_config, context_network):

network = MaxPoolingOverSentences(context_network=context_network)
init_config(config)
network.compile(config, reset_graph=True)
network.compile(config, reset_graph=True, graph_seed=42)

def test(self):
logging.basicConfig(level=logging.INFO)
Expand Down
Binary file modified contrib/source/data/rusentrel-v1_1.zip
Binary file not shown.
50 changes: 36 additions & 14 deletions contrib/source/rusentrel/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,29 +56,51 @@ def __is_supported(version):
raise NotImplementedError("Collection does not supported")
return True

@staticmethod
def __number_from_string(s):
    """ Extract a number from the string by joining all of its ASCII digits.

    s: str or unicode
        text to scan, e.g. a filename.
    returns: int or None
        integer composed of every digit of `s`, taken in order of appearance
        (u"a1b2" results in 12), or None when `s` contains no digits.
    """
    # Only ASCII digits are kept: isdigit() is also True for characters such
    # as superscripts, which int() refuses to parse and raises ValueError.
    digits = [c for c in s if c in "0123456789"]

    if not digits:
        return None

    return int(u"".join(digits))

@staticmethod
def __iter_indicies_from_dataset(version, folder_name):
    """ Iterate over distinct document indices found in the dataset filenames.

    version:
        collection version; must pass the `__is_supported` check.
    folder_name: unicode
        only filenames that contain this text are considered
        (callers in this class pass u"test/" or u"train/").
    yields: int
        the number extracted from every matching filename, in the order
        in which the filenames are provided; every index is yielded once.
    """
    assert(isinstance(folder_name, unicode))
    assert(RuSentRelIOUtils.__is_supported(version))

    seen_indices = set()

    for archive_path in RuSentRelIOUtils.iter_filenames_from_zip(version):
        if folder_name in archive_path:
            doc_index = RuSentRelIOUtils.__number_from_string(archive_path)

            # Skip names without digits and indices that were already reported.
            if doc_index is not None and doc_index not in seen_indices:
                seen_indices.add(doc_index)
                yield doc_index

# region public methods

@staticmethod
def iter_test_indices(version):
assert(RuSentRelIOUtils.__is_supported(version))

if version == RuSentRelVersions.V11:
missed = [70]
for i in xrange(RuSentRelIOUtils.__sep_doc_id, 76):
if i in missed:
continue
yield i
for index in RuSentRelIOUtils.__iter_indicies_from_dataset(version=version, folder_name=u"test/"):
yield index

@staticmethod
def iter_train_indices(version):
assert(RuSentRelIOUtils.__is_supported(version))

if version == RuSentRelVersions.V11:
missed = [9, 22, 26]
for i in xrange(1, RuSentRelIOUtils.__sep_doc_id):
if i in missed:
continue
yield i
for index in RuSentRelIOUtils.__iter_indicies_from_dataset(version=version, folder_name=u"train/"):
yield index

@staticmethod
def iter_collection_indices(version):
Expand Down
Loading

0 comments on commit 18251dc

Please sign in to comment.