In [55]:
# install numpy if not installed
! pip install numpy




[notice] A new release of pip is available: 24.3.1 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [56]:
# numpy for operations 
# re for parsing

import numpy as np
import re

In [57]:
def rearrange(tensor: np.ndarray, pattern: str, **axes_lengths) -> np.ndarray:
    """
    rearranges tensor according to an einops-like pattern.
    supports splitting axes, merging axes, and repeating axes.
    note: this implementation now uses a clearer repeat syntax.
    """
    # check that input is a numpy array
    if not isinstance(tensor, np.ndarray):
        raise TypeError("input tensor must be a numpy ndarray.")
    if "->" not in pattern:
        raise ValueError("pattern must contain '->' separator.")

    # parse the pattern into source and target patterns
    src, target = map(str.strip, pattern.split("->"))
    if not src:
        raise ValueError("src pattern cannot be empty.")
    if not target:
        raise ValueError("target pattern cannot be empty.")

    # tokenize the src and target patterns.
    # tokens can be grouped (in parentheses) or simple axis names.
    src_tokens = [t.strip() for t in re.findall(r'\(.*?\)|[\w\.]+', src)]
    tgt_tokens = [t.strip() for t in re.findall(r'\(.*?\)|[\w\.]+', target)]

    if tensor.ndim != len(src_tokens):
        raise ValueError(f"input tensor has {tensor.ndim} dimensions, but src pattern '{src}' has {len(src_tokens)} tokens.")

    # build a mapping from variable names to their dimension sizes.
    var2dim = {}
    temp_shape = list(tensor.shape)
    processed_src_vars = set()
    src_vars_in_order = []  # flat list of individual axis names in order

    current_dim_index = 0
    for token in src_tokens:
        if token.startswith('(') and token.endswith(')'):
            # handle grouped token from src; expected to be mergeable dimensions
            group_vars = re.findall(r'\w+', token[1:-1])
            if not group_vars:
                raise ValueError(f"empty group found in src pattern: {token}")
            product_known = 1
            unknown_vars = []
            for var in group_vars:
                if var in axes_lengths:
                    size = axes_lengths[var]
                    if var in var2dim and var2dim[var] != size:
                        raise ValueError(f"size mismatch for axis '{var}': {var2dim[var]} vs {size}")
                    var2dim[var] = size
                    product_known *= size
                elif var in var2dim:
                    product_known *= var2dim[var]
                else:
                    unknown_vars.append(var)
            # if all dimensions known, check product against tensor dimension
            shape_dim = temp_shape[current_dim_index]
            if not unknown_vars:
                if product_known != shape_dim:
                    raise ValueError(f"product of dimensions in group {token} ({product_known}) does not match input tensor dimension size ({shape_dim})")
            elif len(unknown_vars) == 1:
                # infer the missing dimension
                missing_var = unknown_vars[0]
                if shape_dim % product_known != 0:
                    raise ValueError(f"cannot infer size for '{missing_var}' in group {token}: dimension size {shape_dim} is not divisible by known product {product_known}")
                inferred = shape_dim // product_known
                var2dim[missing_var] = inferred
            else:
                raise ValueError(f"ambiguous dimensions for group {token}. provide sizes for all but one variable.")
            # add individual vars to flat list
            for var in group_vars:
                src_vars_in_order.append(var)
                processed_src_vars.add(var)
            current_dim_index += 1
        else:
            # simple token: must be a valid identifier.
            if not re.fullmatch(r'\w+', token):
                raise ValueError(f"invalid token in src pattern: '{token}'")
            shape_dim = temp_shape[current_dim_index]
            if token in var2dim and var2dim[token] != shape_dim:
                raise ValueError(f"size mismatch for axis '{token}': {var2dim[token]} vs {shape_dim}")
            if token in axes_lengths and axes_lengths[token] != shape_dim:
                raise ValueError(f"provided size for axis '{token}' ({axes_lengths[token]}) does not match input tensor dimension size ({shape_dim})")
            var2dim[token] = axes_lengths.get(token, shape_dim)
            src_vars_in_order.append(token)
            processed_src_vars.add(token)
            current_dim_index += 1

    # split tensor: reshape tensor into individual axes dimensions
    split_shape = [var2dim[var] for var in src_vars_in_order]
    try:
        tensor_split = tensor.reshape(split_shape)
    except ValueError as e:
        raise ValueError(f"cannot reshape input tensor (shape {tensor.shape}) into intermediate split shape {split_shape}. original error: {e}")

    # process repeat instructions in target pattern.
    # repeat groups are now explicitly defined as (repeat axis1 axis2 ...)
    repeat_instructions = []
    for token in tgt_tokens:
        if token.startswith('(repeat') and token.endswith(')'):
            # tokenize and check that token starts with 'repeat'
            inner_tokens = re.findall(r'\w+', token[1:-1])
            if len(inner_tokens) < 2 or inner_tokens[0] != 'repeat':
                raise ValueError(f"invalid repeat syntax in target: '{token}'. expected format like '(repeat axis1 axis2 ...)'")
            axes_to_repeat = inner_tokens[1:]
            # require a repeat factor to be provided in axes_lengths with key 'repeat'
            if 'repeat' not in axes_lengths:
                raise ValueError(f"missing repeat factor 'repeat' in axes_lengths for repeat group {token}")
            factor = axes_lengths['repeat']
            if not isinstance(factor, int) or factor <= 0:
                raise ValueError(f"repeat factor must be a positive integer, got {factor}")
            repeat_instructions.append({'axes': axes_to_repeat, 'factor': factor, 'token': token})

            # update var2dim for the repeated axes before permutation/reshape.
            for axis in axes_to_repeat:
                if axis not in var2dim:
                    raise ValueError(f"axis '{axis}' specified in repeat group '{token}' not found in src pattern or axes_lengths.")
                var2dim[axis] *= factor

    # apply repeats on the tensor split before permutation.
    if repeat_instructions:
        for instruction in repeat_instructions:
            factor = instruction['factor']
            axes = instruction['axes']
            for axis in axes:
                try:
                    axis_index = src_vars_in_order.index(axis)
                except ValueError:
                    raise ValueError(f"axis '{axis}' from repeat group '{instruction['token']}' not found in src pattern.")
                # use np.repeat along the specified axis.
                tensor_split = np.repeat(tensor_split, factor, axis=axis_index)
                # note: since np.repeat modifies the size in-place, our var2dim update above is consistent.
    
    # build the final target shape from target tokens.
    target_vars_in_order = []
    final_target_shape = []
    for token in tgt_tokens:
        if token.startswith('(') and token.endswith(')'):
            # group token in target pattern
            group_vars = re.findall(r'\w+', token[1:-1])
            # check for repeat group - but note that repeats have already been processed.
            if group_vars and group_vars[0] == 'repeat':
                axes_in_group = group_vars[1:]
                if not axes_in_group:
                    raise ValueError(f"repeat group '{token}' must specify at least one axis.")
                prod = 1
                for var in axes_in_group:
                    if var not in var2dim:
                        raise ValueError(f"axis '{var}' in target group '{token}' not found in src or axes_lengths.")
                    prod *= var2dim[var]
                    target_vars_in_order.append(var)
                final_target_shape.append(prod)
            else:
                prod = 1
                for var in group_vars:
                    if var not in var2dim:
                        raise ValueError(f"axis '{var}' in target group '{token}' not found in src or axes_lengths.")
                    prod *= var2dim[var]
                    target_vars_in_order.append(var)
                final_target_shape.append(prod)
        else:
            # simple token in target
            if not re.fullmatch(r'\w+', token):
                raise ValueError(f"invalid token in target pattern: '{token}'")
            if token not in var2dim:
                raise ValueError(f"axis '{token}' in target pattern not found in src or axes_lengths.")
            target_vars_in_order.append(token)
            final_target_shape.append(var2dim[token])

    # ensure that all src variables are used in target.
    if set(target_vars_in_order) != processed_src_vars:
        missing = processed_src_vars - set(target_vars_in_order)
        extra = set(target_vars_in_order) - processed_src_vars
        extra -= {'repeat'}
        if missing:
            raise ValueError(f"src axes {missing} are not included in the target pattern.")
        if extra:
            raise ValueError(f"target axes {extra} were not found in the src pattern.")

    # compute permutation order based on target vars order.
    try:
        permutation = [src_vars_in_order.index(var) for var in target_vars_in_order]
    except ValueError as e:
        raise ValueError(f"could not find target axis in the flattened src axes list: {e}")
    
    # perform permutation
    tensor_permuted = np.transpose(tensor_split, permutation)
    
    # final reshape to merge axes as specified by target pattern.
    try:
        final_tensor = tensor_permuted.reshape(final_target_shape)
    except ValueError as e:
        actual = np.prod(tensor_permuted.shape)
        expected = np.prod(final_target_shape)
        raise ValueError(f"cannot reshape permuted tensor (shape {tensor_permuted.shape}, {actual} elements) into final target shape {final_target_shape} ({expected} elements). original error: {e}")
    
    return final_tensor


In [58]:
# call einops for evaluation purposes
from einops import rearrange as einops_rearrange
from einops import repeat as einops_repeat

### transposition

In [59]:
def test_case(x, pattern, description=""):
    # run both the custom rearrange and the einops version.
    custom_output = rearrange(x, pattern)
    einops_output = einops_rearrange(x, pattern)
    
    print(f"Test: {description}")
    print("Input shape:", x.shape)
    print("Transformation pattern:", pattern)
    print("Custom output shape:", custom_output.shape)
    print("Einops output shape:", einops_output.shape)
    
    if np.array_equal(custom_output, einops_output):
        print("[PASSED]")
    else:
        print("[FAILED]")
        print("Custom output:\n", custom_output)
        print("Einops output:\n", einops_output, "\n")


def run_tests():
    # transposing a 2D array.
    x = np.random.randint(0, 10, size=(3, 4))
    print("Original array:\n", x)
    print("Original shape:", x.shape)
    print()
    test_case(x, "h w -> w h", "2D array transpose")
    
    # swap the last two axes.
    x = np.random.randint(0, 10, size=(2, 3, 4))
    test_case(x, "b h w -> b w h", "3D array: swap h and w")
    
    # flatten the array.
    x = np.random.randint(0, 10, size=(3, 4))
    test_case(x, "h w -> (h w)", "2D array flatten")
    
    # reshape to 2D by merging first two axes.
    x = np.random.randint(0, 10, size=(2, 3, 4))
    test_case(x, "b h w -> (b h) w", "3D array: merge first two dimensions")
    
    # empty array
    x = np.empty((0, 4))
    test_case(x, "h w -> w h", "Empty array transpose")
    
    # single element array
    x = np.array([[42]])
    test_case(x, "h w -> w h", "Single-element array transpose")

if __name__ == "__main__":
    run_tests()

Original array:
 [[8 8 9 3]
 [8 6 9 6]
 [4 2 3 8]]
Original shape: (3, 4)

Test: 2D array transpose
Input shape: (3, 4)
Transformation pattern: h w -> w h
Custom output shape: (4, 3)
Einops output shape: (4, 3)
[PASSED]
Test: 3D array: swap h and w
Input shape: (2, 3, 4)
Transformation pattern: b h w -> b w h
Custom output shape: (2, 4, 3)
Einops output shape: (2, 4, 3)
[PASSED]
Test: 2D array flatten
Input shape: (3, 4)
Transformation pattern: h w -> (h w)
Custom output shape: (12,)
Einops output shape: (12,)
[PASSED]
Test: 3D array: merge first two dimensions
Input shape: (2, 3, 4)
Transformation pattern: b h w -> (b h) w
Custom output shape: (6, 4)
Einops output shape: (6, 4)
[PASSED]
Test: Empty array transpose
Input shape: (0, 4)
Transformation pattern: h w -> w h
Custom output shape: (4, 0)
Einops output shape: (4, 0)
[PASSED]
Test: Single-element array transpose
Input shape: (1, 1)
Transformation pattern: h w -> w h
Custom output shape: (1, 1)
Einops output shape: (1, 1)
[PASSED

### splitting of axes

In [60]:
def test_case(x, pattern, description="", **kwargs):
    # run both the custom rearrange and the einops version
    custom_output = rearrange(x, pattern, **kwargs)
    einops_output = einops_rearrange(x, pattern, **kwargs)
    
    print(f"test: {description}")
    print("input shape:", x.shape)
    print("transformation pattern:", pattern)
    print("custom output shape:", custom_output.shape)
    print("einops output shape:", einops_output.shape)
    
    if np.array_equal(custom_output, einops_output):
        print("[PASSED]")
    else:
        print("[FAILED]")
        print("custom output:\n", custom_output)
        print("einops output:\n", einops_output, "\n")


def run_tests():

    x = np.random.rand(12, 10,5)
    print("original shape:", x.shape)
    
    pattern = "b (h w) c -> b h w c"
    y = rearrange(x, pattern, h=2)
    print("shape after splitting axes :", y.shape)

    
    # validate the custom rearrange against einops
    test_case(x, pattern, description="split axes with h=2", h=2)

if __name__ == "__main__":
    run_tests()

original shape: (12, 10, 5)
shape after splitting axes : (12, 2, 5, 5)
test: split axes with h=2
input shape: (12, 10, 5)
transformation pattern: b (h w) c -> b h w c
custom output shape: (12, 2, 5, 5)
einops output shape: (12, 2, 5, 5)
[PASSED]


### merging axes

In [61]:
def test_case(x, pattern, description=""):
    # run both the custom rearrange and the einops version
    custom_output = rearrange(x, pattern)
    einops_output = einops_rearrange(x, pattern)
    
    print(f"test: {description}")
    print("input shape:", x.shape)
    print("transformation pattern:", pattern)
    print("custom output shape:", custom_output.shape)
    print("einops output shape:", einops_output.shape)
    
    if np.array_equal(custom_output, einops_output):
        print("[PASSED]\n")
    else:
        print("[FAILED]")
        print("custom output:\n", custom_output)
        print("einops output:\n", einops_output, "\n")


def run_tests():
    x = np.random.rand(3, 4, 5)
    print("original shape:", x.shape)
    
    pattern = "b h w  -> b (h w) "
    y = rearrange(x, pattern)
    print("shape after merged axes :", y.shape)
    
    # validate the custom rearrange against einops
    test_case(x, pattern, description="merged axes")

if __name__ == "__main__":
    run_tests()

original shape: (3, 4, 5)
shape after merged axes : (3, 20)
test: merged axes
input shape: (3, 4, 5)
transformation pattern: b h w  -> b (h w) 
custom output shape: (3, 20)
einops output shape: (3, 20)
[PASSED]



### repeat axis

In [63]:
# test 1: original case - repeating a single axis
x = np.random.rand(2, 3, 4)
print("test 1 - original shape:", x.shape)
# repeat keyword is needed for the axis to be repeated
y = rearrange(x, 'b c h -> b c (repeat h)', repeat=4)
print("[PASSED] test 1 - shape after repeated axis:", y.shape)

# test 2: complex pattern - reordering and repeating multiple axes
x3 = np.random.rand(2, 3, 4, 5)
print("\ntest 2 - original shape:", x3.shape)
# here we repeat the b axis and then permute axes; note that repeat is applied on axis 'b'
try:
    y3 = rearrange(x3, 'b c h w -> c (repeat b) h w', repeat=2)
    print("[PASSED] test 2 - shape after repeating first axis:", y3.shape)
except Exception as e:
    print("[FAILED] test 2 - error in complex repeat test:", e)

test 1 - original shape: (2, 3, 4)
[PASSED] test 1 - shape after repeated axis: (2, 3, 16)

test 2 - original shape: (2, 3, 4, 5)
[PASSED] test 2 - shape after repeating first axis: (3, 4, 4, 5)
