# Einsum

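Examples of `torch.einsum`, adapted from Tim Rocktäschel's post *Einsum is All you Need – Einstein Summation in Deep Learning*: https://rockt.github.io/2018/04/30/einsum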

In [1]:
import torch

In [2]:
# 2 x 3 example matrix used by the next few cells: [[0, 1, 2], [3, 4, 5]]
a = torch.arange(6).view(2, 3)
a

tensor([[0, 1, 2],
        [3, 4, 5]])

In [3]:
# Transpose: swap the index labels, out[j, i] = a[i, j]
torch.einsum('ij->ji', a)

tensor([[0, 3],
        [1, 4],
        [2, 5]])

In [4]:
# Sum of all entries: no output indices, so both i and j are summed out
torch.einsum('ij->', a)

tensor(15)

In [5]:
# Column sums: i is absent from the output, so it is summed over for each j
torch.einsum('ij->j', a)

tensor([3, 5, 7])

In [6]:
# Row sums: j is absent from the output, so it is summed over for each i
torch.einsum('ij->i', a)

tensor([ 3, 12])

In [7]:
# Matrix-vector product: out[i] = sum_k a[i, k] * b[k]
a = torch.arange(6).view(2, 3)
b = torch.arange(3)
torch.einsum('ik,k->i', a, b)

tensor([ 5, 14])

In [8]:
# Matrix-matrix product: out[i, j] = sum_k a[i, k] * b[k, j]
a = torch.arange(6).view(2, 3)
b = torch.arange(15).view(3, 5)
torch.einsum('ik,kj->ij', a, b)

tensor([[ 25,  28,  31,  34,  37],
        [ 70,  82,  94, 106, 118]])

In [10]:
# Vector dot product: out = sum_i a[i] * b[i]
a = torch.arange(3)
b = torch.arange(3, 6)
print(torch.einsum('i,i->', a, b))

# Matrix analogue (Frobenius inner product): sum of a[i, j] * b[i, j] over all entries
a = torch.arange(6).view(2, 3)
b = torch.arange(6, 12).view(2, 3)
print(torch.einsum('ij,ij->', a, b))

tensor(14)
tensor(145)


In [17]:
def random_tensors(shape, num=1, requires_grad=False):
  """Draw `num` independent tensors of size `shape` with torch.randn.

  Returns the tensor itself when num == 1; otherwise returns a list of
  `num` tensors (an empty list when num == 0).
  """
  samples = []
  for _ in range(num):
    samples.append(torch.randn(shape, requires_grad=requires_grad))
  if num == 1:
    return samples[0]
  return samples

def transition(zl, weights=None, bias=None):
  """Apply one state transition for every action at once.

  Computes  zl[:, None, :] + tanh(einsum("bk,aki->bai", zl, weights) + bias),
  which is a residual update of `zl` under each action's weight matrix.

  Args:
    zl: [batch_size x hidden_dimension] current states.
    weights: [num_actions x hidden_dimension x hidden_dimension] per-action
      weights. If None, falls back to the notebook-level `W`.
    bias: bias broadcastable to [batch_size x num_actions x hidden_dimension]
      (e.g. [num_actions x hidden_dimension]). If None, falls back to the
      notebook-level `b`.

  Returns:
    [batch_size x num_actions x hidden_dimension] next states.
  """
  # Earlier cells rebind `b` to unrelated tensors, so the implicit global
  # makes the result depend on cell execution order; pass parameters
  # explicitly when possible. The globals are only read when an arg is None.
  weights = W if weights is None else weights
  bias = b if bias is None else bias
  # zl.unsqueeze(1): [batch x 1 x hidden], broadcast across the action axis
  return zl.unsqueeze(1) + torch.tanh(torch.einsum("bk,aki->bai", zl, weights) + bias)

# Seed the RNG before the first random draw so every sampled tensor in the
# rest of the notebook is reproducible under Restart & Run All.
torch.manual_seed(0)

# -- [num_actions x hidden_dimension]
b = random_tensors([5, 3], requires_grad=True)
# -- [num_actions x hidden_dimension x hidden_dimension]
W = random_tensors([5, 3, 3], requires_grad=True)

# Sampled dummy inputs
# -- [batch_size x hidden_dimension]
zl = random_tensors([2, 3])

transition(zl)

tensor([[[-1.9040,  1.1765,  0.5330],
         [-1.7898,  0.9375, -0.8331],
         [-2.0107,  2.7527, -0.1804],
         [-1.7663,  2.7636,  1.0751],
         [-1.6905,  2.7144,  0.6291]],

        [[ 1.6053,  0.3546,  0.1547],
         [-0.2243, -0.1944,  0.0846],
         [ 0.5732, -0.0291,  0.3685],
         [ 0.9933,  1.7517,  0.9864],
         [-0.2649, -0.0050, -0.4425]]], grad_fn=<AddBackward0>)

In [22]:
import torch.nn.functional as F

# Attention parameters, sampled in this order: bM, br, w, WY, Wh, Wr, Wt
# -- [hidden_dimension]: bias bM (inside the tanh for Mt), bias br (rt term),
#    and the scoring vector w that maps Mt to attention logits
bM, br, w = random_tensors((7,), num=3, requires_grad=True)
# -- [hidden_dimension x hidden_dimension]: projections applied to
#    Y (WY), ht (Wh), rt1 (Wr) and rt1 in the rt term (Wt)
WY, Wh, Wr, Wt = random_tensors((7, 7), num=4, requires_grad=True)

# Single application of attention mechanism
def attention(Y, ht, rt1):
    """Apply the attention mechanism once.

    Reads the notebook-level parameters WY, Wh, Wr, Wt, bM, br and w, which
    must be defined before calling.

    Args:
        Y:   [batch_size x sequence_length x hidden_dimension] sequence to attend over.
        ht:  [batch_size x hidden_dimension] current hidden state.
        rt1: [batch_size x hidden_dimension] previous output r_{t-1}.

    Returns:
        rt: [batch_size x hidden_dimension] attention-weighted sum of Y plus
            tanh(rt1 @ Wt + br).
        at: [batch_size x sequence_length] attention weights, a softmax over
            the sequence positions.
    """
    # Projection of the current state and previous output -- [batch_size x hidden_dimension]
    state_proj = torch.einsum("ik,kl->il", ht, Wh) + torch.einsum("ik,kl->il", rt1, Wr)
    # Projection of every sequence element -- [batch_size x sequence_length x hidden_dimension]
    seq_proj = torch.einsum("ijk,kl->ijl", Y, WY)
    # Repeat the state projection across all sequence positions before the tanh
    Mt = torch.tanh(seq_proj + state_proj.unsqueeze(1).expand_as(Y) + bM)
    # Score each position with w, then normalise over the sequence -- [batch_size x sequence_length]
    logits = torch.einsum("ijk,k->ij", Mt, w)
    at = F.softmax(logits, dim=1)
    # Weighted sum of Y over positions -- [batch_size x hidden_dimension]
    context = torch.einsum("ijk,ij->ik", Y, at)
    rt = context + torch.tanh(torch.einsum("ij,jk->ik", rt1, Wt) + br)
    # -- [batch_size x hidden_dimension], [batch_size x sequence_length]
    return rt, at

# Dummy inputs: a batch of 3 sequences of length 5 with hidden size 7
# -- [batch_size x sequence_length x hidden_dimension]
Y = random_tensors((3, 5, 7))
# -- [batch_size x hidden_dimension] each: current state ht and previous output rt1
ht, rt1 = random_tensors((3, 7), num=2)

rt, at = attention(Y, ht, rt1)
# Attention weights; each row is a softmax over the 5 sequence positions
at

tensor([[0.0696, 0.1255, 0.0665, 0.2041, 0.5343],
        [0.0711, 0.3393, 0.1199, 0.4316, 0.0381],
        [0.1206, 0.0227, 0.0022, 0.7835, 0.0710]], grad_fn=<SoftmaxBackward0>)

In [20]:
(bM, br, w)  # inspect the three [hidden_dimension] attention parameter vectors

(tensor([-0.0569,  0.9318, -0.0543, -1.3124, -0.5706, -0.0218,  0.4713],
        requires_grad=True),
 tensor([0.3007, 0.9175, 1.5571, 2.1650, 0.8585, 0.5402, 0.4908],
        requires_grad=True),
 tensor([ 1.0446,  1.0443,  0.6526, -0.6547, -0.0314, -0.2609,  0.0547],
        requires_grad=True))