<a href="https://colab.research.google.com/github/mananuppadhyay/einops_implementation/blob/main/Einops.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [99]:
import einops
import torch
from functools import reduce
import operator
import re
import math
import numpy as np

In [114]:
import re
import math
import numpy as np
from typing import List, Dict, Tuple, Any, Union, Optional, Set


class EinopsManan:
    """Minimal NumPy implementation of an einops-style ``rearrange``.

    Supported pattern syntax:
      * named axes such as ``b c h w`` (a literal like ``1`` on the left is
        treated as an ordinary name; if it is dropped it must have size 1);
      * at most one ellipsis ``...`` per side, standing for the input axes
        not named on the left side;
      * parenthesised groups: on the left they split one input axis (at most
        one member may have an unspecified size), on the right they merge
        consecutive output axes into one;
      * axes that appear only on the right: their sizes must be passed as
        keyword arguments and the data is repeated along them;
      * axes that appear only on the left: they must have size 1 and are
        squeezed away.
    """

    # One token: an ellipsis, a bare axis name, or a parenthesised group of names.
    _TOKEN_PATTERN = r'\.\.\.|\w+|\([\w\s]+\)'

    def parse_lr(self, pattern: str) -> Tuple[List[str], List[str]]:
        """
        Parse the left and right sides of an einops pattern string.

        Args:
            pattern: A string in the format 'left -> right' where left and right
                     represent the input and output tensor specifications.

        Returns:
            A tuple containing two lists:
            - left_tokens: The tokenized left side of the pattern.
            - right_tokens: The tokenized right side of the pattern.

        Raises:
            ValueError: If the pattern does not contain exactly one '->', or if
                        either side contains characters that do not form a valid
                        token (e.g. unbalanced parentheses or nested groups).

        Example:
            For pattern "b c (h w) -> b h w c", returns
            (['b', 'c', '(h w)'], ['b', 'h', 'w', 'c'])
        """
        if pattern.count('->') != 1:
            raise ValueError(f"Pattern must contain exactly one '->', got {pattern!r}")
        left, right = pattern.split('->')
        return self._tokenize(left), self._tokenize(right)

    def _tokenize(self, side: str) -> List[str]:
        """Split one side of a pattern into tokens, rejecting unrecognised characters."""
        side = side.strip()
        tokens = re.findall(self._TOKEN_PATTERN, side)
        # Whatever the token regex cannot consume (a stray '(' or ',', '...'
        # inside a group, ...) would otherwise be dropped without any error.
        leftover = re.sub(self._TOKEN_PATTERN, ' ', side).strip()
        if leftover:
            raise ValueError(f"Unrecognized characters {leftover!r} in pattern side {side!r}")
        return tokens

    def create_idx_mapping(self, l_expanded: List[str], r_expanded: List[str]) -> Tuple[Dict[str, int], Dict[str, int]]:
        """
        Create index mappings for the expanded left and right token lists.

        Args:
            l_expanded: The expanded list of tokens from the left side of the pattern.
            r_expanded: The expanded list of tokens from the right side of the pattern.

        Returns:
            A tuple containing two dictionaries:
            - mp_l: Maps each variable in l_expanded to its index.
            - mp_r: Maps each variable in r_expanded to its index.
            If a name repeats, the dictionary keeps its last index.
        """
        mp_l = {var: i for i, var in enumerate(l_expanded)}
        mp_r = {var: i for i, var in enumerate(r_expanded)}
        return mp_l, mp_r

    def _ellipsis_ndim(self, tensor: np.ndarray, l_tokens: List[str]) -> int:
        """
        Return how many input axes the left-side ellipsis stands for.

        Args:
            tensor: The input numpy array.
            l_tokens: The tokenized left side of the pattern.

        Returns:
            The number of axes covered by '...' (0 when there is no ellipsis).

        Raises:
            ValueError: If the left side has more than one '...', or if the
                        tensor's rank is incompatible with the left side.
        """
        n_dots = l_tokens.count('...')
        if n_dots > 1:
            raise ValueError("Left side may contain at most one ellipsis '...'")
        # Every non-ellipsis token (a name or a whole group) consumes one input axis.
        n_named = len(l_tokens) - n_dots
        ndim = len(tensor.shape)
        if n_dots == 0 and ndim != n_named:
            raise ValueError(f"Pattern expects {n_named} dimensions, but tensor has {ndim}")
        if ndim < n_named:
            raise ValueError(f"Pattern expects at least {n_named} dimensions, but tensor has {ndim}")
        return ndim - n_named

    def expand_left_and_right_sides(self, tensor: np.ndarray, l_tokens: List[str], r_tokens: List[str]) -> Tuple[List[str], List[str]]:
        """
        Expand the left and right token lists, handling ellipsis and grouped dimensions.

        Args:
            tensor: The input numpy array.
            l_tokens: The tokenized left side of the pattern.
            r_tokens: The tokenized right side of the pattern.

        Returns:
            A tuple containing two lists:
            - l_expanded: The expanded list of tokens from the left side.
            - r_expanded: The expanded list of tokens from the right side.

        Raises:
            ValueError: If the tensor rank does not fit the left side, if either
                        side has more than one '...', or if '...' appears on the
                        right side only.

        Note:
            This function handles ellipsis (...) by replacing it with named dimensions
            ('__ellipsis0', '__ellipsis1', ...) and expands grouped dimensions like
            (h w) into separate dimensions.
        """
        n_ellipsis = self._ellipsis_ndim(tensor, l_tokens)
        if r_tokens.count('...') > 1:
            raise ValueError("Right side may contain at most one ellipsis '...'")
        if '...' in r_tokens and '...' not in l_tokens:
            raise ValueError("Ellipsis '...' on the right side requires one on the left side")
        ellipsis_vars = [f'__ellipsis{i}' for i in range(n_ellipsis)]

        def expand_tokens(tokens: List[str]) -> List[str]:
            """Flatten tokens: '...' -> ellipsis names, '(h w)' -> 'h', 'w'."""
            expanded = []
            for token in tokens:
                if token == '...':
                    expanded.extend(ellipsis_vars)
                elif token.startswith('('):
                    expanded.extend(token[1:-1].split())
                else:
                    expanded.append(token)
            return expanded

        return expand_tokens(l_tokens), expand_tokens(r_tokens)

    def get_brackets(self, r_tokens: List[str], mp_r: Dict[str, int]) -> List[List[int]]:
        """
        Extract the bracketed (grouped) dimensions from the right side tokens.

        Args:
            r_tokens: The tokenized right side of the pattern.
            mp_r: The mapping from dimension names to indices in the expanded right tokens.

        Returns:
            A list of lists, where each inner list contains the indices of dimensions that
            are grouped together in the output tensor.
        """
        brackets = []
        for token in r_tokens:
            if token.startswith('('):
                grouped = token[1:-1].split()
                brackets.append([mp_r[var] for var in grouped])
        return brackets

    @staticmethod
    def _reject_duplicates(names: List[str], side: str) -> None:
        """Raise ValueError if any name occurs more than once in ``names``."""
        seen = set()
        repeated = []
        for name in names:
            if name in seen and name not in repeated:
                repeated.append(name)
            seen.add(name)
        if repeated:
            raise ValueError(f"Axis names repeated on the {side}: {repeated}")

    def _split_group(self, grouped: List[str], total: int, axes_lengths: Dict[str, int]) -> List[int]:
        """
        Resolve the member sizes of one left-side group that spans an axis of size ``total``.

        Raises:
            ValueError: If more than one member size is unknown, or the sizes
                        cannot multiply to ``total``.
        """
        provided = {var: axes_lengths[var] for var in grouped if var in axes_lengths}
        missing = [var for var in grouped if var not in axes_lengths]
        if len(missing) > 1:
            raise ValueError(f"Missing sizes for {missing} in group {grouped}")
        known_prod = math.prod(provided.values())
        if missing:
            # known_prod == 0 would otherwise raise ZeroDivisionError below.
            if known_prod == 0 or total % known_prod != 0:
                raise ValueError(f"Cannot split {total} into {grouped} with provided sizes {provided}")
            provided[missing[0]] = total // known_prod
        if math.prod(provided.values()) != total:
            raise ValueError(f"Product of {grouped} sizes does not match {total}")
        return [provided[var] for var in grouped]

    def _split_left_dims(self, tensor: np.ndarray, l_tokens: List[str], n_ellipsis: int,
                         axes_lengths: Dict[str, int]) -> List[int]:
        """
        Compute the shape the input must be reshaped to so that every left-side
        group is split into its member axes.

        Raises:
            ValueError: If an explicitly sized axis does not match the tensor,
                        or a group cannot be split.
        """
        dims = iter(tensor.shape)  # rank was validated, so next() never runs dry
        reshape_dims = []
        for token in l_tokens:
            if token == '...':
                reshape_dims.extend([next(dims) for _ in range(n_ellipsis)])
            elif token.startswith('('):
                reshape_dims.extend(self._split_group(token[1:-1].split(), next(dims), axes_lengths))
            else:
                actual_size = next(dims)
                expected_size = axes_lengths.get(token, actual_size)
                if expected_size != actual_size:
                    raise ValueError(f"Expected size {expected_size} for axis '{token}', got {actual_size}")
                reshape_dims.append(actual_size)
        return reshape_dims

    def get_resulting_array(self, tensor: np.ndarray, l_tokens: List[str], r_tokens: List[str],
                           l_expanded: List[str], r_expanded: List[str], brackets: List[List[int]],
                           **axes_lengths: int) -> np.ndarray:
        """
        Perform the actual tensor transformation based on the parsed pattern.

        Args:
            tensor: The input numpy array.
            l_tokens: The tokenized left side of the pattern.
            r_tokens: The tokenized right side of the pattern. Not used here (the
                      grouping information arrives via ``brackets``); kept for
                      interface compatibility.
            l_expanded: The expanded list of tokens from the left side.
            r_expanded: The expanded list of tokens from the right side.
            brackets: The list of indices representing grouped dimensions in output.
            **axes_lengths: Keyword arguments specifying the sizes of dimensions.

        Returns:
            The transformed numpy array.

        Raises:
            ValueError: If there are issues with dimension sizes, incompatible shapes,
                       repeated axis names, or insufficient information to determine
                       dimension sizes.
        """
        # The index mappings below key on names, so a repeated left-side name
        # would silently collapse two axes into one.
        self._reject_duplicates(l_expanded, "left side")

        n_ellipsis = self._ellipsis_ndim(tensor, l_tokens)
        reshape_dims = self._split_left_dims(tensor, l_tokens, n_ellipsis, axes_lengths)

        # Classify right-side names: taken from the input, or newly created.
        left_vars_set = set(l_expanded)
        existing_vars_in_right = [var for var in r_expanded if var in left_vars_set]
        self._reject_duplicates(existing_vars_in_right, "right side")
        new_vars_in_right = {var for var in r_expanded if var not in left_vars_set}

        for var in r_expanded:
            if var in new_vars_in_right and var not in axes_lengths:
                raise ValueError(f"New axis '{var}' requires a size in axes_lengths")

        # Shared axes go first, in right-side order; left-only axes go last so
        # they can be squeezed off the end.
        mp_l, _ = self.create_idx_mapping(l_expanded, [])
        shared = set(existing_vars_in_right)
        left_only_vars = [var for var in l_expanded if var not in shared]
        new_order = [mp_l[var] for var in existing_vars_in_right + left_only_vars]

        try:
            b = np.transpose(tensor.reshape(reshape_dims), new_order)
        except ValueError as e:
            raise ValueError(f"Failed to reshape or transpose: {e}") from e

        # Insert each right-only axis at its output position and repeat the data
        # along it. Positions are visited in increasing order, so all earlier
        # output axes are already in place when axis i is inserted.
        for i, var in enumerate(r_expanded):
            if var in new_vars_in_right:
                b = np.repeat(np.expand_dims(b, axis=i), axes_lengths[var], axis=i)

        # The left-only axes now trail the len(r_expanded) output axes; drop them
        # last-first so the remaining positions stay valid.
        for pos in reversed(range(len(r_expanded), b.ndim)):
            if b.shape[pos] != 1:
                raise ValueError(f"Dimension {pos} must be 1 to squeeze, but is {b.shape[pos]}")
            b = np.squeeze(b, axis=pos)

        # Merge each right-side group into one axis. Merging k axes shifts every
        # later index left by k - 1, which `offset` accumulates.
        offset = 0
        for group in brackets:
            start, end = group[0] - offset, group[-1] - offset
            merged_size = math.prod(b.shape[start:end + 1])
            b = b.reshape(b.shape[:start] + (merged_size,) + b.shape[end + 1:])
            offset += end - start

        return b

    def rearrange(self, tensor: np.ndarray, pattern: str, **axes_lengths: int) -> np.ndarray:
        """
        Rearrange a tensor according to the provided einops-like pattern.

        Args:
            tensor: The input numpy array to rearrange.
            pattern: A string in the format 'left -> right' that specifies how to transform the tensor.
                    Where 'left' describes the input tensor dimensions and 'right' describes the output.
            **axes_lengths: Keyword arguments specifying the sizes for any new dimensions or
                           dimensions that need explicit size specification.

        Returns:
            The rearranged numpy array according to the provided pattern.

        Raises:
            ValueError: If the pattern is malformed or incompatible with the tensor.

        Examples:
            >>> x = np.random.randn(3, 4, 5)  # shape (3, 4, 5)
            >>> einops = EinopsManan()
            >>> y = einops.rearrange(x, 'a b c -> b a c')  # Transpose axes 0 and 1
            >>> y.shape
            (4, 3, 5)

            >>> # Split dimension 'b' into 'h' and 'w'
            >>> z = einops.rearrange(x, 'a (h w) c -> a h w c', h=2)
            >>> z.shape
            (3, 2, 2, 5)
        """
        l_tokens, r_tokens = self.parse_lr(pattern)
        l_expanded, r_expanded = self.expand_left_and_right_sides(tensor, l_tokens, r_tokens)
        _, mp_r = self.create_idx_mapping(l_expanded, r_expanded)
        brackets = self.get_brackets(r_tokens, mp_r)
        return self.get_resulting_array(tensor, l_tokens, r_tokens, l_expanded, r_expanded, brackets, **axes_lengths)

In [115]:
# NOTE: this rebinds the name `einops`, shadowing the `einops` package imported
# in the first cell; every later `einops.rearrange(...)` call in this notebook
# therefore uses EinopsManan, not the library.
einops = EinopsManan()

## Assignment test cases

In [116]:
# Each case: (input shape, pattern, explicit axis sizes for rearrange).
assignment_cases = [
    ((3, 4), 'h w -> w h', {}),
    ((12, 10), '(h w) c -> h w c', {'h': 3}),
    ((3, 4, 5), 'a b c -> (a b) c', {}),
    # 'b' exists only on the right, so its size is supplied; the size-1 axis
    # named '1' exists only on the left and is squeezed away.
    ((3, 1, 5), 'a 1 c -> a b c', {'b': 4}),
    ((2, 3, 4, 5), '... h w -> ... (h w)', {}),
]
for shape, pattern, sizes in assignment_cases:
    x = np.random.rand(*shape)
    result = einops.rearrange(x, pattern, **sizes)
    print(result.shape)


(4, 3)
(3, 4, 10)
(12, 5)
(3, 4, 5)
(2, 3, 20)


## Test cases:

In [117]:
x = np.random.rand(2, 3, 4)
result = einops.rearrange(tensor=x, pattern='... -> ...')
print(result.shape)  # identity pattern; expected (2, 3, 4)

(2, 3, 4)


In [118]:
x = np.random.rand(2, 2, 5)
result = einops.rearrange(tensor=x, pattern='a b c -> (a b) c')
print(result.shape)  # a and b merged into one axis; expected (4, 5)

(4, 5)


In [119]:
# Split a merged axis back apart: (4, 5) -> (2, 2, 5). Rearranging the freshly
# created x (rather than `result` from an earlier cell) keeps this cell
# self-contained; previously the randn array was created and never used.
x = np.random.randn(4, 5)
x = einops.rearrange(x, '(a b) c -> a b c', a=2)
print(x.shape)  # (2, 2, 5)

(2, 2, 5)


In [120]:
x = np.random.rand(2, 3, 4)
result = einops.rearrange(tensor=x, pattern='a b c -> c b a')
print(result.shape)  # axis order reversed; expected (4, 3, 2)

(4, 3, 2)


In [121]:
x = np.random.rand(2, 3, 4)
result = einops.rearrange(tensor=x, pattern='... c -> c ...')
print(result.shape)  # last axis moved to the front; expected (4, 2, 3)

(4, 2, 3)


In [122]:
x = np.random.rand(2, 6, 5)
result = einops.rearrange(tensor=x, pattern='b (h w) c -> (c h) b w', h=2)
print(result.shape)  # split 6 into h=2, w=3, then merge c*h; expected (10, 2, 3)

(10, 2, 3)


In [123]:
x = np.random.rand(2, 3, 4, 5)
result = einops.rearrange(tensor=x, pattern='a b c d -> (a b) (c d)')
print(result.shape)  # two independent merges; expected (6, 20)

(6, 20)


In [124]:
x = np.random.rand(5)
result = einops.rearrange(tensor=x, pattern='c -> t c l', t=1, l=1)
print(result.shape)  # new size-1 axes t and l around c; expected (1, 5, 1)


(1, 5, 1)


In [125]:
# Intentional failure case: the last axis has size 5, which cannot be split
# into (h w) with h=2. Catching the error keeps "Run All" from stopping here,
# so the cells below still execute.
x = np.random.rand(2, 3, 4, 5)
try:
    result = einops.rearrange(x, '... (h w) -> ... h w', h=2)
    print(result.shape)
except ValueError as err:
    print(f"ValueError: {err}")

ValueError: Cannot split 5 into ['h', 'w'] with provided sizes {'h': 2}

In [126]:
x = np.random.rand(2, 6, 8)
result = einops.rearrange(tensor=x, pattern='b (h1 h2) (w1 w2) -> b h1 w1 h2 w2', h1=2, w1=2)
print(result.shape)  # h2 and w2 inferred as 3 and 4; expected (2, 2, 2, 3, 4)

(2, 2, 2, 3, 4)


In [127]:
x = np.random.rand(24, 8)
result = einops.rearrange(tensor=x, pattern='(a b c) d -> c b a d', a=2, b=3)
print(result.shape)  # c inferred as 24 // (2*3) = 4; expected (4, 3, 2, 8)


(4, 3, 2, 8)
