In [13]:
import torch
# `te`: terse alias for torch.tensor, used to build small literal tensors inline.
from torch import tensor as te

In [14]:
from my_einsum import my_einsum

In [15]:
einsum_imp = torch.einsum
einsum_imp = my_einsum

In [16]:
def print_congrats(case_name):
    """Print a one-line celebratory banner announcing that `case_name` passed."""
    banner = '✅ {} PASSED 🎊🥳💃🎉 !'.format(case_name)
    print(banner)

In [17]:
CASE_NAME = "(1d matrix)"
x = torch.arange(6)

# NOTE: torch.allclose broadcasts its arguments, so it cannot detect a result
# of the wrong shape. torch.testing.assert_close also checks shape and dtype,
# and prints a readable mismatch report on failure.

# ---------------------------------------------------
# Identity: out[i] = x[i]
actual = einsum_imp('i->i', x)
expected = torch.tensor([0, 1, 2, 3, 4, 5])
torch.testing.assert_close(actual, expected)

# ---------------------------------------------------
# Full reduction: an empty output spec sums over i.
# torch.einsum returns a 0-d tensor here, so the expected value is 0-d too.
actual = einsum_imp('i->', x)
expected = torch.tensor(15)
torch.testing.assert_close(actual, expected)

print_congrats(CASE_NAME)

✅ (1d matrix) PASSED 🎊🥳💃🎉 !


In [18]:
CASE_NAME = "(1d matrix),(1d matrix) of different shapes"

x = torch.arange(6)
y = torch.arange(3)

print(f'{x=} {y=}')

# NOTE: torch.testing.assert_close (unlike torch.allclose, which broadcasts)
# also fails when `actual` does not have exactly the expected shape/dtype.

# ---------------------------------------------------
# AKA OUTER PRODUCT: out[i, j] = x[i] * y[j]
actual = einsum_imp('i,j->ij', x, y)
expected = torch.stack([
    0*y,
    1*y,
    2*y,
    3*y,
    4*y,
    5*y,
])
torch.testing.assert_close(actual, expected)

# ---------------------------------------------------
# Transposed outer product: out[j, i] = x[i] * y[j]
actual = einsum_imp('i,j->ji', x, y)
expected = torch.stack([
    0*x,
    1*x,
    2*x,
])
torch.testing.assert_close(actual, expected)

# ---------------------------------------------------
# j is summed away: out[i] = x[i] * sum(y) = x[i] * 3
actual = einsum_imp('i,j->i', x, y)
expected = torch.tensor([0, 3, 6, 9, 12, 15])
torch.testing.assert_close(actual, expected)

# ---------------------------------------------------
# i is summed away: out[j] = sum(x) * y[j] = 15 * y[j]
actual = einsum_imp('i,j->j', x, y)
expected = torch.tensor([0, 15, 30])
torch.testing.assert_close(actual, expected)

# ---------------------------------------------------
# Sum of ALL outer-product entries: sum(x) * sum(y) = 15 * 3 (a 0-d tensor).
# This is NOT the inner product -- that would be 'i,i->', which requires
# both vectors to have the same length.
actual = einsum_imp('i,j->', x, y)
expected = torch.tensor(0 + 15 + 30)
torch.testing.assert_close(actual, expected)

print_congrats(CASE_NAME)

x=tensor([0, 1, 2, 3, 4, 5]) y=tensor([0, 1, 2])
✅ (1d matrix),(1d matrix) of different shapes PASSED 🎊🥳💃🎉 !


In [19]:
CASE_NAME = "(1d matrix),(1d matrix) of SAME shapes"

x = torch.arange(6)
y = torch.arange(6)

print(f'{x=} {y=}')

# NOTE: torch.testing.assert_close checks shape and dtype as well as values;
# torch.allclose would broadcast and hide shape mismatches.

# ---------------------------------------------------
# Repeated index kept in the output: element-wise product, out[i] = x[i] * y[i]
actual = einsum_imp('i,i->i', x, y)
expected = torch.tensor([0, 1, 4, 9, 16, 25])
torch.testing.assert_close(actual, expected)

# ---------------------------------------------------
# AKA INNER (DOT) PRODUCT: repeated index summed away -> 0-d tensor
actual = einsum_imp('i,i->', x, y)
expected = torch.tensor(0+1+4+9+16+25)
torch.testing.assert_close(actual, expected)
# ---------------------------------------------------

print_congrats(CASE_NAME)

x=tensor([0, 1, 2, 3, 4, 5]) y=tensor([0, 1, 2, 3, 4, 5])
✅ (1d matrix),(1d matrix) of SAME shapes PASSED 🎊🥳💃🎉 !


In [20]:
CASE_NAME = "(2d matrix)"
x = torch.arange(6).reshape((2, 3))

# NOTE: torch.testing.assert_close checks shape and dtype as well as values;
# torch.allclose would broadcast and hide shape mismatches.

# ---------------------------------------------------
# Transpose: out[j, i] = x[i, j]
actual = einsum_imp('ij->ji', x)
expected = x.T
torch.testing.assert_close(actual, expected)

# ---------------------------------------------------
# Row sums: j is summed away
actual = einsum_imp('ij->i', x)
expected = torch.tensor([3, 12])
torch.testing.assert_close(actual, expected)

# ---------------------------------------------------
# Column sums: i is summed away
actual = einsum_imp('ij->j', x)
expected = torch.tensor([3, 5, 7])
torch.testing.assert_close(actual, expected)

# ---------------------------------------------------
# Sum of all entries -> 0-d tensor
actual = einsum_imp('ij->', x)
expected = torch.tensor(15)
torch.testing.assert_close(actual, expected)

# ---------------------------------------------------
# A repeated index on one operand needs a square matrix.
x = torch.arange(9).reshape((3, 3))

# Diagonal: out[i] = x[i, i]
actual = einsum_imp('ii->i', x)
expected = torch.tensor([0, 4, 8])
torch.testing.assert_close(actual, expected)

# ---------------------------------------------------
# Trace: diagonal summed -> 0-d tensor
actual = einsum_imp('ii->', x)
expected = torch.tensor(0+4+8)
torch.testing.assert_close(actual, expected)
# ---------------------------------------------------

print_congrats(CASE_NAME)

✅ (2d matrix) PASSED 🎊🥳💃🎉 !


In [21]:
CASE_NAME = "(2d matrix), (1d matrix)"
x = torch.arange(6).reshape((2, 3))
y = torch.arange(3)

# NOTE: torch.allclose broadcasts, so e.g. a (3,) result would "match" a (2, 3)
# expectation. torch.testing.assert_close requires exact shape and dtype.

# ---------------------------------------------------
# out[k, j, i] = x[i, j] * y[k]: one copy of x.T scaled by each y[k]
actual = einsum_imp('ij,k->kji', x, y)
expected = torch.stack([
    0*x.T,
    1*x.T,
    2*x.T,
])
torch.testing.assert_close(actual, expected)

# ---------------------------------------------------
# Shared index j kept: out[i, j] = x[i, j] * y[j] (each row scaled by y)
actual = einsum_imp('ij,j->ij', x, y)
expected = torch.tensor([
    [0*0, 1*1, 2*2],
    [3*0, 4*1, 5*2],
])
torch.testing.assert_close(actual, expected)

# ---------------------------------------------------
# k not in the output, so it is summed away: out[i, j] = x[i, j] * sum(y)
actual = einsum_imp('ij,k->ij', x, y)
expected = torch.tensor([
    [0*0 + 0*1 + 0*2, 1*0 + 1*1 + 1*2, 2*0 + 2*1 + 2*2],
    [3*0 + 3*1 + 3*2, 4*0 + 4*1 + 4*2, 5*0 + 5*1 + 5*2],
])
torch.testing.assert_close(actual, expected)

# ---------------------------------------------------
# Diagonal of a square x times y element-wise: out[i] = x[i, i] * y[i]
x = torch.arange(9).reshape((3, 3))
y = torch.arange(3)

actual = einsum_imp('ii,i->i', x, y)
expected = torch.tensor([0*0, 1*4, 2*8])
torch.testing.assert_close(actual, expected)

print_congrats(CASE_NAME)

✅ (2d matrix), (1d matrix) PASSED 🎊🥳💃🎉 !


In [22]:
CASE_NAME = "(2d matrix), (2d matrix)"
x = torch.arange(6).reshape((2, 3))
y = torch.arange(6).reshape((3, 2))

print(f'{x=}')
print()
print(f'{y=}')
print()

# NOTE: torch.testing.assert_close checks shape and dtype as well as values;
# torch.allclose would broadcast and hide shape mismatches.

# ---------------------------------------------------
# AKA matrix multiplication: out[i, k] = sum_j x[i, j] * y[j, k]
actual = einsum_imp('ij,jk->ik', x, y)
expected = x@y
torch.testing.assert_close(actual, expected)

# ---------------------------------------------------
# [ ] TODO: interesting visualization case. At least review, ideally go visually through whole process
# Both j and k are summed away: out[i] = sum_k (x @ y)[i, k], i.e. for each row
# of x, the sum of its dot products with every column of y.
actual = einsum_imp('ij,jk->i', x, y)

def dot(a, b):
    """Dot product of two equal-length number sequences, returned as a 0-d tensor."""
    return (te(a)*te(b)).sum()

# Rows of x: [0,1,2], [3,4,5]; columns of y: [0,2,4], [1,3,5].
expected = torch.tensor([
    dot([0,1,2], [0,2,4])+dot([0,1,2], [1,3,5]),
    dot([3,4,5], [0,2,4])+dot([3,4,5], [1,3,5]),
])
torch.testing.assert_close(actual, expected)

print_congrats(CASE_NAME)

x=tensor([[0, 1, 2],
        [3, 4, 5]])

y=tensor([[0, 1],
        [2, 3],
        [4, 5]])

✅ (2d matrix), (2d matrix) PASSED 🎊🥳💃🎉 !


In [23]:
CASE_NAME = "(2d), (2d), (2d)"
x = torch.arange(6).reshape((2, 3))
y = torch.arange(6).reshape((3, 2))
z = torch.arange(8).reshape((2, 4))

print(f'{x=}')
print()
print(f'{y=}')
print()
print(f'{z=}')
print()

# ---------------------------------------------------
# Chained contraction over b and c; a and d are not in the output either,
# so the result is the sum of ALL entries of x @ y @ z (a 0-d tensor).
actual = einsum_imp('ab,bc,cd->', x, y, z)
expected = ((x@y)@z).sum()
# assert_close (unlike torch.allclose) also checks shape and dtype.
torch.testing.assert_close(actual, expected)


print_congrats(CASE_NAME)

x=tensor([[0, 1, 2],
        [3, 4, 5]])

y=tensor([[0, 1],
        [2, 3],
        [4, 5]])

z=tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])

✅ (2d), (2d), (2d) PASSED 🎊🥳💃🎉 !
