Skip to content

Commit

Permalink
Added reduction funcitons
Browse files Browse the repository at this point in the history
  • Loading branch information
bclarkson-code committed Jan 14, 2024
1 parent a528bbd commit 669df29
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 0 deletions.
85 changes: 85 additions & 0 deletions src/tricycle_v2/reduce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import numpy as np

from tricycle_v2.ops import _parse_subscripts, einsum, to_tensor
from tricycle_v2.tensor import Tensor


def radd(tensor: Tensor, subscript: str):
"""
Generate an indicator tensor that, when einsummed with the tensor, results
in a tensor that is equal to the result of summing along the indices
that dont appear in the output of the subscript
"""
indices, output = _parse_subscripts(subscript)
assert (
len(indices) == 1
), f"Can only reduce a single tensor at a time. Indices suggeststed: {len(indices)} tensors: {indices}"
[idx] = indices

indicator_indices = ""
reduce_along_axes = []
for i, char in enumerate(idx):
if char not in output:
indicator_indices += char
reduce_along_axes.append(i)

if not reduce_along_axes:
return tensor

indicator_shape = [tensor.shape[i] for i in reduce_along_axes]
indicator = to_tensor(np.ones(indicator_shape, dtype=np.bool_), requires_grad=False)

new_subscript = f"{idx},{indicator_indices}->{output}"
return einsum(new_subscript, tensor, indicator)


def rmax(tensor: Tensor, subscript: str):
"""
Generate an indicator tensor that, when einsummed with the tensor, results
in a tensor that is equal to the result of max applied along the indices
that dont appear in the output of the subscript
"""
indices, output = _parse_subscripts(subscript)
assert (
len(indices) == 1
), f"Can only reduce a single tensor at a time. Indices suggeststed: {len(indices)} tensors: {indices}"
[idx] = indices

reduce_along_axes = [i for i, char in enumerate(idx) if char not in output]

if not reduce_along_axes:
return tensor

indicator = (
tensor == np.max(tensor, axis=tuple(reduce_along_axes), keepdims=True)
).astype(int)

new_subscript = f"{idx},{idx}->{output}"

return einsum(new_subscript, tensor, indicator)


def rmin(tensor: Tensor, subscript: str):
"""
Generate an indicator tensor that, when einsummed with the tensor, results
in a tensor that is equal to the result of min applied along the indices
that dont appear in the output of the subscript
"""
indices, output = _parse_subscripts(subscript)
assert (
len(indices) == 1
), f"Can only reduce a single tensor at a time. Indices suggeststed: {len(indices)} tensors: {indices}"
[idx] = indices

reduce_along_axes = [i for i, char in enumerate(idx) if char not in output]

if not reduce_along_axes:
return tensor

indicator = (
tensor == np.min(tensor, axis=tuple(reduce_along_axes), keepdims=True)
).astype(int)

new_subscript = f"{idx},{idx}->{output}"

return einsum(new_subscript, tensor, indicator)
72 changes: 72 additions & 0 deletions tests/test_reduce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import numpy as np

from tricycle_v2.ops import to_tensor
from tricycle_v2.reduce import radd, rmax, rmin


def test_can_radd():
in_tensor = to_tensor(np.arange(3 * 4 * 5).reshape(3, 4, 5))

out_tensor = radd(in_tensor, "ijk->ik")

assert out_tensor.shape == (3, 5)

assert np.allclose(
out_tensor,
np.array(
[[30, 34, 38, 42, 46], [110, 114, 118, 122, 126], [190, 194, 198, 202, 206]]
),
)

out_tensor.backward()

assert np.allclose(
in_tensor.grad,
np.ones_like(in_tensor),
)


def test_can_rmax():
in_tensor = to_tensor(np.arange(3 * 4 * 5).reshape(3, 4, 5))

out_tensor = rmax(in_tensor, "ijk->ik")

assert out_tensor.shape == (3, 5)
assert np.allclose(
out_tensor,
np.array([[15, 16, 17, 18, 19], [35, 36, 37, 38, 39], [55, 56, 57, 58, 59]]),
)

out_tensor.backward()

assert np.allclose(
in_tensor.grad,
[
[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]],
],
)


def test_can_rmin():
in_tensor = to_tensor(np.arange(3 * 4 * 5).reshape(3, 4, 5))

out_tensor = rmin(in_tensor, "ijk->ik")

assert out_tensor.shape == (3, 5)
assert np.allclose(
out_tensor,
np.array([[0, 1, 2, 3, 4], [20, 21, 22, 23, 24], [40, 41, 42, 43, 44]]),
)

out_tensor.backward()

assert np.allclose(
in_tensor.grad,
[
[[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
],
)

0 comments on commit 669df29

Please sign in to comment.