In [7]:
import torch

# make a tensor out of a range
a = torch.tensor(range(6)).reshape(2, 3)
b = torch.tensor(range(1, -5, -1)).reshape(2, 3)
c = torch.tensor([0.5, 1.5]).unsqueeze(1)
print(a)
print(b)
print(c)

tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[ 1,  0, -1],
        [-2, -3, -4]])
tensor([[0.5000],
        [1.5000]])


In [8]:
# unary operations
print("identity")
print(a)
print(torch.einsum('ij->ij', a))
print(torch.einsum('ij', a))
print("transpose")
print(a.T)
print(torch.einsum('ij->ji', a))
print("row sum")
print(a.sum(dim=1))
print(torch.einsum('ij->i', a))
print("column sum")
print(a.sum(dim=0))
print(torch.einsum('ij->j', a))
print("sum")
print(a.sum())
print(torch.einsum('ij->', a))

identity
tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[0, 1, 2],
        [3, 4, 5]])
transpose
tensor([[0, 3],
        [1, 4],
        [2, 5]])
tensor([[0, 3],
        [1, 4],
        [2, 5]])
row sum
tensor([ 3, 12])
tensor([ 3, 12])
column sum
tensor([3, 5, 7])
tensor([3, 5, 7])
sum
tensor(15)
tensor(15)


In [11]:
# you can also do the following, apparently

print(torch.einsum('ii', a))

tensor([[0, 1, 2],
        [3, 4, 5]])


RuntimeError: einsum(): subscript i is repeated for operand 0 but the sizes don't match, 3 != 2

In [63]:
# binary operations: Hadamard products
print(a)
print(b)
print("hadamard product")
print(a*b)
print(torch.einsum('ij,ij->ij', a, b), a*b)
print("hadamard product transpose")
print((a*b).T)
print(torch.einsum('ij,ij->ji', a, b)) 
print("hadamard product row sum")
print((a*b).sum(dim=1))
print(torch.einsum('ij,ij->i', a, b))
print("hadamard product column sum")
print((a*b).sum(dim=0))
print(torch.einsum('ij,ij->j', a, b))
print("hadamard product sum")
print((a*b).sum())
print(torch.einsum('ij,ij->', a, b))

tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[ 1,  0, -1],
        [-2, -3, -4]])
hadamard product
tensor([[  0,   0,  -2],
        [ -6, -12, -20]])
tensor([[  0,   0,  -2],
        [ -6, -12, -20]]) tensor([[  0,   0,  -2],
        [ -6, -12, -20]])
hadamard product row sum
tensor([ -2, -38])
tensor([ -2, -38])
hadamard product column sum
tensor([ -6, -12, -22])
tensor([ -6, -12, -22])
hadamard product sum
tensor(-40)
tensor(-40)
hadamard product transpose
tensor([[  0,  -6],
        [  0, -12],
        [ -2, -20]])
tensor([[  0,  -6],
        [  0, -12],
        [ -2, -20]])


In [66]:
# binary operations: transposed matrix multiplication
print(a)
print(b)
print("a x b^T")
print(a @ b.T)
print(torch.einsum('ij,kj->ik', a, b))
print("a^T x b")
print(a.T @ b)
print(torch.einsum('ij,ik->jk', a, b))
print("a x b.T row sum")
print((a @ b.T).sum(dim=1))
print(torch.einsum('ij,kj->i', a, b))
print("a^T x b column sum")
print((a.T @ b).sum(dim=0))
print(torch.einsum('ij,ik->k', a, b))
print("a x b.T sum")
print((a @ b.T).sum())
print(torch.einsum('ij,kj->', a, b))
print("a^T x b row sum")
print((a.T @ b).sum())
print(torch.einsum('ij,ik->', a, b))
print("a x b.T column sum")
print((a @ b.T).sum())
print(torch.einsum('ij,kj->', a, b))
print("a^T x b sum")
print((a.T @ b).sum())
print(torch.einsum('ij,ik->', a, b))

tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[ 1,  0, -1],
        [-2, -3, -4]])
a x b^T
tensor([[ -2, -11],
        [ -2, -38]])
tensor([[ -2, -11],
        [ -2, -38]])
a^T x b
tensor([[ -6,  -9, -12],
        [ -7, -12, -17],
        [ -8, -15, -22]])
tensor([[ -6,  -9, -12],
        [ -7, -12, -17],
        [ -8, -15, -22]])
a x b.T row sum
tensor([-13, -40])
tensor([-13, -40])
a^T x b column sum
tensor([-21, -36, -51])
tensor([-21, -36, -51])
a x b.T sum
tensor(-53)
tensor(-53)
a^T x b row sum
tensor(-108)
tensor(-108)
a x b.T column sum
tensor(-53)
tensor(-53)
a^T x b sum
tensor(-108)
tensor(-108)


In [43]:
# I am unclear on this one
print(a)
print(b)
torch.einsum('ij,ik->ijk', a, b) # the way Bing describes this is the element at i,j,k is a[i,j] * b[i,k].  Okay so actually...

tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[ 1,  0, -1],
        [-2, -3, -4]])


tensor([[[  0,   0,   0],
         [  1,   0,  -1],
         [  2,   0,  -2]],

        [[ -6,  -9, -12],
         [ -8, -12, -16],
         [-10, -15, -20]]])

In [78]:
# binary operations: outer product
print(a)
print(b)
print("Kronecker product, sort of")
print(torch.kron(a, b), torch.kron(a, b).shape)
print(torch.einsum('ij,kl->ijkl', a, b), torch.einsum('ij,kl->ijkl', a, b).shape)
# So it occurs to me that in addition to einsum, my skills with reshape are also weak


tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[ 1,  0, -1],
        [-2, -3, -4]])
Kronecker product, sort of
tensor([[  0,   0,   0,   1,   0,  -1,   2,   0,  -2],
        [  0,   0,   0,  -2,  -3,  -4,  -4,  -6,  -8],
        [  3,   0,  -3,   4,   0,  -4,   5,   0,  -5],
        [ -6,  -9, -12,  -8, -12, -16, -10, -15, -20]]) torch.Size([4, 9])
tensor([[[[  0,   0,   0],
          [  0,   0,   0]],

         [[  1,   0,  -1],
          [ -2,  -3,  -4]],

         [[  2,   0,  -2],
          [ -4,  -6,  -8]]],


        [[[  3,   0,  -3],
          [ -6,  -9, -12]],

         [[  4,   0,  -4],
          [ -8, -12, -16]],

         [[  5,   0,  -5],
          [-10, -15, -20]]]]) torch.Size([2, 3, 2, 3])
tensor([[  0,   0,   0,   0,   0,   0,   1,   0,  -1],
        [ -2,  -3,  -4,   2,   0,  -2,  -4,  -6,  -8],
        [  3,   0,  -3,  -6,  -9, -12,   4,   0,  -4],
        [ -8, -12, -16,   5,   0,  -5, -10, -15, -20]])


In [80]:
# A manual loop that performs einsum, minus several options such as diagonals.
import itertools
def loop_einsum(s, a, b):
    in_string, out_string = s.split('->')
    in_string1, in_string2 = in_string.split(',')
    in_string1, in_string2, out_string = list(in_string1), list(in_string2), list(out_string)
    inter_string = in_string1 + [letter for letter in in_string2 if letter not in in_string1]
    inter_dims = list(a.shape)
    for i, letter in enumerate(in_string2):
        if letter not in in_string1:
            inter_dims.append(b.shape[i])
    inter_vals = torch.zeros(inter_dims)
    coords = list(itertools.product(*[range(dim) for dim in inter_dims]))
    # element-wise multiplication
    for coord in coords:
        coords_a = [coord[i] for i, letter in enumerate(inter_string) if letter in in_string1]
        coords_b = [coord[i] for i, letter in enumerate(inter_string) if letter in in_string2]
        inter_vals[coord] = a[*coords_a] * b[*coords_b]
    # summation
    sum_dims = tuple([i for i, letter in enumerate(inter_string) if letter not in out_string])
    if len(sum_dims) == 0:
        out_vals = inter_vals
    else:
        out_vals = inter_vals.sum(dim=sum_dims)
    # transpose
    ordered_out_string = [letter for letter in inter_string if letter in out_string]
    if tuple(ordered_out_string) != tuple(out_string):
        out_vals = out_vals.permute(*[ordered_out_string.index(letter) for letter in out_string])
    return out_vals

import torch
a = torch.tensor(range(6)).reshape(2, 3)
b = torch.tensor(range(1, -5, -1)).reshape(2, 3)
test_cases = [
    "ij,ij->ij",
    "ij,ij->ji",
    "ij,ij->i",
    "ij,ij->j",
    "ij,ij->",
    "ij,ik->ijk",
    "ij,kl->ijkl",
]
for test_case in test_cases:
    canonical = torch.einsum(test_case, a, b)
    mine = loop_einsum(test_case, a, b)
    print(test_case, (canonical==mine).all())

# All my tests pass, and I think I understand it now.  So the one thing I didn't implement were diagonals...how does that work again?

ij,ij->ij tensor(True)
ij,ij->ji tensor(True)
ij,ij->i tensor(True)
ij,ij->j tensor(True)
ij,ij-> tensor(True)
ij,ik->ijk tensor(True)
ij,kl->ijkl tensor(True)


In [None]:
# Alright so what does this mean in terms of attention and linear attention?
#KV = torch.einsum("nshd,nshm->nhmd", K, values)
# exactly one unshared dimension means it's a matmul, I think.  Quadratic, but scaling with model dimension rather than sequence length.

#Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps)
# dropping one dimension, with all the others shared, I think means a Hadamard product followed by a sum, so that's in linear time.
# V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z) 
# uh oh, triple threat.  This one might be another matmul scaling with model dim.