In [1]:
# import warnings
# warnings.filterwarnings('ignore')

Einsum Tutorial

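Source: https://rockt.github.io/2018/04/30/einsum

In the formulas below, the final right-hand side (where present) is written in Einstein notation: every index that appears on the right but not in the output is implicitly summed over.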

Matrix Transpose:
$$B_{ji} = A_{ij}$$

In [2]:
import torch
a = torch.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
torch.einsum('ij->ji', a)  # swap the two axes -> shape [3 x 2]

tensor([[0, 3],
        [1, 4],
        [2, 5]])

Sum:

$$b = \sum_{i}\sum_{j}A_{ij} = A_{ij}$$

In [3]:
a = torch.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
torch.einsum('ij->', a)  # no output indices: sum over both i and j

tensor(15)

Column Sum:

$$b_{j} = \sum_{i}A_{ij} = A_{ij}$$

In [4]:
a = torch.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
torch.einsum('ij->j', a)  # keep j, sum over i

tensor([3, 5, 7])

Row Sum:

$$b_{i} = \sum_{j}A_{ij} = A_{ij}$$

In [5]:
a = torch.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
torch.einsum('ij->i', a)  # keep i, sum over j

tensor([ 3, 12])

Matrix-Vector Multiplication:

$$c_{i} = \sum_{k}A_{ik}b_{k} = A_{ik}b_{k}$$

In [6]:
a = torch.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
b = torch.arange(3)                # [0, 1, 2]
torch.einsum('ik,k->i', a, b)      # contract over the shared index k

tensor([ 5, 14])

Matrix-Matrix Multiplication:

$$C_{ij} = \sum_{k}A_{ik}B_{kj} = A_{ik}B_{kj}$$

In [7]:
a = torch.arange(6).reshape(2, 3)   # [2 x 3]
b = torch.arange(15).reshape(3, 5)  # [3 x 5]
torch.einsum('ik,kj->ij', a, b)     # contract over k -> [2 x 5]

tensor([[ 25,  28,  31,  34,  37],
        [ 70,  82,  94, 106, 118]])

Dot Product Vector:

$$c = \sum_{i}a_{i}b_{i} = a_{i}b_{i}$$

In [8]:
a = torch.arange(3)     # [0, 1, 2]
b = torch.arange(3, 6)  # [3, 4, 5]
torch.einsum('i,i->', a, b)  # multiply elementwise, then sum over i

tensor(14)

Dot Product Matrix:

$$c = \sum_{i}\sum_{j}A_{ij}B_{ij} = A_{ij}B_{ij}$$

In [9]:
a = torch.arange(6).reshape(2, 3)      # [[0, 1, 2], [3, 4, 5]]
b = torch.arange(6, 12).reshape(2, 3)  # [[6, 7, 8], [9, 10, 11]]
torch.einsum('ij,ij->', a, b)          # elementwise product summed over i and j

tensor(145)

Hadamard Product:

$$C_{ij} = A_{ij}B_{ij}$$

In [10]:
a = torch.arange(6).reshape(2, 3)      # [[0, 1, 2], [3, 4, 5]]
b = torch.arange(6, 12).reshape(2, 3)  # [[6, 7, 8], [9, 10, 11]]
torch.einsum('ij,ij->ij', a, b)        # elementwise product, both indices kept

tensor([[ 0,  7, 16],
        [27, 40, 55]])

Bilinear Transformation:

$$D_{ij} = \sum_{k}\sum_{l}A_{ik}B_{jkl}C_{il} = A_{ik}B_{jkl}C_{il}$$

In [11]:
# Fixed seed so the random operands, and therefore the displayed result,
# are identical on every Restart & Run All.
torch.manual_seed(0)
a = torch.randn(2, 3)     # A: [2 x 3]   (indices i, k)
b = torch.randn(5, 3, 7)  # B: [5 x 3 x 7] (indices j, k, l)
c = torch.randn(2, 7)     # C: [2 x 7]   (indices i, l)
# D_ij = sum_k sum_l A_ik B_jkl C_il  -> [2 x 5]
d = torch.einsum('ik,jkl,il->ij', a, b, c)
d

tensor([[ 2.3644, -1.4066, -2.8853, -3.3441,  2.4853],
        [-0.7584,  2.4459,  1.3680,  5.7902, -3.5133]])

Attention:

Rocktäschel, Grefenstette, Hermann, Kočiský and Blunsom. *Reasoning about Entailment with Neural Attention.* In: International Conference on Learning Representations (ICLR), 2016.

$$
\begin{aligned}
M_{t} &= \tanh\left(W^{y}Y + (W^{h}h_{t} + W^{r}r_{t-1}) \otimes e_{L}\right) & M_{t} &\in \mathbb{R}^{k \times L} \\
\alpha_{t} &= \mathrm{softmax}(w^{T}M_{t}) & \alpha_{t} &\in \mathbb{R}^{L} \\
r_{t} &= Y\alpha^{T}_{t} + \tanh(W^{t}r_{t-1}) & r_{t} &\in \mathbb{R}^{k}
\end{aligned}
$$

In [12]:
import torch.nn.functional as F

def random_tensors(shape, num=1, requires_grad=False):
  """Sample `num` tensors of the given shape from a standard normal (torch.randn).

  Args:
    shape: shape of each tensor, as accepted by torch.randn.
    num: how many tensors to sample.
    requires_grad: whether the sampled tensors track gradients.

  Returns:
    A single tensor when num == 1, otherwise a list of `num` tensors.
  """
  tensors = [torch.randn(shape, requires_grad=requires_grad) for _ in range(num)]
  return tensors[0] if num == 1 else tensors

# Fix the RNG before sampling so the parameters, and everything sampled after
# them in this cell, are reproducible on every Restart & Run All.
torch.manual_seed(0)

# Parameters
# -- [hidden_dimension]
bM, br, w = random_tensors([7], num=3, requires_grad=True)
# -- [hidden_dimension x hidden_dimension]
WY, Wh, Wr, Wt = random_tensors([7, 7], num=4, requires_grad=True)

# Single application of attention mechanism 
def attention(Y, ht, rt1):
  """Apply one step of the attention mechanism from the formulas above.

  Reads the module-level parameters WY, Wh, Wr, Wt, w and the biases bM, br
  (the biases are not shown in the formulas).

  Args:
    Y: [batch_size x sequence_length x hidden_dimension] sequence attended over.
    ht: [batch_size x hidden_dimension] hidden state h_t.
    rt1: [batch_size x hidden_dimension] previous representation r_{t-1}.

  Returns:
    rt: [batch_size x hidden_dimension] new representation r_t.
    at: [batch_size x sequence_length] attention weights; each row sums to 1.
  """
  # -- [batch_size x hidden_dimension]
  query = torch.einsum("ik,kl->il", ht, Wh) + torch.einsum("ik,kl->il", rt1, Wr)
  # Broadcasting the query over the sequence axis implements the "(...) ⊗ e_L" term.
  # torch.tanh is used because F.tanh is deprecated.
  Mt = torch.tanh(torch.einsum("ijk,kl->ijl", Y, WY) + query.unsqueeze(1) + bM)
  # -- [batch_size x sequence_length]; normalise explicitly over the sequence axis
  # (an implicit softmax dim is deprecated and warns).
  at = F.softmax(torch.einsum("ijk,k->ij", Mt, w), dim=-1)
  # -- [batch_size x hidden_dimension]
  rt = torch.einsum("ijk,ij->ik", Y, at) + torch.tanh(torch.einsum("ij,jk->ik", rt1, Wt) + br)
  # -- [batch_size x hidden_dimension], [batch_size x sequence_length]
  return rt, at

# Dummy inputs drawn via random_tensors (standard normal samples)
# -- Y: [batch_size x sequence_length x hidden_dimension]
Y = random_tensors((3, 5, 7))
# -- ht, rt1: [batch_size x hidden_dimension] each
ht, rt1 = random_tensors((3, 7), num=2)

# Run a single attention step and display only the weights.
rt, at = attention(Y=Y, ht=ht, rt1=rt1)
at  # -- attention weights, [batch_size x sequence_length]

tensor([[0.2531, 0.0308, 0.4751, 0.2212, 0.0198],
        [0.1354, 0.2035, 0.4196, 0.1273, 0.1142],
        [0.0536, 0.5476, 0.0097, 0.0278, 0.3613]], grad_fn=<SoftmaxBackward0>)