Skip to content

Commit

Permalink
Handle NER usecase (#263)
Browse files Browse the repository at this point in the history
* Handle NER usecase

* Fix test
  • Loading branch information
Dref360 committed Aug 29, 2023
1 parent f9402e2 commit da99ee3
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion baal/transformers_trainer_wrapper.py
Expand Up @@ -79,7 +79,7 @@ def predict_on_dataset_generator(
)

out = map_on_tensor(lambda o: o.view([iterations, -1, *o.size()[1:]]), out)
out = map_on_tensor(lambda o: o.permute(1, 2, *range(3, o.ndimension()), 0), out)
out = map_on_tensor(lambda o: o.permute(1, *range(3, o.ndimension()), 2, 0), out)
out = map_on_tensor(lambda x: x.detach(), out)
if half:
out = map_on_tensor(lambda x: x.half(), out)
Expand Down
8 changes: 4 additions & 4 deletions baal/utils/iterutils.py
@@ -1,10 +1,10 @@
from collections.abc import Sequence
from collections.abc import Sequence, MutableMapping


def map_on_tensor(fn, val):
"""Map a function on a Tensor or a list of Tensors"""
if isinstance(val, Sequence):
return [fn(v) for v in val]
elif isinstance(val, dict):
return {k: fn(v) for k, v in val.items()}
return [map_on_tensor(fn, v) for v in val]
elif isinstance(val, (dict, MutableMapping)):
return {k: map_on_tensor(fn, v) for k, v in val.items()}
return fn(val)
4 changes: 2 additions & 2 deletions tests/transformers_trainer_wrapper_test.py
Expand Up @@ -40,11 +40,11 @@ def test_predict_on_dataset_generator(self):

# iteration == 1
pred = self.wrapper.predict_on_dataset_generator(self.dataset, 1, False)
assert next(pred).shape == (5, 10, 100, 1)
assert next(pred).shape == (5, 100, 10, 1)

# iterations > 1
pred = self.wrapper.predict_on_dataset_generator(self.dataset, 10, False)
assert next(pred).shape == (5, 10, 100, 10)
assert next(pred).shape == (5, 100, 10, 10)

# Test generators
l_gen = self.wrapper.predict_on_dataset_generator(self.dataset, 10, False)
Expand Down

0 comments on commit da99ee3

Please sign in to comment.