Skip to content

Commit

Permalink
Fix related to #158 and #156.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Aug 15, 2021
1 parent fdbc4fd commit fcde390
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
3 changes: 0 additions & 3 deletions tests/contrib/networks/test_samples_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

sys.path.append('../../../')

from tests.contrib.networks.labels import TestThreeLabelScaler

from arekit.common.experiment import const
from arekit.contrib.networks.core.input.rows_parser import ParsedSampleRow
from arekit.contrib.networks.context.configurations.base.base import DefaultNetworkConfig
Expand Down Expand Up @@ -83,7 +81,6 @@ def __test_core(self, words_vocab, config, samples_filepath):
assert(isinstance(samples_filepath, unicode))

samples = []
labels_scaler = TestThreeLabelScaler()
for i, row in enumerate(self.__iter_tsv_gzip(input_file=samples_filepath)):

# Perform row parsing process.
Expand Down
15 changes: 11 additions & 4 deletions tests/contrib/networks/test_tf_ctx_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,19 @@

sys.path.append('../../../')

from tests.contrib.networks.labels import TestThreeLabelScaler
from tests.contrib.networks.tf_networks.supported import get_supported
from tests.contrib.networks.tf_networks.utils import init_config

from arekit.common.experiment.data_type import DataType
from arekit.common.labels.scaler import BaseLabelScaler

from arekit.contrib.networks.context.configurations.base.base import DefaultNetworkConfig
from arekit.contrib.networks.sample import InputSample
from arekit.contrib.networks.core.feeding.bags.bag import Bag
from arekit.contrib.networks.core.feeding.batch.base import MiniBatch
from arekit.contrib.networks.core.nn import NeuralNetwork

from tests.contrib.networks.tf_networks.supported import get_supported
from tests.contrib.networks.tf_networks.utils import init_config


class TestContextNetworkFeeding(unittest.TestCase):

Expand Down Expand Up @@ -48,16 +50,18 @@ def create_minibatch(config, labels_count):

@staticmethod
def run_feeding(network, network_config, create_minibatch_func, logger,
labels_scaler,
display_hidden_values=True,
display_idp_values=True):
assert(isinstance(network, NeuralNetwork))
assert(isinstance(network_config, DefaultNetworkConfig))
assert(isinstance(labels_scaler, BaseLabelScaler))
assert(callable(create_minibatch_func))

init_config(network_config)
# Init network.
network.compile(config=network_config, reset_graph=True, graph_seed=42)
minibatch = create_minibatch_func(config=network_config, labels_count=3)
minibatch = create_minibatch_func(config=network_config, labels_scaler=labels_scaler)

network_optimiser = network_config.Optimiser.minimize(network.Cost)

Expand Down Expand Up @@ -107,11 +111,14 @@ def test(self):
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)

labels_scaler = TestThreeLabelScaler()

for cfg, network in get_supported():
logger.debug("Feed to the network: {}".format(type(network)))
self.run_feeding(network=network,
network_config=cfg,
create_minibatch_func=self.create_minibatch,
labels_scaler=labels_scaler,
logger=logger)


Expand Down
9 changes: 5 additions & 4 deletions tests/contrib/networks/test_tf_mi_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,18 @@

sys.path.append('../../../')

from tests.contrib.networks.labels import TestNeutralLabel
from tests.contrib.networks.labels import TestNeutralLabel, TestThreeLabelScaler
from tests.contrib.networks.test_tf_ctx_feed import TestContextNetworkFeeding
from tests.contrib.networks.tf_networks.supported import get_supported

from arekit.common.labels.scaler import BaseLabelScaler

from arekit.contrib.networks.core.feeding.bags.bag import Bag
from arekit.contrib.networks.core.feeding.batch.multi import MultiInstanceMiniBatch

from arekit.contrib.networks.multi.configurations.max_pooling import MaxPoolingOverSentencesConfig
from arekit.contrib.networks.context.configurations.base.base import DefaultNetworkConfig
from arekit.contrib.networks.sample import InputSample
from arekit.contrib.networks.multi.architectures.max_pooling import MaxPoolingOverSentences

from arekit.common.labels.scaler import BaseLabelScaler


class TestMultiInstanceFeed(unittest.TestCase):
Expand Down Expand Up @@ -51,6 +49,8 @@ def test(self):
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

labels_scaler = TestThreeLabelScaler()

for ctx_config, ctx_network in get_supported():
for config, network in self.multiinstances_supported(ctx_config, ctx_network):
logger.info(type(network))
Expand All @@ -59,6 +59,7 @@ def test(self):
network_config=config,
create_minibatch_func=self.__create_minibatch,
logger=logger,
labels_scaler=labels_scaler,
display_idp_values=False)


Expand Down

0 comments on commit fcde390

Please sign in to comment.