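# Understanding `Aggregate`

Evaluate a RASP `Aggregate` over a constant selector and a constant SOp with tracr, then reproduce the same numbers by hand with NumPy to show that each output position is the mean of the values selected by that position's selector row.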

In [1]:
# Environment setup and imports.
import os
# Disable XLA's up-front GPU memory preallocation (see JAX's GPU memory
# allocation docs). NOTE(review): XLA reads this when its backend initializes,
# so it has to be set before anything imports jax — presumably tracr does,
# which is why this line sits above the tracr imports; keep that order.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
from tracr.rasp import rasp
from tracr.compiler import compiling  # NOTE(review): not used in the cells shown here
import numpy as np

In [2]:
# Selector matrix: row i marks (with 1) which positions of `seq` are averaged
# into output position i, e.g. row 0 picks seq[0] and seq[2] -> (1 + 2) / 2.
mat = [[1, 0, 1],
       [0, 1, 0],
       [1, 1, 1]]

# Values being aggregated.
seq = [1, 3, 2]

selector = rasp.ConstantSelector(mat)
sop = rasp.ConstantSOp(seq)
out = rasp.Aggregate(selector, sop)

# Evaluate on a dummy input of the same length as `seq` and print one result
# per line. NOTE(review): presumably the constant selector/SOp ignore the
# input's values and only its length matters — confirm against tracr.
print("\n".join(str(value) for value in out([0] * len(seq))))

1.5
3.0
2.0


In [3]:
# Re-derive the Aggregate result by hand with NumPy: output position i is the
# mean of the `seq` entries whose column is True in row i of the selector.
mat = np.array(mat, dtype=bool)
seq = np.array(seq)

print('Setup: we want to aggregate')
print(mat, 'with', seq)
print()
# Broadcasting multiplies every selector row element-wise by `seq`; unselected
# entries become 0, so they contribute nothing to the row sums below.
interm = mat * seq
print("Row-wise multiplying gives")
print(interm)
print()
print("...and then averaging across rows gives")
# Divide each row sum by the number of *selected* entries in that row (not by
# the row length). A row with no selected entries has no mean: leave it as NaN
# explicitly rather than evaluating 0/0 (which emits a RuntimeWarning).
# NOTE(review): RASP's Aggregate presumably uses a default value for empty
# selections instead of NaN — confirm against tracr if that case matters.
# NOTE(review): this rebinds `out`, which the previous cell used for the RASP SOp.
counts = mat.sum(axis=1)
out = np.divide(interm.sum(axis=1), counts,
                out=np.full(counts.shape, np.nan), where=counts > 0)
print(out[:, np.newaxis])
print()
print("Note we average only over those elements in each row"
      " marked by True in the original select matrix.")


Setup: we want to aggregate
[[ True False  True]
 [False  True False]
 [ True  True  True]] with [1 3 2]

Row-wise multiplying gives
[[1 0 2]
 [0 3 0]
 [1 3 2]]

...and then averaging across rows gives
[[1.5]
 [3. ]
 [2. ]]

Note we average only over those elements in each row marked by True in the original select matrix.
