Skip to content

Commit

Permalink
#229. Done.
Browse files (browse the repository at this point in the history)
  • Loading branch information
nicolay-r committed Dec 16, 2021
1 parent f20d855 commit adc5742
Show file tree
Hide file tree
Showing 18 changed files with 38 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from arekit.common.entities.base import Entity
from arekit.common.entities.str_fmt import StringEntitiesFormatter
from arekit.common.entities.types import EntityType
from arekit.common.languages.ru.cases import RussianCases
from arekit.common.languages.ru.number import RussianNumberType
from arekit.processing.languages.ru.cases import RussianCases
from arekit.processing.languages.ru.number import RussianNumberType
from arekit.processing.pos.russian import RussianPOSTagger


Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from arekit.common.context.terms_mapper import TextTermsMapper
from arekit.common.languages.pos import PartOfSpeechType
from arekit.processing.languages.pos import PartOfSpeechType
from arekit.processing.pos.base import POSTagger


Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from arekit.common.languages.mods import BaseLanguageMods
from arekit.processing.languages.mods import BaseLanguageMods


class RussianLanguageMods(BaseLanguageMods):
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from arekit.common.languages.pos import PartOfSpeechType
from arekit.processing.languages.pos import PartOfSpeechType


class PartOfSpeechTypesService(object):
Expand Down
11 changes: 6 additions & 5 deletions arekit/processing/pos/mystem_wrap.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from arekit.common.languages.pos import PartOfSpeechType
from arekit.common.languages.ru.cases import RussianCases, RussianCasesService
from arekit.common.languages.ru.number import RussianNumberType, RussianNumberTypeService
from arekit.common.languages.ru.pos_service import PartOfSpeechTypesService
from arekit.processing.pos.russian import RussianPOSTagger
from pymystem3 import Mystem

from arekit.processing.languages.pos import PartOfSpeechType
from arekit.processing.languages.ru.cases import RussianCases, RussianCasesService
from arekit.processing.languages.ru.number import RussianNumberType, RussianNumberTypeService
from arekit.processing.languages.ru.pos_service import PartOfSpeechTypesService
from arekit.processing.pos.russian import RussianPOSTagger


class POSMystemWrapper(RussianPOSTagger):

Expand Down
4 changes: 2 additions & 2 deletions arekit/processing/text/pipeline_frames.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from arekit.common.frames.text_variant import TextFrameVariant
from arekit.common.frames.variants.collection import FrameVariantsCollection
from arekit.common.languages.mods import BaseLanguageMods
from arekit.common.languages.ru.mods import RussianLanguageMods
from arekit.common.text.pipeline_ctx import PipelineContext
from arekit.common.text.pipeline_item import TextParserPipelineItem
from arekit.processing.languages.mods import BaseLanguageMods
from arekit.processing.languages.ru.mods import RussianLanguageMods


class FrameVariantsParser(TextParserPipelineItem):
Expand Down
2 changes: 1 addition & 1 deletion arekit/processing/text/pipeline_frames_lemmatized.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from arekit.common.languages.ru.mods import RussianLanguageMods
from arekit.common.text.pipeline_ctx import PipelineContext
from arekit.common.text.stemmer import Stemmer
from arekit.processing.languages.ru.mods import RussianLanguageMods
from arekit.processing.text.pipeline_frames import FrameVariantsParser


Expand Down
10 changes: 6 additions & 4 deletions tests/contrib/networks/test_tf_ctx_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import sys
import unittest


sys.path.append('../../../')

from tests.contrib.networks.tf_networks.supported import get_supported
from tests.contrib.networks.tf_networks.utils import init_config

from arekit.common.languages.ru.pos_service import PartOfSpeechTypesService
from arekit.contrib.networks.context.configurations.base.base import DefaultNetworkConfig
from arekit.processing.languages.ru.pos_service import PartOfSpeechTypesService


class TestContextNetworkCompilation(unittest.TestCase):
Expand All @@ -17,15 +17,17 @@ def test(self):
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

pos_items_count = PartOfSpeechTypesService.get_mystem_pos_count()

for config, network in get_supported():
assert(isinstance(config, DefaultNetworkConfig))
config.modify_classes_count(3)
config.set_pos_count(PartOfSpeechTypesService.get_mystem_pos_count())
config.set_pos_count(pos_items_count)

logger.info("Compile: {}".format(type(network)))
logger.info("Clases count: {}".format(config.ClassesCount))

init_config(config)
init_config(config=config, pos_items_count=pos_items_count)
network.compile(config, reset_graph=True, graph_seed=42)


Expand Down
9 changes: 8 additions & 1 deletion tests/contrib/networks/test_tf_ctx_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from arekit.contrib.networks.core.feeding.batch.base import MiniBatch
from arekit.contrib.networks.core.nn import NeuralNetwork

from arekit.processing.languages.ru.pos_service import PartOfSpeechTypesService


class TestContextNetworkFeeding(unittest.TestCase):

Expand Down Expand Up @@ -56,7 +58,12 @@ def run_feeding(network, network_config, create_minibatch_func, logger,
assert(isinstance(labels_scaler, BaseLabelScaler))
assert(callable(create_minibatch_func))

init_config(network_config)
pos_items_count = PartOfSpeechTypesService.get_mystem_pos_count()

# Init config.
init_config(config=network_config,
pos_items_count=pos_items_count)

# Init network.
network.compile(config=network_config, reset_graph=True, graph_seed=42)
minibatch = create_minibatch_func(config=network_config, labels_scaler=labels_scaler)
Expand Down
7 changes: 6 additions & 1 deletion tests/contrib/networks/test_tf_mi_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from arekit.contrib.networks.context.configurations.base.base import DefaultNetworkConfig
from arekit.contrib.networks.multi.architectures.max_pooling import MaxPoolingOverSentences
from arekit.contrib.networks.multi.configurations.base import BaseMultiInstanceConfig
from arekit.processing.languages.ru.pos_service import PartOfSpeechTypesService


class TestMultiInstanceCompile(unittest.TestCase):
Expand All @@ -27,7 +28,11 @@ def mpmi(context_config, context_network):
config.modify_classes_count(3)

network = MaxPoolingOverSentences(context_network=context_network)
init_config(config)

pos_items_count = PartOfSpeechTypesService.get_mystem_pos_count()
init_config(config=config,
pos_items_count=pos_items_count)

network.compile(config, reset_graph=True, graph_seed=42)

def test(self):
Expand Down
7 changes: 4 additions & 3 deletions tests/contrib/networks/tf_networks/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import numpy as np

from arekit.common.languages.ru.pos_service import PartOfSpeechTypesService
from arekit.contrib.networks.context.configurations.base.base import DefaultNetworkConfig


def init_config(config):
def init_config(config, pos_items_count):
assert(isinstance(config, DefaultNetworkConfig))
assert(isinstance(pos_items_count, int))

config.modify_classes_count(3)
config.set_term_embedding(np.zeros((100, 100)))
config.set_class_weights([1] * config.ClassesCount)
config.set_pos_count(PartOfSpeechTypesService.get_mystem_pos_count())
config.set_pos_count(pos_items_count)

# Notify other subscribers that initialization process has been completed.
config.init_initializers()
Expand Down

0 comments on commit adc5742

Please sign in to comment.