# Basic operations w/ Einsum

In [1]:
import torch


In [5]:
# initalizing a tensor
x = torch.rand((2,3))
x

tensor([[0.9506, 0.0962, 0.4298],
        [0.5465, 0.4349, 0.7661]])

In [7]:
# premutation of Tensor
torch.einsum("ij->ji",x) # this here is a special case of Transposing (T)

tensor([[0.9506, 0.5465],
        [0.0962, 0.4349],
        [0.4298, 0.7661]])

In [8]:
# summation
torch.einsum("ij->",x)

tensor(3.2241)

In [9]:
# column sum
torch.einsum("ij->j",x)

tensor([1.4971, 0.5310, 1.1960])

In [10]:
# row sum
torch.einsum("ij->i",x)

tensor([1.4766, 1.7475])

In [11]:
# Matrix Vector Multiplication
V = torch.rand((1,3))
torch.einsum("ij, kj->ik",x,V)

tensor([[0.4631],
        [0.9800]])

In [16]:
# Matrix - Matrix multiplication
y = torch.rand((3,4))
torch.einsum("ij, jk-> ik",x,y)

tensor([[0.8822, 1.0980, 1.0431, 0.8831],
        [1.2143, 1.4004, 0.9319, 0.5482]])

In [17]:
# Dot product first row with first row of matrix
torch.einsum("i, i->",x[0],x[0])

tensor(1.0977)

In [18]:
# Dot product with matrix
torch.einsum("ij,ij->",x,x)

tensor(2.1724)

In [19]:
# Hadamard Product (element wise multiplication)
torch.einsum("ij,ij->ij",x,x)

tensor([[0.9037, 0.0092, 0.1848],
        [0.2987, 0.1891, 0.5869]])

In [20]:
# outer product
a = torch.rand((3))
b = torch.rand((5))
torch.einsum("i,j->ij",a,b)

tensor([[1.2311e-01, 3.6428e-01, 7.5051e-04, 5.3622e-01, 5.7032e-01],
        [8.9691e-02, 2.6540e-01, 5.4680e-04, 3.9067e-01, 4.1552e-01],
        [8.3478e-02, 2.4702e-01, 5.0892e-04, 3.6361e-01, 3.8674e-01]])

In [21]:
# Batch Matrix Multiplication
a = torch.rand((3,2,5))
b = torch.rand((3,5,3))
torch.einsum("ijk,ikl->ijl",a,b)

tensor([[[1.4419, 1.3241, 0.9689],
         [1.5320, 1.4319, 1.0265]],

        [[0.5533, 0.9317, 0.2423],
         [1.1477, 1.4001, 0.6014]],

        [[1.1152, 0.7390, 1.4030],
         [0.2587, 0.2661, 0.4532]]])

In [22]:
# Matrix diagonal
x = torch.rand((3,3))
torch.einsum("ii->i",x)

tensor([0.6895, 0.3006, 0.6306])

In [23]:
# Matrix trace
torch.einsum("ii->",x)

tensor(1.6207)

# Case Studies

## TREEQN

Given a low-dimensional state representation zl at layer l and a transition function W^a per action a, we want to calculate all next-state representations zal+1 using a residual connection.


In [24]:
import torch.nn.functional as F

def random_tensors(shape, num=1, requires_grad=False):
  tensors = [torch.randn(shape, requires_grad=requires_grad) for i in range(0, num)]
  return tensors[0] if num == 1 else tensors

# Parameters
# -- [num_actions x hidden_dimension]
b = random_tensors([5, 3], requires_grad=True)
# -- [num_actions x hidden_dimension x hidden_dimension]
W = random_tensors([5, 3, 3], requires_grad=True)

def transition(zl):
  # -- [batch_size x num_actions x hidden_dimension]
  return zl.unsqueeze(1) + F.tanh(torch.einsum("bk,aki->bai", [zl, W]) + b)

# Sampled dummy inputs
# -- [batch_size x hidden_dimension]
zl = random_tensors([2, 3])

transition(zl)



tensor([[[ 0.2614,  0.0652,  1.0455],
         [-0.4112,  1.8931,  1.5323],
         [ 0.4651,  1.9163,  1.5342],
         [ 0.4624,  1.7946,  0.8379],
         [-0.7446,  1.9776,  1.6927]],

        [[ 0.4550, -0.2439,  0.7862],
         [-0.6796,  1.2814, -0.5299],
         [ 0.6686,  1.0737,  0.6388],
         [ 0.4404,  1.1259,  0.5036],
         [ 0.0452,  1.2452,  0.2531]]], grad_fn=<AddBackward0>)

## ATTENTION
word-by-word attention mechanism.

In [25]:
# Parameters
# -- [hidden_dimension]
bM, br, w = random_tensors([7], num=3, requires_grad=True)
# -- [hidden_dimension x hidden_dimension]
WY, Wh, Wr, Wt = random_tensors([7, 7], num=4, requires_grad=True)

# Single application of attention mechanism 
def attention(Y, ht, rt1):
  # -- [batch_size x hidden_dimension] 
  tmp = torch.einsum("ik,kl->il", [ht, Wh]) + torch.einsum("ik,kl->il", [rt1, Wr])
  Mt = F.tanh(torch.einsum("ijk,kl->ijl", [Y, WY]) + tmp.unsqueeze(1).expand_as(Y) + bM)
  # -- [batch_size x sequence_length]
  at = F.softmax(torch.einsum("ijk,k->ij", [Mt, w])) 
  # -- [batch_size x hidden_dimension]
  rt = torch.einsum("ijk,ij->ik", [Y, at]) + F.tanh(torch.einsum("ij,jk->ik", [rt1, Wt]) + br)
  # -- [batch_size x hidden_dimension], [batch_size x sequence_dimension]
  return rt, at

# Sampled dummy inputs
# -- [batch_size x sequence_length x hidden_dimension]
Y = random_tensors([3, 5, 7])
# -- [batch_size x hidden_dimension]
ht, rt1 = random_tensors([3, 7], num=2)

rt, at = attention(Y, ht, rt1)
at  # -- print attention weights

  at = F.softmax(torch.einsum("ijk,k->ij", [Mt, w]))


tensor([[0.0656, 0.2419, 0.2439, 0.3282, 0.1205],
        [0.3106, 0.5143, 0.0332, 0.0881, 0.0538],
        [0.1002, 0.1683, 0.0513, 0.3149, 0.3654]], grad_fn=<SoftmaxBackward>)