In [1]:
from enum import Enum
from dataclasses import dataclass, asdict
import numpy as np
import math

from operator import add

In [2]:
from minitorch.functional import product, summation, multiply_lists, reduce

In [3]:
def strides_from_shape(shape):
    """Return contiguous (row-major) strides for a tensor of ``shape``.

    The last dimension has stride 1 and every earlier stride is the product
    of all later dimension sizes, e.g. ``(3, 3, 3) -> (9, 3, 1)``.

    Args:
        shape: sequence of dimension sizes (tuple, list or 1-D ndarray).

    Returns:
        tuple: one stride per dimension; ``()`` for a 0-d (empty) shape.
    """
    strides = []
    offset = 1
    # Walk from the fastest-varying (last) dimension outwards, recording the
    # running product *before* multiplying in the current size, so an empty
    # shape yields no strides at all.
    for size in reversed(shape):
        strides.append(offset)
        offset = offset * size
    return tuple(reversed(strides))

In [4]:
def to_index(ordinal, shape, out_index) -> None:
    """Convert a flat ordinal into a multi-index for a contiguous ``shape``.

    Uses row-major order (the last dimension varies fastest), e.g. ordinal 5
    in shape ``(2, 3)`` gives ``[1, 2]``. The result is written into
    ``out_index`` in place.

    Args:
        ordinal: flat position, expected to lie in ``range(prod(shape))``.
        shape: dimension sizes (tuple, list or 1-D ndarray).
        out_index: mutable sequence with ``len(shape)`` slots; overwritten.

    Returns:
        None.
    """
    remaining = int(ordinal)
    # Peel coordinates off starting from the last (fastest-varying) dimension:
    # the remainder is that dimension's coordinate, the quotient carries over.
    for i in range(len(shape) - 1, -1, -1):
        size = int(shape[i])
        out_index[i] = remaining % size
        remaining = remaining // size


def broadcast_index(big_index, big_shape, shape, out_index) -> None:
    """Map an index of a broadcast ("big") tensor onto a smaller tensor.

    ``shape`` is right-aligned against ``big_shape``. Each coordinate of the
    small tensor copies the aligned coordinate of ``big_index``, except that
    dimensions of size 1 always map to 0. The result is written into
    ``out_index`` in place.
    """
    # Number of leading dimensions that exist only in the big shape.
    lead = len(big_shape) - len(shape)
    for k, size in enumerate(shape):
        if size == 1:
            out_index[k] = 0
        else:
            out_index[k] = big_index[lead + k]


def index_to_position(index, strides) -> int:
    """Return the flat storage position of a multi-index.

    Computes ``sum(index[k] * strides[k])``. Accepts any sequences (tuple,
    list or 1-D ndarray), so the tuple returned by ``strides_from_shape`` can
    be passed without wrapping it in ``np.array`` first.

    Args:
        index: per-dimension coordinates.
        strides: per-dimension storage strides, same length as ``index``.

    Returns:
        int: the storage offset (0 for empty inputs).
    """
    # Convert each element to a Python int so numpy scalars cannot overflow
    # and the result is a plain int.
    return sum(int(i) * int(s) for i, s in zip(index, strides))


def shape_broadcast(shape_a, shape_b):
    """Return the shape produced by broadcasting ``shape_a`` with ``shape_b``.

    Follows the NumPy rules. Shapes are right-aligned by left-padding the
    shorter one with 1s. In each dimension the sizes must be equal, or one of
    them must be 1, in which case the other size wins (so 1 with 0 gives 0).

    Args:
        shape_a: dimension sizes (tuple, list or 1-D ndarray).
        shape_b: dimension sizes (tuple, list or 1-D ndarray).

    Returns:
        tuple of int: the broadcast shape.

    Raises:
        IndexError: if some aligned pair of sizes is incompatible.
    """
    # Normalise to tuples of ints first: `(1,) + ndarray` would broadcast-add
    # instead of concatenating, and `(1,) + list` raises TypeError.
    dims_a = tuple(int(d) for d in shape_a)
    dims_b = tuple(int(d) for d in shape_b)

    # Right-align by padding the shorter shape with leading 1s.
    ndim = max(len(dims_a), len(dims_b))
    dims_a = (1,) * (ndim - len(dims_a)) + dims_a
    dims_b = (1,) * (ndim - len(dims_b)) + dims_b

    broadcast_shape = []
    for a, b in zip(dims_a, dims_b):
        if a == b or b == 1:
            broadcast_shape.append(a)
        elif a == 1:
            broadcast_shape.append(b)
        else:
            raise IndexError(
                f"Shapes {dims_a} and {dims_b} cannot be broadcast together."
            )
    return tuple(broadcast_shape)

In [21]:
# Reduction set-up: fold values of `a` with `fn` along axis `dim`, seeding
# every output cell with `start`.
fn = add
dim, start = 0, 0

# Input tensor `a` as flat storage plus shape/strides metadata.
a_arr = np.arange(27).reshape(3, 3, 3)
a_shape = np.array(a_arr.shape)
a_strides = np.array(strides_from_shape(a_shape))
a_storage = a_arr.flatten()
a_index = np.zeros_like(a_shape)

# Output tensor: the shape of `a` with the reduced axis collapsed to size 1.
out_shape = np.array(
    [1 if axis == dim else size for axis, size in enumerate(a_arr.shape)]
)
out_strides = np.array(strides_from_shape(out_shape))
out_arr = np.zeros(tuple(out_shape))
out_storage = np.full(out_arr.size, start, dtype=float)
out_index = np.zeros_like(out_shape)


In [22]:
# Reduce `a` along axis `dim` with the configured `fn`, one output cell at a
# time. Each output cell folds the `a_shape[dim]` input values that differ
# from it only in the `dim` coordinate, starting from its current value in
# `out_storage` (seeded with `start`).
out_size = int(product(out_shape.tolist()))
reduce_size = int(a_shape[dim])  # loop-invariant: length of the reduced axis
for i in range(out_size):
    # Multi-index and storage position of this output cell (its `dim` coordinate is 0).
    to_index(i, out_shape, out_index)
    out_position = index_to_position(out_index, out_strides)

    # Copy the output index into a separate buffer and sweep the reduced axis
    # of `a`, so `out_index` itself is never clobbered.
    a_index[:] = out_index
    a_values = []
    for j in range(reduce_size):
        a_index[dim] = j
        a_values.append(a_storage[index_to_position(a_index, a_strides)])

    out_storage[out_position] = reduce(fn, out_storage[out_position])(a_values)

In [23]:
# Show the flat result buffer with the output tensor's shape.
np.reshape(out_storage, out_shape)

array([[[27., 30., 33.],
        [36., 39., 42.],
        [45., 48., 51.]]])

In [24]:
# Display the input tensor so the dim-0 sums shown above can be checked by hand.
a_arr

array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]])

In [53]:
# Element-wise zip with broadcasting: for every element of the output (visited
# by flat ordinal), out = add(a[broadcast(idx)], b[broadcast(idx)]).
# NOTE(review): b_shape, b_index, b_strides and b_storage are not defined in any
# cell shown here, and the displayed 2x5 result does not match the 3x3x3 set-up
# in In[21]. This cell presumably relies on state from deleted cells, so it will
# likely fail on Restart & Run All. TODO restore the a/b/out set-up cell.
# NOTE(review): this uses `add` directly rather than the configured `fn`.
out_size = int(product(out_shape.tolist()))
for i in range(out_size):
    # Multi-index of output element i, then the matching broadcast index into a and b.
    to_index(i, out_shape, out_index)
    broadcast_index(out_index, out_shape, a_shape, a_index)
    broadcast_index(out_index, out_shape, b_shape, b_index)

    a_position = index_to_position(a_index, a_strides)
    b_position = index_to_position(b_index, b_strides)
    
    # Writes at ordinal i rather than index_to_position(out_index, out_strides).
    # That is only equivalent while out_storage is laid out contiguously — TODO confirm.
    out_storage[i] = add(a_storage[a_position], b_storage[b_position])

out_storage.reshape(out_shape)

array([[ 0.,  2.,  4.,  6.,  8.],
       [ 5.,  7.,  9., 11., 13.]])

In [41]:
# Reference result from NumPy's own broadcasting add, for comparison with the loop above.
# NOTE(review): b_arr is not defined in any cell shown here — TODO confirm where it comes from.
a_arr + b_arr

array([[ 0,  2,  4,  6,  8],
       [ 5,  7,  9, 11, 13]])

In [30]:
# Element-wise exp written into the existing `out_arr` buffer (the return value is that buffer).
# NOTE(review): in_arr is not defined in any cell shown here, and the shapes must match out_arr.
np.exp(in_arr, out=out_arr)

array([[1.        , 1.01005017, 1.02020134, 1.03045453, 1.04081077],
       [1.        , 1.01005017, 1.02020134, 1.03045453, 1.04081077],
       [1.        , 1.01005017, 1.02020134, 1.03045453, 1.04081077],
       [1.        , 1.01005017, 1.02020134, 1.03045453, 1.04081077]])

In [73]:
# Map one index of the big (broadcast) array back onto the small array.
# NOTE(review): big_arr and small_arr are not defined in any cell shown here.
big_shape = big_arr.shape
small_shape = small_arr.shape
big_index = (0, 0, 1)
assert len(big_index) == len(big_shape), "big_index must have one entry per dim of big_arr"
out_index = [1] * len(small_shape)

# broadcast_index takes the index first, then the big shape.
broadcast_index(big_index, big_shape, small_shape, out_index)
out_index

[0, 0]