In [1]:
from typing import Tuple, List

# Multidimensional index into a tensor: one integer per dimension.
Index = Tuple[int, ...]
# Tensor shape: the size of each dimension.
Shape = Tuple[int, ...]
# Mutable index buffer that a function (e.g. `broadcast_index`) fills in place.
OutIndex = List[int]

In [2]:
def broadcast_index(
    big_index: Index, big_shape: Shape, shape: Shape, out_index: OutIndex
) -> None:
    """
    Convert a `big_index` into `big_shape` to a smaller `out_index`
    into `shape` following broadcasting rules.

    `big_shape` may be larger than `shape` or have more dimensions. The
    two shapes are aligned on their trailing (rightmost) dimensions:
    leading dimensions of `big_shape` with no counterpart in `shape` are
    dropped, and every dimension of size 1 in `shape` maps to index 0.

    Args:
        big_index : multidimensional index of bigger tensor
        big_shape : tensor shape of bigger tensor
        shape : tensor shape of smaller tensor
        out_index : multidimensional index of smaller tensor; its contents
            are replaced in place with exactly `len(shape)` entries

    Returns:
        None
    """
    # Number of leading dimensions of `big_shape` that have no counterpart
    # in `shape`. It is negative when `shape` has more dimensions; those
    # extra leading dimensions of `shape` are treated as broadcast (index 0).
    offset = len(big_shape) - len(shape)

    result = []
    for i, dim in enumerate(shape):
        j = i + offset  # matching dimension in big_shape (right-aligned)
        if dim == 1 or j < 0 or big_shape[j] == 1:
            # Broadcast dimension: the smaller tensor has a single slot here
            # (or the big tensor has no such axis), so the index is always 0.
            result.append(0)
        else:
            result.append(big_index[j])

    # Slice assignment mutates the caller's object instead of rebinding the
    # local name; for a list it also sets the length to len(shape).
    out_index[:] = result

In [3]:
from typing import Tuple, List

# A tensor shape as supplied by the user: the size of each dimension.
UserShape = Tuple[int, ...]


def shape_broadcast(shape1: UserShape, shape2: UserShape) -> UserShape:
    """
    Broadcast two shapes to create a new union shape.

    The shapes are right-aligned by padding the shorter one with leading
    1s. Each aligned pair of dimensions must then be equal, or one of them
    must be 1; the result takes the other (non-1) size. In particular a 1
    paired with a 0-size dimension yields 0, matching NumPy broadcasting.

    Args:
        shape1 : first shape
        shape2 : second shape

    Returns:
        broadcasted shape

    Raises:
        IndexingError : if cannot broadcast
    """
    # Make the shapes the same length by adding 1s to the left. tuple() lets
    # list inputs work with the tuple concatenation below.
    len1, len2 = len(shape1), len(shape2)
    ndim = max(len1, len2)
    shape1 = (1,) * (ndim - len1) + tuple(shape1)
    shape2 = (1,) * (ndim - len2) + tuple(shape2)

    # Check for compatibility and compute the broadcasted shape.
    broadcasted_shape = []
    for s1, s2 in zip(shape1, shape2):
        if s1 == 1:
            # A size-1 dimension stretches to the other size (even 0); using
            # max() here would wrongly turn (1, 0) into 1.
            broadcasted_shape.append(s2)
        elif s2 == 1 or s1 == s2:
            broadcasted_shape.append(s1)
        else:
            raise IndexingError(f"Cannot broadcast shapes {shape1} and {shape2}")

    return tuple(broadcasted_shape)

In [5]:
# Example usage of shape_broadcast
shape1 = (1, 4, 5)
shape2 = (1, 2, 4, 5)
broadcasted_shape = shape_broadcast(shape1, shape2)
print("Broadcasted shape:", broadcasted_shape)

# Example usage of broadcast_index
# NOTE(review): in this example `big_shape` has fewer dimensions than
# `shape`, and its size 3 does not broadcast against the aligned size 2 in
# `shape` (shape_broadcast would reject that pair), so the "big"/"small"
# roles look swapped. A conventional example would be something like
# big_shape=(2, 4, 5) with shape=(4, 1) — TODO confirm intent.
big_index = (1, 2, 3)
big_shape = (3, 4, 5)
shape = (1, 2, 4, 5)
# Pre-sized index buffer; broadcast_index overwrites its contents in place.
out_index = [0] * len(shape)

broadcast_index(big_index, big_shape, shape, out_index)
print("out_index:", out_index)

Broadcasted shape: (1, 2, 4, 5)
out_index: [0, 1, 2, 3]
