In [1]:
import torch
from torch import nn
import torch.nn.functional as F

##  einsum review


> 'ik,kj->ij'

- free indices: specified in the output（`->` 之后/右边的 index）
- summation indices: all other indices（也就是出现在 `->` 左边、但没有出现在 `->` 右边的 index，都会被求和 reduce 掉）


$$
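\begin{aligned}
(A\cdot B)_{ij} &= \sum_k A_{ik}\cdot B_{kj}\\
(AB)_{ij} &= A_{ik}B_{kj} \quad \text{(Einstein convention: the repeated index } k \text{ is summed)}
\end{aligned}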
$$

## bilinear

$$
y=x_1^TAx_2+b
$$

- $x_1\in\mathbb R^{d_1}$: 列向量（数学中默认）
- $x_2\in\mathbb R^{d_2}$：列向量（数学中默认）
- $A\in \mathbb R^{d_1\times d_2}$
- $b\in\mathbb R$：偏置（bias），是一个 scalar
- $y\in\mathbb R$：是一个 scalar

## `torch.nn.functional.bilinear`/`F.bilinear`

In [3]:
# Column vectors x_1 (3x1) and x_2 (2x1), and the matrix A (3x2) of the
# bilinear form x_1^T A x_2; all entries are standard-normal samples.
x_1 = torch.randn(3, 1)
x_2 = torch.randn(2, 1)
A = torch.randn(3, 2)

In [4]:
# x_1^T A x_2 as two chained matrix products: (1, 3) @ ((3, 2) @ (2, 1)) -> (1, 1)
x_1.T @ (A @ x_2)

tensor([[0.7900]])

In [14]:
# Add a leading axis of size 1 to every operand:
#   x_1^T: (1, 3) -> (1, 1, 3), x_2^T: (1, 2) -> (1, 1, 2), A: (3, 2) -> (1, 3, 2)
F.bilinear(
    x_1.T[None],
    x_2.T[None],
    A[None],
)

tensor([[[0.7900]]])

## `torch.nn.Bilinear`

In [66]:
# Bilinear layer mapping a (20-dim, 30-dim) input pair to 40 output features.
m = nn.Bilinear(in1_features=20, in2_features=30, out_features=40)
# weight is laid out as (out, in1, in2); bias as (out,)
for param in (m.weight, m.bias):
    print(param.shape)

torch.Size([40, 20, 30])
torch.Size([40])


In [67]:
# A batch of 128 (x1, x2) pairs, one pair per row:
#   row i of input1 is x1^T (x1 is a 20x1 column vector)
#   row i of input2 is x2^T (x2 is a 30x1 column vector)
input1 = torch.randn((128, 20))
input2 = torch.randn((128, 30))

In [68]:
# Apply the bilinear layer to the whole batch of (x1, x2) pairs in one call.
output = m(input1, input2)

In [69]:
# One row per input pair, one column per output feature.
output.size()

torch.Size([128, 40])

In [70]:
# Reference result from nn.Bilinear; the manual computations below are compared against it.
output

tensor([[-0.5529, -1.8147, -1.4468,  ...,  2.4288,  0.1620, -1.5111],
        [-2.1974, -4.0457, -0.4051,  ...,  1.7315,  4.8727,  5.2919],
        [-1.4074,  1.8243,  0.7437,  ..., -0.7086, -0.5513,  0.8196],
        ...,
        [-2.5626,  5.0811, -3.8342,  ...,  4.6002,  0.7150,  1.7936],
        [-0.9222,  1.4171,  1.4422,  ...,  3.3310,  1.9861, -3.7736],
        [-2.5675,  6.2278,  0.9713,  ...,  2.5801, -2.5891,  0.3585]],
       grad_fn=<AddBackward0>)

In [71]:
# Reference implementation: evaluate y = x1^T A x2 + b one scalar at a time.
# The sizes are read from the inputs and the layer instead of repeating the
# literals 128 and 40, so this cell stays correct if those sizes change.
Y = torch.zeros(input1.shape[0], m.out_features)
# i selects the (x1, x2) pair -> free index b
for i in range(Y.shape[0]):
    # j selects the output feature, i.e. the matrix A = m.weight[j] -> free index a
    for j in range(Y.shape[1]):
        # (1, in1) @ ((in1, in2) @ (in2, 1)) gives a 1x1 result, stored as the scalar Y[i, j]
        Y[i, j] = input1[i, :].view(1, -1) @ (m.weight[j, :, :] @ input2[i, :].view(-1, 1)) + m.bias[j]

In [72]:
# Loop-based result; the displayed entries match `output` from nn.Bilinear.
Y

tensor([[-0.5529, -1.8147, -1.4468,  ...,  2.4288,  0.1620, -1.5111],
        [-2.1974, -4.0457, -0.4051,  ...,  1.7315,  4.8727,  5.2919],
        [-1.4074,  1.8243,  0.7437,  ..., -0.7086, -0.5513,  0.8196],
        ...,
        [-2.5626,  5.0811, -3.8342,  ...,  4.6002,  0.7150,  1.7936],
        [-0.9222,  1.4171,  1.4422,  ...,  3.3310,  1.9861, -3.7736],
        [-2.5675,  6.2278,  0.9713,  ...,  2.5801, -2.5891,  0.3585]],
       grad_fn=<CopySlices>)

In [73]:
# Index legend (letter -> axis):
#   input1:   (b, n)    = (128, 20)
#   input2:   (b, m)    = (128, 30)
#   m.weight: (a, n, m) = (40, 20, 30)
# Step 1: sum over m, giving A x2 for every output feature a and sample b -> (a, b, n)
weight_x2 = torch.einsum('anm,bm->abn', m.weight, input2)
# Step 2: sum over n, giving x1^T (A x2) -> (b, a); the bias (a,) broadcasts over b
torch.einsum('bn,abn->ba', input1, weight_x2) + m.bias

tensor([[-0.5529, -1.8147, -1.4468,  ...,  2.4288,  0.1620, -1.5111],
        [-2.1974, -4.0457, -0.4051,  ...,  1.7315,  4.8727,  5.2919],
        [-1.4074,  1.8243,  0.7437,  ..., -0.7086, -0.5513,  0.8196],
        ...,
        [-2.5626,  5.0811, -3.8342,  ...,  4.6002,  0.7150,  1.7936],
        [-0.9222,  1.4171,  1.4422,  ...,  3.3310,  1.9861, -3.7736],
        [-2.5675,  6.2278,  0.9713,  ...,  2.5801, -2.5891,  0.3585]],
       grad_fn=<AddBackward0>)

In [74]:
# Same computation as a single three-operand contraction:
# n and m are summed; b (sample) and a (output feature) are kept.
torch.einsum(
    'bn,anm,bm->ba',
    input1, m.weight, input2,
) + m.bias

tensor([[-0.5529, -1.8147, -1.4468,  ...,  2.4288,  0.1620, -1.5111],
        [-2.1974, -4.0457, -0.4051,  ...,  1.7315,  4.8727,  5.2919],
        [-1.4074,  1.8243,  0.7437,  ..., -0.7086, -0.5513,  0.8196],
        ...,
        [-2.5626,  5.0811, -3.8342,  ...,  4.6002,  0.7150,  1.7936],
        [-0.9222,  1.4171,  1.4422,  ...,  3.3310,  1.9861, -3.7736],
        [-2.5675,  6.2278,  0.9713,  ...,  2.5801, -2.5891,  0.3585]],
       grad_fn=<AddBackward0>)