In [1]:
import numpy as np
from typing import Dict, List, Tuple, Any, Set
import dataclasses
import re # Using regex for some pattern checks

In [2]:
# --- 1. Custom Exception ---
class EinopsError(ValueError):
    """Raised when an einops-style operation cannot be carried out.

    Subclasses ``ValueError``, so handlers written for ``ValueError`` also
    catch it.
    """

In [3]:
# --- 2. Data Structure for Parsed Pattern ---
@dataclasses.dataclass
class ParsedPattern:
    """Plan object built by `_parse_pattern` and consumed by `_execute_rearrangement`.

    Field order matters: the dataclass-generated ``__init__`` accepts the
    fields positionally in the order they are declared below.
    """
    # --- Input side of the pattern (text left of '->') ---
    # NOTE(review): the exact split between `lhs_axes` (presumably top-level
    # groups) and `input_axis_names` (presumably flattened names) is not yet
    # fixed by the parser -- TODO confirm once parsing is implemented.
    lhs_axes: List[str]
    input_shape_spec: Tuple[Any, ...]
    input_axis_names: List[str]
    # True when the left-hand side contains '...'.
    has_ellipsis_lhs: bool

    # --- Output side of the pattern (text right of '->') ---
    rhs_axes: List[str]
    output_shape_spec: Tuple[Any, ...]
    output_axis_names: List[str]
    # True when the right-hand side contains '...'.
    has_ellipsis_rhs: bool

    # --- Resolved information ---
    # Axis name -> length; the parser currently seeds this from a copy of the
    # caller-supplied axes_lengths.
    resolved_axes: Dict[str, int]
    # Each flag below switches on one step of the executor.
    needs_reshaping_input: bool   # reshape to intermediate_shape_after_lhs_reshape
    needs_reshaping_output: bool  # reshape to final_shape
    needs_repeating: bool         # repeat/tile step (executor side is still a placeholder)
    needs_transposing: bool       # np.transpose using transpose_indices

    # --- Execution plan details ---
    # Target shape of the first reshape in the executor.
    intermediate_shape_after_lhs_reshape: Tuple[int, ...]
    # Passed as `axes=` to np.transpose.
    transpose_indices: Tuple[int, ...]
    # Target of the last reshape; when non-empty the executor also checks the
    # result's shape against it.
    final_shape: Tuple[int, ...]
    # Details needed for repeating axes (schema not defined yet).
    repeat_instructions: Dict

In [4]:
# --- 3. Main Public Function ---
def rearrange(tensor: np.ndarray, pattern: str, **axes_lengths: int) -> np.ndarray:
    """
    Rearrange a NumPy ndarray according to an einops-style pattern.

    Runs three stages in order: `_validate_input` (type/syntax checks),
    `_parse_pattern` (builds a `ParsedPattern` plan) and
    `_execute_rearrangement` (applies the plan).

    Args:
        tensor: Array to rearrange.
        pattern: Pattern string containing exactly one '->' separator.
        **axes_lengths: Lengths for named axes, forwarded to validation and parsing.

    Returns:
        The array returned by `_execute_rearrangement`.

    Raises:
        EinopsError: If any stage raises `EinopsError`; the new error adds the
            pattern, the tensor shape (or type, if it has no shape) and the
            supplied axis lengths to the original message.
        RuntimeError: If any stage raises another exception type.
    """
    try:
        # Stage 1: initial validation
        _validate_input(tensor, pattern, axes_lengths)

        # Stage 2a: parse the pattern into an execution plan
        parsed_pattern = _parse_pattern(pattern, tensor.shape, axes_lengths)

        # Stage 2b: execute the plan
        return _execute_rearrangement(tensor, parsed_pattern)
    except EinopsError as e:
        # Validation errors land here too, including "tensor is not an ndarray",
        # so `tensor` may have no `.shape`. Reading it unconditionally would
        # raise AttributeError from inside this handler and hide the real error.
        shape = getattr(tensor, "shape", None)
        if shape is None:
            tensor_info = f"Input tensor type: {type(tensor)}. "
        else:
            tensor_info = f"Input tensor shape: {shape}. "

        message = f' Error while processing pattern "{pattern}".'
        message += f"\n {tensor_info}"
        message += f"Additional axes lengths: {axes_lengths}."
        raise EinopsError(message + f"\n Original error: {e}") from e
    except Exception as e:
        # Unexpected failure: report the type only, since the tensor itself
        # may be what is malformed.
        message = f' Unexpected error while processing pattern "{pattern}".'
        message += f"\n Input tensor type: {type(tensor)}. "
        message += f"Additional axes lengths: {axes_lengths}."
        raise RuntimeError(message + f"\n Original error: {e.__class__.__name__}: {e}") from e

In [None]:
# --- 4. Internal Helper Functions ---

# --- 4a. Validation (Checker Logic) ---
def _validate_input(tensor: np.ndarray, pattern: str, axes_lengths: Dict[str, int]) -> None:
    """
    Performs initial syntax and type checks on the inputs.
    Incorporates checks from 'einops_error_conditions'.
    """
    # Check tensor type
    if not isinstance(tensor, np.ndarray):
        raise EinopsError("Input tensor must be a NumPy ndarray.")

    # Check pattern type and basic structure
    if not isinstance(pattern, str):
        raise EinopsError("Pattern must be a string.")
    if '->' not in pattern:
        raise EinopsError("Pattern must contain '->' separator.")
    if pattern.count('->') > 1:
        raise EinopsError("Pattern must contain exactly one '->' separator.")
    # Basic check for balanced parentheses
    if pattern.count('(') != pattern.count(')'):
        raise EinopsError(f"Pattern has unbalanced parentheses: '{pattern}'")

    # Check for invalid characters using regex
    allowed_pattern = r"(\w+|\(|\)|\s+|\.\.\.|->)"
    remaining_chars = re.sub(allowed_pattern, '', pattern)
    if remaining_chars:
        individual_allowed = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_() .-") # Include dot and hyphen here
        invalid_chars_found = set(remaining_chars) - individual_allowed
        if not invalid_chars_found:
             invalid_chars_found = set(remaining_chars)

        raise EinopsError(f"Pattern contains invalid characters or structure: {invalid_chars_found} in '{pattern}'")

    # Check axes_lengths keys and values
    for name, length in axes_lengths.items():
        if not isinstance(name, str) or not name.isidentifier():
            raise EinopsError(f"Axis name '{name}' in axes_lengths is not a valid identifier.")
        if name.startswith('_') or name.endswith('_'):
             raise EinopsError(f"Axis name '{name}' in axes_lengths should not start or end with underscore.")
        if not isinstance(length, int) or length <= 0:
            raise EinopsError(f"Length for axis '{name}' must be a positive integer, got {length}.")

    # Only print if validation passes the initial checks
    # print("Input validation passed (basic checks).") # Placeholder for debugging

In [6]:
# --- 4b. Parsing (Parser Logic) ---
def _parse_pattern(pattern: str, tensor_shape: Tuple[int, ...], axes_lengths: Dict[str, int]) -> ParsedPattern:
    """
    Split the pattern, run the ellipsis checks, and return a `ParsedPattern`.

    NOTE: only the structural checks of Stage 1 are implemented. Stages 2-3
    (axis parsing, shape inference, plan generation) are still TODO, so the
    returned plan is a placeholder: every `needs_*` flag is False and every
    axis/shape field is empty. `tensor_shape` is accepted but not used yet.

    Args:
        pattern: Pattern string containing exactly one '->' separator.
        tensor_shape: Shape of the input tensor (currently unused).
        axes_lengths: Caller-supplied axis lengths; a copy is stored in
            `ParsedPattern.resolved_axes`.

    Returns:
        A `ParsedPattern` carrying the two ellipsis flags and the copied axis
        lengths.

    Raises:
        EinopsError: If the pattern does not contain exactly one '->', has dots
            outside a three-dot ellipsis, has more than one ellipsis on a side,
            or has an ellipsis on the right side only.
    """
    # --- Stage 1: Split and Basic Structure Validation ---
    lhs_str, separator, rhs_str = pattern.partition('->')
    if not separator or '->' in rhs_str:
        # Raise the module's own error type instead of a bare unpacking failure.
        raise EinopsError("Pattern must contain exactly one '->' separator.")
    lhs_str = lhs_str.strip()
    rhs_str = rhs_str.strip()

    # Every run of dots must be exactly '...'. Each side is checked on its own:
    # concatenating them would let a trailing '.' on the left fuse with a
    # leading '..' on the right and pass as an ellipsis.
    for side in (lhs_str, rhs_str):
        if any(len(dot_run) != 3 for dot_run in re.findall(r'\.+', side)):
            raise EinopsError("Pattern may contain dots only inside ellipsis (...)")
    # Check for multiple ellipses
    if lhs_str.count('...') > 1 or rhs_str.count('...') > 1:
        raise EinopsError("Pattern may contain dots only inside ellipsis (...); only one ellipsis per side.")

    has_ellipsis_lhs = '...' in lhs_str
    has_ellipsis_rhs = '...' in rhs_str

    # Check for ellipsis on RHS only
    if not has_ellipsis_lhs and has_ellipsis_rhs:
        raise EinopsError(f"Ellipsis found in right side, but not left side of pattern '{pattern}'")

    # --- Stage 2: Detailed Parsing (LHS & RHS) ---
    # Placeholder containers; nothing fills them yet.
    parsed_lhs_axes: List[str] = []
    parsed_rhs_axes: List[str] = []
    decomposed_lhs_names: List[str] = []
    decomposed_rhs_names: List[str] = []
    resolved_axes_lengths: Dict[str, int] = axes_lengths.copy()
    repeat_instructions_dict = {}

    # --- Stage 2a: Parse LHS ---
    # TODO: Implement detailed LHS parsing logic here, including:
    #   - Tokenising groups, names and the ellipsis
    #   - Error checks for nesting, bad names, duplicates, anonymous axes > 1
    #   - Dimension number matching & ellipsis expansion
    #   - Shape inference & divisibility checks

    # --- Stage 2b: Parse RHS ---
    # TODO: Implement detailed RHS parsing logic here, including:
    #   - Similar checks as LHS
    #   - Handling numeric literals for repetition
    #   - Checking axis consistency with LHS

    # --- Stage 3: Final Semantic Validation & Plan Generation ---
    # TODO: Implement final checks and calculate execution plan details

    # --- Stage 4: Create and return ParsedPattern object ---
    # Placeholder plan: with every needs_* flag False, the executor will hand
    # the tensor back unchanged.
    return ParsedPattern(
        lhs_axes=parsed_lhs_axes,
        input_shape_spec=(),
        input_axis_names=decomposed_lhs_names,
        has_ellipsis_lhs=has_ellipsis_lhs,
        rhs_axes=parsed_rhs_axes,
        output_shape_spec=(),
        output_axis_names=decomposed_rhs_names,
        has_ellipsis_rhs=has_ellipsis_rhs,
        resolved_axes=resolved_axes_lengths,
        needs_reshaping_input=False,  # Set based on parsing
        needs_reshaping_output=False,  # Set based on parsing
        needs_repeating=False,  # Set based on parsing (numeric literals)
        needs_transposing=False,  # Set based on parsing (order change)
        intermediate_shape_after_lhs_reshape=(),  # Calculate based on LHS
        transpose_indices=(),  # Calculate based on axis mapping
        final_shape=(),  # Calculate based on RHS
        repeat_instructions=repeat_instructions_dict,
    )

In [7]:
# --- 4c. Execution (Executor Logic) ---
def _execute_rearrangement(tensor: np.ndarray, plan: ParsedPattern) -> np.ndarray:
    """
    Apply the steps enabled in `plan` to `tensor` using NumPy operations.

    Steps run in a fixed order, each gated by its `needs_*` flag:
      1. reshape to `plan.intermediate_shape_after_lhs_reshape`
      2. repeat/tile (not implemented yet -- currently a no-op)
      3. `np.transpose` with `plan.transpose_indices`
      4. reshape to `plan.final_shape`
    If `plan.final_shape` is non-empty, the result's shape is then checked
    against it.

    Args:
        tensor: Input array.
        plan: Execution plan produced by `_parse_pattern`.

    Returns:
        The transformed array; the input itself when no step is enabled.

    Raises:
        EinopsError: If a NumPy step raises `ValueError` (reported with the
            shape reached before the failing step), or if the final shape does
            not match `plan.final_shape`.
    """
    current_tensor = tensor
    temp_shape = tensor.shape  # Last shape successfully reached, for error messages

    try:
        # 1. Initial Reshape (Decomposition based on LHS)
        if plan.needs_reshaping_input:
            current_tensor = current_tensor.reshape(plan.intermediate_shape_after_lhs_reshape)
            temp_shape = current_tensor.shape

        # 2. Repeat/Tile (Based on numeric literals in RHS)
        if plan.needs_repeating:
            # TODO: Implement repeat/tile logic based on plan.repeat_instructions
            temp_shape = current_tensor.shape

        # 3. Transpose (Reordering axes)
        if plan.needs_transposing:
            current_tensor = np.transpose(current_tensor, axes=plan.transpose_indices)
            temp_shape = current_tensor.shape

        # 4. Final Reshape (Composition based on RHS)
        if plan.needs_reshaping_output:
            current_tensor = current_tensor.reshape(plan.final_shape)
            temp_shape = current_tensor.shape

    except ValueError as ve:
        # Errors from the NumPy operations above (e.g. reshape size mismatch)
        raise EinopsError(f"NumPy error during execution: {ve}. Current shape during error: {temp_shape}") from ve

    # Sanity check on the final shape. It sits outside the try block on
    # purpose: EinopsError subclasses ValueError, so raising it inside would
    # be caught above and mislabelled as a NumPy error.
    expected_shape = getattr(plan, 'final_shape', None)
    if expected_shape and current_tensor.shape != tuple(expected_shape):
        # A mismatch here means the plan itself is inconsistent.
        raise EinopsError(f"Internal error: Final shape mismatch. Expected {plan.final_shape}, got {current_tensor.shape}")

    return current_tensor

In [None]:
# my_einops.py
# --- Optional: Example Usage within the module ---
if __name__ == '__main__':
    print("Running example usage:")

    def run_test(test_name, tensor, pattern, lengths, expected_shape=None, expect_error=None):
        """Run one example through validate -> parse -> execute and print the outcome.

        `expect_error`, when given, is a substring expected in the EinopsError
        message; `expected_shape` is only printed next to the actual shape.
        """
        print(f"\n--- {test_name} ---")
        print(f"Pattern: '{pattern}', Lengths: {lengths}")
        try:
            # Validation runs first, so the shape is printed only for inputs
            # that got past it.
            _validate_input(tensor, pattern, lengths)
            print("Input shape:", tensor.shape)

            plan = _parse_pattern(pattern, tensor.shape, lengths)
            output = _execute_rearrangement(tensor, plan)

            if not expect_error:
                # Printed only; the shape is not asserted yet.
                print(f"Output shape: {output.shape} (Expected: {expected_shape})")
            else:
                print(f"!!! ERROR: Expected EinopsError ({expect_error}), but got result.")
        except EinopsError as err:
            if not expect_error:
                print(f"!!! ERROR: Caught unexpected EinopsError: {err}")
                return
            print("OK: Caught expected EinopsError.")
            if expect_error in str(err):
                print(f"    Message contains expected text: '{expect_error}'")
            else:
                print(f"    WARN: Error message mismatch. Expected substring '{expect_error}', Got '{err}'")
        except Exception as err:
            print(f"!!! ERROR: Caught unexpected {type(err).__name__}: {err}")

    # Each entry: (test name, tensor, pattern, axes lengths, expectation kwargs).
    _example_cases = [
        # --- Basic validation ---
        ("Invalid Tensor", [1, 2, 3], 'a -> a', {},
         {"expect_error": "Input tensor must be a NumPy ndarray"}),
        ("No Separator", np.zeros(1), 'a b c', {},
         {"expect_error": "Pattern must contain '->' separator"}),
        ("Multiple Separators", np.zeros(1), 'a -> b -> c', {},
         {"expect_error": "Pattern must contain exactly one '->' separator"}),
        ("Invalid Char", np.zeros(1), 'a $ -> b', {},
         {"expect_error": "Pattern contains invalid characters or structure"}),
        ("Invalid Structure", np.zeros(1), 'a . . -> b', {},
         {"expect_error": "Pattern contains invalid characters or structure"}),
        ("Unbalanced Parens 1", np.zeros(1), '(a -> b', {},
         {"expect_error": "Pattern has unbalanced parentheses"}),
        ("Unbalanced Parens 2", np.zeros(1), 'a) -> b', {},
         {"expect_error": "Pattern has unbalanced parentheses"}),
        ("Invalid axes_lengths Name", np.zeros(1), 'a -> b', {'1b': 2},
         {"expect_error": "not a valid identifier"}),
        ("Invalid axes_lengths Value", np.zeros(1), 'a -> b', {'b': 0},
         {"expect_error": "must be a positive integer"}),
        ("Invalid axes_lengths Underscore", np.zeros(1), 'a -> b', {'_b': 2},
         {"expect_error": "should not start or end with underscore"}),

        # --- Parsing / semantic cases (many depend on the unfinished parser) ---
        ("Dots outside Ellipsis", np.zeros((2, 3)), 'a . b -> a b', {},
         {"expect_error": "dots only inside ellipsis"}),
        ("Multiple Ellipsis LHS", np.zeros((2, 3)), '... a ... -> a', {},
         {"expect_error": "only one ellipsis per side"}),
        ("Ellipsis RHS only", np.zeros((2, 3)), 'a b -> ... a b', {},
         {"expect_error": "Ellipsis found in right side, but not left"}),
        ("Transpose", np.arange(12).reshape(3, 4), 'h w -> w h', {},
         {"expected_shape": (4, 3)}),
        ("Split Axis", np.arange(12).reshape(6, 2), '(h w) c -> h w c', {'h': 3},
         {"expected_shape": (3, 2, 2)}),
        ("Split Axis Fail (Div)", np.arange(10).reshape(5, 2), '(h w) c -> h w c', {'h': 3},
         {"expect_error": "Shape mismatch, can't divide"}),
        ("Merge Axes", np.arange(12).reshape(2, 3, 2), 'a b c -> (a b) c', {},
         {"expected_shape": (6, 2)}),
        # Disabled until repeat logic exists:
        # ("Repeat Axis", np.arange(6).reshape(2, 1, 3), 'a 1 c -> a b c', {'b': 4},
        #  {"expected_shape": (2, 4, 3)}),
        ("Ellipsis", np.arange(24).reshape(2, 3, 4), '... w -> ... (w 2)', {},
         {"expected_shape": (2, 3, 8)}),
        ("Ellipsis Mismatch", np.arange(24).reshape(2, 3, 4), 'a ... w -> a ... (w 2)', {},
         {"expected_shape": (2, 3, 8)}),
        ("Axis Mismatch", np.arange(12).reshape(3, 4), 'h w -> w c', {},
         {"expect_error": "Identifiers only on one side"}),
    ]

    for _case in _example_cases:
        run_test(*_case[:4], **_case[4])



Running example usage:

--- Invalid Tensor ---
Pattern: 'a -> a', Lengths: {}
OK: Caught expected EinopsError.
    Message contains expected text: 'Input tensor must be a NumPy ndarray'

--- No Separator ---
Pattern: 'a b c', Lengths: {}
OK: Caught expected EinopsError.
    Message contains expected text: 'Pattern must contain '->' separator'

--- Multiple Separators ---
Pattern: 'a -> b -> c', Lengths: {}
OK: Caught expected EinopsError.
    Message contains expected text: 'Pattern must contain exactly one '->' separator'

--- Invalid Char ---
Pattern: 'a $ -> b', Lengths: {}
OK: Caught expected EinopsError.
    Message contains expected text: 'Pattern contains invalid characters or structure'

--- Invalid Structure ---
Pattern: 'a . . -> b', Lengths: {}
OK: Caught expected EinopsError.
    Message contains expected text: 'Pattern contains invalid characters or structure'

--- Unbalanced Parens 1 ---
Pattern: '(a -> b', Lengths: {}
Input shape: (1,)
!!! ERROR: Expected EinopsError (Pa