# Tensor Shape Assert

> This is a simple notebook that illustrates how a recursive function checks the shape of a tensor stored as nested Python lists.

## The (very verbose) recursive function

This function recursively checks the integrity of an n-dimensional tensor represented as nested lists. Here's how it works:

1. If the input is not a list, we've reached a scalar value, so we return `True` and an empty shape list.
1. If the input is an empty list, we consider it valid and return `True` with a shape of `[0]`.
1. Otherwise, we check the integrity of the first element and use its shape as the expected shape for all other elements at this level.
1. We iterate through the remaining elements, checking each one's integrity and comparing its shape against the first element's.
1. If all elements are valid and have the same shape, we return `True` and the current level's shape (the length of the current list followed by the first element's shape).
1. If any inconsistency is found, we return `False` and an empty shape list.

The function returns a tuple: `(is_valid, shape)`. If the tensor is valid, `is_valid` is `True` and `shape` is a list holding the dimensions of the tensor. If it's invalid, `is_valid` is `False` and `shape` is an empty list.
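
Before reading the verbose version, it may help to see the same steps without any tracing output. The sketch below is a condensed, print-free rendering of the rules listed above; the name `check_quiet` is made up for this illustration and is not part of the notebook's code.

```python
def check_quiet(tensor):
    """Print-free sketch of the steps above; returns (is_valid, shape)."""
    if not isinstance(tensor, list):
        return True, []                      # scalar reached
    if not tensor:
        return True, [0]                     # empty list counts as valid
    first_valid, first_shape = check_quiet(tensor[0])
    if not first_valid:
        return False, []
    for elem in tensor[1:]:
        valid, elem_shape = check_quiet(elem)
        if not valid or elem_shape != first_shape:
            return False, []                 # sibling shapes disagree
    return True, [len(tensor)] + first_shape


print(check_quiet(7))                  # (True, [])
print(check_quiet([]))                 # (True, [0])
print(check_quiet([[1, 2], [3, 4]]))   # (True, [2, 2])
print(check_quiet([[1, 2], [3]]))      # (False, [])
print(check_quiet([1, [2]]))           # (False, [])
```

The last call shows that mixing a scalar with a list at the same level is rejected too, because the scalar's shape `[]` differs from the list's shape `[1]`.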

In [None]:
def check_tensor_integrity(tensor, depth=0):
    indent = "| "+ "\t|" * depth + " "
    print(indent + "▶")
    print(f"{indent}Depth {depth}: Checking {tensor}")
    
    if not isinstance(tensor, list):
        print(f"{indent}Reached scalar value: {tensor}")
        return True, []  # Base case: we've reached a scalar value
    
    if not tensor:
        print(f"{indent}Empty list at depth {depth}")
        return True, [0]  # Empty list is considered valid
    
    print(f"{indent}Checking first element: {tensor[0]}")
    first_elem_valid, first_elem_shape = check_tensor_integrity(tensor[0], depth + 1)
    if not first_elem_valid:
        print(f"{indent}First element is invalid")
        return False, []
    
    expected_length = len(tensor)
    print(f"{indent}Expected length at this level: {expected_length}")
    current_level_shape = [expected_length] + first_elem_shape
    print(f"{indent}Expected shape at this level: {current_level_shape}")
    
    for i, elem in enumerate(tensor[1:], 1):
        print(f"{indent}Checking element {i}: {elem}")
        elem_valid, elem_shape = check_tensor_integrity(elem, depth + 1)
        if not elem_valid or elem_shape != first_elem_shape:
            print(f"{indent}Element {i} is invalid or has inconsistent shape")
            return False, []
    
    print(f"{indent}All elements at depth {depth} are valid")
    print(f"{indent}Final shape at depth {depth}: {current_level_shape}")
    return True, current_level_shape

In [None]:
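# A helper function to get the shape of a nested list. It only follows the
# first element at every level, so it does not check that siblings agree.
def shape(iterable: list, /) -> tuple[int, ...]:
    if not iterable:
        return ()
    dims = (len(iterable),)
    # Stop on an empty inner list so that iterable[0] cannot raise IndexError
    while iterable and isinstance(iterable[0], list):
        iterable = iterable[0]
        dims = dims + (len(iterable),)
    return dims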

### Examples

#### Good Tensors

In [None]:
tensor_1 = [[[0, 1, 2], [3, 4, 5]]]

print("Checking tensor of shape:", shape(tensor_1))
result = check_tensor_integrity(tensor_1)
print(f"Final result: {result}")

In [None]:
tensor_2 = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]

print("Checking tensor of shape:", shape(tensor_2))
result = check_tensor_integrity(tensor_2)
print(f"Final result: {result}")

The next one produces a long trace!

In [None]:
tensor_3 = [
 [[ 0,  1,  2,  3,  4],
  [ 5,  6,  7,  8,  9],
  [10, 11, 12, 13, 14],
  [15, 16, 17, 18, 19]],
 [[20, 21, 22, 23, 24],
  [25, 26, 27, 28, 29],
  [30, 31, 32, 33, 34],
  [35, 36, 37, 38, 39]],
 [[40, 41, 42, 43, 44],
  [45, 46, 47, 48, 49],
  [50, 51, 52, 53, 54],
  [55, 56, 57, 58, 59]]
]

print("Checking tensor of shape:", shape(tensor_3))
result = check_tensor_integrity(tensor_3)
print(f"Final result: {result}")
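
As a quick sanity check on a valid result, the number of scalar entries should equal the product of the reported dimensions. The sketch below rebuilds a tensor with the same layout as `tensor_3` and compares the two numbers; `count_leaves` is a hypothetical helper introduced only for this illustration.

```python
import math


def count_leaves(tensor):
    """Count the scalar entries in a nested list (hypothetical helper)."""
    if not isinstance(tensor, list):
        return 1
    return sum(count_leaves(elem) for elem in tensor)


# Same layout as tensor_3: 3 blocks of 4 rows of 5 values, numbered 0..59.
tensor = [[[5 * (4 * b + r) + c for c in range(5)] for r in range(4)] for b in range(3)]
expected_shape = (3, 4, 5)

print(count_leaves(tensor))        # 60
print(math.prod(expected_shape))   # 60
```

Matching counts are necessary but not sufficient: `[[1, 2, 3], [4]]` also has four entries, yet it is not a valid `(2, 2)` tensor, which is why the recursive check compares shapes level by level.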

#### Bad Tensors

In [None]:
tensor_4 = [
  [[1, 2], 
   [3, 4]], 
  [[5, 6, 7], 
   [8, 9, 0]]
]

print("Checking tensor of shape:", shape(tensor_4))
result = check_tensor_integrity(tensor_4)
print(f"Final result: {result}")
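
Note that the "Checking tensor of shape:" line still prints a regular-looking shape for this ragged input. That is because the `shape` helper only walks down the first element at each level and never looks at the siblings. The sketch below uses a simplified stand-in for that helper (the name `first_branch_shape` is an assumption made for this example) to show the effect.

```python
def first_branch_shape(nested):
    """Simplified stand-in for the `shape` helper: follows index 0 only."""
    dims = []
    while isinstance(nested, list) and nested:
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


ragged = [[[1, 2], [3, 4]], [[5, 6, 7], [8, 9, 0]]]

# Only [[1, 2], [3, 4]] and then [1, 2] are inspected, so the longer rows
# in the second block go unnoticed.
print(first_branch_shape(ragged))   # (2, 2, 2)
```

Only the recursive integrity check visits every element, which is why it is the one that reports the inconsistency.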

## Final function

Now we can write a cleaner final version: it drops the tracing output, returns only the shape, and raises an exception instead of returning a validity flag.

In [None]:
def check_tensor_integrity(tensor):
    # Base case: we've reached a scalar value
    if not isinstance(tensor, list):
        return []
    
    # Empty list is considered valid
    if not tensor:
        return [0]
    
    # Check the first element to get the expected shape
    expected_shape = check_tensor_integrity(tensor[0])
    current_shape = [len(tensor)] + expected_shape

    # Validate the rest of the elements
    for elem in tensor[1:]:
        elem_shape = check_tensor_integrity(elem)
        if elem_shape != expected_shape:
            raise ValueError(
                "The tensor has an inhomogeneous shape after {dims} dimensions. "
                "The detected shape was {shape} + inhomogeneous part.".format(
                    dims=len(current_shape),
                    shape=tuple(current_shape),
                )
            )
    
    return current_shape
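
One caveat about the message: `dims` and `shape` are taken from the shape implied by the first element at the level where the mismatch surfaces. For `tensor_4` from the earlier section, the function above would report 3 dimensions and `(2, 2, 2)`, even though the elements only agree on the first two dimensions. The sketch below is one possible way to report the point of divergence instead. It is not the notebook's or the library's implementation; the name `check_with_location` and the extra `_outer` parameter are assumptions made for this example.

```python
def check_with_location(tensor, _outer=()):
    """Sketch: same check, but the error says where the shapes diverge.

    `_outer` (hypothetical) carries the lengths of the enclosing lists so
    the message can describe the tensor from the top level down.
    """
    if not isinstance(tensor, list):
        return []
    if not tensor:
        return [0]

    outer = _outer + (len(tensor),)
    expected_shape = check_with_location(tensor[0], outer)

    for elem in tensor[1:]:
        elem_shape = check_with_location(elem, outer)
        if elem_shape != expected_shape:
            # Keep only the leading dimensions on which both siblings agree.
            agreed = []
            for a, b in zip(expected_shape, elem_shape):
                if a != b:
                    break
                agreed.append(a)
            detected = outer + tuple(agreed)
            raise ValueError(
                f"The tensor has an inhomogeneous shape after {len(detected)} "
                f"dimensions. The detected shape was {detected} + inhomogeneous part."
            )

    return [len(tensor)] + expected_shape


try:
    check_with_location([[[1, 2], [3, 4]], [[5, 6, 7], [8, 9, 0]]])
except ValueError as e:
    print(f"Error: {e}")
```

For this input the message names 2 dimensions and a detected shape of `(2, 2)`, since the row lengths (2 versus 3) are where the two blocks stop agreeing.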

One example with a valid 3D tensor:

In [None]:
# Example usage
tensor_1 = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]  # Valid 3D tensor

print(f"Checking Tensor: {tensor_1}")
result_1 = check_tensor_integrity(tensor_1)
print(f"Tensor shape: {result_1}")

And one example with an invalid 3D tensor (this raises a `ValueError`, which the cell catches and prints):

In [None]:
tensor_2 = [[[1, 2], [3, 4]], [[5, 6], [7, 8, 9]]]
#                                       ↑ Inhomogeneous part

try:
    print(f"Checking Tensor: {tensor_2}")
    result_2 = check_tensor_integrity(tensor_2)
    print(f"Tensor shape: {result_2}")
except ValueError as e:
    print(f"Error: {e}")

This function can be found at `smoltorch.util.operations.assert_tensor_shape`, although it's only meant to be used internally.