Skip to content

Commit

Permalink
Fix related to #158 and #156.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Aug 15, 2021
1 parent fcde390 commit 8434e3c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
13 changes: 5 additions & 8 deletions tests/contrib/networks/test_tf_ctx_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,13 @@ def init_session():
return sess

@staticmethod
def create_minibatch(config, labels_count):
def __create_minibatch(config, labels_scaler):
assert(isinstance(config, DefaultNetworkConfig))
assert(isinstance(labels_count, int))

l_uint_min = 0
l_uint_max = labels_count - 1
assert(isinstance(labels_scaler, BaseLabelScaler))

bags = []
for i in range(config.BagsPerMinibatch):
uint_label = random.randint(l_uint_min, l_uint_max)
uint_label = random.randint(0, labels_scaler.LabelsCount)
bag = Bag(uint_label=uint_label)
for j in range(config.BagSize):
bag.add_sample(InputSample._generate_test(config))
Expand All @@ -61,7 +58,7 @@ def run_feeding(network, network_config, create_minibatch_func, logger,
init_config(network_config)
# Init network.
network.compile(config=network_config, reset_graph=True, graph_seed=42)
minibatch = create_minibatch_func(config=network_config, labels_scaler=labels_scaler)
minibatch = create_minibatch_func(config=network_config, labels_count=labels_scaler)

network_optimiser = network_config.Optimiser.minimize(network.Cost)

Expand Down Expand Up @@ -117,7 +114,7 @@ def test(self):
logger.debug("Feed to the network: {}".format(type(network)))
self.run_feeding(network=network,
network_config=cfg,
create_minibatch_func=self.create_minibatch,
create_minibatch_func=self.__create_minibatch,
labels_scaler=labels_scaler,
logger=logger)

Expand Down
4 changes: 2 additions & 2 deletions tests/contrib/networks/test_tf_mi_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ def __create_minibatch(config, labels_scaler):
assert(isinstance(config, DefaultNetworkConfig))
assert(isinstance(labels_scaler, BaseLabelScaler))
bags = []
label = TestNeutralLabel()
no_label = labels_scaler.get_no_label_instance()
empty_sample = InputSample.create_empty(terms_per_context=config.TermsPerContext,
frames_per_context=config.FramesPerContext,
synonyms_per_context=config.SynonymsPerContext)
for i in range(config.BagsPerMinibatch):
bag = Bag(label)
bag = Bag(labels_scaler.label_to_uint(no_label))
for j in range(config.BagSize):
bag.add_sample(empty_sample)
bags.append(bag)
Expand Down

0 comments on commit 8434e3c

Please sign in to comment.