In [1]:
'''
Link to the explanation video:
https://drive.google.com/file/d/18As9KEln4FgOReIXqjxmudeilEShhL1a/view?usp=sharing
Last 3 minutes are explanations of the comparisons with PyTorch :)
'''
import torch

## Q1

In [2]:
# Error raised by expand_as and broadcast_tensors when the shapes are incompatible
class BroadcastError(Exception):
    """Signals that two tensors cannot be broadcast together.

    The constructor takes the two tensors involved, but the message is a
    fixed string and does not depend on them.
    """

    def __init__(self, a, b):
        message = "The tensors can't be broadcasted together!"
        super().__init__(message)

### Q1.b

In [4]:
def is_broadcastable(a, b):
    """Check whether two tensors can be broadcast together.

    Shapes are aligned on their trailing (rightmost) axes. Two aligned axes
    are compatible when they are equal or when at least one of them is 1.
    Axes that exist only in the higher-rank tensor are carried over as-is.

    Args:
        a: first tensor (only ``a.shape`` is read).
        b: second tensor (only ``b.shape`` is read).

    Returns:
        ``False`` when the shapes are incompatible; otherwise the tuple
        ``(True, shape)`` where ``shape`` is the ``torch.Size`` of the
        broadcast result. The tuple is always truthy, so the return value can
        be used directly in a boolean context.
    """
    a_shape, b_shape = tuple(a.shape), tuple(b.shape)
    if len(a_shape) >= len(b_shape):
        longer, shorter = a_shape, b_shape
    else:
        longer, shorter = b_shape, a_shape
    rank_gap = len(longer) - len(shorter)

    # Leading axes that only the higher-rank tensor has are kept unchanged.
    broadcast_shape = list(longer[:rank_gap])

    for long_dim, short_dim in zip(longer[rank_gap:], shorter):
        if long_dim != short_dim and long_dim != 1 and short_dim != 1:
            return False
        # A size-1 axis takes the size of its partner. This is deliberately not
        # max(): when the partner has size 0 the broadcast axis must be 0.
        broadcast_shape.append(short_dim if long_dim == 1 else long_dim)

    return True, torch.Size(broadcast_shape)

### Q1.a

In [5]:
def expand_as(expand_from, expand_to):
    """Return a copy of ``expand_from`` expanded to the shape of ``expand_to``.

    Shapes are aligned on their trailing axes. Every axis of ``expand_from``
    must either equal the matching axis of ``expand_to`` or be 1; missing
    leading axes are added in front. Size-1 axes are then repeated until they
    reach the target size.

    Args:
        expand_from: tensor to expand.
        expand_to: tensor whose shape is the target (only its shape is read).

    Returns:
        A new tensor with the shape of ``expand_to``. It is built from a clone
        of ``expand_from`` and therefore does not share storage with it.

    Raises:
        BroadcastError: if ``expand_from`` has more axes than ``expand_to``,
            or if one of its axes differs from the target axis and is not 1.
    """
    source_shape = tuple(expand_from.shape)
    target_shape = tuple(expand_to.shape)
    rank_gap = len(target_shape) - len(source_shape)

    # The source may never have more axes than the target.
    if rank_gap < 0:
        raise BroadcastError(expand_from, expand_to)

    # One-directional rule: only a size-1 source axis may change size. This
    # also rejects a size-0 source axis against a size-1 target axis, which
    # cannot be expanded.
    for source_dim, target_dim in zip(source_shape, target_shape[rank_gap:]):
        if source_dim != target_dim and source_dim != 1:
            raise BroadcastError(expand_from, expand_to)

    # Work on a clone so the input tensor is never modified or aliased.
    expanded = expand_from.clone()
    for _ in range(rank_gap):
        expanded = torch.unsqueeze(expanded, 0)

    for axis, target_dim in enumerate(target_shape):
        if expanded.shape[axis] == target_dim:
            continue
        # After validation, any axis that still differs has size 1.
        if target_dim == 0:
            # Expanding 1 -> 0 yields an empty axis; cat of zero tensors is
            # not defined, so take a zero-length slice instead.
            expanded = expanded.narrow(axis, 0, 0)
        else:
            # One concatenation of all copies instead of growing step by step.
            expanded = torch.cat([expanded] * target_dim, dim=axis)
    return expanded

### Q1.c

In [6]:
def broadcast_tensors(a, b):
    """Broadcast two tensors against each other.

    Args:
        a: first tensor.
        b: second tensor.

    Returns:
        A tuple ``(broadcasted_a, broadcasted_b)``, in the same order as the
        arguments, where both tensors have the common broadcast shape.

    Raises:
        BroadcastError: if the shapes of ``a`` and ``b`` are incompatible.
    """
    # is_broadcastable is symmetric in its arguments, so one call is enough.
    # It returns False, or (True, shape) with the common broadcast shape.
    verdict = is_broadcastable(a, b)
    if not verdict:
        raise BroadcastError(a, b)
    _, broadcast_shape = verdict

    # Reusing the shape computed by is_broadcastable keeps the leading axes of
    # the higher-rank tensor in their original order.
    # expand_as only reads the shape of its target, so an uninitialised
    # placeholder of the right shape is sufficient.
    shape_carrier = torch.empty(broadcast_shape)

    return expand_as(a, shape_carrier), expand_as(b, shape_carrier)

### Q1.d

In [12]:
# Tensor pairs [a, b] used by the comparison cell below.
tensors_pairs = [
    # small hand-written tensors and one 5-D pair
    [torch.tensor([3, 3]), torch.tensor([2])],
    [torch.tensor([[1, 3, 5], [1, 3, 5]]), torch.tensor([2])],
    [torch.tensor([1, 2]), torch.tensor([[2, 3, 4], [5, 6, 7]])],
    [
        torch.arange(10 ** 4).reshape(10, 10, 10, 1, 10),
        torch.arange(10 ** 5).reshape(10, 10, 10, 10, 10),
    ],
    [torch.tensor([1, 2, 3]), torch.tensor([[2, 3, 4], [5, 6, 7]])],
    [torch.tensor([[1, 2, 3]]), torch.tensor([[2, 3, 4], [5, 6, 7]])],

    # first tensor with size-1 axes and/or a different rank than the second
    [torch.tensor([[[1, 2, 3]]]), torch.tensor([[2, 3, 4], [5, 6, 7]])],
    [
        torch.arange(10 ** 3).reshape(10, 1, 10, 1, 10),
        torch.arange(10 ** 5).reshape(10, 10, 10, 10, 10),
    ],
    [
        torch.arange(10 ** 3).reshape(10, 10, 1, 10),
        torch.arange(10 ** 5).reshape(10, 10, 10, 10, 10),
    ],
    [
        torch.arange(10 ** 2).reshape(10, 1, 1, 10),
        torch.arange(10 ** 5).reshape(10, 10, 10, 10, 10),
    ],

    # 2-D tensor against a 5-D tensor
    [
        torch.arange(10 ** 2).reshape(10, 10),
        torch.arange(10 ** 5).reshape(10, 10, 10, 10, 10),
    ],
]

In [14]:
# Number of pairs from tensors_pairs (taken from the start) to run below.
number_of_examples = 3

# For each pair, print the tensors and their shapes, then compare the
# implementations above with their PyTorch counterparts in both directions.
for i in range(number_of_examples):
    a, b = tensors_pairs[i]
    print(f'a - {a}', f'b - {b}', sep='\n')
    print(f'a.shape - {a.shape}', f'b.shape - {b.shape}', sep='\n')
    print()
    print(f'Is broadcastable implementation: {is_broadcastable(a, b)}')

    # Both sides of each comparison are evaluated inside the try, so a failure
    # of either the implementation above or PyTorch is reported the same way.
    # `except Exception` (rather than a bare except) keeps KeyboardInterrupt
    # and SystemExit working, so the cell can still be interrupted.
    try:
        print(f'expand_as implementation: {expand_as(a, b)}', f'PyTorch: {a.expand_as(b)}', sep='\t')
    except Exception:
        print('An error occured - expand_as(a,b)')
    try:
        print(f'expand_as implementation (reversed): {expand_as(b, a)}', f'PyTorch: {b.expand_as(a)}', sep='\t')
    except Exception:
        print('An error occured - expand_as(b,a)')
    try:
        print(f'Broadcast implementation: {broadcast_tensors(a, b)}', f'PyTorch: {torch.broadcast_tensors(a, b)}', sep='\t')
    except Exception:
        print('An error occured - broadcast_tensors(a,b)')

    print('\n\n')

a - tensor([3, 3])
b - tensor([2])
a.shape - torch.Size([2])
b.shape - torch.Size([1])

Is broadcastable implementation: (True, torch.Size([2]))
An error occured - expand_as(a,b)
expand_as implementation (reversed): tensor([2, 2])	PyTorch: tensor([2, 2])
Broadcast implementation: (tensor([3, 3]), tensor([2, 2]))	PyTorch: (tensor([3, 3]), tensor([2, 2]))



a - tensor([[1, 3, 5],
        [1, 3, 5]])
b - tensor([2])
a.shape - torch.Size([2, 3])
b.shape - torch.Size([1])

Is broadcastable implementation: (True, torch.Size([2, 3]))
An error occured - expand_as(a,b)
expand_as implementation (reversed): tensor([[2, 2, 2],
        [2, 2, 2]])	PyTorch: tensor([[2, 2, 2],
        [2, 2, 2]])
Broadcast implementation: (tensor([[1, 3, 5],
        [1, 3, 5]]), tensor([[2, 2, 2],
        [2, 2, 2]]))	PyTorch: (tensor([[1, 3, 5],
        [1, 3, 5]]), tensor([[2, 2, 2],
        [2, 2, 2]]))



a - tensor([1, 2])
b - tensor([[2, 3, 4],
        [5, 6, 7]])
a.shape - torch.Size([2])
b.shape - torch.Size