In [88]:
import numpy as np

# Float values 0.0 .. 15.0, reshaped into 2-, 3- and 4-dimensional layouts
x = np.arange(16.0)
two_dim_x = x.reshape(4, 4)            # shape (4, 4)
three_dim_x = x.reshape(4, 2, 2)       # shape (4, 2, 2)
four_dim_x = x.reshape(2, 2, 2, 2)     # shape (2, 2, 2, 2)

In [89]:
# x reshaped to shape (4, 4)
two_dim_x

array([[ 0.,  1.,  2.,  3.],
       [ 4.,  5.,  6.,  7.],
       [ 8.,  9., 10., 11.],
       [12., 13., 14., 15.]])

In [90]:
# x reshaped to shape (4, 2, 2)
three_dim_x

array([[[ 0.,  1.],
        [ 2.,  3.]],

       [[ 4.,  5.],
        [ 6.,  7.]],

       [[ 8.,  9.],
        [10., 11.]],

       [[12., 13.],
        [14., 15.]]])

In [91]:
# x reshaped to shape (2, 2, 2, 2)
four_dim_x

array([[[[ 0.,  1.],
         [ 2.,  3.]],

        [[ 4.,  5.],
         [ 6.,  7.]]],


       [[[ 8.,  9.],
         [10., 11.]],

        [[12., 13.],
         [14., 15.]]]])

einops provides three core functions: `rearrange`, `reduce` and `repeat`.

**Rearrange example**

In [92]:
from einops import rearrange

# Transpose: swap the two axes of two_dim_x so rows become columns
# (same result as two_dim_x.T).
rearrange(two_dim_x, 'x y -> y x')

array([[ 0.,  4.,  8., 12.],
       [ 1.,  5.,  9., 13.],
       [ 2.,  6., 10., 14.],
       [ 3.,  7., 11., 15.]])

In [93]:
# Composition: merge axes x and y into a single flat axis of length 16
rearrange(two_dim_x, pattern='x y -> (x y)')

array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,
       13., 14., 15.])

In [94]:
# Composition on trailing axes: keep b, merge x and y -> shape (4, 4)
rearrange(three_dim_x, pattern='b x y -> b (x y)')

array([[ 0.,  1.,  2.,  3.],
       [ 4.,  5.,  6.,  7.],
       [ 8.,  9., 10., 11.],
       [12., 13., 14., 15.]])

In [113]:
# Flatten all four axes. b=2 is not needed to infer any shape here; einops
# validates it against the actual leading axis, so it acts as a shape check.
rearrange(four_dim_x, pattern='b x y c -> (b x y c)', b=2)

array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,
       13., 14., 15.])

In [96]:
# Decomposition: split the leading axis (length 4) into (b1, b2).
# Only b1=2 is given, so einops infers b2 = 4 // 2 = 2.
rearrange(three_dim_x, pattern='(b1 b2) x y  -> b1 b2 x y', b1=2)

array([[[[ 0.,  1.],
         [ 2.,  3.]],

        [[ 4.,  5.],
         [ 6.,  7.]]],


       [[[ 8.,  9.],
         [10., 11.]],

        [[12., 13.],
         [14., 15.]]]])

In [115]:
# Expand dims: append a new trailing axis of length 1 -> shape (4, 4, 1)
rearrange(two_dim_x, pattern='x y -> x y 1')

array([[[ 0.],
        [ 1.],
        [ 2.],
        [ 3.]],

       [[ 4.],
        [ 5.],
        [ 6.],
        [ 7.]],

       [[ 8.],
        [ 9.],
        [10.],
        [11.]],

       [[12.],
        [13.],
        [14.],
        [15.]]])

**Reduce example**

In [97]:
from einops import reduce

# Sum over y, keeping x (row sums). Other reductions: 'min', 'max', 'mean', 'prod'
reduce(two_dim_x, pattern='x y -> x', reduction='sum')

array([ 6., 22., 38., 54.])

In [98]:
# Max-pooling with a 2x2 kernel: h2 and w2 are the window sizes; the batch
# axis b is folded into the output width axis -> shape (1, 2, 2)
reduce(four_dim_x, pattern='b (h h2) (w w2) c -> h (b w) c', reduction='max', h2=2, w2=2)

array([[[ 6.,  7.],
        [14., 15.]]])

**Repeat example**

In [99]:
from einops import repeat

# Tile along y: each row is placed twice side by side -> shape (4, 8)
repeat(two_dim_x, pattern='x y -> x (repeat y)', repeat=2)

array([[ 0.,  1.,  2.,  3.,  0.,  1.,  2.,  3.],
       [ 4.,  5.,  6.,  7.,  4.,  5.,  6.,  7.],
       [ 8.,  9., 10., 11.,  8.,  9., 10., 11.],
       [12., 13., 14., 15., 12., 13., 14., 15.]])

**Different frameworks example**

In [102]:
import tensorflow as tf
import torch

# Wrap the same NumPy data as a TensorFlow variable and a PyTorch tensor.
# Assign each name explicitly: chaining through `x` would rebind the NumPy
# array `x` defined at the top of the notebook.
tf_three_dim_x = tf.Variable(three_dim_x)
# torch.from_numpy shares memory with three_dim_x (no copy is made).
torch_three_dim_x = torch.from_numpy(three_dim_x)

In [103]:
def basic_check(tensor):
    """Swap the last two axes with einops and print the result's type and shape.

    The pattern semantics are framework-independent, but the returned object
    belongs to the input's framework (a TensorFlow tensor for TensorFlow
    input, a torch.Tensor for PyTorch input).
    """
    swapped = rearrange(tensor, 'b x y -> b y x')
    print(type(swapped), swapped.shape)


for framework_tensor in (tf_three_dim_x, torch_three_dim_x):
    basic_check(framework_tensor)

<class 'tensorflow.python.framework.ops.EagerTensor'> (4, 2, 2)
<class 'torch.Tensor'> torch.Size([4, 2, 2])


**Layers example**

In [104]:
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU, Flatten
from einops.layers.torch import Reduce

# Both models expect a 3-channel 100x100 input (the `val` batch below):
#   conv5: 100 -> 96, pool2: 96 -> 48, conv5: 48 -> 44, pool2: 44 -> 22,
# so the flattened feature size is 16 channels * 22 * 22 = 7744.
# Changing the input resolution requires updating this constant.
FLAT_FEATURES = 16 * 22 * 22

# Reference model built only from torch.nn layers.
model_torch = Sequential(
    Conv2d(3, 6, kernel_size=5),
    MaxPool2d(kernel_size=2),
    Conv2d(6, 16, kernel_size=5),
    MaxPool2d(kernel_size=2),
    Flatten(),
    Linear(FLAT_FEATURES, 120),
    ReLU(),
    Linear(120, 10),
)

# Same architecture, with the second MaxPool2d + Flatten fused into a single
# einops Reduce layer.
model_einops = Sequential(
    Conv2d(3, 6, kernel_size=5),
    MaxPool2d(kernel_size=2),
    Conv2d(6, 16, kernel_size=5),
    Reduce('b c (h 2) (w 2) -> b (c h w)', 'max'),  # 2x2 max pooling and flatten
    Linear(FLAT_FEATURES, 120),
    ReLU(),
    Linear(120, 10),
)

**Performance**

In [105]:
# Random batch of 20 inputs with 3 channels at 100x100, matching Conv2d(3, ...)
val = torch.randn((20, 3, 100, 100))

In [106]:
%%timeit
for i in range(100):
    model_torch(val)

1.84 s ± 180 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [107]:
%%timeit
for i in range(100):
    model_einops(val)

1.92 s ± 144 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [108]:
# einops caches parsed patterns internally, so whether a pattern string
# repeats affects speed; get_pattern(i) yields a distinct pattern for each i.

def get_pattern(i):
    """Return a transpose pattern whose first axis is named ``x_{i}``."""
    return "x_{0} y -> y x_{0}".format(i)

In [109]:
%%timeit
for i in range(10000):
    rearrange(two_dim_x, get_pattern(i))

445 ms ± 40.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [110]:
%%timeit
for i in range(10000):
    rearrange(two_dim_x, get_pattern(0))

45.4 ms ± 3.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
