Skip to content

Commit

Permalink
RF Factor.update_labels(): allow dict with labels not in Factor
Browse files Browse the repository at this point in the history
  • Loading branch information
christianbrodbeck committed Jul 16, 2021
1 parent 2172e50 commit de9cb8f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 20 deletions.
23 changes: 5 additions & 18 deletions eelbrain/_data_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -2943,29 +2943,16 @@ def update_labels(self, labels):
>>> Factor(f, labels={'c': 'v3'})
Factor(['v1', 'v1', 'v1', 'v2', 'v2', 'v2', 'v3', 'v3', 'v3'])
Notes
-----
If ``labels`` contains a key that is not a label of the Factor, a
``KeyError`` is raised.
"""
missing = [old for old in labels if old not in self._codes]
if missing:
if len(missing) == 1:
msg = ("Factor does not contain label %r" % missing[0])
else:
msg = ("Factor does not contain labels %s"
% ', '.join(repr(m) for m in missing))
raise KeyError(msg)

new_labels = {code: labels.get(label, label) for code, label in self._labels.items()}
# check for merged labels
new_labels = {c: labels.get(l, l) for c, l in self._labels.items()}
codes_ = sorted(new_labels)
labels_ = tuple(new_labels[c] for c in codes_)
labels_ = [new_labels[c] for c in codes_]
for i, label in enumerate(labels_):
if label in labels_[:i]:
first_i = labels_.index(label)
if first_i < i:
old_code = codes_[i]
new_code = codes_[labels_.index(label)]
new_code = codes_[first_i]
self.x[self.x == old_code] = new_code
del new_labels[old_code]

Expand Down
5 changes: 3 additions & 2 deletions eelbrain/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,8 +816,9 @@ def test_factor_relabel():
assert_array_equal(f, Factor('cccbbbddd'))
f.update_labels({'d': 'c'})
assert_array_equal(f, Factor('cccbbbccc'))
with pytest.raises(KeyError):
f.update_labels({'a': 'c'})
# label not in f
f.update_labels({'b': 'x', 'a': 'c'})
assert_array_equal(f, Factor('cccxxxccc'))


def test_interaction():
Expand Down

0 comments on commit de9cb8f

Please sign in to comment.