In [7]:
import re
import numpy as np
from typing import Dict, List, Tuple, Union, Optional

def rearrange(tensor: np.ndarray, pattern: str, **axes_lengths) -> np.ndarray:
    """
    Rearrange a numpy array according to an einops-style pattern string.

    The work is done in three steps: the pattern is tokenised by
    EinopsPatternParser, checked against ``tensor.shape`` by
    validate_and_normalize_pattern, and applied by transform_tensor.

    Args:
        tensor: Input numpy array (tensor)
        pattern: String pattern describing the transformation, with the
            input layout on the left of '->' and the output layout on the right
        **axes_lengths: Named sizes for dimensions

    Returns:
        Rearranged numpy array

    Example usage:
        # Transpose
        x = np.random.rand(3, 4)
        result = rearrange(x, 'h w -> w h')

        # Split an axis
        x = np.random.rand(12, 10)
        result = rearrange(x, '(h w) c -> h w c', h=3)

        # Merge axes
        x = np.random.rand(3, 4, 5)
        result = rearrange(x, 'a b c -> (a b) c')
    """
    # Step 1: tokenise both sides of the arrow.
    lhs_tokens, rhs_tokens = EinopsPatternParser(pattern).parse()

    # Step 2: attach concrete sizes to every axis of both sides.
    sized_inputs, sized_outputs = validate_and_normalize_pattern(
        tensor.shape, lhs_tokens, rhs_tokens, axes_lengths
    )

    # Step 3: apply the resolved layout change to the data.
    return transform_tensor(tensor, sized_inputs, sized_outputs)



In [8]:
class EinopsPatternParser:
    """Parser for einops pattern strings.

    A pattern such as ``'b (h w) c -> b h w c'`` is split at the arrow and
    each side is turned into a list of tokens. A token is an axis name, the
    ellipsis ``'...'``, or a parenthesised group kept verbatim (``'(h w)'``).
    """

    # A token is either a parenthesised group without nested parentheses, or a
    # run of characters containing neither whitespace nor parentheses.
    _token_re = re.compile(r'\([^()]*\)|[^\s()]+')

    def __init__(self, pattern: str):
        """
        Initialize the pattern parser.

        Args:
            pattern: String pattern in einops format
        """
        self.pattern = pattern.strip()
        self.composite_axis_re = re.compile(r'\([^()]+\)')  # Matches (a b), (h w), etc.
        self.ellipsis = '...'

    def parse(self) -> Tuple[List[str], List[str]]:
        """
        Parse the pattern string into input and output specifications.

        Returns:
            Tuple of (input_spec, output_spec)

        Raises:
            ValueError: If the pattern does not contain exactly one '->', or
                if either side has nested or unbalanced parentheses.
        """
        if '->' not in self.pattern:
            raise ValueError("Pattern must contain '->' to separate input and output parts")

        sides = self.pattern.split('->')
        if len(sides) != 2:
            # Without this check the two-name unpacking below would fail with
            # an unrelated "too many values to unpack" message.
            raise ValueError(
                f"Pattern must contain exactly one '->', found {len(sides) - 1}: {self.pattern!r}"
            )
        input_pattern, output_pattern = sides

        input_spec = self._parse_part(input_pattern.strip())
        output_spec = self._parse_part(output_pattern.strip())

        return input_spec, output_spec

    def _parse_part(self, part: str) -> List[str]:
        """
        Parse one part of the pattern (input/output).

        Tokens do not have to be separated by whitespace around parentheses,
        so ``'b(h w)c'`` yields ``['b', '(h w)', 'c']``.

        Args:
            part: Input or output pattern string

        Returns:
            List of dimension specifications

        Raises:
            ValueError: If the part contains nested or unbalanced parentheses.
        """
        if not part:
            return []

        # Surround the ellipsis with spaces so it always becomes its own token,
        # even when written directly next to an axis name ('...h').
        part = part.replace('...', ' ... ')

        tokens = self._token_re.findall(part)

        # findall silently skips characters that fit no token. The only
        # non-whitespace characters it can skip are stray parentheses, so
        # comparing the whitespace-free texts detects nesting and imbalance.
        consumed = ''.join(''.join(tokens).split())
        if consumed != ''.join(part.split()):
            raise ValueError(f"Nested or unbalanced parentheses in pattern part: {part.strip()!r}")

        return tokens



In [9]:
class AxisSpec:
    """A single axis of a pattern: either a plain named axis or a group such as (a b)."""

    def __init__(self, name: str, composite: bool = False, components: List[str] = None):
        """
        Build the description of one axis.

        Args:
            name: Name of the axis
            composite: Whether this is a composite axis (e.g., (a b))
            components: List of component axis names if composite
        """
        # The size is not known at construction time and starts out unset.
        self.size = None
        self.components = components if components else []
        self.composite = composite
        self.name = name

    def __str__(self):
        # Composite axes are rendered from their components, plain axes by name.
        return f"({' '.join(self.components)})" if self.composite else self.name

def parse_composite_axes(dimensions: List[str]) -> List[AxisSpec]:
    """
    Turn pattern tokens into AxisSpec objects.

    A token wrapped in parentheses becomes a composite AxisSpec whose
    components are the whitespace-separated names inside the parentheses;
    every other token becomes a plain AxisSpec.

    Args:
        dimensions: List of dimension specifications

    Returns:
        List of normalized AxisSpec objects
    """
    def to_spec(token: str) -> AxisSpec:
        is_group = token.startswith('(') and token.endswith(')')
        if not is_group:
            return AxisSpec(token)
        # Drop the surrounding parentheses, e.g. '(h w)' -> ['h', 'w'].
        return AxisSpec(token, True, token[1:-1].split())

    return [to_spec(token) for token in dimensions]

def validate_and_normalize_pattern(
    tensor_shape: Tuple[int, ...],
    input_spec: List[str],
    output_spec: List[str],
    axes_lengths: Dict[str, int]
) -> Tuple[List[AxisSpec], List[AxisSpec]]:
    """
    Validate the pattern against the tensor shape and normalize it.

    Every non-ellipsis input token (plain or composite) is matched with one
    tensor dimension, in order. An ellipsis is replaced by generated axes
    named ``_ellipsis_<i>`` that cover the remaining dimensions; the same
    AxisSpec objects are substituted for the ellipsis on the output side.
    Each returned AxisSpec has its ``size`` filled in.

    Args:
        tensor_shape: Shape of the input tensor
        input_spec: List of input dimension specifications
        output_spec: List of output dimension specifications
        axes_lengths: Dictionary of axis lengths provided by the user. It is
            read but never modified.

    Returns:
        Tuple of (input_dims, output_dims) with resolved sizes

    Raises:
        ValueError: If the ellipsis is used on one side only or more than once
            per side, if the number of input axes does not match the tensor,
            if a supplied length contradicts the tensor shape, if a composite
            axis cannot be split evenly or has more than one component of
            unknown size, or if the size of an output axis cannot be determined.
    """
    # Inferred component sizes are recorded in a private copy so that the
    # caller's dictionary is left untouched.
    known_lengths = dict(axes_lengths)

    # Parse and expand composite axes
    input_dims = parse_composite_axes(input_spec)
    output_dims = parse_composite_axes(output_spec)

    # Handle ellipsis if present
    input_ellipsis_count = sum(1 for dim in input_dims if dim.name == '...')
    output_ellipsis_count = sum(1 for dim in output_dims if dim.name == '...')

    if bool(input_ellipsis_count) != bool(output_ellipsis_count):
        raise ValueError("Ellipsis must be present in both input and output or in neither")
    if input_ellipsis_count > 1 or output_ellipsis_count > 1:
        raise ValueError("Ellipsis may appear at most once on each side of the pattern")

    # Calculate ellipsis dimensions
    ellipsis_dims = []
    if input_ellipsis_count:
        # Each remaining token, plain or composite, consumes one tensor dimension.
        total_explicit_dims = len(input_dims) - input_ellipsis_count
        ellipsis_dim_count = len(tensor_shape) - total_explicit_dims
        if ellipsis_dim_count < 0:
            raise ValueError(f"Input pattern specifies more dimensions ({total_explicit_dims}) "
                            f"than tensor has ({len(tensor_shape)})")

        # Create named dimensions for ellipsis
        ellipsis_dims = [AxisSpec(f"_ellipsis_{i}") for i in range(ellipsis_dim_count)]

    resolved_input_dims = []
    tensor_dim_idx = 0

    # Mapping from elementary axis name to its size
    axis_sizes = {}

    for dim in input_dims:
        if dim.name == '...':
            resolved_input_dims.extend(ellipsis_dims)

            # Record sizes of ellipsis dimensions
            for ellipsis_dim in ellipsis_dims:
                ellipsis_dim.size = tensor_shape[tensor_dim_idx]
                axis_sizes[ellipsis_dim.name] = ellipsis_dim.size
                tensor_dim_idx += 1
            continue

        # Checked once here so that neither branch below can index past the shape.
        if tensor_dim_idx >= len(tensor_shape):
            raise ValueError("Input pattern has more dimensions than tensor")
        actual_size = tensor_shape[tensor_dim_idx]

        if dim.composite:
            component_sizes = {}
            missing_components = []

            for component in dim.components:
                if component in known_lengths:
                    component_sizes[component] = known_lengths[component]
                else:
                    missing_components.append(component)

            known_product = 1
            for size in component_sizes.values():
                known_product *= size

            if not missing_components:
                # All components are given: their product must match the tensor.
                if known_product != actual_size:
                    raise ValueError(
                        f"Size mismatch for composite dimension {dim}: "
                        f"expected {actual_size}, got {known_product}"
                    )
            elif len(missing_components) == 1:
                # Exactly one unknown component: derive it, but only when the
                # tensor dimension splits evenly. Floor division would
                # otherwise accept e.g. h=5 for a dimension of length 12.
                if known_product <= 0 or actual_size % known_product != 0:
                    raise ValueError(
                        f"Size mismatch for composite dimension {dim}: "
                        f"length {actual_size} is not divisible by {known_product}"
                    )
                missing_component = missing_components[0]
                inferred_size = actual_size // known_product
                component_sizes[missing_component] = inferred_size
                known_lengths[missing_component] = inferred_size
            else:
                # If multiple components are missing, we need explicit sizes
                components_str = ', '.join(missing_components)
                raise ValueError(f"Cannot determine sizes for components: {components_str}")

            # Record individual component sizes
            for component, size in component_sizes.items():
                axis_sizes[component] = size
        else:
            # A length supplied (or already inferred) for a plain axis must
            # agree with the tensor instead of being silently ignored.
            if dim.name in known_lengths and known_lengths[dim.name] != actual_size:
                raise ValueError(
                    f"Size mismatch for dimension '{dim.name}': "
                    f"tensor has {actual_size}, but {known_lengths[dim.name]} was specified"
                )
            axis_sizes[dim.name] = actual_size

        dim.size = actual_size
        resolved_input_dims.append(dim)
        tensor_dim_idx += 1

    if tensor_dim_idx != len(tensor_shape):
        raise ValueError(f"Input pattern doesn't account for all tensor dimensions: "
                        f"pattern: {input_spec}, tensor shape: {tensor_shape}")

    # Resolve output dimensions
    resolved_output_dims = []
    for dim in output_dims:
        if dim.name == '...':
            resolved_output_dims.extend(ellipsis_dims)
        elif dim.composite:
            # The size of a composite output axis is the product of its components.
            merged_size = 1
            for component in dim.components:
                if component in axis_sizes:
                    merged_size *= axis_sizes[component]
                elif component in known_lengths:
                    merged_size *= known_lengths[component]
                else:
                    raise ValueError(f"Cannot determine size for component '{component}' in output dimension '{dim}'")

            dim.size = merged_size
            resolved_output_dims.append(dim)
        else:
            # For simple dimensions, find them in input or axes_lengths
            if dim.name in axis_sizes:
                dim.size = axis_sizes[dim.name]
            elif dim.name in known_lengths:
                dim.size = known_lengths[dim.name]
            else:
                raise ValueError(f"Cannot determine size for output dimension '{dim.name}'")

            resolved_output_dims.append(dim)

    return resolved_input_dims, resolved_output_dims

def _split_composite_sizes(dim, named_sizes):
    """Return one size per component of the composite input axis ``dim``, in component order."""
    sizes = [named_sizes.get(component) for component in dim.components]
    unknown = [index for index, size in enumerate(sizes) if size is None]
    if len(unknown) > 1:
        names = ', '.join(dim.components[index] for index in unknown)
        raise ValueError(f"Cannot determine sizes for components: {names}")

    known_product = 1
    for size in sizes:
        if size is not None:
            known_product *= size

    if unknown:
        if known_product == 0 or dim.size % known_product != 0:
            raise ValueError(
                f"Cannot split dimension {dim} of length {dim.size} "
                f"into components of total length {known_product}"
            )
        # The inferred size is written at the component's own position, so
        # '(ph h)' and '(h ph)' are split differently.
        sizes[unknown[0]] = dim.size // known_product
    elif known_product != dim.size:
        raise ValueError(
            f"Size mismatch for composite dimension {dim}: "
            f"expected {dim.size}, got {known_product}"
        )
    return sizes


def transform_tensor(
    tensor: np.ndarray,
    input_dims: List["AxisSpec"],
    output_dims: List["AxisSpec"]
) -> np.ndarray:
    """
    Transform the tensor according to the input and output specifications.

    The tensor is first reshaped so that every composite input axis is split
    into its components. The elementary axes are then transposed into the
    order in which the output refers to them, and finally reshaped so that
    the components of each composite output axis are merged.

    Component sizes of a composite input axis are looked up by name among the
    plain axes of ``input_dims`` and ``output_dims``. At most one component
    per composite axis may be missing there; its size is derived from the
    size of the composite axis.

    Args:
        tensor: Input tensor to transform
        input_dims: List of input dimension specifications
        output_dims: List of output dimension specifications

    Returns:
        Transformed tensor

    Raises:
        ValueError: If a composite input axis cannot be split, if a component
            of a composite output axis is not an input axis, if an output
            axis that is absent from the input has a length other than 1, or
            if an input axis that is absent from the output has a length
            other than 1.
    """
    # Sizes that can be looked up by name: plain axes carry their own size.
    # The first occurrence wins, input axes before output axes.
    named_sizes = {}
    for dim in list(input_dims) + list(output_dims):
        if not dim.composite and dim.size is not None:
            named_sizes.setdefault(dim.name, dim.size)

    # Split composite input axes into elementary axes.
    expanded_names = []
    expanded_shape = []
    for dim in input_dims:
        if dim.composite:
            expanded_names.extend(dim.components)
            expanded_shape.extend(_split_composite_sizes(dim, named_sizes))
        else:
            expanded_names.append(dim.name)
            expanded_shape.append(dim.size)

    tensor = tensor.reshape(expanded_shape)

    # Positions of each elementary axis that have not been used by the output
    # yet. A list per name keeps repeated names (e.g. two '1' axes) apart.
    free_positions = {}
    for position, name in enumerate(expanded_names):
        free_positions.setdefault(name, []).append(position)

    def claim(name):
        """Take the next unused input position of ``name``, or None if there is none."""
        queue = free_positions.get(name)
        return queue.pop(0) if queue else None

    output_order = []
    final_shape = []
    for dim in output_dims:
        if dim.composite:
            claimed = [claim(component) for component in dim.components]
            if any(position is None for position in claimed):
                raise ValueError(f"Not all components of {dim} found in input dimensions")
            # The components need not be adjacent in the input: the transpose
            # below places them next to each other before they are merged.
            output_order.extend(claimed)
            final_shape.append(dim.size)
        elif dim.name == '...':
            # An unexpanded ellipsis stands for all generated ellipsis axes of the input.
            for name in expanded_names:
                if name.startswith('_ellipsis_'):
                    position = claim(name)
                    if position is not None:
                        output_order.append(position)
                        final_shape.append(expanded_shape[position])
        else:
            position = claim(dim.name)
            if position is not None:
                output_order.append(position)
            elif dim.size != 1:
                # Repeating data along a new axis is not supported.
                raise ValueError(
                    f"Output axis '{dim.name}' of length {dim.size} is not present in the input; "
                    f"only new axes of length 1 can be introduced"
                )
            final_shape.append(dim.size)

    # Input axes the output never mentions may only be dropped when they have
    # length 1; anything else would require a reduction.
    leftover = sorted(position for queue in free_positions.values() for position in queue)
    dropped = [expanded_names[position] for position in leftover if expanded_shape[position] != 1]
    if dropped:
        raise ValueError(
            f"Input axes {dropped} do not appear in the output; "
            f"only axes of length 1 can be dropped"
        )

    # Length-1 leftovers are moved to the end, where they do not influence the
    # element order seen by the final reshape.
    permutation = output_order + leftover
    if permutation != sorted(permutation):
        tensor = np.transpose(tensor, permutation)

    return tensor.reshape(final_shape)

In [10]:
import unittest
import numpy as np
from typing import Tuple, List

class TestEinopsPatternParser(unittest.TestCase):
    """Checks that EinopsPatternParser tokenises both sides of a pattern."""

    def test_basic_pattern_parsing(self):
        # Plain names are split on whitespace on each side of the arrow.
        lhs, rhs = EinopsPatternParser('a b -> b a').parse()
        self.assertEqual(lhs, ['a', 'b'])
        self.assertEqual(rhs, ['b', 'a'])

    def test_composite_pattern_parsing(self):
        # A parenthesised group stays a single token.
        lhs, rhs = EinopsPatternParser('(a b) c -> a b c').parse()
        self.assertEqual(lhs, ['(a b)', 'c'])
        self.assertEqual(rhs, ['a', 'b', 'c'])

    def test_ellipsis_pattern_parsing(self):
        # The ellipsis is a token of its own.
        lhs, rhs = EinopsPatternParser('... h w -> ... (h w)').parse()
        self.assertEqual(lhs, ['...', 'h', 'w'])
        self.assertEqual(rhs, ['...', '(h w)'])

    def test_invalid_pattern(self):
        with self.assertRaises(ValueError):
            EinopsPatternParser('a b c').parse()  # Missing ->



In [11]:
class TestAxisSpec(unittest.TestCase):
    """Checks the attributes stored by AxisSpec."""

    def test_simple_axis(self):
        # Only the name is given, so the axis is plain and has no components.
        spec = AxisSpec('h')
        self.assertEqual(spec.name, 'h')
        self.assertFalse(spec.composite)
        self.assertEqual(spec.components, [])

    def test_composite_axis(self):
        # All three constructor arguments are stored unchanged.
        spec = AxisSpec('(h w)', True, ['h', 'w'])
        self.assertEqual(spec.name, '(h w)')
        self.assertTrue(spec.composite)
        self.assertEqual(spec.components, ['h', 'w'])






In [12]:
class TestParseCompositeAxes(unittest.TestCase):
    """Checks the conversion of pattern tokens into AxisSpec objects."""

    def test_simple_axes(self):
        tokens = ['a', 'b', 'c']
        specs = parse_composite_axes(tokens)
        self.assertEqual(len(specs), 3)
        self.assertEqual([spec.name for spec in specs], tokens)
        # None of the plain names may be flagged as a group.
        self.assertFalse(any(spec.composite for spec in specs))

    def test_mixed_axes(self):
        tokens = ['a', '(b c)', 'd']
        specs = parse_composite_axes(tokens)
        self.assertEqual(len(specs), 3)
        self.assertEqual([spec.name for spec in specs], tokens)
        self.assertEqual([spec.composite for spec in specs], [False, True, False])
        # The group in the middle is split into its component names.
        self.assertEqual(specs[1].components, ['b', 'c'])


In [13]:
class TestRearrangeFunctionality(unittest.TestCase):
    """End-to-end checks of rearrange on small random arrays."""

    def test_transpose(self):
        data = np.random.rand(3, 4)
        swapped = rearrange(data, 'h w -> w h')
        self.assertEqual(swapped.shape, (4, 3))
        np.testing.assert_array_equal(swapped, data.T)

    def test_split_axis(self):
        data = np.random.rand(12, 10)
        split = rearrange(data, '(h w) c -> h w c', h=3)
        self.assertEqual(split.shape, (3, 4, 10))
        # Splitting the leading axis is the same as a plain reshape.
        np.testing.assert_array_equal(split, data.reshape(3, 4, 10))

    def test_merge_axes(self):
        data = np.random.rand(3, 4, 5)
        merged = rearrange(data, 'a b c -> (a b) c')
        self.assertEqual(merged.shape, (12, 5))
        # Merging the two leading axes is the same as a plain reshape.
        np.testing.assert_array_equal(merged, data.reshape(12, 5))

    def test_complex_pattern(self):
        images = np.random.rand(8, 32, 32, 3)  # Batch, Height, Width, Channels
        patches = rearrange(images, 'b (h ph) (w pw) c -> b h w (ph pw c)', ph=4, pw=4)
        self.assertEqual(patches.shape, (8, 8, 8, 48))

    def test_error_missing_size(self):
        data = np.random.rand(12, 10)
        with self.assertRaises(ValueError):
            rearrange(data, '(h w) c -> h w c')  # Missing h or w

    def test_error_size_mismatch(self):
        data = np.random.rand(12, 10)
        with self.assertRaises(ValueError):
            rearrange(data, '(h w) c -> h w c', h=5)  # 5 * w != 12

    def test_repeat_axis(self):
        data = np.random.rand(3, 1, 5)
        with self.assertRaises(ValueError):
            # Our current implementation doesn't support repeating
            # This is more complex and would require additional functionality
            rearrange(data, 'a 1 c -> a b c', b=4)

def run_tests():
    """Collect the four test cases into one suite, run it verbosely and return the result."""
    case_classes = (
        TestEinopsPatternParser,
        TestAxisSpec,
        TestParseCompositeAxes,
        TestRearrangeFunctionality,
    )

    suite = unittest.TestSuite()
    loader = unittest.TestLoader()
    for case_class in case_classes:
        suite.addTests(loader.loadTestsFromTestCase(case_class))

    return unittest.TextTestRunner(verbosity=2).run(suite)

In [14]:
# Run the whole suite when this code executes as the top-level module. The test
# ids in the recorded output below are prefixed with ``__main__``, which shows
# that the guard holds when the cell is executed.
if __name__ == "__main__":
    run_tests()

test_basic_pattern_parsing (__main__.TestEinopsPatternParser.test_basic_pattern_parsing) ... ok
test_composite_pattern_parsing (__main__.TestEinopsPatternParser.test_composite_pattern_parsing) ... ok
test_ellipsis_pattern_parsing (__main__.TestEinopsPatternParser.test_ellipsis_pattern_parsing) ... ok
test_invalid_pattern (__main__.TestEinopsPatternParser.test_invalid_pattern) ... ok
test_composite_axis (__main__.TestAxisSpec.test_composite_axis) ... ok
test_simple_axis (__main__.TestAxisSpec.test_simple_axis) ... ok
test_mixed_axes (__main__.TestParseCompositeAxes.test_mixed_axes) ... ok
test_simple_axes (__main__.TestParseCompositeAxes.test_simple_axes) ... ok
test_complex_pattern (__main__.TestRearrangeFunctionality.test_complex_pattern) ... ok
test_error_missing_size (__main__.TestRearrangeFunctionality.test_error_missing_size) ... ok
test_error_size_mismatch (__main__.TestRearrangeFunctionality.test_error_size_mismatch) ... ok
test_merge_axes (__main__.TestRearrangeFunctionality.te