Skip to content

Commit

Permalink
FIX Model: interaction with Var
Browse files Browse the repository at this point in the history
  • Loading branch information
christianbrodbeck committed Feb 27, 2017
1 parent df45d33 commit 5c10c64
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
25 changes: 22 additions & 3 deletions eelbrain/_data_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,20 @@
* Dataset
Effect types
------------
These are elementary effects in a Model, and identified by :func:`is_effect`
- _Effect
- Factor
- Interaction
- NestedEffect
- Var
- NonBasicEffect
"""

from __future__ import division
Expand Down Expand Up @@ -1608,7 +1622,9 @@ def __rsub__(self, other):
return Var(x, info=info)

def __mul__(self, other):
if iscategorial(other):
if isinstance(other, Model):
return Model((self,)) * other
elif iscategorial(other):
return Model((self, other, self % other))
elif isinstance(other, Var):
x = self.x * other.x
Expand Down Expand Up @@ -6087,7 +6103,7 @@ class Interaction(_Effect):
Parameters
----------
base : list
base : sequence
List of data-objects that form the basis of the interaction.
Attributes
Expand Down Expand Up @@ -6581,7 +6597,10 @@ def __mod__(self, other):
out = []
for e_self in self.effects:
for e_other in Model(other).effects:
out.append(e_self % e_other)
if isinstance(e_self, Var) and isinstance(e_other, Var):
out.append(e_self * e_other)
else:
out.append(e_self % e_other)
return Model(out)

def __eq__(self, other):
Expand Down
17 changes: 17 additions & 0 deletions eelbrain/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,7 @@ def test_model():
"Test Model class"
a = Factor('ab', repeat=3, name='a')
b = Factor('ab', tile=3, name='b')
u = Var([1, 1, 1, -1, -1, -1], 'u')
v = Var([1., 2., 3., 4., 5., 6.], 'v')
w = Var([1., 0., 0., 1., 1., 0.], 'w')

Expand Down Expand Up @@ -752,6 +753,22 @@ def test_model():
mp = pickle.loads(pickle.dumps(m, pickle.HIGHEST_PROTOCOL))
assert_array_equal(m.full, mp.full)

# nested Vars
m = (v + w) * u
assert_dataobj_equal(m.effects[2], u)
assert_dataobj_equal(m.effects[3], v * u)
assert_dataobj_equal(m.effects[4], w * u)
m = u * (v + w)
assert_dataobj_equal(m.effects[0], u)
assert_dataobj_equal(m.effects[3], u * v)
assert_dataobj_equal(m.effects[4], u * w)
m = (v + w) % u
assert_dataobj_equal(m.effects[0], v * u)
assert_dataobj_equal(m.effects[1], w * u)
m = u % (v + w)
assert_dataobj_equal(m.effects[0], u * v)
assert_dataobj_equal(m.effects[1], u * w)


def test_ndvar():
"Test the NDVar class"
Expand Down

0 comments on commit 5c10c64

Please sign in to comment.