<a href="https://colab.research.google.com/github/nishika26/sarvam_assignment/blob/main/sarvam_assignment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [27]:
def parse_pattern(pattern):
    """
    Parse a rearrange pattern string into input and output axes.

    Each side of the pattern is turned into a list whose elements are either
    an axis name (``str``) or a parenthesized group of axis names (``list`` of
    ``str``). An empty group ``()`` is returned as an empty list.

    Example:
        'b h w c -> b (h w c)' -> (['b', 'h', 'w', 'c'], ['b', ['h', 'w', 'c']])

    Args:
        pattern: Pattern of the form 'input -> output'.

    Returns:
        A tuple ``(input_axes, output_axes)``.

    Raises:
        ValueError: If '->' is missing or appears more than once, or if the
            parentheses on either side are unbalanced or nested.
    """
    if '->' not in pattern:
        raise ValueError("Invalid pattern: missing '->'. Pattern should be in the format 'input -> output'.")
    if pattern.count('->') != 1:
        raise ValueError("Invalid pattern: '->' must appear exactly once.")

    input_pattern, output_pattern = pattern.split('->')

    def parse_side(pattern_str):
        """Tokenize one side of the pattern into names and groups of names."""
        elements = []
        i = 0
        length = len(pattern_str)
        while i < length:
            char = pattern_str[i]
            if char.isspace():
                i += 1
            elif char == '(':
                # A group runs up to the next ')'; without one the pattern is
                # malformed and must not be silently truncated.
                close = pattern_str.find(')', i + 1)
                if close == -1:
                    raise ValueError(f"Invalid pattern: unbalanced '(' in '{pattern_str}'.")
                inner = pattern_str[i + 1:close]
                if '(' in inner:
                    raise ValueError(f"Invalid pattern: nested parentheses in '{pattern_str}'.")
                elements.append(inner.split())  # Grouped axis names
                i = close + 1
            elif char == ')':
                raise ValueError(f"Invalid pattern: unbalanced ')' in '{pattern_str}'.")
            else:
                # Single axis name: runs until whitespace or a parenthesis
                j = i
                while j < length and not pattern_str[j].isspace() and pattern_str[j] not in '()':
                    j += 1
                elements.append(pattern_str[i:j])
                i = j
        return elements

    input_str = parse_side(input_pattern.strip())
    output_str = parse_side(output_pattern.strip())

    return input_str, output_str

# Sanity check: parse a pattern whose output side merges axes `a` and `b` into one group.
parse_pattern('a b c -> (a b) c')

(['a', 'b', 'c'], [['a', 'b'], 'c'])

In [88]:
import numpy as np

def parse_pattern(pattern):
    """
    Parse a rearrange pattern string into input and output axes.

    Each side of the pattern is turned into a list whose elements are either
    an axis name (``str``) or a parenthesized group of axis names (``list`` of
    ``str``). An empty group ``()`` is returned as an empty list.

    Example:
        'b h w c -> b (h w c)' -> (['b', 'h', 'w', 'c'], ['b', ['h', 'w', 'c']])

    Args:
        pattern: Pattern of the form 'input -> output'.

    Returns:
        A tuple ``(input_axes, output_axes)``.

    Raises:
        ValueError: If '->' is missing or appears more than once, or if the
            parentheses on either side are unbalanced or nested.
    """
    if '->' not in pattern:
        raise ValueError("Invalid pattern: missing '->'. Pattern should be in the format 'input -> output'.")
    if pattern.count('->') != 1:
        raise ValueError("Invalid pattern: '->' must appear exactly once.")

    input_pattern, output_pattern = pattern.split('->')

    def parse_side(pattern_str):
        """Tokenize one side of the pattern into names and groups of names."""
        elements = []
        i = 0
        length = len(pattern_str)
        while i < length:
            char = pattern_str[i]
            if char.isspace():
                i += 1
            elif char == '(':
                # A group runs up to the next ')'; without one the pattern is
                # malformed and must not be silently truncated.
                close = pattern_str.find(')', i + 1)
                if close == -1:
                    raise ValueError(f"Invalid pattern: unbalanced '(' in '{pattern_str}'.")
                inner = pattern_str[i + 1:close]
                if '(' in inner:
                    raise ValueError(f"Invalid pattern: nested parentheses in '{pattern_str}'.")
                elements.append(inner.split())
                i = close + 1
            elif char == ')':
                raise ValueError(f"Invalid pattern: unbalanced ')' in '{pattern_str}'.")
            else:
                # Single axis name: runs until whitespace or a parenthesis
                j = i
                while j < length and not pattern_str[j].isspace() and pattern_str[j] not in '()':
                    j += 1
                elements.append(pattern_str[i:j])
                i = j
        return elements

    input_str = parse_side(input_pattern.strip())
    output_str = parse_side(output_pattern.strip())

    return input_str, output_str


def _axis_product(lengths):
    """Return the product of ``lengths`` as a plain ``int`` (1 when empty)."""
    total = 1
    for length in lengths:
        total *= int(length)
    return total


def rearrange(tensor, pattern, **axes_lengths) -> np.ndarray:
    """
    Rearrange the input tensor based on the provided pattern.

    Supported operations, freely combined in one pattern:
      * transposition:        'h w -> w h'
      * splitting an axis:    '(h w) c -> h w c' (with ``h=`` or ``w=`` given)
      * merging axes:         'a b c -> (a b) c'
      * unit axes:            '1' on either side, or an empty group '()'
      * repeating a new axis: 'a 1 c -> a b c' (with ``b=`` given)
      * ellipsis:             '... h w -> ... (h w)'

    Args:
        tensor: Array-like input; converted with ``np.asarray``.
        pattern: Pattern of the form 'input -> output'.
        **axes_lengths: Lengths of axes that cannot be read off the input
            shape (one side of a split, or a newly introduced output axis).

    Returns:
        The rearranged array. It may be a view of the input when no copy is
        needed; when a new axis is repeated a fresh array is returned.

    Raises:
        ValueError: If the pattern does not match the tensor's shape, an axis
            length cannot be inferred or is inconsistent, an axis name is
            duplicated, an input axis is missing from the output, an output
            axis has no known length, or an unused length is supplied.
    """
    dots = '...'
    tensor = np.asarray(tensor)
    input_str, output_str = parse_pattern(pattern)

    # --- Ellipsis bookkeeping -------------------------------------------
    if any(isinstance(item, list) and dots in item for item in input_str):
        raise ValueError("Ellipsis inside parentheses is not allowed on the input side.")
    ellipsis_count = input_str.count(dots)
    if ellipsis_count > 1:
        raise ValueError("Only one ellipsis is allowed on the input side.")

    explicit_dims = len(input_str) - ellipsis_count
    batch_rank = tensor.ndim - explicit_dims
    if batch_rank < 0 or (ellipsis_count == 0 and batch_rank != 0):
        raise ValueError(f"Pattern {input_str} does not match the input shape {tensor.shape}.")

    # Synthetic names for the dimensions covered by '...'. They contain a
    # space, so they can never collide with a name produced by the parser.
    batch_names = [f"... {k}" for k in range(batch_rank)]

    # --- Input side: one group of axis names per tensor dimension ---------
    in_groups = []
    for item in input_str:
        if isinstance(item, list):
            in_groups.append(list(item))
        elif item == dots:
            in_groups.extend([name] for name in batch_names)
        else:
            in_groups.append([item])

    lengths = {}   # axis name -> length
    in_axes = []   # named input axes in memory order (unit axes dropped)
    for group, dim in zip(in_groups, tensor.shape):
        known = 1
        unknown = []
        for name in group:
            if name == '1':
                continue  # anonymous unit axis contributes a factor of 1
            if name.isdigit():
                raise ValueError(f"Only '1' is allowed as an anonymous axis, got '{name}'.")
            if name in in_axes:
                raise ValueError(f"Axis '{name}' appears more than once on the input side.")
            in_axes.append(name)
            if name in axes_lengths:
                lengths[name] = int(axes_lengths[name])
                known *= lengths[name]
            else:
                unknown.append(name)
        if len(unknown) > 1:
            raise ValueError(f"Cannot infer lengths of axes {unknown}; provide all but one of them.")
        if unknown:
            if known <= 0 or dim % known != 0:
                raise ValueError(
                    f"Dimension of size {dim} cannot be split into {group} with the given axes lengths."
                )
            lengths[unknown[0]] = dim // known
        elif known != dim:
            raise ValueError(f"Axes {group} have total length {known}, but the dimension has size {dim}.")

    # --- Output side -----------------------------------------------------
    out_groups = []
    for item in output_str:
        names = item if isinstance(item, list) else [item]
        expanded = []
        for name in names:
            if name == dots:
                if ellipsis_count == 0:
                    raise ValueError("Ellipsis appears in the output but not in the input.")
                expanded.extend(batch_names)
            else:
                expanded.append(name)
        if isinstance(item, list) or item != dots:
            out_groups.append(expanded)
        else:
            # A top-level ellipsis keeps every batch dimension separate.
            out_groups.extend([name] for name in expanded)

    out_axes = []
    new_axes = set()
    for group in out_groups:
        for name in group:
            if name == '1':
                continue
            if name.isdigit():
                raise ValueError(f"Only '1' is allowed as an anonymous axis, got '{name}'.")
            if name in out_axes:
                raise ValueError(f"Axis '{name}' appears more than once on the output side.")
            out_axes.append(name)
            if name not in lengths:
                # An axis that only exists in the output is repeated, so its
                # length has to be supplied by the caller.
                if name not in axes_lengths:
                    raise ValueError(f"Length of new output axis '{name}' must be provided.")
                lengths[name] = int(axes_lengths[name])
                new_axes.add(name)

    missing = [name for name in in_axes if name not in out_axes]
    if missing:
        raise ValueError(f"Input axes {missing} are missing from the output side.")
    unused = [name for name in axes_lengths if name not in lengths]
    if unused:
        raise ValueError(f"Axes lengths were given for names not used in the pattern: {unused}.")

    # --- Apply: split -> transpose -> repeat -> merge ----------------------
    result = tensor.reshape(tuple(lengths[name] for name in in_axes))

    position = {name: index for index, name in enumerate(in_axes)}
    kept = [name for name in out_axes if name not in new_axes]
    if len(kept) > 1:
        result = result.transpose([position[name] for name in kept])

    if new_axes:
        unit_shape = tuple(1 if name in new_axes else lengths[name] for name in out_axes)
        full_shape = tuple(lengths[name] for name in out_axes)
        # broadcast_to returns a read-only view; copy so the caller gets a
        # regular, writable array.
        result = np.broadcast_to(result.reshape(unit_shape), full_shape).copy()

    final_shape = tuple(
        _axis_product(lengths[name] for name in group if name != '1')
        for group in out_groups
    )
    return result.reshape(final_shape)




In [94]:
# Shape checks for the rearrange cases that are currently enabled
def test_rearrange():
    """Verify output shapes for the transpose and merge patterns, then report success.

    Only these two cases are exercised here; the split, repeat and ellipsis
    cases are not part of this run.
    """
    # (input shape, pattern, expected output shape, label used in the failure message)
    cases = [
        ((3, 4), 'h w -> w h', (4, 3), "transpose"),
        ((3, 4, 5), 'a b c -> (a b) c', (12, 5), "merging axes"),
    ]

    for in_shape, pattern, expected, label in cases:
        out = rearrange(np.random.rand(*in_shape), pattern)
        assert out.shape == expected, f"Test failed for {label}. Expected shape {expected}, got {out.shape}"

    print("All tests passed!")

# Run the tests; the first failing assertion raises AssertionError and stops the run
test_rearrange()

All tests passed!


def test_rearrange():
    """Exercise rearrange on transposition, split, merge, repeat, ellipsis and edge cases.

    Where the expected values are cheap to state with plain NumPy, the values
    are checked in addition to the shape, because a shape-only check cannot
    tell a real transpose from a plain reshape.

    Raises:
        AssertionError: On the first check that fails.
    """
    # Test Transposition
    x = np.random.rand(3, 4)
    result = rearrange(x, 'h w -> w h')
    assert result.shape == (4, 3), f"Test failed for transpose. Expected shape (4, 3), got {result.shape}"
    assert np.array_equal(result, x.T), "Test failed for transpose. Values do not match x.T"

    # Test Splitting of Axes
    x = np.random.rand(12, 10)
    result = rearrange(x, '(h w) c -> h w c', h=3)
    assert result.shape == (3, 4, 10), f"Test failed for splitting axis. Expected shape (3, 4, 10), got {result.shape}"
    assert np.array_equal(result, x.reshape(3, 4, 10)), "Test failed for splitting axis. Values do not match"

    # Test Merging of Axes
    x = np.random.rand(3, 4, 5)
    result = rearrange(x, 'a b c -> (a b) c')
    assert result.shape == (12, 5), f"Test failed for merging axes. Expected shape (12, 5), got {result.shape}"
    assert np.array_equal(result, x.reshape(12, 5)), "Test failed for merging axes. Values do not match"

    # Test Repeating an Axis
    x = np.random.rand(3, 1, 5)
    result = rearrange(x, 'a 1 c -> a b c', b=4)
    assert result.shape == (3, 4, 5), f"Test failed for repeating axis. Expected shape (3, 4, 5), got {result.shape}"
    assert np.array_equal(result, np.broadcast_to(x, (3, 4, 5))), "Test failed for repeating axis. Values do not match"

    # Test with ellipsis (batch dimensions): only the last two axes are merged
    x = np.random.rand(2, 3, 4, 5)
    result = rearrange(x, '... h w -> ... (h w)')
    assert result.shape == (2, 3, 20), f"Test failed for ellipsis. Expected shape (2, 3, 20), got {result.shape}"
    assert np.array_equal(result, x.reshape(2, 3, 20)), "Test failed for ellipsis. Values do not match"

    # Edge Case 1: Empty Tensor
    x = np.random.rand(0, 0)
    result = rearrange(x, 'h w -> w h')
    assert result.shape == (0, 0), f"Test failed for empty tensor. Expected shape (0, 0), got {result.shape}"

    # Edge Case 2: Single Element Tensor
    x = np.random.rand(1, 1)
    result = rearrange(x, 'h w -> w h')
    assert result.shape == (1, 1), f"Test failed for single element tensor. Expected shape (1, 1), got {result.shape}"

    # Edge Case 3: Single Dimension Tensor
    x = np.random.rand(3)
    result = rearrange(x, 'h -> h')
    assert result.shape == (3,), f"Test failed for single dimension tensor. Expected shape (3,), got {result.shape}"

    # Edge Case 4: High-Dimensional Tensor
    x = np.random.rand(2, 3, 4, 5, 6)
    result = rearrange(x, 'a b c d e -> a e d c b')
    assert result.shape == (2, 6, 5, 4, 3), f"Test failed for high-dimensional tensor. Expected shape (2, 6, 5, 4, 3), got {result.shape}"
    assert np.array_equal(result, x.transpose(0, 4, 3, 2, 1)), "Test failed for high-dimensional tensor. Values do not match"

    # Edge Case 5: Repeating Axes Larger Than One
    x = np.random.rand(3, 1, 5)
    result = rearrange(x, 'a 1 c -> a b c', b=10)
    assert result.shape == (3, 10, 5), f"Test failed for repeating axes larger than one. Expected shape (3, 10, 5), got {result.shape}"

    # Edge Case 6: Ellipsis in Higher-Dimensional Tensor
    # (no explicit h: the last two axes have sizes 5 and 6, so h=3 would conflict)
    x = np.random.rand(2, 3, 4, 5, 6)
    result = rearrange(x, '... h w -> ... (h w)')
    assert result.shape == (2, 3, 4, 30), f"Test failed for ellipsis in higher-dimensional tensor. Expected shape (2, 3, 4, 30), got {result.shape}"

    # Edge Case 7: Merging all axes into one (a valid pattern, so it must succeed)
    x = np.random.rand(3, 4, 5)
    result = rearrange(x, 'a b c -> (a b c)')
    assert result.shape == (60,), f"Test failed for merging all axes. Expected shape (60,), got {result.shape}"
    assert np.array_equal(result, x.reshape(60)), "Test failed for merging all axes. Values do not match"

    print("All tests passed!")

