# Teacher's Assignment No. 11

***Author:*** *Ofir Paz* $\qquad$ ***Version:*** *14.03.2024* $\qquad$ ***Course:*** *22961 - Deep Learning*

Welcome to the first assignment of the course *Deep Learning*. \
In this first assignment, we will implement:

$\quad$ [**a.**](#a) $\space$ `A.expand_as(B)` functionality. \
$\quad$ [**b.**](#b) $\space$ A function that tests whether two tensors can be broadcasted together and returns the size of the broadcast. \
$\quad$ [**c.**](#c) $\space$ A function that broadcasts two tensors together. \
$\quad$ [**d.**](#d) $\space$ Tests for the different implemented functions.

Allowed functions in this assignment:

$\qquad$ `squeeze` `unsqueeze` `cat` `stack` `x.reshape` `x.reshape_as` `x.clone`

Unallowed functions in this assignment:

$\qquad$ `x.expand` `x.expand_as` `x.repeat` `broadcast_tensors` `broadcast_to` `vmap`

## Imports

First, we import the package required for this assignment:
- [PyTorch](https://pytorch.org/) - One of the most fundamental and widely used tensor libraries.

In [1]:
import torch  # pytorch

## Part A <a id='a'></a>

In this part, I implement `expand_as`, a function that takes two tensors, checks whether the first tensor can be expanded to the shape of the second, and if so, expands it. The dimensions of the second tensor are never altered. We start with the helper function `is_legal_expand`, which returns `True` if the first tensor can be expanded to the second one and `False` otherwise.
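
To make the right-to-left rule concrete, the sketch below lines up two shapes from their rightmost dimension and marks each pair. `show_alignment` is a hypothetical helper written only for illustration; it is not part of the assignment and works on plain tuples rather than tensors.

```python
from itertools import zip_longest


def show_alignment(a_shape, b_shape):
    """Return one row per dimension pair compared by the right-to-left expansion rule.

    Hypothetical helper for illustration only. A missing dimension is shown as '-'.
    """
    # Pair dimensions starting from the rightmost one; pad the shorter shape with '-'.
    pairs = zip_longest(reversed(a_shape), reversed(b_shape), fillvalue='-')
    rows = []
    for a_dim, b_dim in reversed(list(pairs)):  # display in the usual left-to-right order
        # `b` must have the dimension; `a` may lack it, match it, or have size 1 there.
        ok = b_dim != '-' and (a_dim == '-' or a_dim == 1 or a_dim == b_dim)
        rows.append(f"{a_dim!s:>4} {b_dim!s:>4}  {'ok' if ok else 'MISMATCH'}")
    return rows


# Legal expansion: every aligned pair is compatible.
for row in show_alignment((1, 82, 1, 10), (1, 2, 82, 1, 10)):
    print(row)
print()
# Swapped shapes: the 2 cannot become 1, and `b` has no slot for the leading 1 of `a`.
for row in show_alignment((1, 2, 82, 1, 10), (1, 82, 1, 10)):
    print(row)
```

In the first call every row is marked `ok`; in the second, the two leftmost rows are marked `MISMATCH`.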

In [2]:
def is_legal_expand(a_shape: torch.Size, b_shape: torch.Size) -> bool:
  """Check if the shape of `a` can be expanded to the shape of `b`.
  
  The laws of tensor expansion are as follows:
  
  a.  `a` must not have more dimensions than `b`; otherwise, return `False`.

  b.  Start from the last (rightmost) dimension of both tensors and check whether
      they are equal to each other, or the current dimension of `a` is equal to 1.
      If neither condition is met, return `False`.

  c.  Move one dimension to the left in each of the tensors and repeat the above check.

  d.  Once every dimension of `a` has been checked from right to left
      without encountering a mismatch, return `True`.

  Args:
    a_shape: shape of tensor a.
    b_shape: shape of tensor b.

  Returns:
    True if `a` can be expanded to `b`, False otherwise.
  """

  # reverse the shapes of tensors.
  a_shape = a_shape[::-1]
  b_shape = b_shape[::-1]

  # Check if there are enough dimensions in `b` to expand `a` to.
  if len(a_shape) > len(b_shape):
    return False

  # iterate over the shapes of tensors.
  for a_dim, b_dim in zip(a_shape, b_shape):
    if a_dim != b_dim and a_dim != 1:
      return False

  return True
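
The dimension-count guard in `is_legal_expand` is not redundant. Python's `zip` stops at the shorter sequence, so without the guard an extra leading dimension of `a` is never examined. The two hypothetical functions below, written on plain tuples for illustration, show the difference.

```python
def naive_is_legal_expand(a_shape, b_shape):
    """Right-to-left check WITHOUT the dimension-count guard (illustration only)."""
    return all(a == b or a == 1 for a, b in zip(reversed(a_shape), reversed(b_shape)))


def checked_is_legal_expand(a_shape, b_shape):
    """The same check plus the guard used in `is_legal_expand` above."""
    if len(a_shape) > len(b_shape):
        return False
    return naive_is_legal_expand(a_shape, b_shape)


# `zip` pairs only (3, 3), so the leading 2 of `a` is silently ignored.
print(naive_is_legal_expand((2, 3), (3,)))    # True, which is wrong
print(checked_is_legal_expand((2, 3), (3,)))  # False
```

For broadcasting, this truncation is actually the desired behavior, because missing leading dimensions are treated as size 1 on either side.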

In [3]:
def expand_as(tensor_a: torch.Tensor, tensor_b: torch.Tensor) -> torch.Tensor | bool:
  """Expand `tensor_a` to the same size as `tensor_b`.
  
  This is my implementation of the expand_as function.
  If `tensor_a` can't be expanded to `tensor_b`, the function returns False.

  During the expansion stage, we duplicate the tensor values according to the following rules:

  a.  If `tensor_a` has fewer dimensions than `tensor_b`, we prepend singleton
      dimensions to it until both tensors have the same number of dimensions.

  b.  Since the compatibility check has already passed, every dimension in which
      the two shapes differ must have size 1 in `tensor_a`, so `tensor_a` is
      replicated along that dimension until it matches `tensor_b`.

  Args:
    tensor_a (torch.Tensor): tensor to expand.
    tensor_b (torch.Tensor): tensor to expand to.

  Returns:
    torch.Tensor | False: `tensor_a` expanded to the same size as `tensor_b` or
    `False` if `tensor_a` can't be expanded to `tensor_b`.
  """
  
  # Extracts the tensors' shapes.
  a_shape, b_shape = tensor_a.shape, tensor_b.shape 

  # Return `False` if tensor_a can't be expanded to `tensor_b`.
  if not is_legal_expand(a_shape, b_shape):
    return False
  
  # expand `tensor_a` to the same size as `tensor_b`.
  # Stage 1: Prepend singleton dimensions to `tensor_a`.
  for _ in range(len(b_shape) - len(a_shape)):
    tensor_a = tensor_a.unsqueeze(0) 
  
  a_shape = tensor_a.shape  # update `a_shape` after prepending singleton dimensions.

  # Stage 2: expand `tensor_a` to `tensor_b` while assuming it can be expanded.
  for idx_dim, (a_dim, b_dim) in enumerate(zip(a_shape, b_shape)):
    if a_dim != b_dim:  # This means that `a_dim` is 1, and `b_dim` is not 1.
      tensor_a = torch.cat([tensor_a] * b_dim, dim=idx_dim)  # duplicate `tensor_a` values.

  return tensor_a
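
Stage 2 relies on the fact that concatenating `n` copies of a tensor along a size-1 dimension produces the same values as expanding that dimension to `n`. The sketch below walks a small tensor through both stages by hand and compares the result with PyTorch's `expand`, which is used here only as a reference and not in the implementation. The tensor `t` and the target shape are arbitrary examples.

```python
import torch

# Distinct values, so copies placed in the wrong position would be visible.
t = torch.arange(3).reshape(3, 1)        # shape (3, 1)
target = (2, 3, 4)

x = t.unsqueeze(0)                       # Stage 1: prepend a singleton dim -> (1, 3, 1)
x = torch.cat([x] * target[0], dim=0)    # Stage 2, dim 0: (1, 3, 1) -> (2, 3, 1)
x = torch.cat([x] * target[2], dim=2)    # Stage 2, dim 2: (2, 3, 1) -> (2, 3, 4)

print(x.shape)                           # torch.Size([2, 3, 4])
print(torch.equal(x, t.expand(*target))) # True
```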
  

## Part B <a id='b'></a>

This part requires us to create a function that checks whether two tensors can be broadcasted together and, if so, returns the size of the broadcast. As in [Part A](#a), we split this into two separate functions. The code is similar to that of the previous part, but we cannot reuse those functions: expansion only enlarges the first tensor to match a fixed target, whereas broadcasting may enlarge both participating tensors.
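
A small example shows why the expansion check cannot simply be reused. With shapes `(3, 1)` and `(1, 4)`, neither tensor can be expanded to the other, yet the two broadcast together to `(3, 4)` because each side grows along its own singleton dimension. The helpers `can_expand` and `can_broadcast` are hypothetical, shape-only sketches of the two rules.

```python
def can_expand(a_shape, b_shape):
    """`a` may grow only where it has size 1; `b` stays fixed (illustration only)."""
    if len(a_shape) > len(b_shape):
        return False
    return all(a == b or a == 1 for a, b in zip(reversed(a_shape), reversed(b_shape)))


def can_broadcast(a_shape, b_shape):
    """Either side may grow where it has size 1; missing dims count as 1 (illustration only)."""
    return all(a == b or a == 1 or b == 1 for a, b in zip(reversed(a_shape), reversed(b_shape)))


a_shape, b_shape = (3, 1), (1, 4)
print(can_expand(a_shape, b_shape), can_expand(b_shape, a_shape))  # False False
print(can_broadcast(a_shape, b_shape))                             # True
```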

In [4]:
def is_legal_broadcast(a_shape: torch.Size, b_shape: torch.Size) -> bool:
  """Check if `a` and `b` can be broadcasted together based on their shapes.
  
  The laws of tensor broadcast are as follows:
  
  a.  Start from the last (rightmost) dimension of both tensors and check whether
      they are equal to each other, or the current dimension of either tensor is equal to 1.
      If neither condition is met, return `False`.

  b.  Move one dimension to the left in each of the tensors and repeat the above check.

  c.  Once we have iterated over all dimensions from right to left in at least one of the tensors
      without encountering any error, return `True`.

  Args:
    a_shape: shape of tensor a.
    b_shape: shape of tensor b.

  Returns:
    True if `a` and `b` can be broadcasted together, False otherwise.
  """

  # reverse the shapes of tensors.
  a_shape = a_shape[::-1]
  b_shape = b_shape[::-1]

  # iterate over the shapes of tensors and return `False` if the condition is not met.
  for a_dim, b_dim in zip(a_shape, b_shape):
    if a_dim != b_dim and a_dim != 1 and b_dim != 1:
      return False

  return True

In [5]:
def get_broadcast_info(tensor_a: torch.Tensor, tensor_b: torch.Tensor) -> \
    tuple[torch.Size | None, bool]:
  """Get the broadcasting information between `tensor_a` and `tensor_b`.
  
  Args:
    tensor_a (torch.Tensor): first tensor of the broadcasting operation. 
    tensor_b (torch.Tensor): second tensor of the broadcasting operation.

  Returns:
    tuple[torch.Size | None, bool]: A tuple containing the following information:
    - The shape of the broadcast result if the second element is True, and None otherwise.
    - A boolean value indicating whether `tensor_a` can be broadcasted with `tensor_b`.
  """

  # Extracts the tensors' shapes.
  a_shape, b_shape = tensor_a.shape, tensor_b.shape

  # Return `False` if `tensor_a` can't be broadcasted with `tensor_b`.
  if not is_legal_broadcast(a_shape, b_shape):
    return None, False

  # Prepend singleton dimensions to the shorter shape.
  for _ in range(abs(len(b_shape) - len(a_shape))):
    if len(a_shape) < len(b_shape):
      a_shape = (1,) + a_shape
    else:
      b_shape = (1,) + b_shape
  
  # Initialize the broadcast shape.
  broadcast_shape = []

  # Iterate over the shapes of tensors.
  for a_dim, b_dim in zip(a_shape[::-1], b_shape[::-1]):
    broadcast_shape.append(max(a_dim, b_dim))  # get the maximum value between `a_dim` and `b_dim`.
  
  return torch.Size(broadcast_shape[::-1]), True  # return the broadcast shape and `True`.
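
The same shape computation can be sketched more compactly on plain tuples with `itertools.zip_longest`, using `fillvalue=1` so that a missing leading dimension behaves like a singleton. The function name `broadcast_shape` is a hypothetical choice for this illustration, and it returns `None` instead of the `(None, False)` pair used above.

```python
from itertools import zip_longest


def broadcast_shape(a_shape, b_shape):
    """Return the broadcast shape of two shapes as a tuple, or None if incompatible."""
    result = []
    # Walk from the rightmost dimension; a missing leading dimension counts as size 1.
    for a, b in zip_longest(reversed(a_shape), reversed(b_shape), fillvalue=1):
        if a != b and a != 1 and b != 1:
            return None
        result.append(b if a == 1 else a)  # the non-singleton size wins (either one if equal)
    return tuple(reversed(result))


print(broadcast_shape((1, 2, 82, 1, 10), (100, 30, 1, 82, 7, 1)))  # (100, 30, 2, 82, 7, 10)
print(broadcast_shape((2, 8, 1), (2, 82, 1, 5)))                   # None
print(broadcast_shape((0,), (1,)))                                 # (0,)
```

Picking `b if a == 1 else a` instead of `max(a, b)` also handles zero-size dimensions, since broadcasting a size-0 dimension against a size-1 dimension yields size 0, a case where `max` would return 1.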

## Part C <a id='c'></a>

In this part, we combine the functions from the previous parts into a general broadcasting function. This function takes two tensors, checks whether they can be broadcasted together, and if so, broadcasts them.

In [6]:
def broadcast_tensors(tensor_a: torch.Tensor, tensor_b: torch.Tensor) -> \
    tuple[torch.Tensor, torch.Tensor] | None:
  """Broadcast `tensor_a` and `tensor_b` together.

  This is my implementation of the broadcast_tensors function.
  If `tensor_a` can't be broadcasted with `tensor_b`, the function will not do anything
  and return `None`.

  Args:
    tensor_a (torch.Tensor): first tensor of the broadcasting operation. 
    tensor_b (torch.Tensor): second tensor of the broadcasting operation.
  
  Returns:
    tuple[torch.Tensor, torch.Tensor] | None: A tuple containing the following information:
    - The broadcasted `tensor_a` and `tensor_b` if the tensors can be broadcasted, and None otherwise.
  """

  # Get the info about the broadcast.
  broadcast_shape, can_broadcast = get_broadcast_info(tensor_a, tensor_b)

  # Check if the two tensors can be broadcasted together. The flag is named
  # `can_broadcast` so that it does not shadow the `is_legal_broadcast` function.
  if not can_broadcast:
    return None
  
  # Broadcast `tensor_a` and `tensor_b` together.
  # Extract the current shapes of the tensors.
  a_shape, b_shape = tensor_a.shape, tensor_b.shape

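  # Clone the inputs so the returned tensors never share memory with the originals
  # (for example, when a tensor already has the broadcast shape and needs no duplication).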
  tensor_a, tensor_b = tensor_a.clone(), tensor_b.clone()

  # Stage 1: Prepend singleton dimensions to the smaller tensor.
  for _ in range(abs(len(b_shape) - len(a_shape))):
    if len(a_shape) < len(b_shape):
      tensor_a = tensor_a.unsqueeze(0)
    else:  # len(a_shape) > len(b_shape)
      tensor_b = tensor_b.unsqueeze(0)
  
  # update shapes after prepending singleton dimensions.
  a_shape, b_shape = tensor_a.shape, tensor_b.shape

  # Stage 2: broadcast `tensor_a` with `tensor_b` while assuming they can be broadcasted.
  for idx_dim, (a_dim, b_dim, bc_dim) in enumerate(zip(a_shape, b_shape, broadcast_shape)):
    if a_dim != bc_dim:  # Broadcast dimension of `tensor_a` if needed.
      tensor_a = torch.cat([tensor_a] * bc_dim, dim=idx_dim)  # duplicate `tensor_a` values.

    if b_dim != bc_dim:  # Broadcast dimension of `tensor_b` if needed.
      tensor_b = torch.cat([tensor_b] * bc_dim, dim=idx_dim)  # duplicate `tensor_b` values.
  
  return tensor_a, tensor_b
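
One design consequence of building the result with `torch.cat` is memory use. PyTorch's own `torch.broadcast_tensors` returns views in which a broadcast dimension has stride 0, so no data is copied, whereas concatenation materializes every duplicated element. Producing stride-0 views would need functions outside the allowed list, so the copy-based approach is the natural choice here. The sketch below makes the difference visible; the tensor `t` is an arbitrary example.

```python
import torch

t = torch.zeros(3, 1)

# PyTorch's broadcasting returns a view: the repeated dimension has stride 0.
view = torch.broadcast_tensors(t, torch.zeros(1, 4))[0]
print(view.shape, view.stride())              # torch.Size([3, 4]) (1, 0)
print(view.data_ptr() == t.data_ptr())        # True: the view shares storage with `t`

# Replicating with `torch.cat`, as in the implementation above, allocates a full copy.
materialized = torch.cat([t] * 4, dim=1)
print(materialized.shape, materialized.stride())  # torch.Size([3, 4]) (4, 1)
```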

## Part D <a id='d'></a>

In this part, we implement tests for the functions defined above.
Each test compares our output with the expected result, using PyTorch's built-in functions as the reference where applicable.

The next block contains tests for [Part A](#a).

In [7]:
# `expand_as` correctness tests.
a = torch.zeros(1, 2, 82, 1, 10)
b = torch.zeros(100, 30, 2, 82, 7, 10)
c = torch.zeros(1, 82, 1, 10)

# Inputs for `False` expected output.
print(f"{expand_as(a, c) = }", f"{expand_as(b, a) = }", sep='\n')  # expected output: False.
print('Test Passed!' if expand_as(a, c) is False and expand_as(b, a) is False else 'Test Failed!')
print()  # Next test.

print(f"{expand_as(a, b).shape = }", f"{expand_as(c, a).shape = }", sep='\n')  # expected output: Shapes.
print('Test Passed!' if expand_as(a, b).shape == a.expand_as(b).shape and
          expand_as(c, a).shape == c.expand_as(a).shape else 'Test Failed!')

expand_as(a, c) = False
expand_as(b, a) = False
Test Passed!

expand_as(a, b).shape = torch.Size([100, 30, 2, 82, 7, 10])
expand_as(c, a).shape = torch.Size([1, 2, 82, 1, 10])
Test Passed!


The next block contains tests for [Part B](#b).

In [8]:
# `get_broadcast_info` correctness tests.
a = torch.zeros(1, 2, 82, 1, 10)
b = torch.zeros(100, 30, 1, 82, 7, 1)
c = torch.zeros(2, 82, 1, 5)
d = torch.zeros(2, 8, 1)

# Inputs for `None`, `False` expected output.
print(f"{get_broadcast_info(d, c) = }", f"{get_broadcast_info(c, a) = }",
       sep='\n')  # expected output: `None`, `False`.
print('Test Passed!' if get_broadcast_info(d, c) == (None, False) and
          get_broadcast_info(c, a) == (None, False) else 'Test Failed!')
print()  # Next test.

print(f"{get_broadcast_info(a, b) = }", f"{get_broadcast_info(b, c) = }",
        sep='\n')  # expected output: Shapes.
print('Test Passed!' if get_broadcast_info(a, b)[0] == torch.broadcast_tensors(a, b)[0].shape and
          get_broadcast_info(b, c)[0] == torch.broadcast_tensors(b, c)[0].shape else 'Test Failed!')

get_broadcast_info(d, c) = (None, False)
get_broadcast_info(c, a) = (None, False)
Test Passed!

get_broadcast_info(a, b) = (torch.Size([100, 30, 2, 82, 7, 10]), True)
get_broadcast_info(b, c) = (torch.Size([100, 30, 2, 82, 7, 5]), True)
Test Passed!
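
The fixed shapes above cover only a few cases. A randomized cross-check against `torch.broadcast_shapes` (available in recent PyTorch versions, and assumed here to raise `RuntimeError` on incompatible shapes) can exercise many more. Because the notebook's functions take tensors, the sketch below restates the rule of `get_broadcast_info` on plain tuples as a hypothetical `shape_rule` helper. Sizes are drawn from {1, 2, 3}, so the `max` step is valid here.

```python
import random

import torch


def shape_rule(a_shape, b_shape):
    """Shape-only restatement of the rule in `get_broadcast_info` (hypothetical helper)."""
    a_shape, b_shape = tuple(a_shape), tuple(b_shape)
    n = max(len(a_shape), len(b_shape))
    a_shape = (1,) * (n - len(a_shape)) + a_shape  # prepend singleton dimensions
    b_shape = (1,) * (n - len(b_shape)) + b_shape
    if any(a != b and a != 1 and b != 1 for a, b in zip(a_shape, b_shape)):
        return None
    return tuple(max(a, b) for a, b in zip(a_shape, b_shape))


def reference(a_shape, b_shape):
    """PyTorch's answer, or None when the shapes cannot be broadcast."""
    try:
        return tuple(torch.broadcast_shapes(a_shape, b_shape))
    except RuntimeError:
        return None


random.seed(0)
for _ in range(1000):
    a = tuple(random.choice([1, 2, 3]) for _ in range(random.randint(0, 4)))
    b = tuple(random.choice([1, 2, 3]) for _ in range(random.randint(0, 4)))
    assert shape_rule(a, b) == reference(a, b), (a, b)
print("All random shape pairs agree.")
```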


The next block contains tests for [Part C](#c).

In [9]:
# `broadcast_tensors` correctness tests.
a = torch.zeros(1, 2, 82, 1, 10)
b = torch.zeros(100, 30, 1, 82, 7, 1)
c = torch.zeros(2, 82, 1, 5)
d = torch.zeros(2, 8, 1)

# Inputs for `None` expected output.
print(f"{broadcast_tensors(d, c) = }", f"{broadcast_tensors(c, a) = }",
       sep='\n')  # expected output: `None`.
print('Test Passed!' if broadcast_tensors(d, c) is None and
          broadcast_tensors(c, a) is None else 'Test Failed!')
print()  # Next test.

print(f"{broadcast_tensors(a, b)[0].shape = }", f"{broadcast_tensors(b, c)[0].shape = }",
        sep='\n')  # expected output: Shapes.
print('Test Passed!' if (broadcast_tensors(a, b)[0] == torch.broadcast_tensors(a, b)[0]).all() and
          (broadcast_tensors(b, c)[0] == torch.broadcast_tensors(b, c)[0]).all() else 'Test Failed!')

broadcast_tensors(d, c) = None
broadcast_tensors(c, a) = None
Test Passed!

broadcast_tensors(a, b)[0].shape = torch.Size([100, 30, 2, 82, 7, 10])
broadcast_tensors(b, c)[0].shape = torch.Size([100, 30, 2, 82, 7, 5])
Test Passed!
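
The value checks above compare all-zero tensors with `==` followed by `.all()`. Because `==` itself broadcasts its operands, such a check cannot detect a wrong shape, and with all-zero inputs it cannot detect misplaced values either. `torch.equal` is stricter, since it requires identical sizes and identical elements. The tensors below are arbitrary examples.

```python
import torch

x = torch.zeros(3, 1)
y = torch.zeros(1, 4)

# `==` broadcasts, so an all-zeros comparison passes even though the shapes differ.
print((x == y).all())             # tensor(True)

# `torch.equal` requires identical sizes and identical elements.
print(torch.equal(x, y))          # False

# With distinct values, reordered copies are detected as well.
a = torch.arange(6).reshape(2, 3)
print(torch.equal(a, a.flip(0)))  # False
```

Filling the test inputs with distinct values (for example, `torch.arange(...).reshape(...)`) and comparing both returned tensors with `torch.equal` would make the Part C test considerably stronger.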
